This is page 18 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/retraining_pipeline.py:
--------------------------------------------------------------------------------
```python
"""Automated retraining pipeline for ML models with data drift detection."""
import logging
from collections.abc import Callable
from datetime import datetime
from typing import Any
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from .model_manager import ModelManager
logger = logging.getLogger(__name__)
class DataDriftDetector:
"""Detects data drift in features and targets."""
def __init__(self, significance_level: float = 0.05):
"""Initialize drift detector.
Args:
significance_level: Statistical significance level for drift detection
"""
self.significance_level = significance_level
self.reference_data: pd.DataFrame | None = None
self.reference_target: pd.Series | None = None
self.feature_stats: dict[str, dict[str, float]] = {}
def set_reference_data(
self, features: pd.DataFrame, target: pd.Series | None = None
):
"""Set reference data for drift detection.
Args:
features: Reference feature data
target: Reference target data (optional)
"""
self.reference_data = features.copy()
self.reference_target = target.copy() if target is not None else None
# Calculate reference statistics
self.feature_stats = {}
for col in features.columns:
if features[col].dtype in ["float64", "float32", "int64", "int32"]:
self.feature_stats[col] = {
"mean": features[col].mean(),
"std": features[col].std(),
"min": features[col].min(),
"max": features[col].max(),
"median": features[col].median(),
}
logger.info(
f"Set reference data with {len(features)} samples and {len(features.columns)} features"
)
def detect_feature_drift(
self, new_features: pd.DataFrame
) -> dict[str, dict[str, Any]]:
"""Detect drift in features using statistical tests.
Args:
new_features: New feature data to compare
Returns:
Dictionary with drift detection results per feature
"""
if self.reference_data is None:
raise ValueError("Reference data not set")
drift_results = {}
for col in new_features.columns:
if col not in self.reference_data.columns:
continue
if new_features[col].dtype not in ["float64", "float32", "int64", "int32"]:
continue
ref_data = self.reference_data[col].dropna()
new_data = new_features[col].dropna()
if len(ref_data) == 0 or len(new_data) == 0:
continue
# Perform statistical tests
drift_detected = False
test_results = {}
try:
# Kolmogorov-Smirnov test for distribution change
ks_statistic, ks_p_value = stats.ks_2samp(ref_data, new_data)
test_results["ks_statistic"] = ks_statistic
test_results["ks_p_value"] = ks_p_value
ks_drift = ks_p_value < self.significance_level
# Mann-Whitney U test for location shift
mw_statistic, mw_p_value = stats.mannwhitneyu(
ref_data, new_data, alternative="two-sided"
)
test_results["mw_statistic"] = mw_statistic
test_results["mw_p_value"] = mw_p_value
mw_drift = mw_p_value < self.significance_level
# Levene test for variance change
levene_statistic, levene_p_value = stats.levene(ref_data, new_data)
test_results["levene_statistic"] = levene_statistic
test_results["levene_p_value"] = levene_p_value
levene_drift = levene_p_value < self.significance_level
# Overall drift detection
drift_detected = ks_drift or mw_drift or levene_drift
# Calculate effect sizes
test_results["mean_diff"] = new_data.mean() - ref_data.mean()
test_results["std_ratio"] = new_data.std() / (ref_data.std() + 1e-8)
except Exception as e:
logger.warning(f"Error in drift detection for {col}: {e}")
test_results["error"] = str(e)
drift_results[col] = {
"drift_detected": drift_detected,
"test_results": test_results,
"reference_stats": self.feature_stats.get(col, {}),
"new_stats": {
"mean": new_data.mean(),
"std": new_data.std(),
"min": new_data.min(),
"max": new_data.max(),
"median": new_data.median(),
},
}
return drift_results
def detect_target_drift(self, new_target: pd.Series) -> dict[str, Any]:
"""Detect drift in target variable.
Args:
new_target: New target data to compare
Returns:
Dictionary with target drift results
"""
if self.reference_target is None:
logger.warning("No reference target data set")
return {"drift_detected": False, "reason": "no_reference_target"}
ref_target = self.reference_target.dropna()
new_target = new_target.dropna()
if len(ref_target) == 0 or len(new_target) == 0:
return {"drift_detected": False, "reason": "insufficient_data"}
drift_results = {"drift_detected": False}
try:
# For categorical targets, use chi-square test
if ref_target.dtype == "object" or ref_target.nunique() < 10:
ref_counts = ref_target.value_counts()
new_counts = new_target.value_counts()
# Align the categories
all_categories = set(ref_counts.index) | set(new_counts.index)
ref_aligned = [ref_counts.get(cat, 0) for cat in all_categories]
new_aligned = [new_counts.get(cat, 0) for cat in all_categories]
if sum(ref_aligned) > 0 and sum(new_aligned) > 0:
chi2_stat, chi2_p_value = stats.chisquare(new_aligned, ref_aligned)
drift_results.update(
{
"test_type": "chi_square",
"chi2_statistic": chi2_stat,
"chi2_p_value": chi2_p_value,
"drift_detected": chi2_p_value < self.significance_level,
}
)
# For continuous targets
else:
ks_statistic, ks_p_value = stats.ks_2samp(ref_target, new_target)
drift_results.update(
{
"test_type": "kolmogorov_smirnov",
"ks_statistic": ks_statistic,
"ks_p_value": ks_p_value,
"drift_detected": ks_p_value < self.significance_level,
}
)
except Exception as e:
logger.warning(f"Error in target drift detection: {e}")
drift_results["error"] = str(e)
return drift_results
def get_drift_summary(
self, feature_drift: dict[str, dict], target_drift: dict[str, Any]
) -> dict[str, Any]:
"""Get summary of drift detection results.
Args:
feature_drift: Feature drift results
target_drift: Target drift results
Returns:
Summary dictionary
"""
total_features = len(feature_drift)
drifted_features = sum(
1 for result in feature_drift.values() if result["drift_detected"]
)
target_drift_detected = target_drift.get("drift_detected", False)
drift_severity = "none"
if target_drift_detected or drifted_features > total_features * 0.5:
drift_severity = "high"
elif drifted_features > total_features * 0.2:
drift_severity = "medium"
elif drifted_features > 0:
drift_severity = "low"
return {
"total_features": total_features,
"drifted_features": drifted_features,
"drift_percentage": drifted_features / max(total_features, 1) * 100,
"target_drift_detected": target_drift_detected,
"drift_severity": drift_severity,
"recommendation": self._get_retraining_recommendation(
drift_severity, target_drift_detected
),
}
def _get_retraining_recommendation(
self, drift_severity: str, target_drift: bool
) -> str:
"""Get retraining recommendation based on drift severity."""
if target_drift:
return "immediate_retraining"
elif drift_severity == "high":
return "urgent_retraining"
elif drift_severity == "medium":
return "scheduled_retraining"
elif drift_severity == "low":
return "monitor_closely"
else:
return "no_action_needed"
class ModelPerformanceMonitor:
"""Monitors model performance and detects degradation."""
def __init__(self, performance_threshold: float = 0.05):
"""Initialize performance monitor.
Args:
performance_threshold: Threshold for performance degradation detection
"""
self.performance_threshold = performance_threshold
self.baseline_metrics: dict[str, float] = {}
self.performance_history: list[dict[str, Any]] = []
def set_baseline_performance(self, metrics: dict[str, float]):
"""Set baseline performance metrics.
Args:
metrics: Baseline performance metrics
"""
self.baseline_metrics = metrics.copy()
logger.info(f"Set baseline performance: {metrics}")
def evaluate_performance(
self,
model: BaseEstimator,
X_test: pd.DataFrame,
y_test: pd.Series,
additional_metrics: dict[str, float] | None = None,
) -> dict[str, Any]:
"""Evaluate current model performance.
Args:
model: Trained model
X_test: Test features
y_test: Test targets
additional_metrics: Additional metrics to include
Returns:
Performance evaluation results
"""
try:
# Make predictions
y_pred = model.predict(X_test)
# Calculate metrics
metrics = {
"accuracy": accuracy_score(y_test, y_pred),
"timestamp": datetime.now().isoformat(),
}
# Add additional metrics if provided
if additional_metrics:
metrics.update(additional_metrics)
# Detect performance degradation
degradation_detected = False
degradation_details = {}
for metric_name, current_value in metrics.items():
if metric_name in self.baseline_metrics and metric_name != "timestamp":
baseline_value = self.baseline_metrics[metric_name]
degradation = (baseline_value - current_value) / abs(baseline_value)
if degradation > self.performance_threshold:
degradation_detected = True
degradation_details[metric_name] = {
"baseline": baseline_value,
"current": current_value,
"degradation": degradation,
}
evaluation_result = {
"metrics": metrics,
"degradation_detected": degradation_detected,
"degradation_details": degradation_details,
"classification_report": classification_report(
y_test, y_pred, output_dict=True
),
}
# Store in history
self.performance_history.append(evaluation_result)
# Keep only recent history
if len(self.performance_history) > 100:
self.performance_history = self.performance_history[-100:]
return evaluation_result
except Exception as e:
logger.error(f"Error evaluating model performance: {e}")
return {"error": str(e)}
def get_performance_trend(self, metric_name: str = "accuracy") -> dict[str, Any]:
"""Analyze performance trend over time.
Args:
metric_name: Metric to analyze
Returns:
Trend analysis results
"""
if not self.performance_history:
return {"trend": "no_data"}
values = []
timestamps = []
for record in self.performance_history:
if metric_name in record["metrics"]:
values.append(record["metrics"][metric_name])
timestamps.append(record["metrics"]["timestamp"])
if len(values) < 3:
return {"trend": "insufficient_data"}
# Calculate trend
x = np.arange(len(values))
slope, _, r_value, p_value, _ = stats.linregress(x, values)
trend_direction = "stable"
if p_value < 0.05: # Statistically significant trend
if slope > 0:
trend_direction = "improving"
else:
trend_direction = "degrading"
return {
"trend": trend_direction,
"slope": slope,
"r_squared": r_value**2,
"p_value": p_value,
"recent_values": values[-5:],
"timestamps": timestamps[-5:],
}
class AutoRetrainingPipeline:
"""Automated pipeline for model retraining with drift detection and performance monitoring."""
def __init__(
self,
model_manager: ModelManager,
model_factory: Callable[[], BaseEstimator],
feature_extractor: Callable[[pd.DataFrame], pd.DataFrame],
target_extractor: Callable[[pd.DataFrame], pd.Series],
retraining_schedule_hours: int = 24,
min_samples_for_retraining: int = 100,
):
"""Initialize auto-retraining pipeline.
Args:
model_manager: Model manager instance
model_factory: Function that creates new model instances
feature_extractor: Function to extract features from data
target_extractor: Function to extract targets from data
retraining_schedule_hours: Hours between scheduled retraining checks
min_samples_for_retraining: Minimum samples required for retraining
"""
self.model_manager = model_manager
self.model_factory = model_factory
self.feature_extractor = feature_extractor
self.target_extractor = target_extractor
self.retraining_schedule_hours = retraining_schedule_hours
self.min_samples_for_retraining = min_samples_for_retraining
self.drift_detector = DataDriftDetector()
self.performance_monitor = ModelPerformanceMonitor()
self.last_retraining: dict[str, datetime] = {}
self.retraining_history: list[dict[str, Any]] = []
def should_retrain(
self,
model_id: str,
new_data: pd.DataFrame,
force_check: bool = False,
) -> tuple[bool, str]:
"""Determine if a model should be retrained.
Args:
model_id: Model identifier
new_data: New data for evaluation
force_check: Force retraining check regardless of schedule
Returns:
Tuple of (should_retrain, reason)
"""
# Check schedule
last_retrain = self.last_retraining.get(model_id)
if not force_check and last_retrain is not None:
time_since_retrain = datetime.now() - last_retrain
if (
time_since_retrain.total_seconds()
< self.retraining_schedule_hours * 3600
):
return False, "schedule_not_due"
# Check minimum samples
if len(new_data) < self.min_samples_for_retraining:
return False, "insufficient_samples"
# Extract features and targets
try:
features = self.feature_extractor(new_data)
targets = self.target_extractor(new_data)
except Exception as e:
logger.error(f"Error extracting features/targets: {e}")
return False, f"extraction_error: {e}"
# Check for data drift
if self.drift_detector.reference_data is not None:
feature_drift = self.drift_detector.detect_feature_drift(features)
target_drift = self.drift_detector.detect_target_drift(targets)
drift_summary = self.drift_detector.get_drift_summary(
feature_drift, target_drift
)
if drift_summary["recommendation"] in [
"immediate_retraining",
"urgent_retraining",
]:
return True, f"data_drift_{drift_summary['drift_severity']}"
# Check performance degradation
current_model = self.model_manager.load_model(model_id)
if current_model is not None and current_model.model is not None:
try:
# Split data for evaluation
X_train, X_test, y_train, y_test = train_test_split(
features, targets, test_size=0.3, random_state=42, stratify=targets
)
# Scale features if scaler is available
if current_model.scaler is not None:
X_test_scaled = current_model.scaler.transform(X_test)
else:
X_test_scaled = X_test
# Evaluate performance
performance_result = self.performance_monitor.evaluate_performance(
current_model.model, X_test_scaled, y_test
)
if performance_result.get("degradation_detected", False):
return True, "performance_degradation"
except Exception as e:
logger.warning(f"Error evaluating model performance: {e}")
return False, "no_triggers"
def retrain_model(
self,
model_id: str,
training_data: pd.DataFrame,
validation_split: float = 0.2,
**model_params,
) -> str | None:
"""Retrain a model with new data.
Args:
model_id: Model identifier
training_data: Training data
validation_split: Fraction of data to use for validation
**model_params: Additional parameters for model training
Returns:
New model version string if successful, None otherwise
"""
try:
# Extract features and targets
features = self.feature_extractor(training_data)
targets = self.target_extractor(training_data)
# Split data
X_train, X_val, y_train, y_val = train_test_split(
features,
targets,
test_size=validation_split,
random_state=42,
stratify=targets,
)
# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
# Create and train new model
model = self.model_factory()
model.set_params(**model_params)
model.fit(X_train_scaled, y_train)
# Evaluate model
train_score = model.score(X_train_scaled, y_train)
val_score = model.score(X_val_scaled, y_val)
# Create version string
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
new_version = f"v_{timestamp}"
# Prepare metadata
metadata = {
"training_samples": len(X_train),
"validation_samples": len(X_val),
"features_count": X_train.shape[1],
"model_params": model_params,
"retraining_trigger": "automated",
}
# Prepare performance metrics
performance_metrics = {
"train_accuracy": train_score,
"validation_accuracy": val_score,
"overfitting_gap": train_score - val_score,
}
# Save model
success = self.model_manager.save_model(
model_id=model_id,
version=new_version,
model=model,
scaler=scaler,
metadata=metadata,
performance_metrics=performance_metrics,
set_as_active=True, # Set as active if validation performance is good
)
if success:
# Update retraining history
self.last_retraining[model_id] = datetime.now()
self.retraining_history.append(
{
"model_id": model_id,
"version": new_version,
"timestamp": datetime.now().isoformat(),
"training_samples": len(X_train),
"validation_accuracy": val_score,
}
)
# Update drift detector reference data
self.drift_detector.set_reference_data(features, targets)
# Update performance monitor baseline
self.performance_monitor.set_baseline_performance(performance_metrics)
logger.info(
f"Successfully retrained model {model_id} -> {new_version} (val_acc: {val_score:.4f})"
)
return new_version
else:
logger.error(f"Failed to save retrained model {model_id}")
return None
except Exception as e:
logger.error(f"Error retraining model {model_id}: {e}")
return None
def run_retraining_check(
self, model_id: str, new_data: pd.DataFrame
) -> dict[str, Any]:
"""Run complete retraining check and execute if needed.
Args:
model_id: Model identifier
new_data: New data for evaluation
Returns:
Dictionary with check results and actions taken
"""
start_time = datetime.now()
try:
# Check if retraining is needed
should_retrain, reason = self.should_retrain(model_id, new_data)
result = {
"model_id": model_id,
"timestamp": start_time.isoformat(),
"should_retrain": should_retrain,
"reason": reason,
"data_samples": len(new_data),
"new_version": None,
"success": False,
}
if should_retrain:
logger.info(f"Retraining {model_id} due to: {reason}")
new_version = self.retrain_model(model_id, new_data)
if new_version:
result.update(
{
"new_version": new_version,
"success": True,
"action": "retrained",
}
)
else:
result.update(
{
"action": "retrain_failed",
"error": "Model retraining failed",
}
)
else:
result.update(
{
"action": "no_retrain",
"success": True,
}
)
# Calculate execution time
execution_time = (datetime.now() - start_time).total_seconds()
result["execution_time_seconds"] = execution_time
return result
except Exception as e:
logger.error(f"Error in retraining check for {model_id}: {e}")
return {
"model_id": model_id,
"timestamp": start_time.isoformat(),
"should_retrain": False,
"reason": "check_error",
"error": str(e),
"success": False,
"execution_time_seconds": (datetime.now() - start_time).total_seconds(),
}
def get_retraining_summary(self) -> dict[str, Any]:
"""Get summary of retraining pipeline status.
Returns:
Summary dictionary
"""
return {
"total_models_managed": len(self.last_retraining),
"total_retrainings": len(self.retraining_history),
"recent_retrainings": self.retraining_history[-10:],
"last_retraining_times": {
model_id: timestamp.isoformat()
for model_id, timestamp in self.last_retraining.items()
},
"retraining_schedule_hours": self.retraining_schedule_hours,
"min_samples_for_retraining": self.min_samples_for_retraining,
}
# Alias for backward compatibility
RetrainingPipeline = AutoRetrainingPipeline
# Ensure all expected names are available
__all__ = [
"DataDriftDetector",
"ModelPerformanceMonitor",
"AutoRetrainingPipeline",
"RetrainingPipeline", # Alias for backward compatibility
]
```
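A minimal, self-contained sketch of the drift-detection flow defined above, using synthetic pandas data. The import path follows the package layout shown in the directory tree; the shifted `rsi` column and the feature names are purely illustrative, not part of the library.

```python
import numpy as np
import pandas as pd

from maverick_mcp.backtesting.retraining_pipeline import DataDriftDetector

rng = np.random.default_rng(42)

# Reference window: the data the model was originally trained on.
reference = pd.DataFrame(
    {
        "rsi": rng.normal(50, 10, 500),
        "volume_z": rng.normal(0, 1, 500),
    }
)
reference_target = pd.Series(rng.integers(0, 2, 500))

# New window: simulate a mean shift in one feature.
new_features = pd.DataFrame(
    {
        "rsi": rng.normal(60, 10, 500),  # distribution shifted vs. reference
        "volume_z": rng.normal(0, 1, 500),
    }
)
new_target = pd.Series(rng.integers(0, 2, 500))

detector = DataDriftDetector(significance_level=0.05)
detector.set_reference_data(reference, reference_target)

# Per-feature KS / Mann-Whitney / Levene tests, plus a chi-square test on the
# (low-cardinality) target, rolled up into a severity and recommendation.
feature_drift = detector.detect_feature_drift(new_features)
target_drift = detector.detect_target_drift(new_target)
summary = detector.get_drift_summary(feature_drift, target_drift)

print(summary["drift_severity"], summary["recommendation"])
```

Wiring the full `AutoRetrainingPipeline` additionally requires a `ModelManager` instance and feature/target extractor callables, which are constructed elsewhere in the package.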
--------------------------------------------------------------------------------
/maverick_mcp/data/performance.py:
--------------------------------------------------------------------------------
```python
"""
Performance optimization utilities for Maverick-MCP.
This module provides Redis connection pooling, request caching,
and query optimization features to improve application performance.
"""
import hashlib
import json
import logging
import time
from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any, TypeVar, cast
import redis.asyncio as redis
from redis.asyncio.client import Pipeline
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from maverick_mcp.config.settings import get_settings
from maverick_mcp.data.session_management import get_async_db_session
settings = get_settings()
logger = logging.getLogger(__name__)
# Type variables for generic typing
F = TypeVar("F", bound=Callable[..., Any])
class RedisConnectionManager:
"""
Centralized Redis connection manager with connection pooling.
This manager provides:
- Connection pooling with configurable limits
- Automatic failover and retry logic
- Health monitoring and metrics
- Graceful degradation when Redis is unavailable
"""
def __init__(self):
self._pool: redis.ConnectionPool | None = None
self._client: redis.Redis | None = None
self._initialized = False
self._healthy = False
self._last_health_check = 0
self._health_check_interval = 30 # seconds
# Connection pool configuration
self._max_connections = settings.db.redis_max_connections
self._retry_on_timeout = settings.db.redis_retry_on_timeout
self._socket_timeout = settings.db.redis_socket_timeout
self._socket_connect_timeout = settings.db.redis_socket_connect_timeout
self._health_check_interval_sec = 30
# Metrics
self._metrics = {
"connections_created": 0,
"connections_closed": 0,
"commands_executed": 0,
"errors": 0,
"health_checks": 0,
"last_error": None,
}
async def initialize(self) -> bool:
"""
Initialize Redis connection pool.
Returns:
bool: True if initialization successful, False otherwise
"""
if self._initialized:
return self._healthy
try:
# Create connection pool
self._pool = redis.ConnectionPool.from_url(
settings.redis.url,
max_connections=self._max_connections,
retry_on_timeout=self._retry_on_timeout,
socket_timeout=self._socket_timeout,
socket_connect_timeout=self._socket_connect_timeout,
decode_responses=True,
health_check_interval=self._health_check_interval_sec,
)
# Create Redis client
self._client = redis.Redis(connection_pool=self._pool)
client = self._client
if client is None: # Defensive guard for static type checking
msg = "Redis client initialization failed"
raise RuntimeError(msg)
# Test connection
await client.ping()
self._healthy = True
self._initialized = True
self._metrics["connections_created"] += 1
logger.info(
f"Redis connection pool initialized: "
f"max_connections={self._max_connections}, "
f"url={settings.redis.url}"
)
return True
except Exception as e:
logger.error(f"Failed to initialize Redis connection pool: {e}")
self._metrics["errors"] += 1
self._metrics["last_error"] = str(e)
self._healthy = False
return False
async def get_client(self) -> redis.Redis | None:
"""
Get Redis client from the connection pool.
Returns:
Redis client or None if unavailable
"""
if not self._initialized:
await self.initialize()
if not self._healthy:
await self._health_check()
return self._client if self._healthy else None
async def _health_check(self) -> bool:
"""
Perform health check on Redis connection.
Returns:
bool: True if healthy, False otherwise
"""
current_time = time.time()
# Skip health check if recently performed
if (current_time - self._last_health_check) < self._health_check_interval:
return self._healthy
self._last_health_check = current_time
self._metrics["health_checks"] += 1
try:
if self._client:
await self._client.ping()
self._healthy = True
logger.debug("Redis health check passed")
else:
self._healthy = False
except Exception as e:
logger.warning(f"Redis health check failed: {e}")
self._healthy = False
self._metrics["errors"] += 1
self._metrics["last_error"] = str(e)
# Try to reinitialize
await self.initialize()
return self._healthy
async def execute_command(self, command: str, *args, **kwargs) -> Any:
"""
Execute Redis command with error handling and metrics.
Args:
command: Redis command name
*args: Command arguments
**kwargs: Command keyword arguments
Returns:
Command result or None if failed
"""
client = await self.get_client()
if not client:
return None
try:
self._metrics["commands_executed"] += 1
result = await getattr(client, command)(*args, **kwargs)
return result
except Exception as e:
logger.error(f"Redis command '{command}' failed: {e}")
self._metrics["errors"] += 1
self._metrics["last_error"] = str(e)
return None
async def pipeline(self) -> Pipeline | None:
"""
Create Redis pipeline for batch operations.
Returns:
Redis pipeline or None if unavailable
"""
client = await self.get_client()
if not client:
return None
return client.pipeline()
def get_metrics(self) -> dict[str, Any]:
"""Get connection pool metrics."""
metrics = self._metrics.copy()
metrics.update(
{
"healthy": self._healthy,
"initialized": self._initialized,
"pool_size": self._max_connections,
"pool_created": bool(self._pool),
}
)
if self._pool:
# Safely get pool metrics with fallbacks for missing attributes
try:
metrics["pool_created_connections"] = getattr(
self._pool, "created_connections", 0
)
except AttributeError:
metrics["pool_created_connections"] = 0
try:
metrics["pool_available_connections"] = len(
getattr(self._pool, "_available_connections", [])
)
except (AttributeError, TypeError):
metrics["pool_available_connections"] = 0
try:
metrics["pool_in_use_connections"] = len(
getattr(self._pool, "_in_use_connections", [])
)
except (AttributeError, TypeError):
metrics["pool_in_use_connections"] = 0
return metrics
async def close(self):
"""Close connection pool gracefully."""
if self._client:
# Use aclose() instead of close() to avoid deprecation warning
# aclose() is the new async close method in redis-py 5.0+
if hasattr(self._client, "aclose"):
await self._client.aclose()
else:
# Fallback for older versions
await self._client.close()
self._metrics["connections_closed"] += 1
if self._pool:
await self._pool.disconnect()
self._initialized = False
self._healthy = False
logger.info("Redis connection pool closed")
# Global Redis connection manager instance
redis_manager = RedisConnectionManager()
class RequestCache:
"""
Smart request-level caching system.
This system provides:
- Automatic cache key generation based on function signature
- TTL strategies for different data types
- Cache invalidation mechanisms
- Hit/miss metrics and monitoring
"""
def __init__(self):
self._hit_count = 0
self._miss_count = 0
self._error_count = 0
# Default TTL values for different data types (in seconds)
self._default_ttls = {
"stock_data": 3600, # 1 hour for stock data
"technical_analysis": 1800, # 30 minutes for technical indicators
"market_data": 300, # 5 minutes for market data
"screening": 7200, # 2 hours for screening results
"portfolio": 1800, # 30 minutes for portfolio analysis
"macro_data": 3600, # 1 hour for macro data
"default": 900, # 15 minutes default
}
def _generate_cache_key(self, prefix: str, *args, **kwargs) -> str:
"""
Generate cache key from function arguments.
Args:
prefix: Cache key prefix
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Generated cache key
"""
# Create a hash of the arguments
key_data = {
"args": args,
"kwargs": sorted(kwargs.items()),
}
key_hash = hashlib.sha256(
json.dumps(key_data, sort_keys=True, default=str).encode()
).hexdigest()[:16] # Use first 16 chars for brevity
return f"cache:{prefix}:{key_hash}"
def _get_ttl(self, data_type: str) -> int:
"""Get TTL for data type."""
return self._default_ttls.get(data_type, self._default_ttls["default"])
async def get(self, key: str) -> Any | None:
"""
Get value from cache.
Args:
key: Cache key
Returns:
Cached value or None if not found
"""
try:
client = await redis_manager.get_client()
if not client:
return None
data = await client.get(key)
if data:
self._hit_count += 1
logger.debug(f"Cache hit for key: {key}")
return json.loads(data)
else:
self._miss_count += 1
logger.debug(f"Cache miss for key: {key}")
return None
except Exception as e:
logger.error(f"Error getting from cache: {e}")
self._error_count += 1
return None
async def set(
self, key: str, value: Any, ttl: int | None = None, data_type: str = "default"
) -> bool:
"""
Set value in cache.
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds
data_type: Data type for TTL determination
Returns:
True if successful, False otherwise
"""
try:
client = await redis_manager.get_client()
if not client:
return False
if ttl is None:
ttl = self._get_ttl(data_type)
serialized_value = json.dumps(value, default=str)
success = await client.setex(key, ttl, serialized_value)
if success:
logger.debug(f"Cached value for key: {key} (TTL: {ttl}s)")
return bool(success)
except Exception as e:
logger.error(f"Error setting cache: {e}")
self._error_count += 1
return False
async def delete(self, key: str) -> bool:
"""Delete key from cache."""
try:
client = await redis_manager.get_client()
if not client:
return False
result = await client.delete(key)
return bool(result)
except Exception as e:
logger.error(f"Error deleting from cache: {e}")
self._error_count += 1
return False
async def delete_pattern(self, pattern: str) -> int:
"""Delete all keys matching pattern."""
try:
client = await redis_manager.get_client()
if not client:
return 0
keys = await client.keys(pattern)
if keys:
result = await client.delete(*keys)
logger.info(f"Deleted {result} keys matching pattern: {pattern}")
return result
return 0
except Exception as e:
logger.error(f"Error deleting pattern: {e}")
self._error_count += 1
return 0
def get_metrics(self) -> dict[str, Any]:
"""Get cache metrics."""
total_requests = self._hit_count + self._miss_count
hit_rate = (self._hit_count / total_requests) if total_requests > 0 else 0
return {
"hit_count": self._hit_count,
"miss_count": self._miss_count,
"error_count": self._error_count,
"total_requests": total_requests,
"hit_rate": hit_rate,
"ttl_config": self._default_ttls,
}
# Global request cache instance
request_cache = RequestCache()
def cached(
data_type: str = "default",
ttl: int | None = None,
key_prefix: str | None = None,
invalidate_patterns: list[str] | None = None,
):
"""
Decorator for automatic function result caching.
Args:
data_type: Data type for TTL determination
ttl: Custom TTL in seconds
key_prefix: Custom cache key prefix
invalidate_patterns: Patterns to invalidate on update
Example:
@cached(data_type="stock_data", ttl=3600)
async def get_stock_price(symbol: str) -> float:
# Expensive operation
return price
"""
def decorator(func: F) -> F:
@wraps(func)
async def wrapper(*args, **kwargs):
# Generate cache key
prefix = key_prefix or f"{func.__module__}.{func.__name__}"
cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
# Try to get from cache
cached_result = await request_cache.get(cache_key)
if cached_result is not None:
return cached_result
# Execute function
result = await func(*args, **kwargs)
# Cache result
if result is not None:
await request_cache.set(cache_key, result, ttl, data_type)
return result
# Add cache invalidation method
async def invalidate_cache(*args, **kwargs):
"""Invalidate cache for this function."""
prefix = key_prefix or f"{func.__module__}.{func.__name__}"
cache_key = request_cache._generate_cache_key(prefix, *args, **kwargs)
await request_cache.delete(cache_key)
# Invalidate patterns if specified
if invalidate_patterns:
for pattern in invalidate_patterns:
await request_cache.delete_pattern(pattern)
typed_wrapper = cast(F, wrapper)
cast(Any, typed_wrapper).invalidate_cache = invalidate_cache
return typed_wrapper
return decorator
class QueryOptimizer:
"""
Database query optimization utilities.
This class provides:
- Query performance monitoring
- Index recommendations
- N+1 query detection
- Connection pool monitoring
"""
def __init__(self):
self._query_stats = {}
self._slow_query_threshold = 1.0 # seconds
self._slow_queries = []
def monitor_query(self, query_name: str):
"""
Decorator for monitoring query performance.
Args:
query_name: Name for the query (for metrics)
"""
def decorator(func: F) -> F:
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
try:
result = await func(*args, **kwargs)
execution_time = time.time() - start_time
# Update statistics
if query_name not in self._query_stats:
self._query_stats[query_name] = {
"count": 0,
"total_time": 0,
"avg_time": 0,
"max_time": 0,
"min_time": float("inf"),
}
stats = self._query_stats[query_name]
stats["count"] += 1
stats["total_time"] += execution_time
stats["avg_time"] = stats["total_time"] / stats["count"]
stats["max_time"] = max(stats["max_time"], execution_time)
stats["min_time"] = min(stats["min_time"], execution_time)
# Track slow queries
if execution_time > self._slow_query_threshold:
self._slow_queries.append(
{
"query_name": query_name,
"execution_time": execution_time,
"timestamp": time.time(),
"args": str(args)[:200], # Truncate long args
}
)
# Keep only last 100 slow queries
if len(self._slow_queries) > 100:
self._slow_queries = self._slow_queries[-100:]
logger.warning(
f"Slow query detected: {query_name} took {execution_time:.2f}s"
)
return result
except Exception as e:
execution_time = time.time() - start_time
logger.error(
f"Query {query_name} failed after {execution_time:.2f}s: {e}"
)
raise
return cast(F, wrapper)
return decorator
def get_query_stats(self) -> dict[str, Any]:
"""Get query performance statistics."""
return {
"query_stats": self._query_stats,
"slow_queries": self._slow_queries[-10:], # Last 10 slow queries
"slow_query_threshold": self._slow_query_threshold,
}
async def analyze_missing_indexes(
self, session: AsyncSession
) -> list[dict[str, Any]]:
"""
Analyze database for missing indexes.
Args:
session: Database session
Returns:
List of recommended indexes
"""
recommendations = []
try:
# Check for common missing indexes
queries = [
# PriceCache table analysis
{
"name": "PriceCache date range queries",
"query": """
SELECT schemaname, tablename, attname, n_distinct, correlation
FROM pg_stats
WHERE tablename = 'stocks_pricecache'
AND attname IN ('date', 'stock_id', 'volume')
""",
"recommendation": "Consider composite index on (stock_id, date) if not exists",
},
# Stock lookup performance
{
"name": "Stock ticker lookups",
"query": """
SELECT schemaname, tablename, attname, n_distinct, correlation
FROM pg_stats
WHERE tablename = 'stocks_stock'
AND attname = 'ticker_symbol'
""",
"recommendation": "Ensure unique index on ticker_symbol exists",
},
# Screening tables
{
"name": "Maverick screening queries",
"query": """
SELECT schemaname, tablename, attname, n_distinct
FROM pg_stats
WHERE tablename IN ('stocks_maverickstocks', 'stocks_maverickbearstocks', 'stocks_supply_demand_breakouts')
AND attname IN ('score', 'rank', 'date_analyzed')
""",
"recommendation": "Consider indexes on score, rank, and date_analyzed columns",
},
]
for query_info in queries:
try:
result = await session.execute(text(query_info["query"]))
rows = result.fetchall()
if rows:
recommendations.append(
{
"analysis": query_info["name"],
"recommendation": query_info["recommendation"],
"stats": [dict(row._mapping) for row in rows],
}
)
except Exception as e:
logger.error(f"Failed to analyze {query_info['name']}: {e}")
# Check for tables without proper indexes
missing_indexes_query = """
SELECT
schemaname,
tablename,
seq_scan,
seq_tup_read,
idx_scan,
idx_tup_fetch,
CASE
WHEN seq_scan = 0 THEN 0
ELSE seq_tup_read / seq_scan
END as avg_seq_read
FROM pg_stat_user_tables
WHERE schemaname = 'public'
AND tablename LIKE 'stocks_%'
ORDER BY seq_tup_read DESC
"""
result = await session.execute(text(missing_indexes_query))
scan_stats = result.fetchall()
for row in scan_stats:
if row.seq_scan > 100 and row.avg_seq_read > 1000:
recommendations.append(
{
"analysis": f"High sequential scans on {row.tablename}",
"recommendation": f"Consider adding indexes to reduce {row.seq_tup_read} sequential reads",
"stats": dict(row._mapping),
}
)
except Exception as e:
logger.error(f"Error analyzing missing indexes: {e}")
return recommendations
# Global query optimizer instance
query_optimizer = QueryOptimizer()
async def initialize_performance_systems():
"""Initialize all performance optimization systems."""
logger.info("Initializing performance optimization systems...")
# Initialize Redis connection manager
redis_success = await redis_manager.initialize()
logger.info(
f"Performance systems initialized: Redis={'✓' if redis_success else '✗'}"
)
return {
"redis_manager": redis_success,
"request_cache": True,
"query_optimizer": True,
}
async def get_performance_metrics() -> dict[str, Any]:
"""Get comprehensive performance metrics."""
return {
"redis_manager": redis_manager.get_metrics(),
"request_cache": request_cache.get_metrics(),
"query_optimizer": query_optimizer.get_query_stats(),
"timestamp": time.time(),
}
async def cleanup_performance_systems():
"""Cleanup performance systems gracefully."""
logger.info("Cleaning up performance optimization systems...")
await redis_manager.close()
logger.info("Performance systems cleanup completed")
# Context manager for database session with query monitoring
@asynccontextmanager
async def monitored_db_session(query_name: str = "unknown"):
"""
Context manager for database sessions with automatic query monitoring.
Args:
query_name: Name for the query (for metrics)
Example:
async with monitored_db_session("get_stock_data") as session:
result = await session.execute(
text("SELECT * FROM stocks_stock WHERE ticker_symbol = :symbol"),
{"symbol": "AAPL"},
)
stock = result.first()
"""
async with get_async_db_session() as session:
start_time = time.time()
try:
yield session
# Record successful query
execution_time = time.time() - start_time
if query_name not in query_optimizer._query_stats:
query_optimizer._query_stats[query_name] = {
"count": 0,
"total_time": 0,
"avg_time": 0,
"max_time": 0,
"min_time": float("inf"),
}
stats = query_optimizer._query_stats[query_name]
stats["count"] += 1
stats["total_time"] += execution_time
stats["avg_time"] = stats["total_time"] / stats["count"]
stats["max_time"] = max(stats["max_time"], execution_time)
stats["min_time"] = min(stats["min_time"], execution_time)
except Exception as e:
execution_time = time.time() - start_time
logger.error(
f"Database query '{query_name}' failed after {execution_time:.2f}s: {e}"
)
raise
```
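A minimal sketch of how the `cached` decorator and the lifecycle helpers above fit together, assuming the project's settings (including a Redis URL) resolve in your environment; `fetch_closing_price` is a hypothetical example function, not part of the module. If Redis is unreachable, the connection manager degrades gracefully and the decorated call simply runs uncached.

```python
import asyncio

from maverick_mcp.data.performance import (
    cached,
    cleanup_performance_systems,
    get_performance_metrics,
    initialize_performance_systems,
)


@cached(data_type="stock_data", ttl=3600)
async def fetch_closing_price(symbol: str) -> float:
    # Stand-in for an expensive provider call; the decorator stores the
    # JSON-serializable result under a key derived from the function name
    # and its arguments.
    await asyncio.sleep(0.1)
    return 123.45


async def main() -> None:
    await initialize_performance_systems()

    first = await fetch_closing_price("AAPL")   # cache miss: body executes
    second = await fetch_closing_price("AAPL")  # served from Redis when available

    # Explicit invalidation for the same argument set.
    await fetch_closing_price.invalidate_cache("AAPL")

    print(first, second)
    print(await get_performance_metrics())

    await cleanup_performance_systems()


asyncio.run(main())
```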
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/monitoring.py:
--------------------------------------------------------------------------------
```python
"""
Monitoring and health check endpoints for MaverickMCP.
This module provides endpoints for:
- Prometheus metrics exposure
- Health checks (basic, detailed, readiness)
- System status and diagnostics
- Monitoring dashboard data
"""
import time
from typing import Any
from fastapi import APIRouter, HTTPException, Response
from pydantic import BaseModel
from maverick_mcp.config.settings import settings
from maverick_mcp.monitoring.metrics import (
get_backtesting_metrics,
get_metrics_for_prometheus,
)
from maverick_mcp.utils.database_monitoring import (
get_cache_monitor,
get_database_monitor,
)
from maverick_mcp.utils.logging import get_logger
from maverick_mcp.utils.monitoring import get_metrics, get_monitoring_service
logger = get_logger(__name__)
router = APIRouter()
class HealthStatus(BaseModel):
"""Health check response model."""
status: str
timestamp: float
version: str
environment: str
uptime_seconds: float
class DetailedHealthStatus(BaseModel):
"""Detailed health check response model."""
status: str
timestamp: float
version: str
environment: str
uptime_seconds: float
services: dict[str, dict[str, Any]]
metrics: dict[str, Any]
class SystemMetrics(BaseModel):
"""System metrics response model."""
cpu_usage_percent: float
memory_usage_mb: float
open_file_descriptors: int
active_connections: int
database_pool_status: dict[str, Any]
redis_info: dict[str, Any]
class ServiceStatus(BaseModel):
"""Individual service status."""
name: str
status: str
last_check: float
details: dict[str, Any] = {}
# Track server start time for uptime calculation
_server_start_time = time.time()
@router.get("/health", response_model=HealthStatus)
async def health_check():
"""
Basic health check endpoint.
Returns basic service status and uptime information.
Used by load balancers and orchestration systems.
"""
return HealthStatus(
status="healthy",
timestamp=time.time(),
version="1.0.0", # You might want to get this from a version file
environment=settings.environment,
uptime_seconds=time.time() - _server_start_time,
)
@router.get("/health/detailed", response_model=DetailedHealthStatus)
async def detailed_health_check():
"""
Detailed health check endpoint.
Returns comprehensive health information including:
- Service dependencies status
- Database connectivity
- Redis connectivity
- Performance metrics
"""
services = {}
# Check database health
try:
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
services["database"] = {
"status": "healthy" if pool_status else "unknown",
"details": pool_status,
"last_check": time.time(),
}
except Exception as e:
services["database"] = {
"status": "unhealthy",
"details": {"error": str(e)},
"last_check": time.time(),
}
# Check Redis health
try:
cache_monitor = get_cache_monitor()
redis_info = (
await cache_monitor.redis_monitor.get_redis_info()
if cache_monitor.redis_monitor
else {}
)
services["redis"] = {
"status": "healthy" if redis_info else "unknown",
"details": redis_info,
"last_check": time.time(),
}
except Exception as e:
services["redis"] = {
"status": "unhealthy",
"details": {"error": str(e)},
"last_check": time.time(),
}
# Check monitoring services
try:
monitoring = get_monitoring_service()
services["monitoring"] = {
"status": "healthy",
"details": {
"sentry_enabled": monitoring.sentry_enabled,
},
"last_check": time.time(),
}
except Exception as e:
services["monitoring"] = {
"status": "unhealthy",
"details": {"error": str(e)},
"last_check": time.time(),
}
# Overall status
overall_status = "healthy"
for service in services.values():
if service["status"] == "unhealthy":
overall_status = "unhealthy"
break
elif service["status"] == "unknown" and overall_status == "healthy":
overall_status = "degraded"
return DetailedHealthStatus(
status=overall_status,
timestamp=time.time(),
version="1.0.0",
environment=settings.environment,
uptime_seconds=time.time() - _server_start_time,
services=services,
metrics=await _get_basic_metrics(),
)
@router.get("/health/readiness")
async def readiness_check():
"""
Readiness check endpoint.
Indicates whether the service is ready to handle requests.
Used by Kubernetes and other orchestration systems.
"""
try:
# Check critical dependencies
checks = []
# Database readiness
try:
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
if pool_status and pool_status.get("pool_size", 0) > 0:
checks.append(True)
else:
checks.append(False)
except Exception:
checks.append(False)
# Redis readiness (if configured)
try:
cache_monitor = get_cache_monitor()
if cache_monitor.redis_monitor:
redis_info = await cache_monitor.redis_monitor.get_redis_info()
checks.append(bool(redis_info))
else:
checks.append(True) # Redis not required
except Exception:
checks.append(False)
if all(checks):
return {"status": "ready", "timestamp": time.time()}
else:
raise HTTPException(status_code=503, detail="Service not ready")
except HTTPException:
raise
except Exception as e:
logger.error(f"Readiness check failed: {e}")
raise HTTPException(status_code=503, detail="Readiness check failed")
@router.get("/health/liveness")
async def liveness_check():
"""
Liveness check endpoint.
Indicates whether the service is alive and should not be restarted.
Used by Kubernetes and other orchestration systems.
"""
# Simple check - if we can respond, we're alive
return {"status": "alive", "timestamp": time.time()}
@router.get("/metrics")
async def prometheus_metrics():
"""
Prometheus metrics endpoint.
Returns comprehensive metrics in Prometheus text format for scraping.
Includes both system metrics and backtesting-specific metrics.
"""
try:
# Get standard system metrics
system_metrics = get_metrics()
# Get backtesting-specific metrics
backtesting_metrics = get_metrics_for_prometheus()
# Combine all metrics
combined_metrics = system_metrics + "\n" + backtesting_metrics
return Response(
content=combined_metrics,
media_type="text/plain; version=0.0.4; charset=utf-8",
)
except Exception as e:
logger.error(f"Failed to generate metrics: {e}")
raise HTTPException(status_code=500, detail="Failed to generate metrics")
@router.get("/metrics/backtesting")
async def backtesting_metrics():
"""
Specialized backtesting metrics endpoint.
Returns backtesting-specific metrics in Prometheus text format.
Useful for dedicated backtesting monitoring and alerting.
"""
try:
backtesting_metrics_text = get_metrics_for_prometheus()
return Response(
content=backtesting_metrics_text,
media_type="text/plain; version=0.0.4; charset=utf-8",
)
except Exception as e:
logger.error(f"Failed to generate backtesting metrics: {e}")
raise HTTPException(
status_code=500, detail="Failed to generate backtesting metrics"
)
@router.get("/metrics/json")
async def metrics_json():
"""
Get metrics in JSON format for dashboards and monitoring.
Returns structured metrics data suitable for consumption by
monitoring dashboards and alerting systems.
"""
try:
return {
"timestamp": time.time(),
"system": await _get_system_metrics(),
"application": await _get_application_metrics(),
"business": await _get_business_metrics(),
"backtesting": await _get_backtesting_metrics(),
}
except Exception as e:
logger.error(f"Failed to generate JSON metrics: {e}")
raise HTTPException(status_code=500, detail="Failed to generate JSON metrics")
@router.get("/status", response_model=SystemMetrics)
async def system_status():
"""
Get current system status and performance metrics.
Returns real-time system performance data including:
- CPU and memory usage
- Database connection pool status
- Redis connection status
- File descriptor usage
"""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
# Get database pool status
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
# Get Redis info
cache_monitor = get_cache_monitor()
redis_info = {}
if cache_monitor.redis_monitor:
redis_info = await cache_monitor.redis_monitor.get_redis_info()
return SystemMetrics(
cpu_usage_percent=process.cpu_percent(),
memory_usage_mb=memory_info.rss / 1024 / 1024,
open_file_descriptors=process.num_fds()
if hasattr(process, "num_fds")
else 0,
active_connections=0, # This would come from your connection tracking
database_pool_status=pool_status,
redis_info=redis_info,
)
except Exception as e:
logger.error(f"Failed to get system status: {e}")
raise HTTPException(status_code=500, detail="Failed to get system status")
@router.get("/diagnostics")
async def system_diagnostics():
"""
Get comprehensive system diagnostics.
Returns detailed diagnostic information for troubleshooting:
- Environment configuration
- Feature flags
- Service dependencies
- Performance metrics
- Recent errors
"""
try:
diagnostics = {
"timestamp": time.time(),
"environment": {
"name": settings.environment,
"auth_enabled": False, # Disabled for personal use
"debug_mode": settings.api.debug,
},
"uptime_seconds": time.time() - _server_start_time,
"services": await _get_service_diagnostics(),
"performance": await _get_performance_diagnostics(),
"configuration": _get_configuration_diagnostics(),
}
return diagnostics
except Exception as e:
logger.error(f"Failed to generate diagnostics: {e}")
raise HTTPException(status_code=500, detail="Failed to generate diagnostics")
async def _get_basic_metrics() -> dict[str, Any]:
"""Get basic performance metrics."""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
return {
"cpu_usage_percent": process.cpu_percent(),
"memory_usage_mb": memory_info.rss / 1024 / 1024,
"uptime_seconds": time.time() - _server_start_time,
}
except Exception as e:
logger.error(f"Failed to get basic metrics: {e}")
return {}
async def _get_system_metrics() -> dict[str, Any]:
"""Get detailed system metrics."""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
return {
"cpu": {
"usage_percent": process.cpu_percent(),
"times": process.cpu_times()._asdict(),
},
"memory": {
"rss_mb": memory_info.rss / 1024 / 1024,
"vms_mb": memory_info.vms / 1024 / 1024,
"percent": process.memory_percent(),
},
"file_descriptors": {
"open": process.num_fds() if hasattr(process, "num_fds") else 0,
},
"threads": process.num_threads(),
}
except Exception as e:
logger.error(f"Failed to get system metrics: {e}")
return {}
async def _get_application_metrics() -> dict[str, Any]:
"""Get application-specific metrics."""
try:
# Get database metrics
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
# Get cache metrics
cache_monitor = get_cache_monitor()
redis_info = {}
if cache_monitor.redis_monitor:
redis_info = await cache_monitor.redis_monitor.get_redis_info()
return {
"database": {
"pool_status": pool_status,
},
"cache": {
"redis_info": redis_info,
},
"monitoring": {
"sentry_enabled": get_monitoring_service().sentry_enabled,
},
}
except Exception as e:
logger.error(f"Failed to get application metrics: {e}")
return {}
async def _get_business_metrics() -> dict[str, Any]:
"""Get business-related metrics."""
# This would typically query your database for business metrics
# For now, return placeholder data
return {
"users": {
"total_active": 0,
"daily_active": 0,
"monthly_active": 0,
},
"tools": {
"total_executions": 0,
"average_execution_time": 0,
},
"engagement": {
"portfolio_reviews": 0,
"watchlists_managed": 0,
},
}
async def _get_backtesting_metrics() -> dict[str, Any]:
"""Get backtesting-specific metrics summary."""
try:
# Get the backtesting metrics collector
get_backtesting_metrics()
# Return a summary of key backtesting metrics
# In a real implementation, you might query the metrics registry
# or maintain counters in the collector class
return {
"strategy_performance": {
"total_backtests_run": 0, # Would be populated from metrics
"average_execution_time": 0.0,
"successful_backtests": 0,
"failed_backtests": 0,
},
"api_usage": {
"total_api_calls": 0,
"rate_limit_hits": 0,
"average_response_time": 0.0,
"error_rate": 0.0,
},
"resource_usage": {
"peak_memory_usage_mb": 0.0,
"average_computation_time": 0.0,
"database_query_count": 0,
},
"anomalies": {
"total_anomalies_detected": 0,
"critical_anomalies": 0,
"warning_anomalies": 0,
},
}
except Exception as e:
logger.error(f"Failed to get backtesting metrics: {e}")
return {}
async def _get_service_diagnostics() -> dict[str, Any]:
"""Get service dependency diagnostics."""
services = {}
# Database diagnostics
try:
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
services["database"] = {
"status": "healthy" if pool_status else "unknown",
"pool_status": pool_status,
"url_configured": bool(settings.database.url),
}
except Exception as e:
services["database"] = {
"status": "error",
"error": str(e),
}
# Redis diagnostics
try:
cache_monitor = get_cache_monitor()
if cache_monitor.redis_monitor:
redis_info = await cache_monitor.redis_monitor.get_redis_info()
services["redis"] = {
"status": "healthy" if redis_info else "unknown",
"info": redis_info,
}
else:
services["redis"] = {
"status": "not_configured",
}
except Exception as e:
services["redis"] = {
"status": "error",
"error": str(e),
}
return services
async def _get_performance_diagnostics() -> dict[str, Any]:
"""Get performance diagnostics."""
try:
import gc
import psutil
process = psutil.Process()
return {
"garbage_collection": {
"stats": gc.get_stats(),
"counts": gc.get_count(),
},
"process": {
"create_time": process.create_time(),
"num_threads": process.num_threads(),
"connections": len(process.connections())
if hasattr(process, "connections")
else 0,
},
}
except Exception as e:
logger.error(f"Failed to get performance diagnostics: {e}")
return {}
def _get_configuration_diagnostics() -> dict[str, Any]:
"""Get configuration diagnostics."""
return {
"environment": settings.environment,
"features": {
"auth_enabled": False, # Disabled for personal use
"debug_mode": settings.api.debug,
},
"database": {
"url_configured": bool(settings.database.url),
},
}
# Health check dependencies for other endpoints
async def require_healthy_database():
"""Dependency that ensures database is healthy."""
try:
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
if not pool_status or pool_status.get("pool_size", 0) == 0:
raise HTTPException(status_code=503, detail="Database not available")
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=503, detail=f"Database health check failed: {e}"
)
async def require_healthy_redis():
"""Dependency that ensures Redis is healthy."""
try:
cache_monitor = get_cache_monitor()
if cache_monitor.redis_monitor:
redis_info = await cache_monitor.redis_monitor.get_redis_info()
if not redis_info:
raise HTTPException(status_code=503, detail="Redis not available")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=503, detail=f"Redis health check failed: {e}")
@router.get("/alerts")
async def get_active_alerts():
"""
Get active alerts and anomalies detected by the monitoring system.
Returns current alerts for:
- Strategy performance anomalies
- API rate limiting issues
- Resource usage threshold violations
- Data quality problems
"""
try:
alerts = []
timestamp = time.time()
# Get the backtesting metrics collector to check for anomalies
get_backtesting_metrics()
# In a real implementation, you would query stored alert data
        # For now, evaluate current metrics against thresholds and build alerts on the fly
# Example: Check current system metrics against thresholds
import psutil
process = psutil.Process()
memory_mb = process.memory_info().rss / 1024 / 1024
# Check memory usage threshold
if memory_mb > 1000: # 1GB threshold
alerts.append(
{
"id": "memory_high_001",
"type": "resource_usage",
"severity": "warning" if memory_mb < 2000 else "critical",
"title": "High Memory Usage",
"description": f"Process memory usage is {memory_mb:.1f}MB",
"timestamp": timestamp,
"metric_value": memory_mb,
"threshold_value": 1000,
"status": "active",
"tags": ["memory", "system", "performance"],
}
)
# Check database connection pool
try:
db_monitor = get_database_monitor()
pool_status = db_monitor.get_pool_status()
if (
pool_status
and pool_status.get("active_connections", 0)
> pool_status.get("pool_size", 10) * 0.8
):
alerts.append(
{
"id": "db_pool_high_001",
"type": "database_performance",
"severity": "warning",
"title": "High Database Connection Usage",
"description": "Database connection pool usage is above 80%",
"timestamp": timestamp,
"metric_value": pool_status.get("active_connections", 0),
"threshold_value": pool_status.get("pool_size", 10) * 0.8,
"status": "active",
"tags": ["database", "connections", "performance"],
}
)
except Exception:
pass
return {
"alerts": alerts,
"total_count": len(alerts),
"severity_counts": {
"critical": len([a for a in alerts if a["severity"] == "critical"]),
"warning": len([a for a in alerts if a["severity"] == "warning"]),
"info": len([a for a in alerts if a["severity"] == "info"]),
},
"timestamp": timestamp,
}
except Exception as e:
logger.error(f"Failed to get alerts: {e}")
raise HTTPException(status_code=500, detail="Failed to get alerts")
@router.get("/alerts/rules")
async def get_alert_rules():
"""
Get configured alert rules and thresholds.
Returns the current alert rule configuration including:
- Performance thresholds
- Anomaly detection settings
- Alert severity levels
- Notification settings
"""
try:
# Get the backtesting metrics collector
get_backtesting_metrics()
# Return the configured alert rules
rules = {
"performance_thresholds": {
"sharpe_ratio": {
"warning_threshold": 0.5,
"critical_threshold": 0.0,
"comparison": "less_than",
"enabled": True,
},
"max_drawdown": {
"warning_threshold": 20.0,
"critical_threshold": 30.0,
"comparison": "greater_than",
"enabled": True,
},
"win_rate": {
"warning_threshold": 40.0,
"critical_threshold": 30.0,
"comparison": "less_than",
"enabled": True,
},
"execution_time": {
"warning_threshold": 60.0,
"critical_threshold": 120.0,
"comparison": "greater_than",
"enabled": True,
},
},
"resource_thresholds": {
"memory_usage": {
"warning_threshold": 1000, # MB
"critical_threshold": 2000, # MB
"comparison": "greater_than",
"enabled": True,
},
"cpu_usage": {
"warning_threshold": 80.0, # %
"critical_threshold": 95.0, # %
"comparison": "greater_than",
"enabled": True,
},
"disk_usage": {
"warning_threshold": 80.0, # %
"critical_threshold": 95.0, # %
"comparison": "greater_than",
"enabled": True,
},
},
"api_thresholds": {
"response_time": {
"warning_threshold": 30.0, # seconds
"critical_threshold": 60.0, # seconds
"comparison": "greater_than",
"enabled": True,
},
"error_rate": {
"warning_threshold": 5.0, # %
"critical_threshold": 10.0, # %
"comparison": "greater_than",
"enabled": True,
},
"rate_limit_usage": {
"warning_threshold": 80.0, # %
"critical_threshold": 95.0, # %
"comparison": "greater_than",
"enabled": True,
},
},
"anomaly_detection": {
"enabled": True,
"sensitivity": "medium",
"lookback_period_hours": 24,
"minimum_data_points": 10,
},
"notification_settings": {
"webhook_enabled": False,
"email_enabled": False,
"slack_enabled": False,
"webhook_url": None,
},
}
return {
"rules": rules,
"total_rules": sum(
len(category)
for category in rules.values()
if isinstance(category, dict)
),
"enabled_rules": sum(
len(
[
rule
for rule in category.values()
if isinstance(rule, dict) and rule.get("enabled", False)
]
)
for category in rules.values()
if isinstance(category, dict)
),
"timestamp": time.time(),
}
except Exception as e:
logger.error(f"Failed to get alert rules: {e}")
raise HTTPException(status_code=500, detail="Failed to get alert rules")
```
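The router above is meant to be polled by orchestration and monitoring systems. As a rough smoke test (not part of the repository), the sketch below assumes the router is mounted at the root of a server on `http://localhost:8000` and uses `httpx`; adjust the base URL and any route prefix to match your deployment.

```python
# Hypothetical smoke test for the health/monitoring router.
# Assumes the FastAPI app exposes these routes at http://localhost:8000;
# adjust BASE_URL and paths to your deployment.
import asyncio

import httpx

BASE_URL = "http://localhost:8000"


async def probe_health() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, timeout=10.0) as client:
        # Liveness should return 200 as long as the process can respond.
        liveness = await client.get("/health/liveness")
        print("liveness:", liveness.json())

        # Readiness returns 503 until the database (and Redis, if configured)
        # are available, so treat a non-200 status as "not ready", not fatal.
        readiness = await client.get("/health/readiness")
        print("readiness status:", readiness.status_code)

        # /metrics/json is convenient for dashboards; /metrics serves the
        # Prometheus text format for scraping.
        metrics = await client.get("/metrics/json")
        print("metric sections:", sorted(metrics.json().keys()))


if __name__ == "__main__":
    asyncio.run(probe_health())
```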
--------------------------------------------------------------------------------
/tests/integration/test_full_backtest_workflow.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive end-to-end integration tests for VectorBT backtesting workflow.
Tests cover:
- Full workflow integration from data fetching to result visualization
- LangGraph workflow orchestration with real agents
- Database persistence with real PostgreSQL operations
- Chart generation and visualization pipeline
- ML strategy integration with adaptive learning
- Performance benchmarks for complete workflow
- Error recovery and resilience testing
- Concurrent workflow execution
- Resource cleanup and memory management
- Cache integration and optimization
"""
import asyncio
import base64
import logging
from datetime import datetime
from unittest.mock import Mock, patch
from uuid import UUID
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.backtesting.persistence import (
BacktestPersistenceManager,
)
from maverick_mcp.backtesting.vectorbt_engine import VectorBTEngine
from maverick_mcp.backtesting.visualization import (
generate_equity_curve,
generate_performance_dashboard,
)
from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
from maverick_mcp.workflows.backtesting_workflow import BacktestingWorkflow
logger = logging.getLogger(__name__)
class TestFullBacktestWorkflowIntegration:
"""Integration tests for complete backtesting workflow."""
@pytest.fixture
async def mock_stock_data_provider(self):
"""Create a mock stock data provider with realistic data."""
provider = Mock(spec=EnhancedStockDataProvider)
# Generate realistic stock data
dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
        returns = np.random.normal(0.0008, 0.02, len(dates))  # 2% daily std (~32% annualized volatility)
prices = 150 * np.cumprod(1 + returns) # Start at $150
volumes = np.random.randint(1000000, 10000000, len(dates))
stock_data = pd.DataFrame(
{
"Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
"High": prices * np.random.uniform(1.00, 1.03, len(dates)),
"Low": prices * np.random.uniform(0.97, 1.00, len(dates)),
"Close": prices,
"Volume": volumes,
"Adj Close": prices,
},
index=dates,
)
# Ensure OHLC constraints
stock_data["High"] = np.maximum(
stock_data["High"], np.maximum(stock_data["Open"], stock_data["Close"])
)
stock_data["Low"] = np.minimum(
stock_data["Low"], np.minimum(stock_data["Open"], stock_data["Close"])
)
provider.get_stock_data.return_value = stock_data
return provider
@pytest.fixture
async def vectorbt_engine(self, mock_stock_data_provider):
"""Create VectorBT engine with mocked data provider."""
engine = VectorBTEngine(data_provider=mock_stock_data_provider)
return engine
@pytest.fixture
def workflow_with_real_agents(self):
"""Create workflow with real agents (not mocked)."""
return BacktestingWorkflow()
async def test_complete_workflow_execution(
self, workflow_with_real_agents, db_session, benchmark_timer
):
"""Test complete workflow from start to finish with database persistence."""
with benchmark_timer() as timer:
# Execute intelligent backtest
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="AAPL",
start_date="2023-01-01",
end_date="2023-12-31",
initial_capital=10000.0,
)
# Test basic result structure
assert "symbol" in result
assert result["symbol"] == "AAPL"
assert "execution_metadata" in result
assert "market_analysis" in result
assert "strategy_selection" in result
assert "recommendation" in result
# Test execution metadata
metadata = result["execution_metadata"]
assert "total_execution_time_ms" in metadata
assert "workflow_completed" in metadata
assert "steps_completed" in metadata
# Test performance requirements
assert timer.elapsed < 60.0 # Should complete within 1 minute
assert metadata["total_execution_time_ms"] > 0
# Test that meaningful analysis occurred
market_analysis = result["market_analysis"]
assert "regime" in market_analysis
assert "regime_confidence" in market_analysis
strategy_selection = result["strategy_selection"]
assert "selected_strategies" in strategy_selection
assert "selection_reasoning" in strategy_selection
recommendation = result["recommendation"]
assert "recommended_strategy" in recommendation
assert "recommendation_confidence" in recommendation
logger.info(f"Complete workflow executed in {timer.elapsed:.2f}s")
async def test_workflow_with_persistence_integration(
self, workflow_with_real_agents, db_session, sample_vectorbt_results
):
"""Test workflow integration with database persistence."""
# First run the workflow
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="TSLA", start_date="2023-01-01", end_date="2023-12-31"
)
# Simulate saving backtest results to database
with BacktestPersistenceManager(session=db_session) as persistence:
# Modify sample results to match workflow output
sample_vectorbt_results["symbol"] = "TSLA"
sample_vectorbt_results["strategy"] = result["recommendation"][
"recommended_strategy"
]
backtest_id = persistence.save_backtest_result(
vectorbt_results=sample_vectorbt_results,
execution_time=result["execution_metadata"]["total_execution_time_ms"]
/ 1000,
notes=f"Intelligent backtest - {result['recommendation']['recommendation_confidence']:.2%} confidence",
)
# Verify persistence
assert backtest_id is not None
assert UUID(backtest_id)
# Retrieve and verify
saved_result = persistence.get_backtest_by_id(backtest_id)
assert saved_result is not None
assert saved_result.symbol == "TSLA"
assert (
saved_result.strategy_type
== result["recommendation"]["recommended_strategy"]
)
async def test_workflow_with_visualization_integration(
self, workflow_with_real_agents, sample_vectorbt_results
):
"""Test workflow integration with visualization generation."""
# Run workflow
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="NVDA", start_date="2023-01-01", end_date="2023-12-31"
)
# Generate visualizations based on workflow results
equity_curve_data = pd.Series(sample_vectorbt_results["equity_curve"])
drawdown_data = pd.Series(sample_vectorbt_results["drawdown_series"])
# Test equity curve generation
equity_chart = generate_equity_curve(
equity_curve_data,
drawdown=drawdown_data,
title=f"NVDA - {result['recommendation']['recommended_strategy']} Strategy",
)
assert isinstance(equity_chart, str)
assert len(equity_chart) > 100
# Verify base64 image
try:
decoded_bytes = base64.b64decode(equity_chart)
assert decoded_bytes.startswith(b"\x89PNG")
except Exception as e:
pytest.fail(f"Invalid chart generation: {e}")
# Test performance dashboard
dashboard_metrics = {
"Strategy": result["recommendation"]["recommended_strategy"],
"Confidence": f"{result['recommendation']['recommendation_confidence']:.1%}",
"Market Regime": result["market_analysis"]["regime"],
"Regime Confidence": f"{result['market_analysis']['regime_confidence']:.1%}",
"Total Return": sample_vectorbt_results["metrics"]["total_return"],
"Sharpe Ratio": sample_vectorbt_results["metrics"]["sharpe_ratio"],
"Max Drawdown": sample_vectorbt_results["metrics"]["max_drawdown"],
}
dashboard_chart = generate_performance_dashboard(
dashboard_metrics, title="Intelligent Backtest Results"
)
assert isinstance(dashboard_chart, str)
assert len(dashboard_chart) > 100
async def test_workflow_with_ml_strategy_integration(
self, workflow_with_real_agents, mock_stock_data_provider
):
"""Test workflow integration with ML-enhanced strategies."""
# Mock the workflow to use ML strategies
with patch.object(
workflow_with_real_agents.strategy_selector, "select_strategies"
) as mock_selector:
async def mock_select_with_ml(state):
state.selected_strategies = ["adaptive_momentum", "online_learning"]
state.strategy_selection_confidence = 0.85
state.strategy_selection_reasoning = (
"ML strategies selected for volatile market conditions"
)
return state
mock_selector.side_effect = mock_select_with_ml
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="AMZN", start_date="2023-01-01", end_date="2023-12-31"
)
# Verify ML strategy integration
assert (
"adaptive_momentum"
in result["strategy_selection"]["selected_strategies"]
or "online_learning"
in result["strategy_selection"]["selected_strategies"]
)
assert (
"ML" in result["strategy_selection"]["selection_reasoning"]
or "adaptive" in result["strategy_selection"]["selection_reasoning"]
)
async def test_vectorbt_engine_integration(self, vectorbt_engine):
"""Test VectorBT engine integration with workflow."""
# Test data fetching
data = await vectorbt_engine.get_historical_data(
symbol="MSFT", start_date="2023-01-01", end_date="2023-12-31"
)
assert isinstance(data, pd.DataFrame)
assert len(data) > 0
# Check if required columns exist (data should already have lowercase columns)
required_cols = ["open", "high", "low", "close"]
actual_cols = list(data.columns)
missing_cols = [col for col in required_cols if col not in actual_cols]
assert all(col in actual_cols for col in required_cols), (
f"Missing columns: {missing_cols}"
)
# Test backtest execution
backtest_result = await vectorbt_engine.run_backtest(
symbol="MSFT",
strategy_type="sma_crossover",
parameters={"fast_window": 10, "slow_window": 20},
start_date="2023-01-01",
end_date="2023-12-31",
)
assert isinstance(backtest_result, dict)
assert "symbol" in backtest_result
assert "metrics" in backtest_result
assert "equity_curve" in backtest_result
async def test_error_recovery_integration(self, workflow_with_real_agents):
"""Test error recovery in integrated workflow."""
# Test with invalid symbol
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="INVALID_SYMBOL", start_date="2023-01-01", end_date="2023-12-31"
)
# Should handle gracefully
assert "error" in result or "execution_metadata" in result
if "execution_metadata" in result:
assert result["execution_metadata"]["workflow_completed"] is False
# Test with invalid date range
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="AAPL",
start_date="2025-01-01", # Future date
end_date="2025-12-31",
)
# Should handle gracefully
assert isinstance(result, dict)
async def test_concurrent_workflow_execution(
self, workflow_with_real_agents, benchmark_timer
):
"""Test concurrent execution of multiple complete workflows."""
symbols = ["AAPL", "GOOGL", "MSFT", "TSLA"]
async def run_workflow(symbol):
return await workflow_with_real_agents.run_intelligent_backtest(
symbol=symbol, start_date="2023-01-01", end_date="2023-12-31"
)
with benchmark_timer() as timer:
# Run workflows concurrently
results = await asyncio.gather(
*[run_workflow(symbol) for symbol in symbols], return_exceptions=True
)
# Test all completed
assert len(results) == len(symbols)
# Test no exceptions
successful_results = []
for i, result in enumerate(results):
if not isinstance(result, Exception):
successful_results.append(result)
assert result["symbol"] == symbols[i]
else:
logger.warning(f"Workflow failed for {symbols[i]}: {result}")
# At least half should succeed in concurrent execution
assert len(successful_results) >= len(symbols) // 2
# Test reasonable execution time for concurrent runs
assert timer.elapsed < 120.0 # Should complete within 2 minutes
logger.info(
f"Concurrent workflows completed: {len(successful_results)}/{len(symbols)} in {timer.elapsed:.2f}s"
)
async def test_performance_benchmarks_integration(
self, workflow_with_real_agents, benchmark_timer
):
"""Test performance benchmarks for integrated workflow."""
performance_results = {}
# Test quick analysis performance
with benchmark_timer() as timer:
quick_result = await workflow_with_real_agents.run_quick_analysis(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
assert isinstance(quick_result, dict)
assert quick_result.get("symbol") == "AAPL"
quick_time = timer.elapsed
# Test full workflow performance
with benchmark_timer() as timer:
full_result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
full_time = timer.elapsed
# Performance requirements
assert quick_time < 10.0 # Quick analysis < 10 seconds
assert full_time < 60.0 # Full workflow < 1 minute
assert quick_time < full_time # Quick should be faster than full
performance_results["quick_analysis"] = quick_time
performance_results["full_workflow"] = full_time
# Test workflow status tracking performance
if "workflow_completed" in full_result.get("execution_metadata", {}):
workflow_status = workflow_with_real_agents.get_workflow_status(
full_result.get(
"_internal_state",
Mock(
workflow_status="completed",
current_step="finalized",
steps_completed=[
"initialization",
"market_analysis",
"strategy_selection",
],
errors_encountered=[],
validation_warnings=[],
total_execution_time_ms=full_time * 1000,
recommended_strategy=full_result.get("recommendation", {}).get(
"recommended_strategy", "unknown"
),
recommendation_confidence=full_result.get(
"recommendation", {}
).get("recommendation_confidence", 0.0),
),
)
)
assert workflow_status["progress_percentage"] >= 0
assert workflow_status["progress_percentage"] <= 100
logger.info(f"Performance benchmarks: {performance_results}")
async def test_resource_cleanup_integration(self, workflow_with_real_agents):
"""Test resource cleanup after workflow completion."""
import os
import psutil
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
initial_threads = process.num_threads()
# Run multiple workflows
for i in range(3):
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol=f"TEST_{i}", # Use different symbols
start_date="2023-01-01",
end_date="2023-12-31",
)
assert isinstance(result, dict)
# Check resource usage after completion
final_memory = process.memory_info().rss
final_threads = process.num_threads()
memory_growth = (final_memory - initial_memory) / 1024 / 1024 # MB
thread_growth = final_threads - initial_threads
# Memory growth should be reasonable
assert memory_growth < 200 # < 200MB growth
# Thread count should not grow excessively
assert thread_growth <= 5 # Allow some thread growth
logger.info(
f"Resource usage: Memory +{memory_growth:.1f}MB, Threads +{thread_growth}"
)
async def test_cache_optimization_integration(self, workflow_with_real_agents):
"""Test cache optimization in integrated workflow."""
# First run - should populate cache
start_time1 = datetime.now()
result1 = await workflow_with_real_agents.run_intelligent_backtest(
symbol="CACHE_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
time1 = (datetime.now() - start_time1).total_seconds()
# Second run - should use cache
start_time2 = datetime.now()
result2 = await workflow_with_real_agents.run_intelligent_backtest(
symbol="CACHE_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
time2 = (datetime.now() - start_time2).total_seconds()
# Both should complete successfully
assert isinstance(result1, dict)
assert isinstance(result2, dict)
# Second run might be faster due to caching (though not guaranteed)
# We mainly test that caching doesn't break functionality
assert result1["symbol"] == result2["symbol"] == "CACHE_TEST"
logger.info(f"Cache test: First run {time1:.2f}s, Second run {time2:.2f}s")
class TestWorkflowErrorResilience:
"""Test workflow resilience under various error conditions."""
async def test_database_failure_resilience(self, workflow_with_real_agents):
"""Test workflow resilience when database operations fail."""
with patch(
"maverick_mcp.backtesting.persistence.SessionLocal",
side_effect=Exception("Database unavailable"),
):
# Workflow should still complete even if persistence fails
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="DB_FAIL_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
# Should get a result even if database persistence failed
assert isinstance(result, dict)
assert "symbol" in result
async def test_external_api_failure_resilience(self, workflow_with_real_agents):
"""Test workflow resilience when external APIs fail."""
# Mock external API failures
with patch(
"maverick_mcp.providers.stock_data.EnhancedStockDataProvider.get_stock_data",
side_effect=Exception("API rate limit exceeded"),
):
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="API_FAIL_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
# Should handle API failure gracefully
assert isinstance(result, dict)
# Should either have an error field or fallback behavior
assert "error" in result or "execution_metadata" in result
async def test_memory_pressure_resilience(self, workflow_with_real_agents):
"""Test workflow resilience under memory pressure."""
# Simulate memory pressure by creating large objects
memory_pressure = []
try:
# Create memory pressure (but not too much to crash the test)
for _ in range(10):
large_array = np.random.random((1000, 1000)) # ~8MB each
memory_pressure.append(large_array)
# Run workflow under memory pressure
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="MEMORY_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
assert isinstance(result, dict)
assert "symbol" in result
finally:
# Clean up memory pressure
del memory_pressure
async def test_timeout_handling(self, workflow_with_real_agents):
"""Test workflow timeout handling."""
        # Simulate a timeout by patching asyncio.wait_for to raise TimeoutError
with patch.object(asyncio, "wait_for") as mock_wait_for:
mock_wait_for.side_effect = TimeoutError("Workflow timed out")
try:
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="TIMEOUT_TEST",
start_date="2023-01-01",
end_date="2023-12-31",
)
# If we get here, timeout was handled
assert isinstance(result, dict)
except TimeoutError:
# Timeout occurred - this is also acceptable behavior
pass
class TestWorkflowValidation:
"""Test workflow validation and data integrity."""
async def test_input_validation(self, workflow_with_real_agents):
"""Test input parameter validation."""
# Test invalid symbol
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="", # Empty symbol
start_date="2023-01-01",
end_date="2023-12-31",
)
assert "error" in result or (
"execution_metadata" in result
and not result["execution_metadata"]["workflow_completed"]
)
# Test invalid date range
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="AAPL",
start_date="2023-12-31", # Start after end
end_date="2023-01-01",
)
assert isinstance(result, dict) # Should handle gracefully
async def test_output_validation(self, workflow_with_real_agents):
"""Test output structure validation."""
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol="VALIDATE_TEST", start_date="2023-01-01", end_date="2023-12-31"
)
# Validate required fields
required_fields = ["symbol", "execution_metadata"]
for field in required_fields:
assert field in result, f"Missing required field: {field}"
# Validate execution metadata structure
metadata = result["execution_metadata"]
required_metadata = ["total_execution_time_ms", "workflow_completed"]
for field in required_metadata:
assert field in metadata, f"Missing metadata field: {field}"
# Validate data types
assert isinstance(metadata["total_execution_time_ms"], (int, float))
assert isinstance(metadata["workflow_completed"], bool)
if "recommendation" in result:
recommendation = result["recommendation"]
assert "recommended_strategy" in recommendation
assert "recommendation_confidence" in recommendation
assert isinstance(recommendation["recommendation_confidence"], (int, float))
assert 0.0 <= recommendation["recommendation_confidence"] <= 1.0
async def test_data_consistency(self, workflow_with_real_agents, db_session):
"""Test data consistency across workflow components."""
symbol = "CONSISTENCY_TEST"
result = await workflow_with_real_agents.run_intelligent_backtest(
symbol=symbol, start_date="2023-01-01", end_date="2023-12-31"
)
# Test symbol consistency
assert result["symbol"] == symbol
# If workflow completed successfully, all components should be consistent
if result["execution_metadata"]["workflow_completed"]:
# Market analysis should be consistent
if "market_analysis" in result:
market_analysis = result["market_analysis"]
assert "regime" in market_analysis
assert isinstance(
market_analysis.get("regime_confidence", 0), (int, float)
)
# Strategy selection should be consistent
if "strategy_selection" in result:
strategy_selection = result["strategy_selection"]
selected_strategies = strategy_selection.get("selected_strategies", [])
assert isinstance(selected_strategies, list)
# Recommendation should be consistent with selection
if "recommendation" in result and "strategy_selection" in result:
recommended = result["recommendation"]["recommended_strategy"]
if recommended and selected_strategies:
# Recommended strategy should be from selected strategies
# (though fallback behavior might select others)
pass # Allow flexibility for fallback scenarios
if __name__ == "__main__":
# Run integration tests with extended timeout
pytest.main(
[
__file__,
"-v",
"--tb=short",
"--asyncio-mode=auto",
"--timeout=300", # 5 minute timeout for integration tests
"-x", # Stop on first failure
]
)
```
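These tests lean on shared fixtures such as `db_session`, `benchmark_timer`, and `sample_vectorbt_results` that are defined elsewhere in the test suite. For orientation only, here is a minimal sketch of a timer fixture compatible with the `with benchmark_timer() as timer:` / `timer.elapsed` usage above; the project's actual fixture may differ.

```python
# Minimal sketch of a benchmark_timer fixture; the real project fixture
# is not shown on this page and may be implemented differently.
import time
from contextlib import contextmanager

import pytest


class _Timer:
    """Records elapsed wall-clock time; readable both during and after the block."""

    def __init__(self) -> None:
        self._start = time.perf_counter()
        self._stop: float | None = None

    def stop(self) -> None:
        self._stop = time.perf_counter()

    @property
    def elapsed(self) -> float:
        end = self._stop if self._stop is not None else time.perf_counter()
        return end - self._start


@pytest.fixture
def benchmark_timer():
    """Factory fixture: `with benchmark_timer() as timer:` then read `timer.elapsed`."""

    @contextmanager
    def _factory():
        timer = _Timer()
        try:
            yield timer
        finally:
            timer.stop()

    return _factory
```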
--------------------------------------------------------------------------------
/maverick_mcp/core/technical_analysis.py:
--------------------------------------------------------------------------------
```python
"""
Technical analysis functions for Maverick-MCP.
This module contains functions for performing technical analysis on financial data,
including calculating indicators, analyzing trends, and generating trading signals.
DISCLAIMER: All technical analysis functions in this module are for educational
purposes only. Technical indicators are mathematical calculations based on historical
data and do not predict future price movements. Past performance does not guarantee
future results. Always conduct thorough research and consult with qualified financial
professionals before making investment decisions.
"""
import logging
from collections.abc import Sequence
from typing import Any
import numpy as np
import pandas as pd
import pandas_ta as ta
from maverick_mcp.config.technical_constants import TECHNICAL_CONFIG
# Set up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.technical_analysis")
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
"""
Add technical indicators to the dataframe
Args:
df: DataFrame with OHLCV price data
Returns:
DataFrame with added technical indicators
"""
# Ensure column names are lowercase
df = df.copy()
df.columns = [col.lower() for col in df.columns]
# Use pandas_ta for all indicators with configurable parameters
# EMA
df["ema_21"] = ta.ema(df["close"], length=TECHNICAL_CONFIG.EMA_PERIOD)
# SMA
df["sma_50"] = ta.sma(df["close"], length=TECHNICAL_CONFIG.SMA_SHORT_PERIOD)
df["sma_200"] = ta.sma(df["close"], length=TECHNICAL_CONFIG.SMA_LONG_PERIOD)
# RSI
df["rsi"] = ta.rsi(df["close"], length=TECHNICAL_CONFIG.RSI_PERIOD)
# MACD
macd = ta.macd(
df["close"],
fast=TECHNICAL_CONFIG.MACD_FAST_PERIOD,
slow=TECHNICAL_CONFIG.MACD_SLOW_PERIOD,
signal=TECHNICAL_CONFIG.MACD_SIGNAL_PERIOD,
)
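    # NOTE: pandas-ta names MACD output columns after the configured periods
    # (MACD_{fast}_{slow}_{signal}), so the hardcoded names below assume a
    # 12/26/9 configuration.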
if macd is not None and not macd.empty:
df["macd_12_26_9"] = macd["MACD_12_26_9"]
df["macds_12_26_9"] = macd["MACDs_12_26_9"]
df["macdh_12_26_9"] = macd["MACDh_12_26_9"]
else:
df["macd_12_26_9"] = np.nan
df["macds_12_26_9"] = np.nan
df["macdh_12_26_9"] = np.nan
# Bollinger Bands
bbands = ta.bbands(df["close"], length=20, std=2.0)
if bbands is not None and not bbands.empty:
resolved_columns = _resolve_bollinger_columns(bbands.columns)
if resolved_columns:
mid_col, upper_col, lower_col = resolved_columns
df["sma_20"] = bbands[mid_col]
df["bbu_20_2.0"] = bbands[upper_col]
df["bbl_20_2.0"] = bbands[lower_col]
else:
logger.warning(
"Bollinger Bands columns missing expected names: %s",
list(bbands.columns),
)
df["sma_20"] = np.nan
df["bbu_20_2.0"] = np.nan
df["bbl_20_2.0"] = np.nan
else:
df["sma_20"] = np.nan
df["bbu_20_2.0"] = np.nan
df["bbl_20_2.0"] = np.nan
df["stdev"] = df["close"].rolling(window=20).std()
# ATR
df["atr"] = ta.atr(df["high"], df["low"], df["close"], length=14)
# Stochastic Oscillator
stoch = ta.stoch(df["high"], df["low"], df["close"], k=14, d=3, smooth_k=3)
if stoch is not None and not stoch.empty:
df["stochk_14_3_3"] = stoch["STOCHk_14_3_3"]
df["stochd_14_3_3"] = stoch["STOCHd_14_3_3"]
else:
df["stochk_14_3_3"] = np.nan
df["stochd_14_3_3"] = np.nan
# ADX
adx = ta.adx(df["high"], df["low"], df["close"], length=14)
if adx is not None and not adx.empty:
df["adx_14"] = adx["ADX_14"]
else:
df["adx_14"] = np.nan
return df
def _resolve_bollinger_columns(columns: Sequence[str]) -> tuple[str, str, str] | None:
"""Resolve Bollinger Band column names across pandas-ta variants."""
candidate_sets = [
("BBM_20_2.0", "BBU_20_2.0", "BBL_20_2.0"),
("BBM_20_2", "BBU_20_2", "BBL_20_2"),
]
for candidate in candidate_sets:
if set(candidate).issubset(columns):
return candidate
mid_candidates = [column for column in columns if column.startswith("BBM_")]
upper_candidates = [column for column in columns if column.startswith("BBU_")]
lower_candidates = [column for column in columns if column.startswith("BBL_")]
if mid_candidates and upper_candidates and lower_candidates:
return mid_candidates[0], upper_candidates[0], lower_candidates[0]
return None
def identify_support_levels(df: pd.DataFrame) -> list[float]:
"""
Identify support levels using recent lows
Args:
df: DataFrame with price data
Returns:
List of support price levels
"""
# Use the lowest points in recent periods
last_month = df.iloc[-30:] if len(df) >= 30 else df
min_price = last_month["low"].min()
# Additional support levels
support_levels = [
round(min_price, 2),
round(df["close"].iloc[-1] * 0.95, 2), # 5% below current price
round(df["close"].iloc[-1] * 0.90, 2), # 10% below current price
]
return sorted(set(support_levels))
def identify_resistance_levels(df: pd.DataFrame) -> list[float]:
"""
Identify resistance levels using recent highs
Args:
df: DataFrame with price data
Returns:
List of resistance price levels
"""
# Use the highest points in recent periods
last_month = df.iloc[-30:] if len(df) >= 30 else df
max_price = last_month["high"].max()
# Additional resistance levels
resistance_levels = [
round(max_price, 2),
round(df["close"].iloc[-1] * 1.05, 2), # 5% above current price
round(df["close"].iloc[-1] * 1.10, 2), # 10% above current price
]
return sorted(set(resistance_levels))
def analyze_trend(df: pd.DataFrame) -> int:
"""
Calculate the trend strength of a stock based on various technical indicators.
Args:
df: DataFrame with price and indicator data
Returns:
Integer trend strength score (0-7)
"""
try:
trend_strength = 0
close_price = df["close"].iloc[-1]
# Check SMA 50
sma_50 = df["sma_50"].iloc[-1]
if pd.notna(sma_50) and close_price > sma_50:
trend_strength += 1
# Check EMA 21
ema_21 = df["ema_21"].iloc[-1]
if pd.notna(ema_21) and close_price > ema_21:
trend_strength += 1
# Check EMA 21 vs SMA 50
if pd.notna(ema_21) and pd.notna(sma_50) and ema_21 > sma_50:
trend_strength += 1
# Check SMA 50 vs SMA 200
sma_200 = df["sma_200"].iloc[-1]
if pd.notna(sma_50) and pd.notna(sma_200) and sma_50 > sma_200:
trend_strength += 1
# Check RSI
rsi = df["rsi"].iloc[-1]
if pd.notna(rsi) and rsi > 50:
trend_strength += 1
# Check MACD
macd = df["macd_12_26_9"].iloc[-1]
if pd.notna(macd) and macd > 0:
trend_strength += 1
# Check ADX
adx = df["adx_14"].iloc[-1]
if pd.notna(adx) and adx > 25:
trend_strength += 1
return trend_strength
except Exception as e:
logger.error(f"Error calculating trend strength: {e}")
return 0
def analyze_rsi(df: pd.DataFrame) -> dict[str, Any]:
"""
Analyze RSI indicator
Args:
df: DataFrame with price and indicator data
Returns:
Dictionary with RSI analysis
"""
try:
# Check if dataframe is valid and has RSI column
if df.empty:
return {
"current": None,
"signal": "unavailable",
"description": "No data available for RSI calculation",
}
if "rsi" not in df.columns:
return {
"current": None,
"signal": "unavailable",
"description": "RSI indicator not calculated",
}
if len(df) == 0:
return {
"current": None,
"signal": "unavailable",
"description": "Insufficient data for RSI calculation",
}
rsi = df["rsi"].iloc[-1]
# Check if RSI is NaN
if pd.isna(rsi):
return {
"current": None,
"signal": "unavailable",
"description": "RSI data not available (insufficient data points)",
}
if rsi > 70:
signal = "overbought"
elif rsi < 30:
signal = "oversold"
elif rsi > 50:
signal = "bullish"
else:
signal = "bearish"
return {
"current": round(rsi, 2),
"signal": signal,
"description": f"RSI is currently at {round(rsi, 2)}, indicating {signal} conditions.",
}
except Exception as e:
logger.error(f"Error analyzing RSI: {e}")
return {
"current": None,
"signal": "error",
"description": f"Error calculating RSI: {str(e)}",
}
def analyze_macd(df: pd.DataFrame) -> dict[str, Any]:
"""
Analyze MACD indicator
Args:
df: DataFrame with price and indicator data
Returns:
Dictionary with MACD analysis
"""
try:
macd = df["macd_12_26_9"].iloc[-1]
signal = df["macds_12_26_9"].iloc[-1]
histogram = df["macdh_12_26_9"].iloc[-1]
# Check if any values are NaN
if pd.isna(macd) or pd.isna(signal) or pd.isna(histogram):
return {
"macd": None,
"signal": None,
"histogram": None,
"indicator": "unavailable",
"crossover": "unavailable",
"description": "MACD data not available (insufficient data points)",
}
if macd > signal and histogram > 0:
signal_type = "bullish"
elif macd < signal and histogram < 0:
signal_type = "bearish"
elif macd > signal and macd < 0:
signal_type = "improving"
elif macd < signal and macd > 0:
signal_type = "weakening"
else:
signal_type = "neutral"
# Check for crossover (ensure we have enough data)
crossover = "no recent crossover"
if len(df) >= 2:
prev_macd = df["macd_12_26_9"].iloc[-2]
prev_signal = df["macds_12_26_9"].iloc[-2]
if pd.notna(prev_macd) and pd.notna(prev_signal):
if prev_macd <= prev_signal and macd > signal:
crossover = "bullish crossover detected"
elif prev_macd >= prev_signal and macd < signal:
crossover = "bearish crossover detected"
return {
"macd": round(macd, 2),
"signal": round(signal, 2),
"histogram": round(histogram, 2),
"indicator": signal_type,
"crossover": crossover,
"description": f"MACD is {signal_type} with {crossover}.",
}
except Exception as e:
logger.error(f"Error analyzing MACD: {e}")
return {
"macd": None,
"signal": None,
"histogram": None,
"indicator": "error",
"crossover": "error",
"description": "Error calculating MACD",
}
def analyze_stochastic(df: pd.DataFrame) -> dict[str, Any]:
"""
Analyze Stochastic Oscillator
Args:
df: DataFrame with price and indicator data
Returns:
Dictionary with stochastic oscillator analysis
"""
try:
k = df["stochk_14_3_3"].iloc[-1]
d = df["stochd_14_3_3"].iloc[-1]
# Check if values are NaN
if pd.isna(k) or pd.isna(d):
return {
"k": None,
"d": None,
"signal": "unavailable",
"crossover": "unavailable",
"description": "Stochastic data not available (insufficient data points)",
}
if k > 80 and d > 80:
signal = "overbought"
elif k < 20 and d < 20:
signal = "oversold"
elif k > d:
signal = "bullish"
else:
signal = "bearish"
# Check for crossover (ensure we have enough data)
crossover = "no recent crossover"
if len(df) >= 2:
prev_k = df["stochk_14_3_3"].iloc[-2]
prev_d = df["stochd_14_3_3"].iloc[-2]
if pd.notna(prev_k) and pd.notna(prev_d):
if prev_k <= prev_d and k > d:
crossover = "bullish crossover detected"
elif prev_k >= prev_d and k < d:
crossover = "bearish crossover detected"
return {
"k": round(k, 2),
"d": round(d, 2),
"signal": signal,
"crossover": crossover,
"description": f"Stochastic Oscillator is {signal} with {crossover}.",
}
except Exception as e:
logger.error(f"Error analyzing Stochastic: {e}")
return {
"k": None,
"d": None,
"signal": "error",
"crossover": "error",
"description": "Error calculating Stochastic",
}
def analyze_bollinger_bands(df: pd.DataFrame) -> dict[str, Any]:
"""
Analyze Bollinger Bands
Args:
df: DataFrame with price and indicator data
Returns:
Dictionary with Bollinger Bands analysis
"""
try:
current_price = df["close"].iloc[-1]
upper_band = df["bbu_20_2.0"].iloc[-1]
lower_band = df["bbl_20_2.0"].iloc[-1]
middle_band = df["sma_20"].iloc[-1]
# Check if any values are NaN
if pd.isna(upper_band) or pd.isna(lower_band) or pd.isna(middle_band):
return {
"upper_band": None,
"middle_band": None,
"lower_band": None,
"position": "unavailable",
"signal": "unavailable",
"volatility": "unavailable",
"description": "Bollinger Bands data not available (insufficient data points)",
}
if current_price > upper_band:
position = "above upper band"
signal = "overbought"
elif current_price < lower_band:
position = "below lower band"
signal = "oversold"
elif current_price > middle_band:
position = "above middle band"
signal = "bullish"
else:
position = "below middle band"
signal = "bearish"
# Check for BB squeeze (volatility contraction)
volatility = "stable"
if len(df) >= 5:
try:
bb_widths = []
for i in range(-5, 0):
upper = df["bbu_20_2.0"].iloc[i]
lower = df["bbl_20_2.0"].iloc[i]
middle = df["sma_20"].iloc[i]
if (
pd.notna(upper)
and pd.notna(lower)
and pd.notna(middle)
and middle != 0
):
bb_widths.append((upper - lower) / middle)
if len(bb_widths) == 5:
if all(bb_widths[i] < bb_widths[i - 1] for i in range(1, 5)):
volatility = "contracting (potential breakout ahead)"
elif all(bb_widths[i] > bb_widths[i - 1] for i in range(1, 5)):
volatility = "expanding (increased volatility)"
except Exception:
# If volatility calculation fails, keep it as stable
pass
return {
"upper_band": round(upper_band, 2),
"middle_band": round(middle_band, 2),
"lower_band": round(lower_band, 2),
"position": position,
"signal": signal,
"volatility": volatility,
"description": f"Price is {position}, indicating {signal} conditions. Volatility is {volatility}.",
}
except Exception as e:
logger.error(f"Error analyzing Bollinger Bands: {e}")
return {
"upper_band": None,
"middle_band": None,
"lower_band": None,
"position": "error",
"signal": "error",
"volatility": "error",
"description": "Error calculating Bollinger Bands",
}
def analyze_volume(df: pd.DataFrame) -> dict[str, Any]:
"""
Analyze volume patterns
Args:
df: DataFrame with price and volume data
Returns:
Dictionary with volume analysis
"""
try:
current_volume = df["volume"].iloc[-1]
# Check if we have enough data for average
if len(df) < 10:
avg_volume = df["volume"].mean()
else:
avg_volume = df["volume"].iloc[-10:].mean()
# Check for invalid values
if pd.isna(current_volume) or pd.isna(avg_volume) or avg_volume == 0:
return {
"current": None,
"average": None,
"ratio": None,
"description": "unavailable",
"signal": "unavailable",
}
volume_ratio = current_volume / avg_volume
if volume_ratio > 1.5:
volume_desc = "above average"
if len(df) >= 2 and df["close"].iloc[-1] > df["close"].iloc[-2]:
signal = "bullish (high volume on up move)"
else:
signal = "bearish (high volume on down move)"
elif volume_ratio < 0.7:
volume_desc = "below average"
signal = "weak conviction"
else:
volume_desc = "average"
signal = "neutral"
return {
"current": int(current_volume),
"average": int(avg_volume),
"ratio": round(volume_ratio, 2),
"description": volume_desc,
"signal": signal,
}
except Exception as e:
logger.error(f"Error analyzing volume: {e}")
return {
"current": None,
"average": None,
"ratio": None,
"description": "error",
"signal": "error",
}
def identify_chart_patterns(df: pd.DataFrame) -> list[str]:
"""
Identify common chart patterns
Args:
df: DataFrame with price data
Returns:
List of identified chart patterns
"""
patterns = []
# Check for potential double bottom (W formation)
if len(df) >= 40:
recent_lows = df["low"].iloc[-40:].values
potential_bottoms = []
for i in range(1, len(recent_lows) - 1):
if (
recent_lows[i] < recent_lows[i - 1]
and recent_lows[i] < recent_lows[i + 1]
):
potential_bottoms.append(i)
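        # Two local lows at least 5 bars apart and within 5% of each other
        # qualify as a potential double bottom.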
if (
len(potential_bottoms) >= 2
and potential_bottoms[-1] - potential_bottoms[-2] >= 5
):
if (
abs(
recent_lows[potential_bottoms[-1]]
- recent_lows[potential_bottoms[-2]]
)
/ recent_lows[potential_bottoms[-2]]
< 0.05
):
patterns.append("Double Bottom (W)")
# Check for potential double top (M formation)
if len(df) >= 40:
recent_highs = df["high"].iloc[-40:].values
potential_tops = []
for i in range(1, len(recent_highs) - 1):
if (
recent_highs[i] > recent_highs[i - 1]
and recent_highs[i] > recent_highs[i + 1]
):
potential_tops.append(i)
if len(potential_tops) >= 2 and potential_tops[-1] - potential_tops[-2] >= 5:
if (
abs(recent_highs[potential_tops[-1]] - recent_highs[potential_tops[-2]])
/ recent_highs[potential_tops[-2]]
< 0.05
):
patterns.append("Double Top (M)")
# Check for bullish flag/pennant
if len(df) >= 20:
recent_prices = df["close"].iloc[-20:].values
if (
recent_prices[0] < recent_prices[10]
and all(
recent_prices[i] >= recent_prices[i - 1] * 0.99 for i in range(1, 10)
)
and all(
abs(recent_prices[i] - recent_prices[i - 1]) / recent_prices[i - 1]
< 0.02
for i in range(11, 20)
)
):
patterns.append("Bullish Flag/Pennant")
# Check for bearish flag/pennant
if len(df) >= 20:
recent_prices = df["close"].iloc[-20:].values
if (
recent_prices[0] > recent_prices[10]
and all(
recent_prices[i] <= recent_prices[i - 1] * 1.01 for i in range(1, 10)
)
and all(
abs(recent_prices[i] - recent_prices[i - 1]) / recent_prices[i - 1]
< 0.02
for i in range(11, 20)
)
):
patterns.append("Bearish Flag/Pennant")
return patterns
def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
"""
Calculate Average True Range (ATR) for the given dataframe.
Args:
df: DataFrame with high, low, and close price data
period: Period for ATR calculation (default: 14)
Returns:
Series with ATR values
"""
# Ensure column names are lowercase
df_copy = df.copy()
df_copy.columns = [col.lower() for col in df_copy.columns]
# Use pandas_ta to calculate ATR
atr = ta.atr(df_copy["high"], df_copy["low"], df_copy["close"], length=period)
# Ensure we return a Series
if isinstance(atr, pd.Series):
return atr
elif isinstance(atr, pd.DataFrame):
# If it's a DataFrame, take the first column
return pd.Series(atr.iloc[:, 0])
elif atr is not None:
# If it's a numpy array or other iterable
return pd.Series(atr)
else:
# Return empty series if calculation failed
return pd.Series(dtype=float)
def generate_outlook(
df: pd.DataFrame,
trend: str,
rsi_analysis: dict[str, Any],
macd_analysis: dict[str, Any],
stoch_analysis: dict[str, Any],
) -> str:
"""
Generate an overall outlook based on technical analysis
Args:
df: DataFrame with price and indicator data
        trend: Trend direction label such as "uptrend" or "downtrend"
rsi_analysis: RSI analysis from analyze_rsi
macd_analysis: MACD analysis from analyze_macd
stoch_analysis: Stochastic analysis from analyze_stochastic
Returns:
String with overall market outlook
"""
bullish_signals = 0
bearish_signals = 0
# Count signals from different indicators
if trend == "uptrend":
bullish_signals += 2
elif trend == "downtrend":
bearish_signals += 2
if rsi_analysis["signal"] == "bullish" or rsi_analysis["signal"] == "oversold":
bullish_signals += 1
elif rsi_analysis["signal"] == "bearish" or rsi_analysis["signal"] == "overbought":
bearish_signals += 1
if (
macd_analysis["indicator"] == "bullish"
or macd_analysis["crossover"] == "bullish crossover detected"
):
bullish_signals += 1
elif (
macd_analysis["indicator"] == "bearish"
or macd_analysis["crossover"] == "bearish crossover detected"
):
bearish_signals += 1
if stoch_analysis["signal"] == "bullish" or stoch_analysis["signal"] == "oversold":
bullish_signals += 1
elif (
stoch_analysis["signal"] == "bearish"
or stoch_analysis["signal"] == "overbought"
):
bearish_signals += 1
# Generate outlook based on signals
if bullish_signals >= 4:
return "strongly bullish"
elif bullish_signals > bearish_signals:
return "moderately bullish"
elif bearish_signals >= 4:
return "strongly bearish"
elif bearish_signals > bullish_signals:
return "moderately bearish"
else:
return "neutral"
def calculate_rsi(df: pd.DataFrame, period: int = 14) -> pd.Series:
"""
Calculate RSI (Relative Strength Index) for the given dataframe.
Args:
df: DataFrame with price data
period: Period for RSI calculation (default: 14)
Returns:
Series with RSI values
"""
# Handle both uppercase and lowercase column names
df_copy = df.copy()
df_copy.columns = [col.lower() for col in df_copy.columns]
# Ensure we have the required 'close' column
if "close" not in df_copy.columns:
raise ValueError("DataFrame must contain a 'close' or 'Close' column")
# Use pandas_ta to calculate RSI
rsi = ta.rsi(df_copy["close"], length=period)
# Ensure we return a Series
if isinstance(rsi, pd.Series):
return rsi
elif rsi is not None:
# If it's a numpy array or other iterable
return pd.Series(rsi, index=df.index)
else:
# Return empty series if calculation failed
return pd.Series(dtype=float, index=df.index)
def calculate_sma(df: pd.DataFrame, period: int) -> pd.Series:
"""
Calculate Simple Moving Average (SMA) for the given dataframe.
Args:
df: DataFrame with price data
period: Period for SMA calculation
Returns:
Series with SMA values
"""
# Handle both uppercase and lowercase column names
df_copy = df.copy()
df_copy.columns = [col.lower() for col in df_copy.columns]
# Ensure we have the required 'close' column
if "close" not in df_copy.columns:
raise ValueError("DataFrame must contain a 'close' or 'Close' column")
# Use pandas_ta to calculate SMA
sma = ta.sma(df_copy["close"], length=period)
# Ensure we return a Series
if isinstance(sma, pd.Series):
return sma
elif sma is not None:
# If it's a numpy array or other iterable
return pd.Series(sma, index=df.index)
else:
# Return empty series if calculation failed
return pd.Series(dtype=float, index=df.index)
```
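As a usage sketch (not part of the module), the helpers above compose naturally: enrich an OHLCV frame with `add_technical_indicators`, then pass the result to the `analyze_*` functions. The synthetic random-walk data below is purely illustrative.

```python
# Illustrative usage of the technical analysis helpers on synthetic OHLCV data.
import numpy as np
import pandas as pd

from maverick_mcp.core.technical_analysis import (
    add_technical_indicators,
    analyze_bollinger_bands,
    analyze_macd,
    analyze_rsi,
    analyze_trend,
)

dates = pd.date_range("2023-01-01", periods=300, freq="D")
close = 100 * np.cumprod(1 + np.random.normal(0.0005, 0.015, len(dates)))
df = pd.DataFrame(
    {
        "open": close * 0.995,
        "high": close * 1.01,
        "low": close * 0.99,
        "close": close,
        "volume": np.random.randint(1_000_000, 5_000_000, len(dates)),
    },
    index=dates,
)

enriched = add_technical_indicators(df)

# Trend strength is a 0-7 count of bullish conditions (price above SMA 50,
# EMA 21 above SMA 50, RSI above 50, positive MACD, ADX above 25, etc.).
print("trend strength:", analyze_trend(enriched))
print("rsi:", analyze_rsi(enriched)["signal"])
print("macd:", analyze_macd(enriched)["indicator"])
print("bollinger:", analyze_bollinger_bands(enriched)["signal"])
```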
--------------------------------------------------------------------------------
/tests/utils/test_parallel_screening.py:
--------------------------------------------------------------------------------
```python
"""
Tests for parallel_screening.py - 4x faster multi-stock screening.
This test suite achieves high coverage by testing:
1. Parallel execution logic without actual multiprocessing
2. Error handling and partial failures
3. Process pool management and cleanup
4. Function serialization safety
5. Progress tracking functionality
"""
import asyncio
from concurrent.futures import Future
from unittest.mock import Mock, patch
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.utils.parallel_screening import (
BatchScreener,
ParallelScreener,
example_momentum_screen,
make_parallel_safe,
parallel_screen_async,
)
class TestParallelScreener:
"""Test ParallelScreener context manager and core functionality."""
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
def test_context_manager_creates_executor(self, mock_executor_class):
"""Test that context manager creates and cleans up executor."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
with ParallelScreener(max_workers=2) as screener:
assert screener._executor is not None
assert screener._executor == mock_executor
# Verify executor was created with correct parameters
mock_executor_class.assert_called_once_with(max_workers=2)
# Verify shutdown was called
mock_executor.shutdown.assert_called_once_with(wait=True)
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
def test_context_manager_cleanup_on_exception(self, mock_executor_class):
"""Test that executor is cleaned up even on exception."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
try:
with ParallelScreener(max_workers=2) as screener:
assert screener._executor is not None
raise ValueError("Test exception")
except ValueError:
pass
# Executor should still be shut down
mock_executor.shutdown.assert_called_once_with(wait=True)
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
@patch("maverick_mcp.utils.parallel_screening.as_completed")
def test_screen_batch_basic(self, mock_as_completed, mock_executor_class):
"""Test basic batch screening functionality."""
# Mock the executor
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
# Mock futures that return batch results
future1 = Mock(spec=Future)
future1.result.return_value = [
{"symbol": "STOCK0", "score": 0.1, "passed": True},
{"symbol": "STOCK1", "score": 0.2, "passed": True},
]
future2 = Mock(spec=Future)
future2.result.return_value = [
{"symbol": "STOCK2", "score": 0.3, "passed": True}
]
# Mock as_completed to return futures in order
mock_as_completed.return_value = [future1, future2]
# Mock submit to return futures
mock_executor.submit.side_effect = [future1, future2]
# Test screening
def test_screen_func(symbol):
return {"symbol": symbol, "score": 0.5, "passed": True}
with ParallelScreener(max_workers=2) as screener:
results = screener.screen_batch(
["STOCK0", "STOCK1", "STOCK2"], test_screen_func, batch_size=2
)
assert len(results) == 3
assert all("symbol" in r for r in results)
assert all("score" in r for r in results)
# Verify the executor was called correctly
assert mock_executor.submit.call_count == 2
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
@patch("maverick_mcp.utils.parallel_screening.as_completed")
def test_screen_batch_with_timeout(self, mock_as_completed, mock_executor_class):
"""Test batch screening with timeout handling."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
# Mock submit to return a future
mock_future = Mock(spec=Future)
mock_executor.submit.return_value = mock_future
# Mock as_completed to raise TimeoutError when called
from concurrent.futures import TimeoutError
mock_as_completed.side_effect = TimeoutError("Timeout occurred")
def slow_screen_func(symbol):
return {"symbol": symbol, "score": 0.5, "passed": True}
with ParallelScreener(max_workers=2) as screener:
# This should handle the timeout gracefully by catching the exception
try:
results = screener.screen_batch(
["FAST1", "SLOW", "FAST2"],
slow_screen_func,
timeout=0.5, # 500ms timeout
)
# If no exception, results should be empty since timeout occurred
assert isinstance(results, list)
except TimeoutError:
# If TimeoutError propagates, that's also acceptable behavior
pass
# Verify as_completed was called
mock_as_completed.assert_called()
def test_screen_batch_error_handling(self):
"""Test error handling in batch screening."""
def failing_screen_func(symbol):
if symbol == "FAIL":
raise ValueError(f"Failed to process {symbol}")
return {"symbol": symbol, "score": 0.5, "passed": True}
# Mock screen_batch to simulate error handling
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
# Simulate that only the good symbol passes through after error handling
mock_screen_batch.return_value = [
{"symbol": "GOOD1", "score": 0.5, "passed": True}
]
with ParallelScreener(max_workers=2) as screener:
results = screener.screen_batch(
["GOOD1", "FAIL", "GOOD2"], failing_screen_func
)
# Should get results for successful batch only
assert len(results) == 1
assert results[0]["symbol"] == "GOOD1"
def test_screen_batch_progress_callback(self):
"""Test that screen_batch completes without progress callback."""
# Mock the screen_batch method directly to avoid complex internal mocking
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
mock_screen_batch.return_value = [
{"symbol": "A", "score": 0.5, "passed": True},
{"symbol": "B", "score": 0.5, "passed": True},
{"symbol": "C", "score": 0.5, "passed": True},
{"symbol": "D", "score": 0.5, "passed": True},
]
def quick_screen_func(symbol):
return {"symbol": symbol, "score": 0.5, "passed": True}
with ParallelScreener(max_workers=2) as screener:
results = screener.screen_batch(["A", "B", "C", "D"], quick_screen_func)
# Should get all results
assert len(results) == 4
assert all("symbol" in r for r in results)
def test_screen_batch_custom_batch_size(self):
"""Test custom batch size handling."""
# Mock screen_batch to test that the correct batching logic is applied
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
mock_screen_batch.return_value = [
{"symbol": "A", "score": 0.5, "passed": True},
{"symbol": "B", "score": 0.5, "passed": True},
{"symbol": "C", "score": 0.5, "passed": True},
{"symbol": "D", "score": 0.5, "passed": True},
{"symbol": "E", "score": 0.5, "passed": True},
]
with ParallelScreener(max_workers=2) as screener:
results = screener.screen_batch(
["A", "B", "C", "D", "E"],
lambda x: {"symbol": x, "score": 0.5, "passed": True},
batch_size=2,
)
# Should get all 5 results
assert len(results) == 5
symbols = [r["symbol"] for r in results]
assert symbols == ["A", "B", "C", "D", "E"]
def test_screen_batch_without_context_manager(self):
"""Test that screen_batch raises error when not used as context manager."""
screener = ParallelScreener(max_workers=2)
with pytest.raises(
RuntimeError, match="ParallelScreener must be used as context manager"
):
screener.screen_batch(["TEST"], lambda x: {"symbol": x, "passed": True})
class TestBatchScreener:
"""Test BatchScreener with enhanced progress tracking."""
def test_batch_screener_initialization(self):
"""Test BatchScreener initialization and configuration."""
def dummy_func(symbol):
return {"symbol": symbol, "passed": True}
screener = BatchScreener(dummy_func, max_workers=4)
assert screener.screening_func == dummy_func
assert screener.max_workers == 4
assert screener.results == []
assert screener.progress == 0
assert screener.total == 0
@patch("maverick_mcp.utils.parallel_screening.ParallelScreener")
def test_screen_with_progress(self, mock_parallel_screener_class):
"""Test screening with progress tracking."""
# Mock the ParallelScreener context manager
mock_screener = Mock()
mock_parallel_screener_class.return_value.__enter__.return_value = mock_screener
mock_parallel_screener_class.return_value.__exit__.return_value = None
# Mock screen_batch to return one result per call; BatchScreener may invoke
# it once per batch, so use a side effect that tracks the call count
call_count = 0
def mock_screen_batch_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
# Return results based on the batch being processed
if call_count == 1:
return [{"symbol": "A", "score": 0.8, "passed": True}]
elif call_count == 2:
return [{"symbol": "B", "score": 0.6, "passed": True}]
else:
return []
mock_screener.screen_batch.side_effect = mock_screen_batch_side_effect
def dummy_func(symbol):
return {"symbol": symbol, "score": 0.5, "passed": True}
batch_screener = BatchScreener(dummy_func)
results = batch_screener.screen_with_progress(["A", "B"])
assert len(results) == 2
assert batch_screener.progress == 2
assert batch_screener.total == 2
def test_get_summary(self):
"""Test summary statistics generation."""
def dummy_func(symbol):
return {"symbol": symbol, "passed": True}
batch_screener = BatchScreener(dummy_func)
batch_screener.results = [
{"symbol": "A", "score": 0.8, "passed": True},
{"symbol": "B", "score": 0.6, "passed": True},
]
batch_screener.progress = 2
batch_screener.total = 4
# Test the actual BatchScreener attributes
assert len(batch_screener.results) == 2
assert batch_screener.progress == 2
assert batch_screener.total == 4
class TestParallelScreenAsync:
"""Test async wrapper for parallel screening."""
@pytest.mark.asyncio
@patch("maverick_mcp.utils.parallel_screening.ParallelScreener")
async def test_parallel_screen_async_basic(self, mock_screener_class):
"""Test basic async parallel screening."""
# Mock the context manager
mock_screener = Mock()
mock_screener_class.return_value.__enter__.return_value = mock_screener
mock_screener_class.return_value.__exit__.return_value = None
# Mock the screen_batch method
mock_screener.screen_batch.return_value = [
{"symbol": "AA", "score": 0.2, "passed": True},
{"symbol": "BBB", "score": 0.3, "passed": True},
{"symbol": "CCCC", "score": 0.4, "passed": True},
]
def simple_screen(symbol):
return {"symbol": symbol, "score": len(symbol) * 0.1, "passed": True}
results = await parallel_screen_async(
["AA", "BBB", "CCCC"], simple_screen, max_workers=2
)
assert len(results) == 3
symbols = [r["symbol"] for r in results]
assert "AA" in symbols
assert "BBB" in symbols
assert "CCCC" in symbols
@pytest.mark.asyncio
@patch("maverick_mcp.utils.parallel_screening.ParallelScreener")
async def test_parallel_screen_async_error_handling(self, mock_screener_class):
"""Test async error handling."""
# Mock the context manager
mock_screener = Mock()
mock_screener_class.return_value.__enter__.return_value = mock_screener
mock_screener_class.return_value.__exit__.return_value = None
# Mock screen_batch to return only successful results
mock_screener.screen_batch.return_value = [
{"symbol": "OK1", "score": 0.5, "passed": True},
{"symbol": "OK2", "score": 0.5, "passed": True},
]
def failing_screen(symbol):
if symbol == "FAIL":
raise ValueError("Screen failed")
return {"symbol": symbol, "score": 0.5, "passed": True}
results = await parallel_screen_async(["OK1", "FAIL", "OK2"], failing_screen)
# Should only get results for successful symbols
assert len(results) == 2
assert all(r["symbol"] in ["OK1", "OK2"] for r in results)
class TestMakeParallelSafe:
"""Test make_parallel_safe decorator."""
def test_make_parallel_safe_basic(self):
"""Test basic function wrapping."""
@make_parallel_safe
def test_func(x):
return x * 2
result = test_func(5)
assert result == 10
def test_make_parallel_safe_with_exception(self):
"""Test exception handling in wrapped function."""
@make_parallel_safe
def failing_func(x):
raise ValueError(f"Failed with {x}")
result = failing_func(5)
assert isinstance(result, dict)
assert result["error"] == "Failed with 5"
assert result["passed"] is False
def test_make_parallel_safe_serialization(self):
"""Test that wrapped function results are JSON serializable."""
@make_parallel_safe
def complex_func(symbol):
# Return something that might not be JSON serializable
return {
"symbol": symbol,
"data": pd.DataFrame(
{"A": [1, 2, 3]}
), # DataFrame not JSON serializable
"array": np.array([1, 2, 3]), # numpy array not JSON serializable
}
result = complex_func("TEST")
# Should handle non-serializable data
assert result["passed"] is False
assert "error" in result
assert "not JSON serializable" in str(result["error"])
def test_make_parallel_safe_preserves_metadata(self):
"""Test that decorator preserves function metadata."""
@make_parallel_safe
def documented_func(x):
"""This is a documented function."""
return x
assert documented_func.__name__ == "documented_func"
assert documented_func.__doc__ == "This is a documented function."
class TestExampleMomentumScreen:
"""Test the example momentum screening function."""
@patch("maverick_mcp.core.technical_analysis.calculate_rsi")
@patch("maverick_mcp.core.technical_analysis.calculate_sma")
@patch("maverick_mcp.providers.stock_data.StockDataProvider")
def test_example_momentum_screen_success(
self, mock_provider_class, mock_sma, mock_rsi
):
"""Test successful momentum screening."""
# Mock stock data provider
mock_provider = Mock()
mock_provider_class.return_value = mock_provider
# Mock stock data with enough length
dates = pd.date_range(end="2024-01-01", periods=100, freq="D")
mock_df = pd.DataFrame(
{
"Close": np.random.uniform(100, 105, 100),
"Volume": np.random.randint(1000, 1300, 100),
},
index=dates,
)
mock_provider.get_stock_data.return_value = mock_df
# Mock technical indicators
mock_rsi.return_value = pd.Series([62] * 100, index=dates)
mock_sma.return_value = pd.Series([102] * 100, index=dates)
result = example_momentum_screen("TEST")
assert result["symbol"] == "TEST"
assert result["passed"] in [True, False]
assert "price" in result
assert "sma_50" in result
assert "rsi" in result
assert "above_sma" in result
assert result.get("error", False) is False
@patch("maverick_mcp.providers.stock_data.StockDataProvider")
def test_example_momentum_screen_error(self, mock_provider_class):
"""Test error handling in momentum screening."""
# Mock provider to raise exception
mock_provider = Mock()
mock_provider_class.return_value = mock_provider
mock_provider.get_stock_data.side_effect = Exception("Data fetch failed")
result = example_momentum_screen("FAIL")
assert result["symbol"] == "FAIL"
assert result["passed"] is False
assert result.get("error") == "Data fetch failed"
class TestPerformanceValidation:
"""Test performance improvements and speedup validation."""
def test_parallel_vs_sequential_speedup(self):
"""Test that parallel processing logic is called correctly."""
def mock_screen_func(symbol):
return {"symbol": symbol, "score": 0.5, "passed": True}
symbols = [f"STOCK{i}" for i in range(8)]
# Sequential results (for comparison)
sequential_results = []
for symbol in symbols:
result = mock_screen_func(symbol)
if result.get("passed", False):
sequential_results.append(result)
# Mock screen_batch method to return all results without actual multiprocessing
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
mock_screen_batch.return_value = [
{"symbol": f"STOCK{i}", "score": 0.5, "passed": True} for i in range(8)
]
# Parallel results using mocked screener
with ParallelScreener(max_workers=4) as screener:
parallel_results = screener.screen_batch(symbols, mock_screen_func)
# Verify both approaches produce the same number of results
assert len(parallel_results) == len(sequential_results)
assert len(parallel_results) == 8
# Verify ParallelScreener was used correctly
mock_screen_batch.assert_called_once()
def test_optimal_batch_size_calculation(self):
"""Test that batch size is calculated optimally."""
# Mock screen_batch to verify the batching logic works
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
mock_screen_batch.return_value = [
{"symbol": f"S{i}", "score": 0.5, "passed": True} for i in range(10)
]
# Small dataset - should use smaller batches
with ParallelScreener(max_workers=4) as screener:
results = screener.screen_batch(
[f"S{i}" for i in range(10)],
lambda x: {"symbol": x, "score": 0.5, "passed": True},
)
# Check that results are as expected
assert len(results) == 10
symbols = [r["symbol"] for r in results]
expected_symbols = [f"S{i}" for i in range(10)]
assert symbols == expected_symbols
class TestEdgeCases:
"""Test edge cases and error conditions."""
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
@patch("maverick_mcp.utils.parallel_screening.as_completed")
def test_empty_symbol_list(self, mock_as_completed, mock_executor_class):
"""Test handling of empty symbol list."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
# Empty list should result in no futures
mock_as_completed.return_value = []
with ParallelScreener() as screener:
results = screener.screen_batch([], lambda x: {"symbol": x})
assert results == []
# Should not submit any jobs for empty list
mock_executor.submit.assert_not_called()
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
@patch("maverick_mcp.utils.parallel_screening.as_completed")
def test_single_symbol(self, mock_as_completed, mock_executor_class):
"""Test handling of single symbol."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
# Mock single future
future = Mock(spec=Future)
future.result.return_value = [
{"symbol": "SINGLE", "score": 1.0, "passed": True}
]
mock_as_completed.return_value = [future]
mock_executor.submit.return_value = future
with ParallelScreener() as screener:
results = screener.screen_batch(
["SINGLE"], lambda x: {"symbol": x, "score": 1.0, "passed": True}
)
assert len(results) == 1
assert results[0]["symbol"] == "SINGLE"
def test_non_picklable_function(self):
"""Test handling of non-picklable screening function."""
# Functions defined inside a test body (like lambdas) cannot be pickled by the standard pickle module
def non_picklable(x):
return {"symbol": x}
with ParallelScreener() as screener:
# Should handle gracefully
try:
results = screener.screen_batch(["TEST"], non_picklable)
# If it works, that's fine
assert len(results) <= 1
except Exception as e:
# If it fails, should be a pickling error
assert "pickle" in str(e).lower() or "serializ" in str(e).lower()
def test_keyboard_interrupt_handling(self):
"""Test handling of keyboard interrupts."""
def interruptible_func(symbol):
if symbol == "INTERRUPT":
raise KeyboardInterrupt()
return {"symbol": symbol, "passed": True}
# Mock screen_batch to simulate partial results due to interrupt
with patch.object(ParallelScreener, "screen_batch") as mock_screen_batch:
mock_screen_batch.return_value = [{"symbol": "OK", "passed": True}]
with ParallelScreener() as screener:
# The screen_batch should handle the exception gracefully
results = screener.screen_batch(
["OK", "INTERRUPT", "NEVER_REACHED"], interruptible_func
)
# Should get results for OK only since INTERRUPT will fail
assert len(results) == 1
assert results[0]["symbol"] == "OK"
@patch("maverick_mcp.utils.parallel_screening.ProcessPoolExecutor")
@patch("maverick_mcp.utils.parallel_screening.as_completed")
def test_very_large_batch(self, mock_as_completed, mock_executor_class):
"""Test handling of very large symbol batches."""
mock_executor = Mock()
mock_executor_class.return_value = mock_executor
# Create a large list of symbols
large_symbol_list = [f"SYM{i:05d}" for i in range(100)]
# Mock futures for 10 batches (100 symbols / 10 per batch)
futures = []
for i in range(10):
future = Mock(spec=Future)
batch_start = i * 10
batch_end = min((i + 1) * 10, 100)
batch_results = [
{"symbol": f"SYM{j:05d}", "id": j, "passed": True}
for j in range(batch_start, batch_end)
]
future.result.return_value = batch_results
futures.append(future)
mock_as_completed.return_value = futures
mock_executor.submit.side_effect = futures
def quick_func(symbol):
return {"symbol": symbol, "id": int(symbol[3:]), "passed": True}
with ParallelScreener(max_workers=4) as screener:
results = screener.screen_batch(
large_symbol_list, quick_func, batch_size=10
)
# Should process all symbols that passed
assert len(results) == 100
# Extract IDs and verify we got all symbols
result_ids = sorted([r["id"] for r in results])
assert result_ids == list(range(100))
class TestIntegration:
"""Integration tests with real technical analysis."""
@pytest.mark.integration
@patch("maverick_mcp.utils.parallel_screening.ParallelScreener")
def test_full_screening_workflow(self, mock_screener_class):
"""Test complete screening workflow."""
# Mock the context manager
mock_screener = Mock()
mock_screener_class.return_value.__enter__.return_value = mock_screener
mock_screener_class.return_value.__exit__.return_value = None
# With the screener mocked this stays hermetic; a real integration run would use live data
symbols = ["AAPL", "GOOGL", "MSFT"]
# Mock screen_batch to return realistic results
mock_screener.screen_batch.return_value = [
{"symbol": "AAPL", "passed": True, "price": 150.0, "error": False},
{"symbol": "GOOGL", "passed": False, "error": "Insufficient data"},
{"symbol": "MSFT", "passed": True, "price": 300.0, "error": False},
]
async def run_screening():
results = await parallel_screen_async(
symbols, example_momentum_screen, max_workers=2
)
return results
# Run the async screening
results = asyncio.run(run_screening())
# Should get results for all symbols (or errors)
assert len(results) == len(symbols)
for result in results:
assert "symbol" in result
assert "error" in result
```
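For orientation, a hedged usage sketch of the screening utilities exercised above; the import path, keyword arguments, and decorator behaviour are inferred from the tests, so treat them as assumptions rather than documented API:
```python
# Sketch only: mirrors the API surface exercised by the tests above.
import asyncio

from maverick_mcp.utils.parallel_screening import (
    ParallelScreener,
    make_parallel_safe,
    parallel_screen_async,
)


@make_parallel_safe  # per the tests, failures become {"error": ..., "passed": False}
def screen_symbol(symbol: str) -> dict:
    # A real screen would fetch data and compute indicators here.
    return {"symbol": symbol, "score": len(symbol) * 0.1, "passed": True}


if __name__ == "__main__":  # guard is required for multiprocessing on spawn-based platforms
    symbols = ["AAPL", "MSFT", "NVDA", "AMZN"]

    # Synchronous use: the context manager owns the process pool.
    with ParallelScreener(max_workers=2) as screener:
        results = screener.screen_batch(symbols, screen_symbol, batch_size=2)

    # Async wrapper, e.g. from inside a tool handler.
    async_results = asyncio.run(
        parallel_screen_async(symbols, screen_symbol, max_workers=2)
    )
    print(len(results), len(async_results))
```
The screening function is defined at module level so the process pool can pickle it by name, which is exactly the constraint the `test_non_picklable_function` case probes.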
--------------------------------------------------------------------------------
/tests/test_langgraph_workflow.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for LangGraph backtesting workflow.
Tests cover:
- LangGraph workflow state transitions and agent orchestration
- Market regime analysis workflow steps
- Strategy selection and parameter optimization
- Results validation and recommendation generation
- Error handling and fallback strategies
- Performance benchmarks and timing
"""
import asyncio
import logging
from datetime import datetime
from unittest.mock import AsyncMock, Mock
import pytest
from maverick_mcp.workflows.agents import (
MarketAnalyzerAgent,
OptimizerAgent,
StrategySelectorAgent,
ValidatorAgent,
)
from maverick_mcp.workflows.backtesting_workflow import BacktestingWorkflow
from maverick_mcp.workflows.state import BacktestingWorkflowState
logger = logging.getLogger(__name__)
class TestBacktestingWorkflow:
"""Test suite for BacktestingWorkflow class."""
@pytest.fixture
def sample_workflow_state(self) -> BacktestingWorkflowState:
"""Create a sample workflow state for testing."""
from langchain_core.messages import HumanMessage
return BacktestingWorkflowState(
# Base agent state
messages=[HumanMessage(content="Analyze AAPL for backtesting")],
session_id="test_session_123",
persona="intelligent_backtesting_agent",
timestamp=datetime.now(),
token_count=0,
error=None,
analyzed_stocks={},
key_price_levels={},
last_analysis_time={},
conversation_context={},
execution_time_ms=None,
api_calls_made=0,
cache_hits=0,
cache_misses=0,
# Input parameters
symbol="AAPL",
start_date="2023-01-01",
end_date="2023-12-31",
initial_capital=10000.0,
requested_strategy=None,
# Market regime analysis (initialized)
market_regime="bullish",
regime_confidence=0.85,
regime_indicators={
"trend_strength": 0.75,
"volatility": 0.25,
"momentum": 0.80,
},
regime_analysis_time_ms=150.0,
volatility_percentile=35.0,
trend_strength=0.75,
market_conditions={
"trend": "upward",
"volatility": "low",
"volume": "normal",
},
sector_performance={"technology": 0.15},
correlation_to_market=0.75,
volume_profile={"average": 50000000, "relative": 1.2},
support_resistance_levels=[150.0, 160.0, 170.0],
# Strategy selection (initialized)
candidate_strategies=["momentum", "mean_reversion", "breakout"],
strategy_rankings={"momentum": 0.9, "breakout": 0.7, "mean_reversion": 0.6},
selected_strategies=["momentum", "breakout"],
strategy_selection_reasoning="High momentum and trend strength favor momentum strategies",
strategy_selection_confidence=0.85,
# Parameter optimization (initialized)
optimization_config={"method": "grid_search", "cv_folds": 5},
parameter_grids={
"momentum": {"window": [10, 20, 30], "threshold": [0.01, 0.02]}
},
optimization_results={
"momentum": {
"best_sharpe": 1.5,
"best_params": {"window": 20, "threshold": 0.02},
}
},
best_parameters={"momentum": {"window": 20, "threshold": 0.02}},
optimization_time_ms=2500.0,
optimization_iterations=45,
# Validation (initialized)
walk_forward_results={"out_of_sample_sharpe": 1.2, "degradation": 0.2},
monte_carlo_results={"confidence_95": 0.8, "max_drawdown_95": 0.15},
out_of_sample_performance={"sharpe": 1.2, "return": 0.18},
robustness_score={"overall": 0.75, "parameter_sensitivity": 0.8},
validation_warnings=["High parameter sensitivity detected"],
# Final recommendations (initialized)
final_strategy_ranking=[
{"strategy": "momentum", "score": 0.9, "confidence": 0.85}
],
recommended_strategy="momentum",
recommended_parameters={"window": 20, "threshold": 0.02},
recommendation_confidence=0.85,
risk_assessment={"max_drawdown": 0.15, "volatility": 0.25},
# Performance metrics (initialized)
comparative_metrics={"sharpe_vs_benchmark": 1.5, "alpha": 0.05},
benchmark_comparison={"excess_return": 0.08, "information_ratio": 0.6},
risk_adjusted_performance={"calmar": 1.0, "sortino": 1.8},
drawdown_analysis={"max_dd": 0.15, "avg_dd": 0.05, "recovery_days": 30},
# Workflow control (initialized)
workflow_status="analyzing_regime",
current_step="market_analysis",
steps_completed=["initialization"],
total_execution_time_ms=0.0,
# Error handling (initialized)
errors_encountered=[],
fallback_strategies_used=[],
data_quality_issues=[],
# Caching (initialized)
cached_results={},
cache_hit_rate=0.0,
# Advanced analysis (initialized)
regime_transition_analysis={},
multi_timeframe_analysis={},
correlation_analysis={},
macroeconomic_context={},
)
@pytest.fixture
def mock_agents(self):
"""Create mock agents for testing."""
market_analyzer = Mock(spec=MarketAnalyzerAgent)
strategy_selector = Mock(spec=StrategySelectorAgent)
optimizer = Mock(spec=OptimizerAgent)
validator = Mock(spec=ValidatorAgent)
# Set up successful mock responses
async def mock_analyze_market_regime(state):
state.market_regime = "bullish"
state.regime_confidence = 0.85
state.workflow_status = "selecting_strategies"
state.steps_completed.append("market_analysis")
return state
async def mock_select_strategies(state):
state.selected_strategies = ["momentum", "breakout"]
state.strategy_selection_confidence = 0.85
state.workflow_status = "optimizing_parameters"
state.steps_completed.append("strategy_selection")
return state
async def mock_optimize_parameters(state):
state.best_parameters = {"momentum": {"window": 20, "threshold": 0.02}}
state.optimization_iterations = 45
state.workflow_status = "validating_results"
state.steps_completed.append("parameter_optimization")
return state
async def mock_validate_strategies(state):
state.recommended_strategy = "momentum"
state.recommendation_confidence = 0.85
state.workflow_status = "completed"
state.steps_completed.append("validation")
return state
market_analyzer.analyze_market_regime = AsyncMock(
side_effect=mock_analyze_market_regime
)
strategy_selector.select_strategies = AsyncMock(
side_effect=mock_select_strategies
)
optimizer.optimize_parameters = AsyncMock(side_effect=mock_optimize_parameters)
validator.validate_strategies = AsyncMock(side_effect=mock_validate_strategies)
return {
"market_analyzer": market_analyzer,
"strategy_selector": strategy_selector,
"optimizer": optimizer,
"validator": validator,
}
@pytest.fixture
def workflow_with_mocks(self, mock_agents):
"""Create a workflow with mocked agents."""
return BacktestingWorkflow(
market_analyzer=mock_agents["market_analyzer"],
strategy_selector=mock_agents["strategy_selector"],
optimizer=mock_agents["optimizer"],
validator=mock_agents["validator"],
)
async def test_workflow_initialization(self):
"""Test workflow initialization creates proper graph structure."""
workflow = BacktestingWorkflow()
# Test workflow has been compiled
assert workflow.workflow is not None
# Test agent initialization
assert workflow.market_analyzer is not None
assert workflow.strategy_selector is not None
assert workflow.optimizer is not None
assert workflow.validator is not None
# Test workflow nodes exist
nodes = workflow.workflow.get_graph().nodes()
expected_nodes = [
"initialize",
"analyze_market_regime",
"select_strategies",
"optimize_parameters",
"validate_results",
"finalize_workflow",
]
for node in expected_nodes:
assert node in nodes
async def test_successful_workflow_execution(self, workflow_with_mocks):
"""Test successful end-to-end workflow execution."""
start_time = datetime.now()
result = await workflow_with_mocks.run_intelligent_backtest(
symbol="AAPL",
start_date="2023-01-01",
end_date="2023-12-31",
initial_capital=10000.0,
)
execution_time = datetime.now() - start_time
# Test basic structure
assert "symbol" in result
assert result["symbol"] == "AAPL"
assert "execution_metadata" in result
# Test workflow completion
exec_metadata = result["execution_metadata"]
assert exec_metadata["workflow_completed"] is True
assert "initialization" in exec_metadata["steps_completed"]
assert "market_analysis" in exec_metadata["steps_completed"]
assert "strategy_selection" in exec_metadata["steps_completed"]
# Test recommendation structure
assert "recommendation" in result
recommendation = result["recommendation"]
assert recommendation["recommended_strategy"] == "momentum"
assert recommendation["recommendation_confidence"] == 0.85
# Test performance
assert exec_metadata["total_execution_time_ms"] > 0
assert (
execution_time.total_seconds() < 5.0
) # Should complete quickly with mocks
async def test_market_analysis_conditional_routing(
self, workflow_with_mocks, sample_workflow_state
):
"""Test conditional routing after market analysis step."""
workflow = workflow_with_mocks
# Test successful routing
result = workflow._should_proceed_after_market_analysis(sample_workflow_state)
assert result == "continue"
# Test failure routing - unknown regime with low confidence
failure_state = sample_workflow_state.copy()
failure_state.market_regime = "unknown"
failure_state.regime_confidence = 0.05
result = workflow._should_proceed_after_market_analysis(failure_state)
assert result == "fallback"
# Test error routing
error_state = sample_workflow_state.copy()
error_state.errors_encountered = [
{"step": "market_regime_analysis", "error": "Data unavailable"}
]
result = workflow._should_proceed_after_market_analysis(error_state)
assert result == "fallback"
async def test_strategy_selection_conditional_routing(
self, workflow_with_mocks, sample_workflow_state
):
"""Test conditional routing after strategy selection step."""
workflow = workflow_with_mocks
# Test successful routing
result = workflow._should_proceed_after_strategy_selection(
sample_workflow_state
)
assert result == "continue"
# Test failure routing - no strategies selected
failure_state = sample_workflow_state.copy()
failure_state.selected_strategies = []
result = workflow._should_proceed_after_strategy_selection(failure_state)
assert result == "fallback"
# Test low confidence routing
low_conf_state = sample_workflow_state.copy()
low_conf_state.strategy_selection_confidence = 0.1
result = workflow._should_proceed_after_strategy_selection(low_conf_state)
assert result == "fallback"
async def test_optimization_conditional_routing(
self, workflow_with_mocks, sample_workflow_state
):
"""Test conditional routing after parameter optimization step."""
workflow = workflow_with_mocks
# Test successful routing
result = workflow._should_proceed_after_optimization(sample_workflow_state)
assert result == "continue"
# Test failure routing - no best parameters
failure_state = sample_workflow_state.copy()
failure_state.best_parameters = {}
result = workflow._should_proceed_after_optimization(failure_state)
assert result == "fallback"
async def test_workflow_state_transitions(self, workflow_with_mocks):
"""Test that workflow state transitions occur correctly."""
workflow = workflow_with_mocks
# Create initial state
initial_state = workflow._create_initial_state(
symbol="AAPL",
start_date="2023-01-01",
end_date="2023-12-31",
initial_capital=10000.0,
requested_strategy=None,
)
# Test initialization step
state = await workflow._initialize_workflow(initial_state)
assert "initialization" in state.steps_completed
assert state.workflow_status == "analyzing_regime"
assert state.current_step == "initialization_completed"
async def test_workflow_error_handling(self, workflow_with_mocks):
"""Test workflow error handling and recovery."""
# Create workflow with failing market analyzer
workflow = workflow_with_mocks
async def failing_market_analyzer(state):
state.errors_encountered.append(
{
"step": "market_regime_analysis",
"error": "API unavailable",
"timestamp": datetime.now().isoformat(),
}
)
return state
workflow.market_analyzer.analyze_market_regime = AsyncMock(
side_effect=failing_market_analyzer
)
result = await workflow.run_intelligent_backtest(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
# Test that workflow handles error gracefully
assert "execution_metadata" in result
exec_metadata = result["execution_metadata"]
assert len(exec_metadata["errors_encountered"]) > 0
# Test fallback behavior
assert len(exec_metadata["fallback_strategies_used"]) > 0
async def test_workflow_performance_benchmarks(
self, workflow_with_mocks, benchmark_timer
):
"""Test workflow performance meets benchmarks."""
workflow = workflow_with_mocks
with benchmark_timer() as timer:
result = await workflow.run_intelligent_backtest(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
# Test performance benchmarks
execution_time = result["execution_metadata"]["total_execution_time_ms"]
actual_time = timer.elapsed * 1000
# Should complete within reasonable time with mocks
assert execution_time < 1000 # < 1 second
assert actual_time < 5000 # < 5 seconds actual
# Test execution metadata accuracy
assert abs(execution_time - actual_time) < 100 # Within 100ms tolerance
async def test_quick_analysis_workflow(self, workflow_with_mocks):
"""Test quick analysis workflow bypass."""
workflow = workflow_with_mocks
result = await workflow.run_quick_analysis(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
# Test quick analysis structure
assert result["analysis_type"] == "quick_analysis"
assert "market_regime" in result
assert "recommended_strategies" in result
assert "execution_time_ms" in result
# Test performance - quick analysis should be faster
assert result["execution_time_ms"] < 500 # < 500ms
# Test that it skips optimization and validation
assert "optimization" not in result
assert "validation" not in result
async def test_workflow_status_tracking(
self, workflow_with_mocks, sample_workflow_state
):
"""Test workflow status tracking and progress reporting."""
workflow = workflow_with_mocks
# Test initial status
status = workflow.get_workflow_status(sample_workflow_state)
assert status["workflow_status"] == sample_workflow_state.workflow_status
assert status["current_step"] == sample_workflow_state.current_step
assert status["progress_percentage"] >= 0
assert status["progress_percentage"] <= 100
assert (
status["recommended_strategy"] == sample_workflow_state.recommended_strategy
)
# Test progress calculation
expected_progress = (len(sample_workflow_state.steps_completed) / 5) * 100
assert status["progress_percentage"] == expected_progress
async def test_workflow_with_requested_strategy(self, workflow_with_mocks):
"""Test workflow behavior with user-requested strategy."""
workflow = workflow_with_mocks
result = await workflow.run_intelligent_backtest(
symbol="AAPL",
requested_strategy="momentum",
start_date="2023-01-01",
end_date="2023-12-31",
)
# Test that requested strategy is considered
assert "strategy_selection" in result
strategy_info = result["strategy_selection"]
# Should influence selection (the mock still returns its default strategies; a real implementation would weight the requested one)
assert len(strategy_info["selected_strategies"]) > 0
async def test_workflow_fallback_handling(
self, workflow_with_mocks, sample_workflow_state
):
"""Test workflow fallback strategy handling."""
workflow = workflow_with_mocks
# Create incomplete state that triggers fallback
incomplete_state = sample_workflow_state.copy()
incomplete_state.workflow_status = "incomplete"
incomplete_state.recommended_strategy = ""
incomplete_state.best_parameters = {"momentum": {"window": 20}}
final_state = await workflow._finalize_workflow(incomplete_state)
# Test fallback behavior
assert (
final_state.recommended_strategy == "momentum"
) # Should use first available
assert final_state.recommendation_confidence == 0.3 # Low confidence fallback
assert "incomplete_workflow_fallback" in final_state.fallback_strategies_used
async def test_workflow_results_formatting(
self, workflow_with_mocks, sample_workflow_state
):
"""Test comprehensive results formatting."""
workflow = workflow_with_mocks
# Set completed status for full results
complete_state = sample_workflow_state.copy()
complete_state.workflow_status = "completed"
results = workflow._format_results(complete_state)
# Test all major sections are present
expected_sections = [
"symbol",
"period",
"market_analysis",
"strategy_selection",
"optimization",
"validation",
"recommendation",
"performance_analysis",
]
for section in expected_sections:
assert section in results
# Test detailed content
assert results["market_analysis"]["regime"] == "bullish"
assert results["strategy_selection"]["selection_confidence"] == 0.85
assert results["optimization"]["optimization_iterations"] == 45
assert results["recommendation"]["recommended_strategy"] == "momentum"
class TestLangGraphIntegration:
"""Test suite for LangGraph-specific integration aspects."""
async def test_langgraph_state_serialization(self, sample_workflow_state):
"""Test that workflow state can be properly serialized/deserialized for LangGraph."""
# Test JSON serialization compatibility
import json
# Extract serializable data
serializable_data = {
"symbol": sample_workflow_state.symbol,
"workflow_status": sample_workflow_state.workflow_status,
"market_regime": sample_workflow_state.market_regime,
"regime_confidence": sample_workflow_state.regime_confidence,
"selected_strategies": sample_workflow_state.selected_strategies,
"recommendation_confidence": sample_workflow_state.recommendation_confidence,
}
# Test serialization
serialized = json.dumps(serializable_data)
deserialized = json.loads(serialized)
assert deserialized["symbol"] == "AAPL"
assert deserialized["market_regime"] == "bullish"
assert deserialized["regime_confidence"] == 0.85
async def test_langgraph_message_flow(self, workflow_with_mocks):
"""Test message flow through LangGraph nodes."""
workflow = workflow_with_mocks
# Test that messages are properly handled
result = await workflow.run_intelligent_backtest(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
assert isinstance(result, dict)
assert result.get("symbol") == "AAPL"
# Verify mock agents were called in sequence
workflow.market_analyzer.analyze_market_regime.assert_called_once()
workflow.strategy_selector.select_strategies.assert_called_once()
workflow.optimizer.optimize_parameters.assert_called_once()
workflow.validator.validate_strategies.assert_called_once()
async def test_langgraph_conditional_edges(self, workflow_with_mocks):
"""Test LangGraph conditional edge routing logic."""
workflow = workflow_with_mocks
# Create states that should trigger different routing
good_state = Mock()
good_state.market_regime = "bullish"
good_state.regime_confidence = 0.8
good_state.errors_encountered = []
good_state.selected_strategies = ["momentum"]
good_state.strategy_selection_confidence = 0.7
good_state.best_parameters = {"momentum": {}}
bad_state = Mock()
bad_state.market_regime = "unknown"
bad_state.regime_confidence = 0.1
bad_state.errors_encountered = [{"step": "test", "error": "test"}]
bad_state.selected_strategies = []
bad_state.strategy_selection_confidence = 0.1
bad_state.best_parameters = {}
# Test routing decisions
assert workflow._should_proceed_after_market_analysis(good_state) == "continue"
assert workflow._should_proceed_after_market_analysis(bad_state) == "fallback"
assert (
workflow._should_proceed_after_strategy_selection(good_state) == "continue"
)
assert (
workflow._should_proceed_after_strategy_selection(bad_state) == "fallback"
)
assert workflow._should_proceed_after_optimization(good_state) == "continue"
assert workflow._should_proceed_after_optimization(bad_state) == "fallback"
class TestWorkflowStressTests:
"""Stress tests for workflow performance and reliability."""
async def test_concurrent_workflow_execution(self, workflow_with_mocks):
"""Test concurrent execution of multiple workflows."""
workflow = workflow_with_mocks
symbols = ["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN"]
# Run multiple workflows concurrently
tasks = []
for symbol in symbols:
task = workflow.run_intelligent_backtest(
symbol=symbol, start_date="2023-01-01", end_date="2023-12-31"
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Test all succeeded
assert len(results) == len(symbols)
for i, result in enumerate(results):
assert not isinstance(result, Exception)
assert result["symbol"] == symbols[i]
async def test_workflow_memory_usage(self, workflow_with_mocks):
"""Test workflow memory usage doesn't grow excessively."""
import os
import psutil
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
workflow = workflow_with_mocks
# Run multiple workflows
for i in range(10):
await workflow.run_intelligent_backtest(
symbol=f"TEST{i}", start_date="2023-01-01", end_date="2023-12-31"
)
final_memory = process.memory_info().rss
memory_growth = (final_memory - initial_memory) / 1024 / 1024 # MB
# Memory growth should be reasonable (< 50MB for 10 workflows)
assert memory_growth < 50
async def test_workflow_error_recovery(self, mock_agents):
"""Test workflow recovery from various error conditions."""
# Create workflow with intermittently failing agents
failure_count = 0
async def intermittent_failure(state):
nonlocal failure_count
failure_count += 1
if failure_count <= 2:
raise Exception("Simulated failure")
# Eventually succeed
state.market_regime = "bullish"
state.regime_confidence = 0.8
state.workflow_status = "selecting_strategies"
state.steps_completed.append("market_analysis")
return state
mock_agents["market_analyzer"].analyze_market_regime = AsyncMock(
side_effect=intermittent_failure
)
workflow = BacktestingWorkflow(
market_analyzer=mock_agents["market_analyzer"],
strategy_selector=mock_agents["strategy_selector"],
optimizer=mock_agents["optimizer"],
validator=mock_agents["validator"],
)
# This should eventually succeed despite initial failures
try:
result = await workflow.run_intelligent_backtest(
symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
)
# If we reach here, the workflow had some form of error handling
assert "error" in result or "execution_metadata" in result
except Exception:
# Expected for this test - workflow should handle gracefully
pass
if __name__ == "__main__":
# Run tests with detailed output
pytest.main([__file__, "-v", "--tb=short", "--asyncio-mode=auto"])
```
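A corresponding usage sketch for the workflow under test; the constructor defaults and method signatures are inferred from the fixtures and assertions above, so treat them as assumptions:
```python
# Sketch only: drives BacktestingWorkflow the way the tests above do.
import asyncio

from maverick_mcp.workflows.backtesting_workflow import BacktestingWorkflow


async def main() -> None:
    workflow = BacktestingWorkflow()  # builds default agents and compiles the graph

    # Full intelligent backtest: regime analysis -> strategy selection ->
    # parameter optimization -> validation -> recommendation.
    result = await workflow.run_intelligent_backtest(
        symbol="AAPL",
        start_date="2023-01-01",
        end_date="2023-12-31",
        initial_capital=10000.0,
    )
    print(result["recommendation"]["recommended_strategy"])

    # Lighter-weight path that skips optimization and validation.
    quick = await workflow.run_quick_analysis(
        symbol="AAPL", start_date="2023-01-01", end_date="2023-12-31"
    )
    print(quick["market_regime"], quick["recommended_strategies"])


if __name__ == "__main__":
    asyncio.run(main())
```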