This is page 16 of 29. Use http://codebase.md/wshobson/maverick-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/agents/base.py:
--------------------------------------------------------------------------------
```python
"""
Base classes for persona-aware agents using LangGraph best practices.
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated, Any
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from maverick_mcp.config.settings import get_settings
# Module-level logger and eagerly loaded application settings.
# NOTE: settings are read at import time (e.g. by the INVESTOR_PERSONAS table
# below), so configuration must be available before this module is imported.
logger = logging.getLogger(__name__)
settings = get_settings()
class InvestorPersona(BaseModel):
    """Defines an investor persona with risk parameters.

    Instances describe how aggressive an investor profile is; agents and
    persona-aware tools read these fields to scale position sizes, stops,
    and holding periods.
    """

    # Display name, e.g. "Conservative" or "Day Trader".
    name: str
    risk_tolerance: tuple[int, int] = Field(
        description="Risk tolerance range (min, max) on 0-100 scale"
    )
    # NOTE(review): despite the "percentage" wording, this value is treated
    # as a fraction elsewhere (it is multiplied by 100 for display) — confirm.
    position_size_max: float = Field(
        description="Maximum position size as percentage of portfolio"
    )
    # Scales a base stop-loss distance (e.g. ATR-derived) for this persona.
    stop_loss_multiplier: float = Field(
        description="Multiplier for stop loss calculation"
    )
    preferred_timeframe: str = Field(
        default="swing", description="Preferred trading timeframe: day, swing, position"
    )
    characteristics: list[str] = Field(
        default_factory=list, description="Key behavioral characteristics"
    )
# Predefined investor personas, keyed by the lowercase names accepted by
# PersonaAwareAgent(persona=...). All numeric risk limits are sourced from
# application settings at import time.
INVESTOR_PERSONAS = {
    # Lowest risk: long "position" timeframe, capital preservation focus.
    "conservative": InvestorPersona(
        name="Conservative",
        risk_tolerance=(
            settings.financial.risk_tolerance_conservative_min,
            settings.financial.risk_tolerance_conservative_max,
        ),
        position_size_max=settings.financial.max_position_size_conservative,
        stop_loss_multiplier=settings.financial.stop_loss_multiplier_conservative,
        preferred_timeframe="position",
        characteristics=[
            "Prioritizes capital preservation",
            "Focuses on dividend stocks",
            "Prefers established companies",
            "Long-term oriented",
        ],
    ),
    # Default profile used when an unknown persona name is requested.
    "moderate": InvestorPersona(
        name="Moderate",
        risk_tolerance=(
            settings.financial.risk_tolerance_moderate_min,
            settings.financial.risk_tolerance_moderate_max,
        ),
        position_size_max=settings.financial.max_position_size_moderate,
        stop_loss_multiplier=settings.financial.stop_loss_multiplier_moderate,
        preferred_timeframe="swing",
        characteristics=[
            "Balanced risk/reward approach",
            "Mix of growth and value",
            "Diversified portfolio",
            "Medium-term focus",
        ],
    ),
    "aggressive": InvestorPersona(
        name="Aggressive",
        risk_tolerance=(
            settings.financial.risk_tolerance_aggressive_min,
            settings.financial.risk_tolerance_aggressive_max,
        ),
        position_size_max=settings.financial.max_position_size_aggressive,
        stop_loss_multiplier=settings.financial.stop_loss_multiplier_aggressive,
        preferred_timeframe="day",
        characteristics=[
            "High risk tolerance",
            "Growth-focused",
            "Momentum trading",
            "Short-term opportunities",
        ],
    ),
    # Highest turnover: intraday only, with extra liquidity checks applied
    # by PersonaAwareTool._validate_risk_levels.
    "day_trader": InvestorPersona(
        name="Day Trader",
        risk_tolerance=(
            settings.financial.risk_tolerance_day_trader_min,
            settings.financial.risk_tolerance_day_trader_max,
        ),
        position_size_max=settings.financial.max_position_size_day_trader,
        stop_loss_multiplier=settings.financial.stop_loss_multiplier_day_trader,
        preferred_timeframe="day",
        characteristics=[
            "Intraday positions only",
            "High-frequency trading",
            "Technical analysis focused",
            "Tight risk controls",
        ],
    ),
}
class BaseAgentState(TypedDict):
    """Base state for all persona-aware agents.

    This dict flows through the LangGraph workflow. ``messages`` is
    annotated with the ``add_messages`` reducer, so values returned by
    nodes are merged into the history rather than replacing it.
    """

    messages: Annotated[Sequence[BaseMessage], add_messages]
    persona: str  # Persona key, e.g. "moderate"
    session_id: str  # Conversation identifier; also used as the checkpointer thread_id
class PersonaAwareTool(BaseTool):
    """Base class for tools that adapt to investor personas.

    Provides per-symbol analysis caching plus helpers that scale
    risk-sensitive parameters (position size, stops, targets, volatility
    filters) to the currently active :class:`InvestorPersona`.
    """

    # Active persona; while None, every adjustment/validation helper is a no-op.
    persona: InvestorPersona | None = None

    # State tracking: per-symbol caches keyed by upper-cased ticker symbol.
    # NOTE(review): declared as pydantic field defaults on a BaseTool
    # subclass — presumably pydantic copies these per instance; confirm they
    # are not shared mutable class state.
    last_analysis_time: dict[str, datetime] = {}
    analyzed_stocks: dict[str, dict] = {}
    key_price_levels: dict[str, dict] = {}

    # Cache settings: seconds before stored analysis data is considered stale.
    cache_ttl: int = settings.agent.agent_cache_ttl_seconds

    def set_persona(self, persona: InvestorPersona) -> None:
        """Set the active investor persona."""
        self.persona = persona

    def adjust_for_risk(self, value: float, parameter_type: str) -> float:
        """Adjust a value based on the persona's risk profile.

        Args:
            value: Raw parameter value to scale.
            parameter_type: One of "position_size", "stop_loss",
                "profit_target", "volatility_filter" or "time_horizon".
                Unknown types are returned unchanged.

        Returns:
            The risk-adjusted value, or ``value`` unchanged when no persona
            has been set.
        """
        if not self.persona:
            return value

        # Get average risk tolerance; dividing by the moderate midpoint (50)
        # yields a factor near 1.0 for moderate profiles.
        risk_avg = sum(self.persona.risk_tolerance) / 2
        risk_factor = risk_avg / 50  # Normalize to 1.0 at moderate risk

        # Adjust based on parameter type
        if parameter_type == "position_size":
            # Kelly Criterion-inspired sizing, capped at the persona maximum
            kelly_fraction = self._calculate_kelly_fraction(risk_factor)
            adjusted = value * kelly_fraction
            return min(adjusted, self.persona.position_size_max)
        elif parameter_type == "stop_loss":
            # ATR-based dynamic stops
            return value * self.persona.stop_loss_multiplier
        elif parameter_type == "profit_target":
            # Risk-adjusted targets
            return value * (2 - risk_factor)  # Conservative = lower targets
        elif parameter_type == "volatility_filter":
            # Volatility tolerance
            return value * (2 - risk_factor)  # Conservative = lower vol tolerance
        elif parameter_type == "time_horizon":
            # Holding period in days
            if self.persona.preferred_timeframe == "day":
                return 1
            elif self.persona.preferred_timeframe == "swing":
                return 5 * risk_factor  # 2.5-7.5 days
            else:  # position
                return 20 * risk_factor  # 10-30 days
        else:
            return value

    def _calculate_kelly_fraction(self, risk_factor: float) -> float:
        """Calculate position size using Kelly Criterion.

        Returns a non-negative fraction of capital; a conservative 25%
        safety factor is applied to the raw Kelly value.
        """
        # Simplified Kelly: f = (p*b - q) / b
        # where p = win probability, b = win/loss ratio, q = loss probability
        # Using risk factor to adjust expected win rate
        win_probability = 0.45 + (0.1 * risk_factor)  # 45-55% base win rate
        win_loss_ratio = 2.0  # 2:1 reward/risk
        loss_probability = 1 - win_probability
        kelly = (win_probability * win_loss_ratio - loss_probability) / win_loss_ratio
        # Apply safety factor (never use full Kelly)
        safety_factor = 0.25  # Use 25% of Kelly
        return max(0, kelly * safety_factor)

    def update_analysis_data(self, symbol: str, analysis_data: dict[str, Any]):
        """Update stored analysis data for a symbol (keyed upper-cased)."""
        symbol = symbol.upper()
        self.analyzed_stocks[symbol] = analysis_data
        self.last_analysis_time[symbol] = datetime.now()
        # Price levels get their own cache for quick lookup.
        if "price_levels" in analysis_data:
            self.key_price_levels[symbol] = analysis_data["price_levels"]

    def get_stock_context(self, symbol: str) -> dict[str, Any]:
        """Get stored context for a symbol.

        Returns:
            Dict with the cached analysis, its timestamp (or None), any
            stored price levels, and whether the cache entry has expired.
        """
        symbol = symbol.upper()
        return {
            "analysis": self.analyzed_stocks.get(symbol, {}),
            "last_analysis": self.last_analysis_time.get(symbol),
            "price_levels": self.key_price_levels.get(symbol, {}),
            "cache_expired": self._is_cache_expired(symbol),
        }

    def _is_cache_expired(self, symbol: str) -> bool:
        """Check if cached data has expired (never-cached counts as expired)."""
        last_time = self.last_analysis_time.get(symbol.upper())
        if not last_time:
            return True
        age_seconds = (datetime.now() - last_time).total_seconds()
        return age_seconds > self.cache_ttl

    def _adjust_risk_parameters(self, params: dict) -> dict:
        """Adjust parameters based on risk profile.

        Numeric values are scaled up or down depending on keywords found in
        their (case-insensitive) names; non-numeric values and unmatched
        keys pass through unchanged.
        """
        if not self.persona:
            return params

        risk_factor = sum(self.persona.risk_tolerance) / 100  # 0.1-0.9 scale

        # Apply risk adjustments based on parameter names
        adjusted = {}
        for key, value in params.items():
            if isinstance(value, int | float):
                key_lower = key.lower()
                if any(term in key_lower for term in ["stop", "support", "risk"]):
                    # Wider stops/support for conservative, tighter for aggressive
                    adjusted[key] = value * (2 - risk_factor)
                elif any(
                    term in key_lower for term in ["resistance", "target", "profit"]
                ):
                    # Lower targets for conservative, higher for aggressive
                    adjusted[key] = value * risk_factor
                elif any(term in key_lower for term in ["size", "amount", "shares"]):
                    # Smaller positions for conservative, larger for aggressive
                    adjusted[key] = self.adjust_for_risk(value, "position_size")
                elif any(term in key_lower for term in ["volume", "liquidity"]):
                    # Higher liquidity requirements for conservative
                    adjusted[key] = value * (2 - risk_factor)
                elif any(term in key_lower for term in ["volatility", "atr", "std"]):
                    # Lower volatility tolerance for conservative
                    adjusted[key] = self.adjust_for_risk(value, "volatility_filter")
                else:
                    adjusted[key] = value
            else:
                adjusted[key] = value

        return adjusted

    def _validate_risk_levels(self, data: dict) -> bool:
        """Validate if the data meets the persona's risk criteria.

        Combines a volatility/beta risk score with persona-specific
        fundamental and liquidity checks. Returns True when no persona is set.
        """
        if not self.persona:
            return True

        min_risk, max_risk = self.persona.risk_tolerance

        # Extract risk metrics
        volatility = data.get("volatility", 0)
        beta = data.get("beta", 1.0)

        # Convert to risk score (0-100)
        volatility_score = min(100, volatility * 2)  # Assume 50% vol = 100 risk
        beta_score = abs(beta - 1) * 100  # Distance from market

        # Combined risk score
        risk_score = (volatility_score + beta_score) / 2

        if risk_score < min_risk or risk_score > max_risk:
            return False

        # Persona-specific validations
        if self.persona.name == "Conservative":
            # Additional checks for conservative investors
            # NOTE(review): missing fundamentals default to 0, which fails
            # the current-ratio and dividend checks — confirm intended.
            if data.get("debt_to_equity", 0) > 1.5:
                return False
            if data.get("current_ratio", 0) < 1.5:
                return False
            if data.get("dividend_yield", 0) < 0.02:  # Prefer dividend stocks
                return False
        elif self.persona.name == "Day Trader":
            # Day traders need high liquidity
            if data.get("average_volume", 0) < 1_000_000:
                return False
            if data.get("spread_percentage", 0) > 0.1:  # Tight spreads only
                return False

        return True

    def format_for_persona(self, data: dict) -> dict:
        """Format output data based on persona preferences.

        Returns a shallow copy of ``data`` augmented with a
        "persona_insights" section and, when applicable, "risk_warnings".
        The input is returned unmodified when no persona is set.
        """
        if not self.persona:
            return data

        formatted = data.copy()

        # Add persona-specific insights
        formatted["persona_insights"] = {
            "suitable_for_profile": self._validate_risk_levels(data),
            "risk_adjusted_parameters": self._adjust_risk_parameters(
                data.get("parameters", {})
            ),
            "recommended_timeframe": self.persona.preferred_timeframe,
            "max_position_size": self.persona.position_size_max,
        }

        # Add risk warnings if needed
        warnings = []
        if not self._validate_risk_levels(data):
            warnings.append(f"Risk profile outside {self.persona.name} parameters")
        if data.get("volatility", 0) > 50:
            warnings.append("High volatility - consider smaller position size")
        if warnings:
            formatted["risk_warnings"] = warnings

        return formatted
class PersonaAwareAgent(ABC):
    """
    Base class for agents that adapt behavior based on investor personas.

    This follows LangGraph best practices:
    - Uses StateGraph for workflow definition
    - Implements proper node/edge patterns
    - Supports native streaming modes
    - Uses TypedDict for state management
    """

    def __init__(
        self,
        llm,
        tools: list[BaseTool],
        persona: str = "moderate",
        checkpointer: MemorySaver | None = None,
        ttl_hours: int = 1,
    ):
        """
        Initialize a persona-aware agent.

        Args:
            llm: Language model to use
            tools: List of tools available to the agent
            persona: Investor persona name
            checkpointer: Optional checkpointer (defaults to MemorySaver)
            ttl_hours: Time-to-live for memory in hours
        """
        self.llm = llm
        self.tools = tools
        # Unknown persona names silently fall back to "moderate".
        self.persona = INVESTOR_PERSONAS.get(persona, INVESTOR_PERSONAS["moderate"])
        self.ttl_hours = ttl_hours

        # Set up checkpointing (in-memory by default)
        if checkpointer is None:
            self.checkpointer = MemorySaver()
        else:
            self.checkpointer = checkpointer

        # Configure tools with persona so persona-aware tools adapt outputs
        for tool in self.tools:
            if isinstance(tool, PersonaAwareTool):
                tool.set_persona(self.persona)

        # Build the graph
        self.graph = self._build_graph()

        # Track usage
        self.total_tokens = 0
        self.conversation_start = datetime.now()

    def _build_graph(self):
        """Build the LangGraph workflow.

        Shape: START -> agent -> (tools -> agent)* -> END when tools are
        present; otherwise START -> agent -> END.
        """
        # Create the graph builder
        workflow = StateGraph(self.get_state_schema())

        # Add the agent node
        workflow.add_node("agent", self._agent_node)

        # Create tool node if tools are available
        if self.tools:
            tool_node = ToolNode(self.tools)
            workflow.add_node("tools", tool_node)

            # Add conditional edge from agent
            workflow.add_conditional_edges(
                "agent",
                self._should_continue,
                {
                    # If agent returns tool calls, route to tools
                    "continue": "tools",
                    # Otherwise end
                    "end": END,
                },
            )

            # Add edge from tools back to agent
            workflow.add_edge("tools", "agent")
        else:
            # No tools, just end after agent
            workflow.add_edge("agent", END)

        # Set entry point
        workflow.add_edge(START, "agent")

        # Compile with checkpointer
        return workflow.compile(checkpointer=self.checkpointer)

    def _agent_node(self, state: dict[str, Any]) -> dict[str, Any]:
        """The main agent node that processes messages."""
        messages = state["messages"]

        # Add system message if it's the first message
        if len(messages) == 1 and isinstance(messages[0], HumanMessage):
            system_prompt = self._build_system_prompt()
            messages = [SystemMessage(content=system_prompt)] + messages

        # Call the LLM, exposing tools when available
        if self.tools:
            response = self.llm.bind_tools(self.tools).invoke(messages)
        else:
            response = self.llm.invoke(messages)

        # Track tokens (simplified: rough ~4-chars-per-token estimate, not a
        # true tokenizer count)
        if hasattr(response, "content"):
            self.total_tokens += len(response.content) // 4

        # Return the response; add_messages merges it into the state history
        return {"messages": [response]}

    def _should_continue(self, state: dict[str, Any]) -> str:
        """Determine whether to continue to tools or end."""
        last_message = state["messages"][-1]

        # If the LLM makes a tool call, continue to tools
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "continue"

        # Otherwise we're done
        return "end"

    def _build_system_prompt(self) -> str:
        """Build system prompt based on persona."""
        # chr(10) is "\n" — avoids a backslash inside the f-string expression.
        base_prompt = f"""You are a financial advisor configured for a {self.persona.name} investor profile.
Risk Parameters:
- Risk Tolerance: {self.persona.risk_tolerance[0]}-{self.persona.risk_tolerance[1]}/100
- Max Position Size: {self.persona.position_size_max * 100:.1f}% of portfolio
- Stop Loss Multiplier: {self.persona.stop_loss_multiplier}x
- Preferred Timeframe: {self.persona.preferred_timeframe}
Key Characteristics:
{chr(10).join(f"- {char}" for char in self.persona.characteristics)}
Always adjust your recommendations to match this risk profile. Be explicit about risk management."""
        return base_prompt

    @abstractmethod
    def get_state_schema(self) -> type:
        """
        Get the state schema for this agent.

        Subclasses should return their specific state schema.
        """
        # Default implementation returns the base schema; kept abstract so
        # every subclass makes its state choice explicit.
        return BaseAgentState

    async def ainvoke(self, query: str, session_id: str, **kwargs) -> dict[str, Any]:
        """
        Invoke the agent asynchronously.

        Args:
            query: User query
            session_id: Session identifier
            **kwargs: Additional parameters

        Returns:
            Agent response
        """
        # session_id doubles as the checkpointer thread_id.
        config = {
            "configurable": {"thread_id": session_id, "persona": self.persona.name}
        }

        # Merge additional config
        # NOTE(review): dict.update is shallow — a caller-supplied
        # "configurable" key replaces thread_id/persona entirely; confirm
        # that is intended.
        if "config" in kwargs:
            config.update(kwargs["config"])

        # Run the graph
        result = await self.graph.ainvoke(
            {
                "messages": [HumanMessage(content=query)],
                "persona": self.persona.name,
                "session_id": session_id,
            },
            config=config,
        )

        return self._extract_response(result)

    def invoke(self, query: str, session_id: str, **kwargs) -> dict[str, Any]:
        """
        Invoke the agent synchronously.

        Args:
            query: User query
            session_id: Session identifier
            **kwargs: Additional parameters

        Returns:
            Agent response
        """
        config = {
            "configurable": {"thread_id": session_id, "persona": self.persona.name}
        }

        # Merge additional config (shallow; see note in ainvoke)
        if "config" in kwargs:
            config.update(kwargs["config"])

        # Run the graph
        result = self.graph.invoke(
            {
                "messages": [HumanMessage(content=query)],
                "persona": self.persona.name,
                "session_id": session_id,
            },
            config=config,
        )

        return self._extract_response(result)

    async def astream(
        self, query: str, session_id: str, stream_mode: str = "values", **kwargs
    ):
        """
        Stream agent responses asynchronously.

        Args:
            query: User query
            session_id: Session identifier
            stream_mode: Streaming mode (values, updates, messages, custom, debug)
            **kwargs: Additional parameters

        Yields:
            Streamed chunks based on mode
        """
        config = {
            "configurable": {"thread_id": session_id, "persona": self.persona.name}
        }

        # Merge additional config (shallow; see note in ainvoke)
        if "config" in kwargs:
            config.update(kwargs["config"])

        # Stream the graph
        async for chunk in self.graph.astream(
            {
                "messages": [HumanMessage(content=query)],
                "persona": self.persona.name,
                "session_id": session_id,
            },
            config=config,
            stream_mode=stream_mode,
        ):
            yield chunk

    def stream(
        self, query: str, session_id: str, stream_mode: str = "values", **kwargs
    ):
        """
        Stream agent responses synchronously.

        Args:
            query: User query
            session_id: Session identifier
            stream_mode: Streaming mode (values, updates, messages, custom, debug)
            **kwargs: Additional parameters

        Yields:
            Streamed chunks based on mode
        """
        config = {
            "configurable": {"thread_id": session_id, "persona": self.persona.name}
        }

        # Merge additional config (shallow; see note in ainvoke)
        if "config" in kwargs:
            config.update(kwargs["config"])

        # Stream the graph
        yield from self.graph.stream(
            {
                "messages": [HumanMessage(content=query)],
                "persona": self.persona.name,
                "session_id": session_id,
            },
            config=config,
            stream_mode=stream_mode,
        )

    def _extract_response(self, result: dict[str, Any]) -> dict[str, Any]:
        """Extract the final response from graph execution.

        Returns a plain dict with content, status, persona, message count
        and session id; status is "error" only when no messages came back.
        """
        messages = result.get("messages", [])
        if not messages:
            return {"content": "No response generated", "status": "error"}

        # Get the last AI message
        last_message = messages[-1]

        return {
            "content": last_message.content
            if hasattr(last_message, "content")
            else str(last_message),
            "status": "success",
            "persona": self.persona.name,
            "message_count": len(messages),
            "session_id": result.get("session_id", ""),
        }

    def get_risk_adjusted_params(
        self, base_params: dict[str, float]
    ) -> dict[str, float]:
        """Adjust parameters based on persona risk profile.

        Dispatches each value to adjust_for_risk according to keywords in
        its name; unmatched keys pass through unchanged.
        """
        adjusted = {}

        for key, value in base_params.items():
            if "size" in key.lower() or "position" in key.lower():
                adjusted[key] = self.adjust_for_risk(value, "position_size")
            elif "stop" in key.lower():
                adjusted[key] = self.adjust_for_risk(value, "stop_loss")
            elif "target" in key.lower() or "profit" in key.lower():
                adjusted[key] = self.adjust_for_risk(value, "profit_target")
            else:
                adjusted[key] = value

        return adjusted

    def adjust_for_risk(self, value: float, parameter_type: str) -> float:
        """Adjust a value based on the persona's risk profile.

        NOTE: simpler than PersonaAwareTool.adjust_for_risk — position size
        scales linearly with risk_factor here instead of using a Kelly
        fraction, and there is no persona-is-None guard.
        """
        # Get average risk tolerance
        risk_avg = sum(self.persona.risk_tolerance) / 2
        risk_factor = risk_avg / 50  # Normalize to 1.0 at moderate risk

        # Adjust based on parameter type
        if parameter_type == "position_size":
            return min(value * risk_factor, self.persona.position_size_max)
        elif parameter_type == "stop_loss":
            return value * self.persona.stop_loss_multiplier
        elif parameter_type == "profit_target":
            return value * (2 - risk_factor)  # Conservative = lower targets
        else:
            return value
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/health_enhanced.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive health check router for backtesting system.
Provides detailed health monitoring including:
- Component status (database, cache, external APIs)
- Circuit breaker monitoring
- Resource utilization
- Readiness and liveness probes
- Performance metrics
"""
import asyncio
import logging
import time
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from maverick_mcp.config.settings import get_settings
from maverick_mcp.utils.circuit_breaker import get_circuit_breaker_status
logger = logging.getLogger(__name__)
settings = get_settings()
router = APIRouter(prefix="/health", tags=["Health"])
# Service start time for uptime calculation
_start_time = time.time()
class ComponentStatus(BaseModel):
    """Individual component health status."""

    name: str = Field(description="Component name")
    # Free-form string; the checkers in this module emit
    # "healthy" / "degraded" / "unhealthy" (and "unknown" for unmonitored APIs).
    status: str = Field(description="Status (healthy/degraded/unhealthy)")
    # None when no timing was measured for this component.
    response_time_ms: float | None = Field(description="Response time in milliseconds")
    last_check: str = Field(description="Timestamp of last health check")
    details: dict = Field(default_factory=dict, description="Additional status details")
    error: str | None = Field(default=None, description="Error message if unhealthy")
class ResourceUsage(BaseModel):
    """System resource usage information."""

    cpu_percent: float = Field(description="CPU usage percentage")
    memory_percent: float = Field(description="Memory usage percentage")
    disk_percent: float = Field(description="Disk usage percentage")
    memory_used_mb: float = Field(description="Memory used in MB")
    memory_total_mb: float = Field(description="Total memory in MB")
    disk_used_gb: float = Field(description="Disk used in GB")
    disk_total_gb: float = Field(description="Total disk in GB")
    # None on platforms without getloadavg() (e.g. Windows).
    load_average: list[float] | None = Field(
        default=None, description="System load averages"
    )
class CircuitBreakerStatus(BaseModel):
    """Circuit breaker status information."""

    name: str = Field(description="Circuit breaker name")
    state: str = Field(description="Current state (closed/open/half_open)")
    failure_count: int = Field(description="Current consecutive failure count")
    # None when the breaker is closed (no retry pending).
    time_until_retry: float | None = Field(description="Seconds until retry allowed")
    metrics: dict = Field(description="Performance metrics")
class DetailedHealthStatus(BaseModel):
    """Comprehensive health status with all components."""

    status: str = Field(
        description="Overall health status (healthy/degraded/unhealthy)"
    )
    timestamp: str = Field(description="Current timestamp")
    version: str = Field(description="Application version")
    uptime_seconds: float = Field(description="Service uptime in seconds")
    # Keyed by component id (e.g. "database", "cache", circuit-breaker names).
    components: dict[str, ComponentStatus] = Field(
        description="Individual component statuses"
    )
    circuit_breakers: dict[str, CircuitBreakerStatus] = Field(
        description="Circuit breaker statuses"
    )
    resource_usage: ResourceUsage = Field(description="System resource usage")
    # Values are "up" / "degraded" / "down", derived from breaker states.
    services: dict[str, str] = Field(description="External service statuses")
    # Counts of components per status bucket.
    checks_summary: dict[str, int] = Field(description="Summary of check results")
class BasicHealthStatus(BaseModel):
    """Basic health status for simple health checks (load balancers, etc.)."""

    status: str = Field(
        description="Overall health status (healthy/degraded/unhealthy)"
    )
    timestamp: str = Field(description="Current timestamp")
    version: str = Field(description="Application version")
    uptime_seconds: float = Field(description="Service uptime in seconds")
class ReadinessStatus(BaseModel):
    """Readiness probe status (Kubernetes-style)."""

    ready: bool = Field(description="Whether service is ready to accept traffic")
    timestamp: str = Field(description="Current timestamp")
    # Per-dependency boolean readiness, keyed by component name.
    dependencies: dict[str, bool] = Field(description="Dependency readiness statuses")
    details: dict = Field(
        default_factory=dict, description="Additional readiness details"
    )
class LivenessStatus(BaseModel):
    """Liveness probe status (Kubernetes-style)."""

    alive: bool = Field(description="Whether service is alive and functioning")
    timestamp: str = Field(description="Current timestamp")
    last_heartbeat: str = Field(description="Last heartbeat timestamp")
    details: dict = Field(
        default_factory=dict, description="Additional liveness details"
    )
def _get_uptime_seconds() -> float:
    """Return seconds elapsed since this module captured _start_time."""
    now = time.time()
    return now - _start_time
def _get_resource_usage() -> ResourceUsage:
    """Get current system resource usage.

    Returns a zeroed ResourceUsage on any failure so callers never raise.
    NOTE(review): psutil.cpu_percent(interval=1) blocks for ~1 second per
    call; confirm this is acceptable for a health endpoint on the hot path.
    """
    try:
        # CPU usage (sampled over a 1-second interval).
        cpu_percent = psutil.cpu_percent(interval=1)
        # Memory usage; "used" is derived from total - available.
        memory = psutil.virtual_memory()
        memory_used_mb = (memory.total - memory.available) / (1024 * 1024)
        memory_total_mb = memory.total / (1024 * 1024)
        # Disk usage for the filesystem containing the current directory.
        disk = psutil.disk_usage(Path.cwd())
        disk_used_gb = (disk.total - disk.free) / (1024 * 1024 * 1024)
        disk_total_gb = disk.total / (1024 * 1024 * 1024)
        # Load average (Unix systems only).
        load_average = None
        try:
            load_average = list(psutil.getloadavg())
        except (AttributeError, OSError):
            # Windows doesn't have load average
            pass
        return ResourceUsage(
            cpu_percent=round(cpu_percent, 2),
            memory_percent=round(memory.percent, 2),
            disk_percent=round(disk.percent, 2),
            memory_used_mb=round(memory_used_mb, 2),
            memory_total_mb=round(memory_total_mb, 2),
            disk_used_gb=round(disk_used_gb, 2),
            disk_total_gb=round(disk_total_gb, 2),
            load_average=load_average,
        )
    except Exception as e:
        logger.error(f"Failed to get resource usage: {e}")
        # Fall back to all-zero metrics (load_average defaults to None).
        return ResourceUsage(
            cpu_percent=0.0,
            memory_percent=0.0,
            disk_percent=0.0,
            memory_used_mb=0.0,
            memory_total_mb=0.0,
            disk_used_gb=0.0,
            disk_total_gb=0.0,
        )
async def _check_database_health() -> ComponentStatus:
    """Check database connectivity and health.

    Executes a trivial ``SELECT 1`` round trip and reports "healthy" when
    it succeeds, "unhealthy" (with the error message) otherwise. The
    session is always closed.
    """
    start_time = time.time()
    timestamp = datetime.now(UTC).isoformat()
    try:
        from sqlalchemy import text

        from maverick_mcp.data.models import get_db

        # Acquire a session from the generator-based dependency.
        db_session = next(get_db())
        try:
            # SQLAlchemy 2.x requires textual SQL to be wrapped in text();
            # passing a bare string raises ArgumentError.
            result = db_session.execute(text("SELECT 1 as test"))
            test_value = result.scalar()
            response_time_ms = (time.time() - start_time) * 1000
            if test_value == 1:
                return ComponentStatus(
                    name="database",
                    status="healthy",
                    response_time_ms=round(response_time_ms, 2),
                    last_check=timestamp,
                    details={"connection": "active", "query_test": "passed"},
                )
            else:
                return ComponentStatus(
                    name="database",
                    status="unhealthy",
                    response_time_ms=round(response_time_ms, 2),
                    last_check=timestamp,
                    error="Database query returned unexpected result",
                )
        finally:
            db_session.close()
    except Exception as e:
        # Connection failure, driver error, or misconfiguration.
        response_time_ms = (time.time() - start_time) * 1000
        return ComponentStatus(
            name="database",
            status="unhealthy",
            response_time_ms=round(response_time_ms, 2),
            last_check=timestamp,
            error=str(e),
        )
async def _check_cache_health() -> ComponentStatus:
    """Check Redis cache connectivity and health.

    Reports "healthy" when Redis responds to PING, "degraded" when Redis
    is not configured (in-memory fallback) or unreachable. Never reports
    "unhealthy": the cache is treated as a non-critical dependency.
    """
    start_time = time.time()
    timestamp = datetime.now(UTC).isoformat()
    try:
        from maverick_mcp.data.cache import get_redis_client

        redis_client = get_redis_client()
        if redis_client is None:
            # No Redis configured; the app runs with an in-memory cache.
            return ComponentStatus(
                name="cache",
                status="degraded",
                response_time_ms=0,
                last_check=timestamp,
                details={"type": "in_memory", "redis": "not_configured"},
            )
        # Run the blocking client calls off the event loop.
        await asyncio.to_thread(redis_client.ping)
        response_time_ms = (time.time() - start_time) * 1000
        # Get Redis server info for diagnostics.
        info = await asyncio.to_thread(redis_client.info)
        return ComponentStatus(
            name="cache",
            status="healthy",
            response_time_ms=round(response_time_ms, 2),
            last_check=timestamp,
            details={
                "type": "redis",
                "version": info.get("redis_version", "unknown"),
                "memory_usage": info.get("used_memory_human", "unknown"),
                "connected_clients": info.get("connected_clients", 0),
            },
        )
    except Exception as e:
        # Redis unreachable -> degraded, since a fallback cache exists.
        response_time_ms = (time.time() - start_time) * 1000
        return ComponentStatus(
            name="cache",
            status="degraded",
            response_time_ms=round(response_time_ms, 2),
            last_check=timestamp,
            details={"type": "fallback", "redis_error": str(e)},
        )
async def _check_external_apis_health() -> dict[str, ComponentStatus]:
    """Check external API health using circuit breaker status.

    No live requests are made; status is derived entirely from each
    breaker's state: closed -> healthy, half_open -> degraded,
    open -> unhealthy. APIs without a registered breaker are "unknown".
    """
    timestamp = datetime.now(UTC).isoformat()
    # Map circuit breaker names to human-readable API names.
    api_mapping = {
        "yfinance": "Yahoo Finance API",
        "finviz": "Finviz API",
        "fred_api": "FRED Economic Data API",
        "tiingo": "Tiingo Market Data API",
        "openrouter": "OpenRouter AI API",
        "exa": "Exa Search API",
        "news_api": "News API",
        "external_api": "External Services",
    }
    api_statuses = {}
    cb_status = get_circuit_breaker_status()
    for cb_name, display_name in api_mapping.items():
        cb_info = cb_status.get(cb_name)
        if cb_info:
            # Translate breaker state into a component status.
            if cb_info["state"] == "closed":
                status = "healthy"
                error = None
            elif cb_info["state"] == "half_open":
                status = "degraded"
                error = "Circuit breaker testing recovery"
            else:  # open
                status = "unhealthy"
                error = "Circuit breaker open due to failures"
            response_time = cb_info["metrics"].get("avg_response_time", 0)
            api_statuses[cb_name] = ComponentStatus(
                name=display_name,
                status=status,
                # 0 is treated as "no timing data" and reported as None.
                response_time_ms=round(response_time, 2) if response_time else None,
                last_check=timestamp,
                details={
                    "circuit_breaker_state": cb_info["state"],
                    "failure_count": cb_info["consecutive_failures"],
                    "success_rate": cb_info["metrics"].get("success_rate", 0),
                },
                error=error,
            )
        else:
            # API not monitored by a circuit breaker.
            api_statuses[cb_name] = ComponentStatus(
                name=display_name,
                status="unknown",
                response_time_ms=None,
                last_check=timestamp,
                details={"monitoring": "not_configured"},
            )
    return api_statuses
async def _check_ml_models_health() -> ComponentStatus:
    """Check ML model availability and health.

    Verifies that numpy, pandas_ta and talib import, and runs a tiny SMA
    computation through TA-Lib as a smoke test. Missing libraries yield
    "degraded"; a failing computation yields "unhealthy".
    """
    timestamp = datetime.now(UTC).isoformat()
    try:
        # Import inside the function so a missing library degrades the
        # health check instead of breaking module import.
        import numpy as np

        import pandas_ta as ta
        import talib

        # Smoke test: 5-period SMA over a simple ramp.
        test_data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)
        sma_result = talib.SMA(test_data, timeperiod=5)
        sma_last_value = float(sma_result[-1])
        return ComponentStatus(
            name="ML Models & Libraries",
            status="healthy",
            response_time_ms=None,
            last_check=timestamp,
            details={
                "talib": f"available (v{getattr(talib, '__version__', 'unknown')})",
                "pandas_ta": f"available (v{getattr(ta, '__version__', 'unknown')})",
                "numpy": "available",
                "test_computation": "passed",
                "test_computation_sma_last": sma_last_value,
            },
        )
    except ImportError as e:
        # A library is missing but the service can still run.
        return ComponentStatus(
            name="ML Models & Libraries",
            status="degraded",
            response_time_ms=None,
            last_check=timestamp,
            details={"missing_library": str(e)},
            error=f"Missing required library: {e}",
        )
    except Exception as e:
        # Libraries import but computation failed.
        return ComponentStatus(
            name="ML Models & Libraries",
            status="unhealthy",
            response_time_ms=None,
            last_check=timestamp,
            error=str(e),
        )
async def _get_detailed_health_status() -> dict[str, Any]:
    """Get comprehensive health status for all components.

    Runs the individual health checks concurrently, aggregates them with
    circuit-breaker state and resource usage, and computes the overall
    status (any unhealthy -> unhealthy; else any degraded -> degraded).
    Returns a plain dict matching the DetailedHealthStatus schema.
    """
    timestamp = datetime.now(UTC).isoformat()
    # Run all health checks concurrently.
    db_task = _check_database_health()
    cache_task = _check_cache_health()
    apis_task = _check_external_apis_health()
    ml_task = _check_ml_models_health()
    try:
        db_status, cache_status, api_statuses, ml_status = await asyncio.gather(
            db_task, cache_task, apis_task, ml_task
        )
    except Exception as e:
        logger.error(f"Error running health checks: {e}")
        # Return minimal unhealthy status on aggregation failure.
        return {
            "status": "unhealthy",
            "timestamp": timestamp,
            "version": getattr(settings, "version", "1.0.0"),
            "uptime_seconds": _get_uptime_seconds(),
            "components": {},
            "circuit_breakers": {},
            "resource_usage": _get_resource_usage(),
            "services": {},
            "checks_summary": {"healthy": 0, "degraded": 0, "unhealthy": 1},
        }
    # Combine all component statuses (API entries merged alongside core ones).
    components = {
        "database": db_status,
        "cache": cache_status,
        "ml_models": ml_status,
    }
    components.update(api_statuses)
    # Get circuit breaker status as typed models.
    cb_status = get_circuit_breaker_status()
    circuit_breakers = {}
    for name, status in cb_status.items():
        circuit_breakers[name] = CircuitBreakerStatus(
            name=status["name"],
            state=status["state"],
            failure_count=status["consecutive_failures"],
            time_until_retry=status["time_until_retry"],
            metrics=status["metrics"],
        )
    # Tally components per status bucket ("unknown" is counted in none).
    healthy_count = sum(1 for c in components.values() if c.status == "healthy")
    degraded_count = sum(1 for c in components.values() if c.status == "degraded")
    unhealthy_count = sum(1 for c in components.values() if c.status == "unhealthy")
    # Worst component status wins.
    if unhealthy_count > 0:
        overall_status = "unhealthy"
    elif degraded_count > 0:
        overall_status = "degraded"
    else:
        overall_status = "healthy"
    # Derive service up/degraded/down strings from circuit breaker states.
    services = {}
    for name, cb_info in cb_status.items():
        if cb_info["state"] == "open":
            services[name] = "down"
        elif cb_info["state"] == "half_open":
            services[name] = "degraded"
        else:
            services[name] = "up"
    return {
        "status": overall_status,
        "timestamp": timestamp,
        "version": getattr(settings, "version", "1.0.0"),
        "uptime_seconds": _get_uptime_seconds(),
        "components": components,
        "circuit_breakers": circuit_breakers,
        "resource_usage": _get_resource_usage(),
        "services": services,
        "checks_summary": {
            "healthy": healthy_count,
            "degraded": degraded_count,
            "unhealthy": unhealthy_count,
        },
    }
@router.get("/", response_model=BasicHealthStatus)
async def basic_health_check() -> BasicHealthStatus:
    """Basic health check endpoint.

    Returns simple health status without detailed component information.
    Suitable for basic monitoring and load balancer health checks.
    Never raises: any failure is reported as status="unhealthy".
    """
    try:
        # Reuse the comprehensive check but expose only the overall status.
        detailed_status = await _get_detailed_health_status()
        return BasicHealthStatus(
            status=detailed_status["status"],
            timestamp=datetime.now(UTC).isoformat(),
            version=getattr(settings, "version", "1.0.0"),
            uptime_seconds=_get_uptime_seconds(),
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return BasicHealthStatus(
            status="unhealthy",
            timestamp=datetime.now(UTC).isoformat(),
            version=getattr(settings, "version", "1.0.0"),
            uptime_seconds=_get_uptime_seconds(),
        )
@router.get("/detailed", response_model=DetailedHealthStatus)
async def detailed_health_check() -> DetailedHealthStatus:
    """Comprehensive health check with detailed component status.

    Returns detailed information about all system components including:
    - Database connectivity
    - Cache availability
    - External API status
    - Circuit breaker states
    - Resource utilization
    - ML model availability

    Returns:
        DetailedHealthStatus: Comprehensive health information. On any
        failure a minimal all-zero "unhealthy" payload is returned
        instead of raising.
    """
    try:
        health_data = await _get_detailed_health_status()
        return DetailedHealthStatus(**health_data)
    except Exception as e:
        logger.error(f"Detailed health check failed: {e}")
        # Return minimal unhealthy status rather than a 500.
        return DetailedHealthStatus(
            status="unhealthy",
            timestamp=datetime.now(UTC).isoformat(),
            version=getattr(settings, "version", "1.0.0"),
            uptime_seconds=_get_uptime_seconds(),
            components={},
            circuit_breakers={},
            resource_usage=ResourceUsage(
                cpu_percent=0.0,
                memory_percent=0.0,
                disk_percent=0.0,
                memory_used_mb=0.0,
                memory_total_mb=0.0,
                disk_used_gb=0.0,
                disk_total_gb=0.0,
            ),
            services={},
            checks_summary={"healthy": 0, "degraded": 0, "unhealthy": 1},
        )
@router.get("/ready", response_model=ReadinessStatus)
async def readiness_probe() -> ReadinessStatus:
    """Kubernetes-style readiness probe.

    Checks if the service is ready to accept traffic.
    Only "database" is treated as critical: it may be healthy or degraded.
    Non-critical components only fail readiness reporting (not the probe)
    when unhealthy.
    """
    try:
        health_data = await _get_detailed_health_status()
        # Critical dependencies for readiness.
        critical_components = ["database"]
        dependencies = {}
        all_critical_ready = True
        for comp_name, comp_status in health_data["components"].items():
            if comp_name in critical_components:
                # Degraded is still acceptable for readiness.
                is_ready = comp_status.status in ["healthy", "degraded"]
                dependencies[comp_name] = is_ready
                if not is_ready:
                    all_critical_ready = False
            else:
                # Non-critical components never block readiness.
                dependencies[comp_name] = comp_status.status != "unhealthy"
        return ReadinessStatus(
            ready=all_critical_ready,
            timestamp=datetime.now(UTC).isoformat(),
            dependencies=dependencies,
            details={
                "critical_components": critical_components,
                "overall_health": health_data["status"],
            },
        )
    except Exception as e:
        logger.error(f"Readiness probe failed: {e}")
        # Fail closed: report not-ready on any probe error.
        return ReadinessStatus(
            ready=False,
            timestamp=datetime.now(UTC).isoformat(),
            dependencies={},
            details={"error": str(e)},
        )
@router.get("/live", response_model=LivenessStatus)
async def liveness_probe() -> LivenessStatus:
    """Kubernetes-style liveness probe.

    Checks if the service is alive and functioning.
    Returns alive=true if the service can process basic requests; no
    dependency checks are performed here.
    """
    try:
        # Simple check - if we can respond, we're alive.
        current_time = datetime.now(UTC).isoformat()
        # Basic service functionality test.
        uptime = _get_uptime_seconds()
        return LivenessStatus(
            alive=True,
            timestamp=current_time,
            last_heartbeat=current_time,
            details={
                "uptime_seconds": uptime,
                "service_name": settings.app_name,
                "process_id": psutil.Process().pid,
            },
        )
    except Exception as e:
        logger.error(f"Liveness probe failed: {e}")
        return LivenessStatus(
            alive=False,
            timestamp=datetime.now(UTC).isoformat(),
            last_heartbeat=datetime.now(UTC).isoformat(),
            details={"error": str(e)},
        )
@router.get("/circuit-breakers", response_model=dict[str, CircuitBreakerStatus])
async def get_circuit_breakers() -> dict[str, CircuitBreakerStatus]:
    """Get detailed circuit breaker status.

    Returns:
        Dictionary of circuit breaker statuses keyed by breaker name.
    """
    return {
        breaker_name: CircuitBreakerStatus(
            name=info["name"],
            state=info["state"],
            failure_count=info["consecutive_failures"],
            time_until_retry=info["time_until_retry"],
            metrics=info["metrics"],
        )
        for breaker_name, info in get_circuit_breaker_status().items()
    }
@router.post("/circuit-breakers/{name}/reset")
async def reset_circuit_breaker(name: str) -> dict:
    """Reset a specific circuit breaker back to its closed state.

    Args:
        name: Circuit breaker name (path parameter)

    Returns:
        Success response dict

    Raises:
        HTTPException: 404 if no breaker is registered under *name*.
    """
    from maverick_mcp.utils.circuit_breaker import get_circuit_breaker

    breaker = get_circuit_breaker(name)
    if not breaker:
        raise HTTPException(
            status_code=404, detail=f"Circuit breaker '{name}' not found"
        )
    breaker.reset()
    logger.info(f"Circuit breaker '{name}' reset via API")
    return {"status": "success", "message": f"Circuit breaker '{name}' reset"}
@router.post("/circuit-breakers/reset-all")
async def reset_all_circuit_breakers() -> dict:
    """Reset all registered circuit breakers to their closed state.

    Returns:
        Success response dict
    """
    # Imported locally: the utility function shares this endpoint's name,
    # so a module-level import would be shadowed by this route handler.
    from maverick_mcp.utils.circuit_breaker import reset_all_circuit_breakers

    reset_all_circuit_breakers()
    logger.info("All circuit breakers reset via API")
    return {"status": "success", "message": "All circuit breakers reset"}
```
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/analysis.py:
--------------------------------------------------------------------------------
```python
"""Backtest result analysis utilities."""
import logging
from typing import Any
import numpy as np
import pandas as pd
import vectorbt as vbt
logger = logging.getLogger(__name__)
def convert_to_native(value):
"""Convert numpy types to native Python types for JSON serialization."""
if isinstance(value, np.int64 | np.int32 | np.int16 | np.int8):
return int(value)
elif isinstance(value, np.float64 | np.float32 | np.float16):
return float(value)
elif isinstance(value, np.ndarray):
return value.tolist()
elif hasattr(value, "item"): # For numpy scalars
return value.item()
elif pd.isna(value):
return None
return value
class BacktestAnalyzer:
"""Analyzer for backtest results."""
async def run_vectorbt_backtest(
    self,
    data: pd.DataFrame,
    entry_signals: pd.Series,
    exit_signals: pd.Series,
    initial_capital: float = 10000.0,
    fees: float = 0.001,
    slippage: float = 0.001,
) -> dict[str, Any]:
    """Run a backtest using VectorBT with given signals.

    Args:
        data: Price data with OHLCV columns (accepts "close" or "Close")
        entry_signals: Boolean series for entry signals
        exit_signals: Boolean series for exit signals
        initial_capital: Initial capital amount
        fees: Trading fees as percentage
        slippage: Slippage as percentage

    Returns:
        Backtest results dictionary with "metrics", "trades",
        "equity_curve" and "drawdown_series" keys, all JSON-safe.
        Degrades to empty results on bad input and to buy-and-hold
        results when no signal is ever True.
    """
    # Validate inputs to prevent empty array errors inside VectorBT.
    if data is None or len(data) == 0:
        logger.warning("Empty or invalid data provided to run_vectorbt_backtest")
        return self._create_empty_backtest_results(initial_capital)
    if entry_signals is None or exit_signals is None:
        logger.warning("Invalid signals provided to run_vectorbt_backtest")
        return self._create_empty_backtest_results(initial_capital)
    # Check for empty signals or all-False signals.
    if (
        len(entry_signals) == 0
        or len(exit_signals) == 0
        or entry_signals.size == 0
        or exit_signals.size == 0
    ):
        logger.warning("Empty signal arrays provided to run_vectorbt_backtest")
        return self._create_empty_backtest_results(initial_capital)
    # No True values at all -> benchmark against buy-and-hold instead.
    if not entry_signals.any() and not exit_signals.any():
        logger.info("No trading signals generated - returning buy-and-hold results")
        return self._create_buyhold_backtest_results(data, initial_capital)
    # Ensure we have close prices regardless of column capitalization.
    close = data["close"] if "close" in data.columns else data["Close"]
    try:
        # Run VectorBT portfolio simulation (daily frequency assumed).
        portfolio = vbt.Portfolio.from_signals(
            close=close,
            entries=entry_signals,
            exits=exit_signals,
            init_cash=initial_capital,
            fees=fees,
            slippage=slippage,
            freq="D",
        )
    except Exception as e:
        logger.error(f"VectorBT Portfolio.from_signals failed: {e}")
        return self._create_empty_backtest_results(initial_capital, error=str(e))
    # Extract headline metrics; trade-dependent ratios default to 0 when
    # there are no trades, and NaN Sharpe is coerced to 0.
    metrics = {
        "total_return": float(portfolio.total_return()),
        "annual_return": float(portfolio.annualized_return())
        if hasattr(portfolio, "annualized_return")
        else 0,
        "sharpe_ratio": float(portfolio.sharpe_ratio())
        if not np.isnan(portfolio.sharpe_ratio())
        else 0,
        "max_drawdown": float(portfolio.max_drawdown()),
        "win_rate": float(portfolio.trades.win_rate())
        if portfolio.trades.count() > 0
        else 0,
        "total_trades": int(portfolio.trades.count()),
        "profit_factor": float(portfolio.trades.profit_factor())
        if portfolio.trades.count() > 0
        else 0,
    }
    # Extract per-trade details where the records schema allows it.
    trades = []
    if portfolio.trades.count() > 0:
        try:
            # VectorBT trades are in a structured records array; field
            # availability varies by version, hence the dtype.names checks.
            trade_records = portfolio.trades.records
            for i in range(len(trade_records)):
                trade = trade_records[i]
                trades.append(
                    {
                        "entry_time": convert_to_native(trade["entry_idx"])
                        if "entry_idx" in trade.dtype.names
                        else i,
                        "exit_time": convert_to_native(trade["exit_idx"])
                        if "exit_idx" in trade.dtype.names
                        else i + 1,
                        "pnl": convert_to_native(trade["pnl"])
                        if "pnl" in trade.dtype.names
                        else 0.0,
                        "return": convert_to_native(trade["return"])
                        if "return" in trade.dtype.names
                        else 0.0,
                    }
                )
        except (AttributeError, TypeError, KeyError) as e:
            # Fallback for different trade record formats.
            logger.debug(f"Could not extract detailed trades: {e}")
            trades = [
                {
                    "total_trades": int(portfolio.trades.count()),
                    "message": "Detailed trade data not available",
                }
            ]
    # Convert equity curve so keys are strings and values native Python.
    equity_curve_raw = portfolio.value().to_dict()
    equity_curve = {
        str(k): convert_to_native(v) for k, v in equity_curve_raw.items()
    }
    # Also get drawdown series with proper conversion (if supported).
    drawdown_raw = (
        portfolio.drawdown().to_dict() if hasattr(portfolio, "drawdown") else {}
    )
    drawdown_series = {
        str(k): convert_to_native(v) for k, v in drawdown_raw.items()
    }
    return {
        "metrics": metrics,
        "trades": trades,
        "equity_curve": equity_curve,
        "drawdown_series": drawdown_series,
    }
def analyze(self, results: dict[str, Any]) -> dict[str, Any]:
    """Analyze backtest results and provide insights.

    Args:
        results: Backtest results from VectorBTEngine

    Returns:
        Analysis with performance grade, risk assessment, and recommendations
    """
    metrics = results.get("metrics", {})
    trades = results.get("trades", [])
    return {
        "performance_grade": self._grade_performance(metrics),
        "risk_assessment": self._assess_risk(metrics),
        "trade_quality": self._analyze_trades(trades, metrics),
        "strengths": self._identify_strengths(metrics),
        "weaknesses": self._identify_weaknesses(metrics),
        "recommendations": self._generate_recommendations(metrics),
        "summary": self._generate_summary(metrics),
    }
def _grade_performance(self, metrics: dict[str, float]) -> str:
"""Grade overall performance (A-F)."""
score = 0
max_score = 100
# Sharpe ratio (30 points)
sharpe = metrics.get("sharpe_ratio", 0)
if sharpe >= 2.0:
score += 30
elif sharpe >= 1.5:
score += 25
elif sharpe >= 1.0:
score += 20
elif sharpe >= 0.5:
score += 10
else:
score += 5
# Total return (25 points)
total_return = metrics.get("total_return", 0)
if total_return >= 0.50: # 50%+
score += 25
elif total_return >= 0.30:
score += 20
elif total_return >= 0.15:
score += 15
elif total_return >= 0.05:
score += 10
elif total_return > 0:
score += 5
# Win rate (20 points)
win_rate = metrics.get("win_rate", 0)
if win_rate >= 0.60:
score += 20
elif win_rate >= 0.50:
score += 15
elif win_rate >= 0.40:
score += 10
else:
score += 5
# Max drawdown (15 points)
max_dd = abs(metrics.get("max_drawdown", 0))
if max_dd <= 0.10: # Less than 10%
score += 15
elif max_dd <= 0.20:
score += 12
elif max_dd <= 0.30:
score += 8
elif max_dd <= 0.40:
score += 4
# Profit factor (10 points)
profit_factor = metrics.get("profit_factor", 0)
if profit_factor >= 2.0:
score += 10
elif profit_factor >= 1.5:
score += 8
elif profit_factor >= 1.2:
score += 5
elif profit_factor > 1.0:
score += 3
# Convert score to grade
percentage = (score / max_score) * 100
if percentage >= 90:
return "A"
elif percentage >= 80:
return "B"
elif percentage >= 70:
return "C"
elif percentage >= 60:
return "D"
else:
return "F"
def _assess_risk(self, metrics: dict[str, float]) -> dict[str, Any]:
"""Assess risk characteristics."""
max_dd = abs(metrics.get("max_drawdown", 0))
sortino = metrics.get("sortino_ratio", 0)
sharpe = metrics.get("sharpe_ratio", 0)
calmar = metrics.get("calmar_ratio", 0)
recovery = metrics.get("recovery_factor", 0)
risk_level = "Low"
if max_dd > 0.40:
risk_level = "Very High"
elif max_dd > 0.30:
risk_level = "High"
elif max_dd > 0.20:
risk_level = "Medium"
elif max_dd > 0.10:
risk_level = "Low-Medium"
return {
"risk_level": risk_level,
"max_drawdown": max_dd,
"sortino_ratio": sortino,
"calmar_ratio": calmar,
"recovery_factor": recovery,
"risk_adjusted_return": sortino if sortino > 0 else sharpe,
"downside_protection": "Good"
if sortino > 1.5
else "Moderate"
if sortino > 0.5
else "Poor",
}
def _analyze_trades(
self, trades: list[dict], metrics: dict[str, float]
) -> dict[str, Any]:
"""Analyze trade quality and patterns."""
if not trades:
return {
"quality": "No trades",
"total_trades": 0,
"frequency": "None",
}
total_trades = metrics.get("total_trades", 0)
win_rate = metrics.get("win_rate", 0)
avg_duration = metrics.get("avg_duration", 0)
# Determine trade frequency
if total_trades < 10:
frequency = "Very Low"
elif total_trades < 50:
frequency = "Low"
elif total_trades < 100:
frequency = "Moderate"
elif total_trades < 200:
frequency = "High"
else:
frequency = "Very High"
# Determine trade quality
if win_rate >= 0.60 and metrics.get("profit_factor", 0) >= 1.5:
quality = "Excellent"
elif win_rate >= 0.50 and metrics.get("profit_factor", 0) >= 1.2:
quality = "Good"
elif win_rate >= 0.40:
quality = "Average"
else:
quality = "Poor"
return {
"quality": quality,
"total_trades": total_trades,
"frequency": frequency,
"win_rate": win_rate,
"avg_win": metrics.get("avg_win", 0),
"avg_loss": metrics.get("avg_loss", 0),
"best_trade": metrics.get("best_trade", 0),
"worst_trade": metrics.get("worst_trade", 0),
"avg_duration_days": avg_duration,
"risk_reward_ratio": metrics.get("risk_reward_ratio", 0),
}
def _identify_strengths(self, metrics: dict[str, float]) -> list[str]:
"""Identify strategy strengths."""
strengths = []
if metrics.get("sharpe_ratio", 0) >= 1.5:
strengths.append("Excellent risk-adjusted returns")
if metrics.get("win_rate", 0) >= 0.60:
strengths.append("High win rate")
if abs(metrics.get("max_drawdown", 0)) <= 0.15:
strengths.append("Low maximum drawdown")
if metrics.get("profit_factor", 0) >= 1.5:
strengths.append("Strong profit factor")
if metrics.get("sortino_ratio", 0) >= 2.0:
strengths.append("Excellent downside protection")
if metrics.get("calmar_ratio", 0) >= 1.0:
strengths.append("Good return vs drawdown ratio")
if metrics.get("recovery_factor", 0) >= 3.0:
strengths.append("Quick drawdown recovery")
if metrics.get("total_return", 0) >= 0.30:
strengths.append("High total returns")
return strengths if strengths else ["Consistent performance"]
def _identify_weaknesses(self, metrics: dict[str, float]) -> list[str]:
"""Identify strategy weaknesses."""
weaknesses = []
if metrics.get("sharpe_ratio", 0) < 0.5:
weaknesses.append("Poor risk-adjusted returns")
if metrics.get("win_rate", 0) < 0.40:
weaknesses.append("Low win rate")
if abs(metrics.get("max_drawdown", 0)) > 0.30:
weaknesses.append("High maximum drawdown")
if metrics.get("profit_factor", 0) < 1.0:
weaknesses.append("Unprofitable trades overall")
if metrics.get("total_trades", 0) < 10:
weaknesses.append("Insufficient trade signals")
if metrics.get("sortino_ratio", 0) < 0:
weaknesses.append("Poor downside protection")
if metrics.get("total_return", 0) < 0:
weaknesses.append("Negative returns")
return weaknesses if weaknesses else ["Room for optimization"]
def _generate_recommendations(self, metrics: dict[str, float]) -> list[str]:
"""Generate improvement recommendations."""
recommendations = []
# Risk management recommendations
if abs(metrics.get("max_drawdown", 0)) > 0.25:
recommendations.append(
"Implement tighter stop-loss rules to reduce drawdowns"
)
# Win rate improvements
if metrics.get("win_rate", 0) < 0.45:
recommendations.append("Refine entry signals to improve win rate")
# Trade frequency
if metrics.get("total_trades", 0) < 20:
recommendations.append(
"Consider more sensitive parameters for increased signals"
)
elif metrics.get("total_trades", 0) > 200:
recommendations.append("Filter signals to reduce overtrading")
# Risk-reward optimization
if metrics.get("risk_reward_ratio", 0) < 1.5:
recommendations.append("Adjust exit strategy for better risk-reward ratio")
# Profit factor improvements
if metrics.get("profit_factor", 0) < 1.2:
recommendations.append(
"Focus on cutting losses quicker and letting winners run"
)
# Sharpe ratio improvements
if metrics.get("sharpe_ratio", 0) < 1.0:
recommendations.append("Consider position sizing based on volatility")
# Kelly criterion
kelly = metrics.get("kelly_criterion", 0)
if kelly > 0 and kelly < 0.25:
recommendations.append(
f"Consider position size of {kelly * 100:.1f}% based on Kelly Criterion"
)
return (
recommendations
if recommendations
else ["Strategy performing well, consider live testing"]
)
def _generate_summary(self, metrics: dict[str, float]) -> str:
"""Generate a text summary of the backtest."""
total_return = metrics.get("total_return", 0) * 100
sharpe = metrics.get("sharpe_ratio", 0)
max_dd = abs(metrics.get("max_drawdown", 0)) * 100
win_rate = metrics.get("win_rate", 0) * 100
total_trades = metrics.get("total_trades", 0)
summary = f"The strategy generated a {total_return:.1f}% return with a Sharpe ratio of {sharpe:.2f}. "
summary += f"Maximum drawdown was {max_dd:.1f}% with a {win_rate:.1f}% win rate across {total_trades} trades. "
if sharpe >= 1.5 and max_dd <= 20:
summary += (
"Overall performance is excellent with strong risk-adjusted returns."
)
elif sharpe >= 1.0 and max_dd <= 30:
summary += "Performance is good with acceptable risk levels."
elif sharpe >= 0.5:
summary += "Performance is moderate and could benefit from optimization."
else:
summary += "Performance needs significant improvement before live trading."
return summary
def compare_strategies(self, results_list: list[dict[str, Any]]) -> dict[str, Any]:
    """Compare multiple strategy results.

    Args:
        results_list: List of backtest results to compare

    Returns:
        Comparison analysis with rankings (by Sharpe, descending) and
        the best strategy in each category.
    """
    if not results_list:
        return {"error": "No results to compare"}

    comparisons = []
    for entry in results_list:
        entry_metrics = entry.get("metrics", {})
        comparisons.append(
            {
                "strategy": entry.get("strategy", "Unknown"),
                "parameters": entry.get("parameters", {}),
                "total_return": entry_metrics.get("total_return", 0),
                "sharpe_ratio": entry_metrics.get("sharpe_ratio", 0),
                "max_drawdown": abs(entry_metrics.get("max_drawdown", 0)),
                "win_rate": entry_metrics.get("win_rate", 0),
                "profit_factor": entry_metrics.get("profit_factor", 0),
                "total_trades": entry_metrics.get("total_trades", 0),
                "grade": self._grade_performance(entry_metrics),
            }
        )

    # Default ranking: Sharpe ratio, best first; then attach 1-based ranks.
    comparisons.sort(key=lambda row: row["sharpe_ratio"], reverse=True)
    for position, row in enumerate(comparisons, 1):
        row["rank"] = position

    return {
        "rankings": comparisons,
        "best_overall": comparisons[0] if comparisons else None,
        "best_return": max(comparisons, key=lambda row: row["total_return"]),
        "best_sharpe": max(comparisons, key=lambda row: row["sharpe_ratio"]),
        "best_drawdown": min(comparisons, key=lambda row: row["max_drawdown"]),
        "best_win_rate": max(comparisons, key=lambda row: row["win_rate"]),
        "summary": self._generate_comparison_summary(comparisons),
    }
def _generate_comparison_summary(self, comparisons: list[dict]) -> str:
"""Generate summary of strategy comparison."""
if not comparisons:
return "No strategies to compare"
best = comparisons[0]
summary = f"The best performing strategy is {best['strategy']} "
summary += f"with a Sharpe ratio of {best['sharpe_ratio']:.2f} "
summary += f"and total return of {best['total_return'] * 100:.1f}%. "
if len(comparisons) > 1:
summary += (
f"It outperformed {len(comparisons) - 1} other strategies tested."
)
return summary
def _create_empty_backtest_results(
self, initial_capital: float, error: str = None
) -> dict[str, Any]:
"""Create empty backtest results when no valid signals are available.
Args:
initial_capital: Initial capital amount
error: Optional error message to include
Returns:
Empty backtest results dictionary
"""
return {
"metrics": {
"total_return": 0.0,
"annual_return": 0.0,
"sharpe_ratio": 0.0,
"max_drawdown": 0.0,
"win_rate": 0.0,
"total_trades": 0,
"profit_factor": 0.0,
},
"trades": [],
"equity_curve": {str(0): initial_capital},
"drawdown_series": {str(0): 0.0},
"error": error,
"message": "No trading signals generated - empty backtest results returned",
}
def _create_buyhold_backtest_results(
    self, data: pd.DataFrame, initial_capital: float
) -> dict[str, Any]:
    """Create buy-and-hold backtest results when no trading signals are available.

    Args:
        data: Price data
        initial_capital: Initial capital amount

    Returns:
        Buy-and-hold backtest results dictionary
    """
    try:
        # Accept either lowercase or capitalized close column.
        close = data["close"] if "close" in data.columns else data["Close"]
        if len(close) == 0:
            return self._create_empty_backtest_results(initial_capital)

        first_price = close.iloc[0]
        last_price = close.iloc[-1]
        total_return = (last_price - first_price) / first_price

        # Scale prices so the equity series starts at the initial capital.
        equity = close / first_price * initial_capital
        # Drawdown measured against the running peak of the equity curve.
        running_peak = equity.expanding().max()
        dd = (equity - running_peak) / running_peak

        def _serialize(series):
            """Turn a pandas Series into a {str(index): native value} dict."""
            return {
                str(idx): convert_to_native(val)
                for idx, val in series.to_dict().items()
            }

        annual_return = (
            float(total_return * 252 / len(data)) if len(data) > 0 else 0.0
        )
        return {
            "metrics": {
                "total_return": float(total_return),
                "annual_return": annual_return,
                "sharpe_ratio": 0.0,  # Cannot calculate without trading
                "max_drawdown": float(dd.min()) if len(dd) > 0 else 0.0,
                "win_rate": 0.0,  # No trades
                "total_trades": 0,
                "profit_factor": 0.0,  # No trades
            },
            "trades": [],
            "equity_curve": _serialize(equity),
            "drawdown_series": _serialize(dd),
            "message": "No trading signals generated - returning buy-and-hold performance",
        }
    except Exception as e:
        logger.error(f"Error creating buy-and-hold results: {e}")
        return self._create_empty_backtest_results(initial_capital, error=str(e))
```
--------------------------------------------------------------------------------
/maverick_mcp/data/validation.py:
--------------------------------------------------------------------------------
```python
"""
Data Quality Validation Module for MaverickMCP.
This module provides comprehensive data validation functionality for
stock price data, backtesting data, and general data quality checks.
Ensures data integrity before processing and backtesting operations.
"""
import logging
from datetime import date, datetime
from typing import Any
import numpy as np
import pandas as pd
from pandas import DataFrame
from maverick_mcp.exceptions import ValidationError
logger = logging.getLogger(__name__)
class DataValidator:
    """Comprehensive data validation for stock market and backtesting data.

    All validators return a result dict with at least the keys
    ``passed`` (bool), ``warnings`` (list[str]), ``errors`` (list[str])
    and ``metrics`` (dict); some raise ValidationError directly.
    """

    @staticmethod
    def validate_date_range(
        start_date: str | datetime | date,
        end_date: str | datetime | date,
        allow_future: bool = False,
        max_range_days: int | None = None,
    ) -> tuple[datetime, datetime]:
        """
        Validate date range for data queries.

        Args:
            start_date: Start date for the range
            end_date: End date for the range
            allow_future: Whether to allow future dates
            max_range_days: Maximum allowed days in range

        Returns:
            Tuple of validated (start_date, end_date) as datetime objects

        Raises:
            ValidationError: If dates are invalid
        """
        # Convert to datetime objects. Strings go through pandas so any
        # format pandas can parse is accepted; bare dates become midnight.
        if isinstance(start_date, str):
            try:
                start_dt = pd.to_datetime(start_date).to_pydatetime()
            except Exception as e:
                raise ValidationError(f"Invalid start_date format: {start_date}") from e
        elif isinstance(start_date, date):
            start_dt = datetime.combine(start_date, datetime.min.time())
        else:
            start_dt = start_date

        if isinstance(end_date, str):
            try:
                end_dt = pd.to_datetime(end_date).to_pydatetime()
            except Exception as e:
                raise ValidationError(f"Invalid end_date format: {end_date}") from e
        elif isinstance(end_date, date):
            end_dt = datetime.combine(end_date, datetime.min.time())
        else:
            end_dt = end_date

        # Validate chronological order
        if start_dt > end_dt:
            raise ValidationError(
                f"Start date {start_dt.date()} must be before end date {end_dt.date()}"
            )

        # Check future dates if not allowed. A future start is an error;
        # a future end is merely clamped to today with a warning.
        if not allow_future:
            today = datetime.now().date()
            if start_dt.date() > today:
                raise ValidationError(
                    f"Start date {start_dt.date()} cannot be in the future"
                )
            if end_dt.date() > today:
                logger.warning(
                    f"End date {end_dt.date()} is in the future, using today instead"
                )
                end_dt = datetime.combine(today, datetime.min.time())

        # Check maximum range
        if max_range_days:
            range_days = (end_dt - start_dt).days
            if range_days > max_range_days:
                raise ValidationError(
                    f"Date range too large: {range_days} days (max: {max_range_days} days)"
                )

        return start_dt, end_dt

    @staticmethod
    def validate_data_quality(
        data: DataFrame,
        required_columns: list[str] | None = None,
        min_rows: int = 1,
        max_missing_ratio: float = 0.1,
        check_duplicates: bool = True,
    ) -> dict[str, Any]:
        """
        Validate general data quality of a DataFrame.

        Args:
            data: DataFrame to validate
            required_columns: List of required columns
            min_rows: Minimum number of rows required
            max_missing_ratio: Maximum ratio of missing values allowed
            check_duplicates: Whether to check for duplicate rows

        Returns:
            Dictionary with validation results and quality metrics

        Raises:
            ValidationError: If validation fails
        """
        # Only None/empty input raises immediately; all other failures are
        # collected into the result dict with passed=False.
        if data is None or data.empty:
            raise ValidationError("Data is None or empty")

        validation_results = {
            "passed": True,
            "warnings": [],
            "errors": [],
            "metrics": {
                "total_rows": len(data),
                "total_columns": len(data.columns),
                "missing_values": data.isnull().sum().sum(),
                "duplicate_rows": 0,
            },
        }

        # Check minimum rows
        if len(data) < min_rows:
            error_msg = f"Insufficient data: {len(data)} rows (minimum: {min_rows})"
            validation_results["errors"].append(error_msg)
            validation_results["passed"] = False

        # Check required columns
        if required_columns:
            missing_cols = set(required_columns) - set(data.columns)
            if missing_cols:
                error_msg = f"Missing required columns: {list(missing_cols)}"
                validation_results["errors"].append(error_msg)
                validation_results["passed"] = False

        # Check missing values ratio (NaNs across all cells of the frame)
        total_cells = len(data) * len(data.columns)
        if total_cells > 0:
            missing_ratio = (
                validation_results["metrics"]["missing_values"] / total_cells
            )
            validation_results["metrics"]["missing_ratio"] = missing_ratio

            if missing_ratio > max_missing_ratio:
                error_msg = f"Too many missing values: {missing_ratio:.2%} (max: {max_missing_ratio:.2%})"
                validation_results["errors"].append(error_msg)
                validation_results["passed"] = False

        # Check for duplicate rows (warning only, does not fail validation)
        if check_duplicates:
            duplicate_count = data.duplicated().sum()
            validation_results["metrics"]["duplicate_rows"] = duplicate_count
            if duplicate_count > 0:
                warning_msg = f"Found {duplicate_count} duplicate rows"
                validation_results["warnings"].append(warning_msg)

        # Check for completely empty columns
        empty_columns = data.columns[data.isnull().all()].tolist()
        if empty_columns:
            warning_msg = f"Completely empty columns: {empty_columns}"
            validation_results["warnings"].append(warning_msg)

        return validation_results

    @staticmethod
    def validate_price_data(
        data: DataFrame, symbol: str = "Unknown", strict_mode: bool = True
    ) -> dict[str, Any]:
        """
        Validate OHLCV stock price data integrity.

        Args:
            data: DataFrame with OHLCV data (lowercase column names)
            symbol: Stock symbol for error messages
            strict_mode: Whether to apply strict validation rules

        Returns:
            Dictionary with validation results and metrics

        Raises:
            ValidationError: If validation fails in strict mode
        """
        expected_columns = ["open", "high", "low", "close"]

        # Basic data quality check
        quality_results = DataValidator.validate_data_quality(
            data,
            required_columns=expected_columns,
            min_rows=1,
            max_missing_ratio=0.05,  # Allow 5% missing values for price data
        )

        validation_results = {
            "passed": quality_results["passed"],
            "warnings": quality_results["warnings"].copy(),
            "errors": quality_results["errors"].copy(),
            "metrics": quality_results["metrics"].copy(),
            "symbol": symbol,
            "price_validation": {
                "negative_prices": 0,
                "zero_prices": 0,
                "invalid_ohlc_relationships": 0,
                "extreme_price_changes": 0,
                "volume_anomalies": 0,
            },
        }

        # NOTE(review): validate_data_quality above raises on empty data,
        # so this guard appears unreachable — kept defensively.
        if data.empty:
            return validation_results

        # Check for negative prices (hard error)
        price_cols = [col for col in expected_columns if col in data.columns]

        for col in price_cols:
            if col in data.columns:
                negative_count = (data[col] < 0).sum()
                if negative_count > 0:
                    error_msg = (
                        f"Found {negative_count} negative {col} prices for {symbol}"
                    )
                    validation_results["errors"].append(error_msg)
                    validation_results["price_validation"]["negative_prices"] += (
                        negative_count
                    )
                    validation_results["passed"] = False

        # Check for zero prices (warning only)
        for col in price_cols:
            if col in data.columns:
                zero_count = (data[col] == 0).sum()
                if zero_count > 0:
                    warning_msg = f"Found {zero_count} zero {col} prices for {symbol}"
                    validation_results["warnings"].append(warning_msg)
                    validation_results["price_validation"]["zero_prices"] += zero_count

        # Validate OHLC relationships (High >= Open, Close, Low; Low <= Open, Close)
        if all(col in data.columns for col in ["open", "high", "low", "close"]):
            # High should be >= Open, Low, Close
            high_violations = (
                (data["high"] < data["open"])
                | (data["high"] < data["low"])
                | (data["high"] < data["close"])
            ).sum()

            # Low should be <= Open, High, Close
            low_violations = (
                (data["low"] > data["open"])
                | (data["low"] > data["high"])
                | (data["low"] > data["close"])
            ).sum()

            total_ohlc_violations = high_violations + low_violations
            if total_ohlc_violations > 0:
                error_msg = f"OHLC relationship violations for {symbol}: {total_ohlc_violations} bars"
                validation_results["errors"].append(error_msg)
                validation_results["price_validation"]["invalid_ohlc_relationships"] = (
                    total_ohlc_violations
                )
                validation_results["passed"] = False

        # Check for extreme price changes (>50% daily moves) — warning only,
        # since large moves can be legitimate (splits, news).
        if "close" in data.columns and len(data) > 1:
            daily_returns = data["close"].pct_change().dropna()
            extreme_changes = (daily_returns.abs() > 0.5).sum()
            if extreme_changes > 0:
                warning_msg = (
                    f"Found {extreme_changes} extreme price changes (>50%) for {symbol}"
                )
                validation_results["warnings"].append(warning_msg)
                validation_results["price_validation"]["extreme_price_changes"] = (
                    extreme_changes
                )

        # Validate volume data if present
        if "volume" in data.columns:
            negative_volume = (data["volume"] < 0).sum()
            if negative_volume > 0:
                error_msg = (
                    f"Found {negative_volume} negative volume values for {symbol}"
                )
                validation_results["errors"].append(error_msg)
                validation_results["price_validation"]["volume_anomalies"] += (
                    negative_volume
                )
                validation_results["passed"] = False

            # Check for suspiciously high volume (>10x median) — counted as
            # anomalies but produces neither error nor warning message.
            if len(data) > 10:
                median_volume = data["volume"].median()
                if median_volume > 0:
                    high_volume_count = (data["volume"] > median_volume * 10).sum()
                    if high_volume_count > 0:
                        validation_results["price_validation"]["volume_anomalies"] += (
                            high_volume_count
                        )

        # Check data continuity (gaps in date index)
        if hasattr(data.index, "to_series"):
            date_diffs = data.index.to_series().diff()[1:]
            if len(date_diffs) > 0:
                # Check for gaps larger than 7 days (weekend + holiday)
                large_gaps = (date_diffs > pd.Timedelta(days=7)).sum()
                if large_gaps > 0:
                    warning_msg = f"Found {large_gaps} large time gaps (>7 days) in data for {symbol}"
                    validation_results["warnings"].append(warning_msg)

        # Raise error in strict mode if validation failed
        if strict_mode and not validation_results["passed"]:
            error_summary = "; ".join(validation_results["errors"])
            raise ValidationError(
                f"Price data validation failed for {symbol}: {error_summary}"
            )

        return validation_results

    @staticmethod
    def validate_batch_data(
        batch_data: dict[str, DataFrame],
        min_symbols: int = 1,
        max_symbols: int = 100,
        validate_individual: bool = True,
    ) -> dict[str, Any]:
        """
        Validate batch data containing multiple symbol DataFrames.

        Args:
            batch_data: Dictionary mapping symbols to DataFrames
            min_symbols: Minimum number of symbols required
            max_symbols: Maximum number of symbols allowed
            validate_individual: Whether to validate each symbol's data

        Returns:
            Dictionary with batch validation results

        Raises:
            ValidationError: If batch validation fails
        """
        if not isinstance(batch_data, dict):
            raise ValidationError("Batch data must be a dictionary")

        validation_results = {
            "passed": True,
            "warnings": [],
            "errors": [],
            "metrics": {
                "total_symbols": len(batch_data),
                "valid_symbols": 0,
                "invalid_symbols": 0,
                "empty_symbols": 0,
                "total_rows": 0,
            },
            "symbol_results": {},
        }

        # Check symbol count — both bounds fail the whole batch.
        symbol_count = len(batch_data)
        if symbol_count < min_symbols:
            error_msg = f"Insufficient symbols: {symbol_count} (minimum: {min_symbols})"
            validation_results["errors"].append(error_msg)
            validation_results["passed"] = False

        if symbol_count > max_symbols:
            error_msg = f"Too many symbols: {symbol_count} (maximum: {max_symbols})"
            validation_results["errors"].append(error_msg)
            validation_results["passed"] = False

        # Validate each symbol's data. Per-symbol failures are recorded but
        # deliberately do not fail the batch as a whole.
        for symbol, data in batch_data.items():
            try:
                if data is None or data.empty:
                    validation_results["metrics"]["empty_symbols"] += 1
                    validation_results["symbol_results"][symbol] = {
                        "passed": False,
                        "error": "Empty or None data",
                    }
                    continue

                if validate_individual:
                    # Validate price data for each symbol (non-strict so
                    # individual failures are reported, not raised)
                    symbol_validation = DataValidator.validate_price_data(
                        data, symbol, strict_mode=False
                    )
                    validation_results["symbol_results"][symbol] = symbol_validation

                    if symbol_validation["passed"]:
                        validation_results["metrics"]["valid_symbols"] += 1
                    else:
                        validation_results["metrics"]["invalid_symbols"] += 1
                        # Aggregate errors
                        for error in symbol_validation["errors"]:
                            validation_results["errors"].append(f"{symbol}: {error}")
                        # Don't fail entire batch for individual symbol issues
                        # validation_results["passed"] = False
                else:
                    validation_results["metrics"]["valid_symbols"] += 1
                    validation_results["symbol_results"][symbol] = {
                        "passed": True,
                        "rows": len(data),
                    }

                validation_results["metrics"]["total_rows"] += len(data)

            except Exception as e:
                validation_results["metrics"]["invalid_symbols"] += 1
                validation_results["symbol_results"][symbol] = {
                    "passed": False,
                    "error": str(e),
                }
                validation_results["errors"].append(f"{symbol}: Validation error - {e}")

        # Summary metrics. Note: empty symbols count toward neither valid
        # nor invalid, so success_rate reflects valid/total.
        validation_results["metrics"]["success_rate"] = (
            validation_results["metrics"]["valid_symbols"] / symbol_count
            if symbol_count > 0
            else 0.0
        )

        # Add warnings for low success rate
        if validation_results["metrics"]["success_rate"] < 0.8:
            warning_msg = (
                f"Low success rate: {validation_results['metrics']['success_rate']:.1%}"
            )
            validation_results["warnings"].append(warning_msg)

        return validation_results

    @staticmethod
    def validate_technical_indicators(
        data: DataFrame, indicators: dict[str, Any], symbol: str = "Unknown"
    ) -> dict[str, Any]:
        """
        Validate technical indicator data.

        Args:
            data: DataFrame with technical indicator columns
            indicators: Dictionary of indicator configurations; only the
                keys (column names) are used, configs are ignored
            symbol: Symbol name for error messages

        Returns:
            Dictionary with validation results
        """
        validation_results = {
            "passed": True,
            "warnings": [],
            "errors": [],
            "metrics": {
                "total_indicators": len(indicators),
                "valid_indicators": 0,
                "nan_counts": {},
            },
        }

        for indicator_name, _config in indicators.items():
            if indicator_name not in data.columns:
                error_msg = f"Missing indicator '{indicator_name}' for {symbol}"
                validation_results["errors"].append(error_msg)
                validation_results["passed"] = False
                continue

            indicator_data = data[indicator_name]

            # Count NaN values
            nan_count = indicator_data.isnull().sum()
            validation_results["metrics"]["nan_counts"][indicator_name] = nan_count

            # Check for excessive NaN values; an indicator only counts as
            # "valid" when it has zero NaNs.
            if len(data) > 0:
                nan_ratio = nan_count / len(data)
                if nan_ratio > 0.5:  # More than 50% NaN
                    warning_msg = (
                        f"High NaN ratio for '{indicator_name}': {nan_ratio:.1%}"
                    )
                    validation_results["warnings"].append(warning_msg)
                elif nan_ratio == 0:
                    validation_results["metrics"]["valid_indicators"] += 1

            # Check for infinite values (NaNs filled with 0 first so only
            # true infinities trigger the error)
            if np.any(np.isinf(indicator_data.fillna(0))):
                error_msg = f"Infinite values found in '{indicator_name}' for {symbol}"
                validation_results["errors"].append(error_msg)
                validation_results["passed"] = False

        return validation_results

    @classmethod
    def create_validation_report(
        cls, validation_results: dict[str, Any], include_warnings: bool = True
    ) -> str:
        """
        Create a human-readable validation report.

        Args:
            validation_results: Results from validation methods
            include_warnings: Whether to include warnings in report

        Returns:
            Formatted validation report string
        """
        lines = []

        # Header
        status = "✅ PASSED" if validation_results.get("passed", False) else "❌ FAILED"
        lines.append(f"=== Data Validation Report - {status} ===")
        lines.append("")

        # Metrics (floats strictly between 0 and 1 render as percentages)
        if "metrics" in validation_results:
            lines.append("📊 Metrics:")
            for key, value in validation_results["metrics"].items():
                if isinstance(value, float) and 0 < value < 1:
                    lines.append(f"  • {key}: {value:.2%}")
                else:
                    lines.append(f"  • {key}: {value}")
            lines.append("")

        # Errors
        if validation_results.get("errors"):
            lines.append("❌ Errors:")
            for error in validation_results["errors"]:
                lines.append(f"  • {error}")
            lines.append("")

        # Warnings
        if include_warnings and validation_results.get("warnings"):
            lines.append("⚠️ Warnings:")
            for warning in validation_results["warnings"]:
                lines.append(f"  • {warning}")
            lines.append("")

        # Symbol-specific results (for batch validation)
        if "symbol_results" in validation_results:
            failed_symbols = [
                symbol
                for symbol, result in validation_results["symbol_results"].items()
                if not result.get("passed", True)
            ]
            if failed_symbols:
                lines.append(f"🔍 Failed Symbols ({len(failed_symbols)}):")
                for symbol in failed_symbols:
                    result = validation_results["symbol_results"][symbol]
                    error = result.get("error", "Unknown error")
                    lines.append(f"  • {symbol}: {error}")
                lines.append("")

        return "\n".join(lines)
# Convenience functions for common validation scenarios
def validate_stock_data(
    data: DataFrame,
    symbol: str,
    start_date: str | None = None,
    end_date: str | None = None,
    strict: bool = True,
) -> dict[str, Any]:
    """
    Convenience function to validate stock data with date range.

    Args:
        data: Stock price DataFrame
        symbol: Stock symbol
        start_date: Expected start date (optional)
        end_date: Expected end date (optional)
        strict: Whether to use strict validation

    Returns:
        Combined validation results
    """
    validator = DataValidator()

    # Core OHLCV validation first.
    results = validator.validate_price_data(data, symbol, strict_mode=strict)

    # Only validate the range when both endpoints were supplied.
    if start_date and end_date:
        try:
            validator.validate_date_range(start_date, end_date)
        except ValidationError as e:
            results["date_range_valid"] = False
            results["errors"].append(f"Date range validation failed: {e}")
            results["passed"] = False
        else:
            results["date_range_valid"] = True

    return results
def validate_backtest_data(
    data: dict[str, DataFrame], min_history_days: int = 30
) -> dict[str, Any]:
    """
    Convenience function to validate backtesting data requirements.

    Args:
        data: Dictionary of symbol -> DataFrame mappings
        min_history_days: Minimum days of history required

    Returns:
        Validation results for backtesting
    """
    results = DataValidator().validate_batch_data(data, validate_individual=True)

    # Backtests need sufficient history; flag short (non-empty) series.
    for symbol, frame in data.items():
        if frame.empty or len(frame) >= min_history_days:
            continue
        results["warnings"].append(
            f"{symbol}: Only {len(frame)} days of data (minimum: {min_history_days})"
        )

    return results
```
--------------------------------------------------------------------------------
/maverick_mcp/workflows/agents/validator_agent.py:
--------------------------------------------------------------------------------
```python
"""
Validator Agent for backtesting results validation and robustness testing.
This agent performs walk-forward analysis, Monte Carlo simulation, and robustness
testing to validate optimization results and provide confidence-scored recommendations.
"""
import logging
import statistics
from datetime import datetime, timedelta
from typing import Any
from maverick_mcp.backtesting import StrategyOptimizer, VectorBTEngine
from maverick_mcp.workflows.state import BacktestingWorkflowState
logger = logging.getLogger(__name__)
class ValidatorAgent:
"""Intelligent validator for backtesting results and strategy robustness."""
def __init__(
    self,
    vectorbt_engine: VectorBTEngine | None = None,
    strategy_optimizer: StrategyOptimizer | None = None,
):
    """Initialize validator agent.

    Args:
        vectorbt_engine: VectorBT backtesting engine (a default engine is
            created when omitted)
        strategy_optimizer: Strategy optimization engine (defaults to one
            wrapping this agent's engine)
    """
    self.engine = vectorbt_engine or VectorBTEngine()
    self.optimizer = strategy_optimizer or StrategyOptimizer(self.engine)

    # Validation criteria for different regimes. Thresholds are tuned per
    # regime: ranging markets demand the highest quality, volatile markets
    # tolerate larger drawdowns and lower Sharpe ratios.
    self.REGIME_VALIDATION_CRITERIA = {
        "trending": {
            "min_sharpe_ratio": 0.8,
            "max_drawdown_threshold": 0.25,
            "min_total_return": 0.10,
            "min_win_rate": 0.35,
            "stability_threshold": 0.7,
        },
        "ranging": {
            "min_sharpe_ratio": 1.0,  # Higher standard for ranging markets
            "max_drawdown_threshold": 0.15,
            "min_total_return": 0.05,
            "min_win_rate": 0.45,
            "stability_threshold": 0.8,
        },
        "volatile": {
            "min_sharpe_ratio": 0.6,  # Lower expectation in volatile markets
            "max_drawdown_threshold": 0.35,
            "min_total_return": 0.08,
            "min_win_rate": 0.30,
            "stability_threshold": 0.6,
        },
        "volatile_trending": {
            "min_sharpe_ratio": 0.7,
            "max_drawdown_threshold": 0.30,
            "min_total_return": 0.12,
            "min_win_rate": 0.35,
            "stability_threshold": 0.65,
        },
        "low_volume": {
            "min_sharpe_ratio": 0.9,
            "max_drawdown_threshold": 0.20,
            "min_total_return": 0.06,
            "min_win_rate": 0.40,
            "stability_threshold": 0.75,
        },
        # Fallback criteria when the regime could not be classified.
        "unknown": {
            "min_sharpe_ratio": 0.8,
            "max_drawdown_threshold": 0.20,
            "min_total_return": 0.08,
            "min_win_rate": 0.40,
            "stability_threshold": 0.7,
        },
    }

    # Robustness scoring weights (sum to 1.0); used by
    # _calculate_robustness_score to blend the validation components.
    self.ROBUSTNESS_WEIGHTS = {
        "walk_forward_consistency": 0.3,
        "parameter_sensitivity": 0.2,
        "monte_carlo_stability": 0.2,
        "out_of_sample_performance": 0.3,
    }

    logger.info("ValidatorAgent initialized")
async def validate_strategies(
    self, state: BacktestingWorkflowState
) -> BacktestingWorkflowState:
    """Validate optimized strategies through comprehensive testing.

    Runs walk-forward analysis, Monte Carlo simulation and out-of-sample
    testing for each optimized strategy, scores robustness, then writes a
    final ranking and a recommended strategy back onto the state. On
    failure, records the error and falls back to the first optimized
    strategy with low confidence.

    Args:
        state: Current workflow state with optimization results

    Returns:
        Updated state with validation results and final recommendations
    """
    start_time = datetime.now()

    try:
        logger.info(
            f"Validating {len(state.best_parameters)} strategies for {state.symbol}"
        )

        # Get validation criteria for current regime
        validation_criteria = self._get_validation_criteria(state.market_regime)

        # Perform validation for each strategy
        walk_forward_results = {}
        monte_carlo_results = {}
        out_of_sample_performance = {}
        robustness_scores = {}
        validation_warnings = []

        for strategy, parameters in state.best_parameters.items():
            try:
                logger.info(f"Validating {strategy} strategy...")

                # Walk-forward analysis
                wf_result = await self._run_walk_forward_analysis(
                    state, strategy, parameters
                )
                walk_forward_results[strategy] = wf_result

                # Monte Carlo simulation
                mc_result = await self._run_monte_carlo_simulation(
                    state, strategy, parameters
                )
                monte_carlo_results[strategy] = mc_result

                # Out-of-sample testing
                oos_result = await self._run_out_of_sample_test(
                    state, strategy, parameters
                )
                out_of_sample_performance[strategy] = oos_result

                # Calculate robustness score
                robustness_score = self._calculate_robustness_score(
                    wf_result, mc_result, oos_result, validation_criteria
                )
                robustness_scores[strategy] = robustness_score

                # Check for validation warnings
                warnings = self._check_validation_warnings(
                    strategy, wf_result, mc_result, oos_result, validation_criteria
                )
                validation_warnings.extend(warnings)

                logger.info(
                    f"Validated {strategy}: robustness score {robustness_score:.2f}"
                )

            except Exception as e:
                # A failed strategy gets a zero score but does not abort
                # validation of the remaining strategies.
                logger.error(f"Failed to validate {strategy}: {e}")
                robustness_scores[strategy] = 0.0
                validation_warnings.append(
                    f"{strategy}: Validation failed - {str(e)}"
                )

        # Generate final recommendations
        final_ranking = self._generate_final_ranking(
            state.best_parameters, robustness_scores, state.strategy_rankings
        )

        # Select recommended strategy
        recommended_strategy, recommendation_confidence = (
            self._select_recommended_strategy(
                final_ranking, robustness_scores, state.regime_confidence
            )
        )

        # Perform risk assessment
        risk_assessment = self._perform_risk_assessment(
            recommended_strategy,
            walk_forward_results,
            monte_carlo_results,
            validation_criteria,
        )

        # Update state
        state.walk_forward_results = walk_forward_results
        state.monte_carlo_results = monte_carlo_results
        state.out_of_sample_performance = out_of_sample_performance
        state.robustness_score = robustness_scores
        state.validation_warnings = validation_warnings
        state.final_strategy_ranking = final_ranking
        state.recommended_strategy = recommended_strategy
        state.recommended_parameters = state.best_parameters.get(
            recommended_strategy, {}
        )
        state.recommendation_confidence = recommendation_confidence
        state.risk_assessment = risk_assessment

        # Update workflow status
        state.workflow_status = "completed"
        state.current_step = "validation_completed"
        state.steps_completed.append("strategy_validation")

        # Record total execution time (validation time plus the earlier
        # regime-analysis and optimization phases), in milliseconds.
        total_execution_time = (datetime.now() - start_time).total_seconds() * 1000
        state.total_execution_time_ms = (
            state.regime_analysis_time_ms
            + state.optimization_time_ms
            + total_execution_time
        )

        logger.info(
            f"Strategy validation completed for {state.symbol}: "
            f"Recommended {recommended_strategy} with confidence {recommendation_confidence:.2f}"
        )

        return state

    except Exception as e:
        error_info = {
            "step": "strategy_validation",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
            "symbol": state.symbol,
        }
        state.errors_encountered.append(error_info)

        # Fallback recommendation: first optimized strategy, low confidence.
        if state.best_parameters:
            fallback_strategy = list(state.best_parameters.keys())[0]
            state.recommended_strategy = fallback_strategy
            state.recommended_parameters = state.best_parameters[fallback_strategy]
            state.recommendation_confidence = 0.3
            state.fallback_strategies_used.append("validation_fallback")

        logger.error(f"Strategy validation failed for {state.symbol}: {e}")
        return state
def _get_validation_criteria(self, regime: str) -> dict[str, Any]:
    """Look up validation thresholds for a regime, defaulting to 'unknown'."""
    criteria = self.REGIME_VALIDATION_CRITERIA
    return criteria.get(regime, criteria["unknown"])
async def _run_walk_forward_analysis(
    self, state: BacktestingWorkflowState, strategy: str, parameters: dict[str, Any]
) -> dict[str, Any]:
    """Run walk-forward analysis for strategy validation.

    Window and step sizes are chosen from the total span of the backtest
    date range; failures are reported as a zero-consistency result rather
    than raised.
    """
    try:
        # Size the rolling windows to the amount of history available.
        start_dt = datetime.strptime(state.start_date, "%Y-%m-%d")
        end_dt = datetime.strptime(state.end_date, "%Y-%m-%d")
        total_days = (end_dt - start_dt).days

        if total_days > 500:  # ~2 years
            window_size, step_size = 252, 63  # 1 year window, 3 month step
        elif total_days > 250:  # ~1 year
            window_size, step_size = 126, 42  # 6 month window, 6 week step
        else:
            window_size, step_size = 63, 21  # 3 month window, 3 week step

        # Delegate the actual analysis to the optimizer.
        return await self.optimizer.walk_forward_analysis(
            symbol=state.symbol,
            strategy_type=strategy,
            parameters=parameters,
            start_date=state.start_date,
            end_date=state.end_date,
            window_size=window_size,
            step_size=step_size,
        )

    except Exception as e:
        logger.error(f"Walk-forward analysis failed for {strategy}: {e}")
        return {"error": str(e), "consistency_score": 0.0}
async def _run_monte_carlo_simulation(
    self, state: BacktestingWorkflowState, strategy: str, parameters: dict[str, Any]
) -> dict[str, Any]:
    """Run Monte Carlo simulation for strategy validation.

    A base backtest over the full date range is run first and its results
    are fed into the optimizer's Monte Carlo simulation. Failures are
    reported as a zero-stability result rather than raised.
    """
    try:
        # Base backtest supplies the results the simulation operates on.
        base_result = await self.engine.run_backtest(
            symbol=state.symbol,
            strategy_type=strategy,
            parameters=parameters,
            start_date=state.start_date,
            end_date=state.end_date,
            initial_capital=state.initial_capital,
        )

        return await self.optimizer.monte_carlo_simulation(
            backtest_results=base_result,
            num_simulations=500,  # Reduced for performance
        )

    except Exception as e:
        logger.error(f"Monte Carlo simulation failed for {strategy}: {e}")
        return {"error": str(e), "stability_score": 0.0}
async def _run_out_of_sample_test(
    self, state: BacktestingWorkflowState, strategy: str, parameters: dict[str, Any]
) -> dict[str, float]:
    """Run out-of-sample testing on holdout data.

    The final 30% of the configured date range is treated as the holdout
    period. On failure, a dict of zeroed metrics is returned.
    """
    try:
        # Hold out the last 30% of the date range as unseen data.
        start_dt = datetime.strptime(state.start_date, "%Y-%m-%d")
        end_dt = datetime.strptime(state.end_date, "%Y-%m-%d")
        holdout_days = int((end_dt - start_dt).days * 0.3)
        holdout_start = end_dt - timedelta(days=holdout_days)

        result = await self.engine.run_backtest(
            symbol=state.symbol,
            strategy_type=strategy,
            parameters=parameters,
            start_date=holdout_start.strftime("%Y-%m-%d"),
            end_date=state.end_date,
            initial_capital=state.initial_capital,
        )

        metrics = result["metrics"]
        return {
            key: metrics[key]
            for key in (
                "total_return",
                "sharpe_ratio",
                "max_drawdown",
                "win_rate",
                "total_trades",
            )
        }

    except Exception as e:
        logger.error(f"Out-of-sample test failed for {strategy}: {e}")
        return {
            "total_return": 0.0,
            "sharpe_ratio": 0.0,
            "max_drawdown": 0.0,
            "win_rate": 0.0,
            "total_trades": 0,
        }
def _calculate_robustness_score(
    self,
    wf_result: dict[str, Any],
    mc_result: dict[str, Any],
    oos_result: dict[str, float],
    validation_criteria: dict[str, Any],
) -> float:
    """Calculate overall robustness score for a strategy.

    Blends four components (walk-forward consistency, parameter
    sensitivity, Monte Carlo stability, out-of-sample performance)
    using self.ROBUSTNESS_WEIGHTS; each component and the final score
    are clamped to [0, 1].
    """
    scores = {}

    # Walk-forward consistency score: prefer the precomputed score; if
    # absent but period results exist, derive consistency from the spread
    # of per-period returns; otherwise score 0.
    if "consistency_score" in wf_result:
        scores["walk_forward_consistency"] = wf_result["consistency_score"]
    elif "error" not in wf_result and "periods" in wf_result:
        # Calculate consistency from period results
        period_returns = [
            p.get("total_return", 0) for p in wf_result.get("periods", [])
        ]
        if period_returns:
            # Lower std deviation relative to mean = higher consistency
            # (0.01 floor on the mean avoids division blow-up near zero).
            mean_return = statistics.mean(period_returns)
            std_return = (
                statistics.stdev(period_returns) if len(period_returns) > 1 else 0
            )
            consistency = max(0, 1 - (std_return / max(abs(mean_return), 0.01)))
            scores["walk_forward_consistency"] = min(1.0, consistency)
        else:
            scores["walk_forward_consistency"] = 0.0
    else:
        scores["walk_forward_consistency"] = 0.0

    # Parameter sensitivity (inverse of standard error) — not computed
    # here; a fixed moderate default is used.
    scores["parameter_sensitivity"] = 0.7  # Default moderate sensitivity

    # Monte Carlo stability: prefer the precomputed score; if absent but
    # percentiles exist, derive stability from the p10–p90 spread
    # relative to the median; otherwise score 0.
    if "stability_score" in mc_result:
        scores["monte_carlo_stability"] = mc_result["stability_score"]
    elif "error" not in mc_result and "percentiles" in mc_result:
        # Calculate stability from percentile spread
        percentiles = mc_result["percentiles"]
        p10 = percentiles.get("10", 0)
        p90 = percentiles.get("90", 0)
        median = percentiles.get("50", 0)
        if median != 0:
            stability = 1 - abs(p90 - p10) / abs(median)
            scores["monte_carlo_stability"] = max(0, min(1, stability))
        else:
            scores["monte_carlo_stability"] = 0.0
    else:
        scores["monte_carlo_stability"] = 0.0

    # Out-of-sample performance score: partial credit per criterion met
    # (0.3 Sharpe + 0.3 drawdown + 0.2 return + 0.2 win rate = 1.0 max).
    oos_score = 0.0
    if oos_result["sharpe_ratio"] >= validation_criteria["min_sharpe_ratio"]:
        oos_score += 0.3
    if (
        abs(oos_result["max_drawdown"])
        <= validation_criteria["max_drawdown_threshold"]
    ):
        oos_score += 0.3
    if oos_result["total_return"] >= validation_criteria["min_total_return"]:
        oos_score += 0.2
    if oos_result["win_rate"] >= validation_criteria["min_win_rate"]:
        oos_score += 0.2
    scores["out_of_sample_performance"] = oos_score

    # Calculate weighted robustness score
    robustness_score = sum(
        scores[component] * self.ROBUSTNESS_WEIGHTS[component]
        for component in self.ROBUSTNESS_WEIGHTS
    )

    return max(0.0, min(1.0, robustness_score))
def _check_validation_warnings(
self,
strategy: str,
wf_result: dict[str, Any],
mc_result: dict[str, Any],
oos_result: dict[str, float],
validation_criteria: dict[str, Any],
) -> list[str]:
"""Check for validation warnings and concerns."""
warnings = []
# Walk-forward analysis warnings
if "error" in wf_result:
warnings.append(f"{strategy}: Walk-forward analysis failed")
elif (
wf_result.get("consistency_score", 0)
< validation_criteria["stability_threshold"]
):
warnings.append(
f"{strategy}: Low walk-forward consistency ({wf_result.get('consistency_score', 0):.2f})"
)
# Monte Carlo warnings
if "error" in mc_result:
warnings.append(f"{strategy}: Monte Carlo simulation failed")
elif mc_result.get("stability_score", 0) < 0.6:
warnings.append(f"{strategy}: High Monte Carlo variability")
# Out-of-sample warnings
if oos_result["total_trades"] < 5:
warnings.append(
f"{strategy}: Very few out-of-sample trades ({oos_result['total_trades']})"
)
if oos_result["sharpe_ratio"] < validation_criteria["min_sharpe_ratio"]:
warnings.append(
f"{strategy}: Low out-of-sample Sharpe ratio ({oos_result['sharpe_ratio']:.2f})"
)
if (
abs(oos_result["max_drawdown"])
> validation_criteria["max_drawdown_threshold"]
):
warnings.append(
f"{strategy}: High out-of-sample drawdown ({oos_result['max_drawdown']:.2f})"
)
return warnings
def _generate_final_ranking(
    self,
    best_parameters: dict[str, dict[str, Any]],
    robustness_scores: dict[str, float],
    strategy_rankings: dict[str, float],
) -> list[dict[str, Any]]:
    """Build the final list of strategy recommendations, best first.

    Each entry blends robustness (60% weight) with the initial fitness
    score (40% weight) and attaches a qualitative recommendation level.

    Args:
        best_parameters: Optimized parameters per strategy name.
        robustness_scores: Robustness score per strategy (default 0.0).
        strategy_rankings: Initial fitness per strategy (default 0.5).

    Returns:
        Entries sorted by "combined_score", highest first.
    """
    entries: list[dict[str, Any]] = []
    for name, params in best_parameters.items():
        robustness = robustness_scores.get(name, 0.0)
        fitness = strategy_rankings.get(name, 0.5)
        # Weighted blend: 60% robustness, 40% initial fitness.
        blended = robustness * 0.6 + fitness * 0.4
        entries.append(
            {
                "strategy": name,
                "robustness_score": robustness,
                "fitness_score": fitness,
                "combined_score": blended,
                "parameters": params,
                "recommendation": self._get_recommendation_level(blended),
            }
        )
    # Highest combined score first.
    entries.sort(key=lambda entry: entry["combined_score"], reverse=True)
    return entries
def _get_recommendation_level(self, combined_score: float) -> str:
"""Get recommendation level based on combined score."""
if combined_score >= 0.8:
return "Highly Recommended"
elif combined_score >= 0.6:
return "Recommended"
elif combined_score >= 0.4:
return "Acceptable"
else:
return "Not Recommended"
def _select_recommended_strategy(
self,
final_ranking: list[dict[str, Any]],
robustness_scores: dict[str, float],
regime_confidence: float,
) -> tuple[str, float]:
"""Select the final recommended strategy and calculate confidence."""
if not final_ranking:
return "sma_cross", 0.1 # Fallback
# Select top strategy
top_strategy = final_ranking[0]["strategy"]
top_score = final_ranking[0]["combined_score"]
# Calculate recommendation confidence
confidence_factors = []
# Score-based confidence
confidence_factors.append(top_score)
# Robustness-based confidence
robustness = robustness_scores.get(top_strategy, 0.0)
confidence_factors.append(robustness)
# Regime confidence factor
confidence_factors.append(regime_confidence)
# Score separation from second-best
if len(final_ranking) > 1:
score_gap = top_score - final_ranking[1]["combined_score"]
separation_confidence = min(score_gap * 2, 1.0) # Scale to 0-1
confidence_factors.append(separation_confidence)
else:
confidence_factors.append(0.5) # Moderate confidence for single option
# Calculate overall confidence
recommendation_confidence = sum(confidence_factors) / len(confidence_factors)
recommendation_confidence = max(0.1, min(0.95, recommendation_confidence))
return top_strategy, recommendation_confidence
def _perform_risk_assessment(
self,
recommended_strategy: str,
walk_forward_results: dict[str, dict[str, Any]],
monte_carlo_results: dict[str, dict[str, Any]],
validation_criteria: dict[str, Any],
) -> dict[str, Any]:
"""Perform comprehensive risk assessment of recommended strategy."""
wf_result = walk_forward_results.get(recommended_strategy, {})
mc_result = monte_carlo_results.get(recommended_strategy, {})
risk_assessment = {
"overall_risk_level": "Medium",
"key_risks": [],
"risk_mitigation": [],
"confidence_intervals": {},
"worst_case_scenario": {},
}
# Analyze walk-forward results for risk patterns
if "periods" in wf_result:
periods = wf_result["periods"]
negative_periods = [p for p in periods if p.get("total_return", 0) < 0]
if len(negative_periods) / len(periods) > 0.4:
risk_assessment["key_risks"].append("High frequency of losing periods")
risk_assessment["overall_risk_level"] = "High"
max_period_loss = min([p.get("total_return", 0) for p in periods])
if max_period_loss < -0.15:
risk_assessment["key_risks"].append(
f"Severe single-period loss: {max_period_loss:.1%}"
)
# Analyze Monte Carlo results
if "percentiles" in mc_result:
percentiles = mc_result["percentiles"]
worst_case = percentiles.get("5", 0) # 5th percentile
risk_assessment["worst_case_scenario"] = {
"return_5th_percentile": worst_case,
"probability": 0.05,
"description": f"5% chance of returns below {worst_case:.1%}",
}
risk_assessment["confidence_intervals"] = {
"90_percent_range": f"{percentiles.get('5', 0):.1%} to {percentiles.get('95', 0):.1%}",
"median_return": f"{percentiles.get('50', 0):.1%}",
}
# Risk mitigation recommendations
risk_assessment["risk_mitigation"] = [
"Use position sizing based on volatility",
"Implement stop-loss orders",
"Monitor strategy performance regularly",
"Consider diversification across multiple strategies",
]
return risk_assessment
```
--------------------------------------------------------------------------------
/tests/test_production_validation.py:
--------------------------------------------------------------------------------
```python
"""
Production Validation Test Suite for MaverickMCP.
This suite validates that the system is ready for production deployment
by testing configuration, environment setup, monitoring, backup procedures,
and production-like load scenarios.
Validates:
- Environment configuration correctness
- SSL/TLS configuration (when available)
- Monitoring and alerting systems
- Backup and recovery procedures
- Load testing with production-like scenarios
- Security configuration in production mode
- Database migration status and integrity
- Performance optimization effectiveness
"""
import asyncio
import os
import ssl
import time
from pathlib import Path
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
from maverick_mcp.api.api_server import create_api_app
from maverick_mcp.config.settings import get_settings
from maverick_mcp.config.validation import get_validation_status
from maverick_mcp.data.models import SessionLocal
from maverick_mcp.data.performance import (
cleanup_performance_systems,
get_performance_metrics,
initialize_performance_systems,
)
from maverick_mcp.utils.monitoring import get_metrics, initialize_monitoring
@pytest.fixture(scope="session")
def production_settings():
    """Get production-like settings.

    Patches the critical environment variables and calls ``get_settings()``
    while the patch is still active, so the returned object reflects a
    production-style configuration. Session-scoped: shared by all tests.
    """
    with patch.dict(
        os.environ,
        {
            "ENVIRONMENT": "production",
            "AUTH_ENABLED": "true",
            "SECURITY_ENABLED": "true",
            # Deliberately >= 32 chars to satisfy the JWT secret length check.
            "JWT_SECRET": "test-jwt-secret-for-production-validation-tests-minimum-32-chars",
            "DATABASE_URL": "postgresql://test:test@localhost/test_prod_db",
        },
    ):
        # Env vars are restored as soon as this returns.
        return get_settings()
@pytest.fixture
def production_app(production_settings):
    """Create production-configured app.

    Depends on ``production_settings`` so the production-like settings
    are established before the FastAPI app is constructed.
    """
    return create_api_app()
@pytest.fixture
def production_client(production_app):
    """Create a synchronous TestClient bound to the production-configured app."""
    return TestClient(production_app)
class TestEnvironmentConfiguration:
    """Test production environment configuration."""

    @pytest.mark.skip(reason="Incompatible with global test environment configuration")
    def test_environment_variables_set(self, production_settings):
        """Test that all required environment variables are set."""
        # Critical environment variables for production
        critical_vars = [
            "DATABASE_URL",
            "JWT_SECRET",
            "ENVIRONMENT",
        ]
        # Check that critical vars are set (not default values)
        for var in critical_vars:
            env_value = os.getenv(var)
            if var == "DATABASE_URL":
                # Should not be default SQLite in production
                if env_value is None:
                    pytest.skip(f"{var} not set in test environment")
                if env_value:
                    # NOTE(review): `or` means a URL containing "sqlite" but not
                    # "memory" still passes — presumably intentional leniency;
                    # confirm whether `and` was meant.
                    assert (
                        "sqlite" not in env_value.lower()
                        or "memory" not in env_value.lower()
                    )
            elif var == "JWT_SECRET":
                # Should not be default/weak secret
                if env_value is None:
                    pytest.skip(f"{var} not set in test environment")
                if env_value:
                    assert len(env_value) >= 32
                    assert env_value != "your-secret-key-here"
                    assert env_value != "development-key"
            elif var == "ENVIRONMENT":
                if env_value is None:
                    pytest.skip(f"{var} not set in test environment")
                assert env_value in ["production", "staging"]

    def test_security_configuration(self, production_settings):
        """Test security configuration for production."""
        # Authentication should be enabled
        assert production_settings.auth.enabled is True
        # Secure cookies in production
        if production_settings.environment == "production":
            # Cookie security should be enabled (skip if not implemented)
            if not hasattr(production_settings, "cookie_secure"):
                pytest.skip("Cookie secure setting not implemented yet")
        # JWT configuration
        assert production_settings.auth.jwt_algorithm in ["RS256", "HS256"]
        assert (
            production_settings.auth.jwt_access_token_expire_minutes <= 60
        )  # Not too long
        # Redis configuration (should not use default)
        if hasattr(production_settings.auth, "redis_url"):
            redis_url = production_settings.auth.redis_url
            assert "localhost" not in redis_url or os.getenv("REDIS_HOST") is not None

    def test_database_configuration(self, production_settings):
        """Test database configuration for production."""
        # Get database URL from environment or settings
        database_url = os.getenv("DATABASE_URL", "")
        if not database_url:
            pytest.skip("DATABASE_URL not set in environment")
        # Should use production database (not SQLite)
        assert (
            "postgresql" in database_url.lower() or "mysql" in database_url.lower()
        ) and "sqlite" not in database_url.lower()
        # Should not use default credentials
        if "postgresql://" in database_url:
            # NOTE(review): "password" is a substring of "your-password", so the
            # first assert is effectively always true — verify intent.
            assert "password" not in database_url or "your-password" not in database_url
            assert (
                "localhost" not in database_url
                or os.getenv("DATABASE_HOST") is not None
            )
        # Test database connection
        try:
            with SessionLocal() as session:
                result = session.execute("SELECT 1")
                assert result.scalar() == 1
        except Exception as e:
            pytest.skip(f"Database connection test skipped: {e}")

    def test_logging_configuration(self, production_settings):
        """Test logging configuration for production."""
        # Log level should be appropriate for production
        assert production_settings.api.log_level.upper() in ["INFO", "WARNING", "ERROR"]
        # Should not be DEBUG in production
        if production_settings.environment == "production":
            assert production_settings.api.log_level.upper() != "DEBUG"

    def test_api_configuration(self, production_settings):
        """Test API configuration for production."""
        # Debug features should be disabled
        if production_settings.environment == "production":
            assert production_settings.api.debug is False
        # CORS should be properly configured
        cors_origins = production_settings.api.cors_origins
        assert cors_origins is not None
        # Should not allow all origins in production
        if production_settings.environment == "production":
            assert "*" not in cors_origins
class TestSystemValidation:
    """Test system validation and health checks."""

    def test_configuration_validation(self):
        """Test configuration validation system."""
        validation_status = get_validation_status()
        # Should have validation status
        assert "valid" in validation_status
        assert "warnings" in validation_status
        assert "errors" in validation_status
        # In production, should have minimal warnings/errors
        if os.getenv("ENVIRONMENT") == "production":
            assert len(validation_status["errors"]) == 0
            assert len(validation_status["warnings"]) <= 2  # Allow some minor warnings

    def test_health_check_endpoint(self, production_client):
        """Test health check endpoint functionality."""
        response = production_client.get("/health")
        assert response.status_code == 200
        health_data = response.json()
        assert "status" in health_data
        assert health_data["status"] in ["healthy", "degraded"]
        # Should include service information
        assert "services" in health_data
        assert "version" in health_data
        # Should include circuit breakers
        assert "circuit_breakers" in health_data

    @pytest.mark.integration
    def test_database_health(self):
        """Test database health and connectivity."""
        try:
            with SessionLocal() as session:
                # Test basic connectivity
                from sqlalchemy import text

                result = session.execute(text("SELECT 1 as health_check"))
                assert result.scalar() == 1
                # Test transaction capability
                # Session already has a transaction, so just test query
                # Use SQLite-compatible query for testing
                result = session.execute(
                    text("SELECT COUNT(*) FROM sqlite_master WHERE type='table'")
                    if "sqlite" in str(session.bind.url)
                    else text("SELECT COUNT(*) FROM information_schema.tables")
                )
                assert result.scalar() >= 0  # Should return some count
        except Exception as e:
            pytest.fail(f"Database health check failed: {e}")

    @pytest.mark.asyncio
    @pytest.mark.integration
    async def test_performance_systems_health(self):
        """Test performance systems health."""
        # Initialize performance systems
        performance_status = await initialize_performance_systems()
        # Should initialize successfully
        assert isinstance(performance_status, dict)
        assert "redis_manager" in performance_status
        # Get performance metrics
        metrics = await get_performance_metrics()
        assert "redis_manager" in metrics
        assert "request_cache" in metrics
        assert "query_optimizer" in metrics
        assert "timestamp" in metrics
        # Cleanup — always tear the systems back down for test isolation.
        await cleanup_performance_systems()

    def test_monitoring_systems(self):
        """Test monitoring systems are functional."""
        try:
            # Initialize monitoring
            initialize_monitoring()
            # Get metrics
            metrics_data = get_metrics()
            assert isinstance(metrics_data, str)
            # Should be Prometheus format
            # NOTE(review): the trailing `len(metrics_data) > 0` arm makes this
            # assertion pass for any non-empty string — confirm intended.
            assert (
                "# HELP" in metrics_data
                or "# TYPE" in metrics_data
                or len(metrics_data) > 0
            )
        except Exception as e:
            pytest.skip(f"Monitoring test skipped: {e}")
class TestSSLTLSConfiguration:
    """Test SSL/TLS configuration (when available)."""

    def test_ssl_certificate_validity(self):
        """Test SSL certificate validity."""
        # This would test actual SSL certificate in production
        # For testing, we check if SSL context can be created
        try:
            context = ssl.create_default_context()
            assert context.check_hostname is True
            assert context.verify_mode == ssl.CERT_REQUIRED
        except Exception as e:
            pytest.skip(f"SSL test skipped: {e}")

    def test_tls_configuration(self, production_client):
        """Test TLS configuration."""
        # Test security headers are present
        # NOTE(review): the response is discarded and the loop below is a
        # placeholder — no header is actually asserted yet.
        production_client.get("/health")
        # Should have security headers in production
        security_headers = [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "X-XSS-Protection",
        ]
        # Note: These would be set by security middleware
        # Check if security middleware is active
        for _header in security_headers:
            # In test environment, headers might not be set
            # In production, they should be present
            if os.getenv("ENVIRONMENT") == "production":
                # assert header in response.headers
                pass  # Skip for test environment

    def test_secure_cookie_configuration(self, production_client, production_settings):
        """Test secure cookie configuration."""
        if production_settings.environment != "production":
            pytest.skip("Secure cookie test only for production")
        # Test that cookies are set with secure flags
        test_user = {
            "email": "[email protected]",
            "password": "TestPass123!",
            "name": "SSL Test User",
        }
        # Register and login
        production_client.post("/auth/register", json=test_user)
        login_response = production_client.post(
            "/auth/login",
            json={"email": test_user["email"], "password": test_user["password"]},
        )
        # Check cookie headers for security flags
        cookie_header = login_response.headers.get("set-cookie", "")
        if cookie_header:
            # Should have Secure flag in production
            assert "Secure" in cookie_header
            assert "HttpOnly" in cookie_header
            assert "SameSite" in cookie_header
class TestBackupAndRecovery:
    """Test backup and recovery procedures."""

    def test_database_backup_capability(self):
        """Test database backup capability."""
        try:
            with SessionLocal() as session:
                # Test that we can read critical tables
                critical_tables = [
                    "mcp_users",
                    "mcp_api_keys",
                    "auth_audit_log",
                ]
                for table in critical_tables:
                    try:
                        # f-string SQL is acceptable here: table names are
                        # hard-coded above, not user input.
                        result = session.execute(f"SELECT COUNT(*) FROM {table}")
                        count = result.scalar()
                        assert count >= 0  # Should be able to count rows
                    except Exception as e:
                        # Table might not exist in test environment
                        pytest.skip(f"Table {table} not found: {e}")
        except Exception as e:
            pytest.skip(f"Database backup test skipped: {e}")

    def test_configuration_backup(self):
        """Test configuration backup capability."""
        # Test that critical configuration can be backed up
        critical_config_files = [
            "alembic.ini",
            ".env",  # Note: should not backup .env with secrets
            "pyproject.toml",
        ]
        # Resolve the repo root relative to this test file.
        project_root = Path(__file__).parent.parent
        for config_file in critical_config_files:
            config_path = project_root / config_file
            if config_path.exists():
                # Should be readable
                assert config_path.is_file()
                assert os.access(config_path, os.R_OK)
            else:
                # Some files might not exist in test environment
                pass

    def test_graceful_shutdown_capability(self, production_app):
        """Test graceful shutdown capability."""
        # Test that app can handle shutdown signals
        # This is more of a conceptual test since we can't actually shut down
        # Check that shutdown handlers are registered
        # This would be tested in actual deployment
        assert hasattr(production_app, "router")
        assert production_app.router is not None
class TestLoadTesting:
    """Test system under production-like load."""

    @pytest.mark.skip(
        reason="Long-running load test - disabled to conserve CI resources"
    )
    @pytest.mark.asyncio
    @pytest.mark.integration
    async def test_concurrent_user_load(self, production_client):
        """Test system under concurrent user load."""
        # Create multiple test users
        test_users = []
        for i in range(5):
            user = {
                "email": f"loadtest{i}@example.com",
                "password": "LoadTest123!",
                "name": f"Load Test User {i}",
            }
            test_users.append(user)
            # Register user
            response = production_client.post("/auth/register", json=user)
            if response.status_code not in [200, 201]:
                pytest.skip("User registration failed in load test")

        # Simulate concurrent operations
        # NOTE(review): TestClient calls are synchronous, so these "concurrent"
        # sessions actually run serially within the event loop — confirm this
        # still exercises the intended load pattern.
        async def user_session(user_data):
            """Simulate a complete user session."""
            results = []
            # Login
            login_response = production_client.post(
                "/auth/login",
                json={"email": user_data["email"], "password": user_data["password"]},
            )
            results.append(("login", login_response.status_code))
            if login_response.status_code == 200:
                csrf_token = login_response.json().get("csrf_token")
                # Multiple API calls
                for _ in range(3):
                    profile_response = production_client.get(
                        "/user/profile", headers={"X-CSRF-Token": csrf_token}
                    )
                    results.append(("profile", profile_response.status_code))
            return results

        # Run concurrent sessions
        tasks = [user_session(user) for user in test_users]
        session_results = await asyncio.gather(*tasks, return_exceptions=True)
        # Analyze results
        all_results = []
        for result in session_results:
            if isinstance(result, list):
                all_results.extend(result)
        # Should have mostly successful responses
        success_rate = sum(
            1 for op, status in all_results if status in [200, 201]
        ) / len(all_results)
        assert success_rate >= 0.8  # At least 80% success rate

    @pytest.mark.skip(
        reason="Long-running performance test - disabled to conserve CI resources"
    )
    def test_api_endpoint_performance(self, production_client):
        """Test API endpoint performance."""
        # Test key endpoints for performance
        endpoints_to_test = [
            "/health",
            "/",
        ]
        performance_results = {}
        for endpoint in endpoints_to_test:
            times = []
            for _ in range(5):
                start_time = time.time()
                response = production_client.get(endpoint)
                end_time = time.time()
                if response.status_code == 200:
                    times.append(end_time - start_time)
            if times:
                avg_time = sum(times) / len(times)
                max_time = max(times)
                performance_results[endpoint] = {
                    "avg_time": avg_time,
                    "max_time": max_time,
                }
                # Performance assertions
                assert avg_time < 1.0  # Average response under 1 second
                assert max_time < 2.0  # Max response under 2 seconds

    @pytest.mark.skip(
        reason="Long-running memory test - disabled to conserve CI resources"
    )
    def test_memory_usage_stability(self, production_client):
        """Test memory usage stability under load."""
        # Make multiple requests to test for memory leaks
        initial_response_time = None
        final_response_time = None
        for i in range(20):
            start_time = time.time()
            response = production_client.get("/health")
            end_time = time.time()
            if response.status_code == 200:
                response_time = end_time - start_time
                if i == 0:
                    initial_response_time = response_time
                elif i == 19:
                    final_response_time = response_time
        # Response time should not degrade significantly (indicating memory leaks)
        if initial_response_time and final_response_time:
            degradation_ratio = final_response_time / initial_response_time
            assert degradation_ratio < 3.0  # Should not be 3x slower
class TestProductionReadinessChecklist:
    """Final production readiness checklist."""

    def test_database_migrations_applied(self):
        """Test that all database migrations are applied."""
        try:
            with SessionLocal() as session:
                # Check that migration tables exist
                result = session.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_name = 'alembic_version'
                """)
                migration_table_exists = result.scalar() is not None
                if migration_table_exists:
                    # Check current migration version
                    version_result = session.execute(
                        "SELECT version_num FROM alembic_version"
                    )
                    current_version = version_result.scalar()
                    assert current_version is not None
                    assert len(current_version) > 0
        except Exception as e:
            pytest.skip(f"Database migration check skipped: {e}")

    def test_security_features_enabled(self, production_settings):
        """Test that all security features are enabled."""
        # Authentication enabled
        assert production_settings.auth.enabled is True
        # Proper environment
        assert production_settings.environment in ["production", "staging"]

    def test_performance_optimizations_active(self):
        """Test that performance optimizations are active."""
        # This would test actual performance optimizations
        # For now, test that performance modules can be imported
        try:
            from maverick_mcp.data.performance import (
                query_optimizer,
                redis_manager,
                request_cache,
            )

            assert redis_manager is not None
            assert request_cache is not None
            assert query_optimizer is not None
        except ImportError as e:
            pytest.fail(f"Performance optimization modules not available: {e}")

    def test_monitoring_and_logging_ready(self):
        """Test that monitoring and logging are ready."""
        try:
            # Test logging configuration
            from maverick_mcp.utils.logging import get_logger

            logger = get_logger("production_test")
            assert logger is not None
            # Test monitoring availability
            from maverick_mcp.utils.monitoring import get_metrics

            metrics = get_metrics()
            assert isinstance(metrics, str)
        except Exception as e:
            pytest.skip(f"Monitoring test skipped: {e}")

    @pytest.mark.integration
    def test_final_system_integration(self, production_client):
        """Final system integration test."""
        # Test complete workflow with unique email
        import uuid

        unique_id = str(uuid.uuid4())[:8]
        test_user = {
            "email": f"final_test_{unique_id}@example.com",
            "password": "FinalTest123!",
            "name": "Final Test User",
        }
        # 1. Health check
        health_response = production_client.get("/health")
        assert health_response.status_code == 200
        # 2. User registration
        # NOTE(review): other tests in this file register via /auth/register;
        # this one uses /auth/signup — confirm both routes exist.
        register_response = production_client.post("/auth/signup", json=test_user)
        assert register_response.status_code in [200, 201]
        # 3. User login
        login_response = production_client.post(
            "/auth/login",
            json={"email": test_user["email"], "password": test_user["password"]},
        )
        assert login_response.status_code == 200
        # Get tokens from response
        login_data = login_response.json()
        access_token = login_data.get("access_token")
        # If no access token in response body, it might be in cookies
        if not access_token:
            # For cookie-based auth, we just need to make sure login succeeded
            assert "user" in login_data or "message" in login_data
            # 4. Authenticated API access (with cookies)
            profile_response = production_client.get("/user/profile")
            assert profile_response.status_code == 200
        else:
            # Bearer token auth
            headers = {"Authorization": f"Bearer {access_token}"}
            # 4. Authenticated API access
            profile_response = production_client.get("/user/profile", headers=headers)
            assert profile_response.status_code == 200
if __name__ == "__main__":
    # Allow running this suite directly (outside the pytest CLI).
    pytest.main([__file__, "-v", "--tb=short"])
```
--------------------------------------------------------------------------------
/tests/test_security_cors.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive CORS Security Tests for Maverick MCP.
Tests CORS configuration, validation, origin blocking, wildcard security,
and environment-specific behaviors.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from maverick_mcp.config.security import (
CORSConfig,
SecurityConfig,
validate_security_config,
)
from maverick_mcp.config.security_utils import (
apply_cors_to_fastapi,
check_security_config,
get_safe_cors_config,
)
class TestCORSConfiguration:
    """Test CORS configuration validation and creation."""

    def test_cors_config_valid_origins(self):
        """Test CORS config creation with valid origins."""
        config = CORSConfig(
            allowed_origins=["https://example.com", "https://app.example.com"],
            allow_credentials=True,
        )
        assert config.allowed_origins == [
            "https://example.com",
            "https://app.example.com",
        ]
        assert config.allow_credentials is True

    def test_cors_config_wildcard_with_credentials_raises_error(self):
        """Test that wildcard origins with credentials raises validation error."""
        # Wildcard + credentials would let any site send authenticated
        # requests, so the config model must reject it outright.
        with pytest.raises(
            ValueError,
            match="CORS Security Error.*wildcard origin.*serious security vulnerability",
        ):
            CORSConfig(allowed_origins=["*"], allow_credentials=True)

    def test_cors_config_wildcard_without_credentials_warns(self):
        """Test that wildcard origins without credentials logs warning."""
        with patch("logging.getLogger") as mock_logger:
            mock_logger_instance = MagicMock()
            mock_logger.return_value = mock_logger_instance
            config = CORSConfig(allowed_origins=["*"], allow_credentials=False)
            assert config.allowed_origins == ["*"]
            assert config.allow_credentials is False
            # The model accepts the config but must warn exactly once.
            mock_logger_instance.warning.assert_called_once()

    def test_cors_config_multiple_origins_with_wildcard_fails(self):
        """Test that mixed origins including wildcard with credentials fails."""
        with pytest.raises(ValueError, match="CORS Security Error"):
            CORSConfig(
                allowed_origins=["https://example.com", "*"], allow_credentials=True
            )

    def test_cors_config_default_values(self):
        """Test CORS config default values are secure."""
        with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = ["http://localhost:3000"]
                config = CORSConfig()
                assert config.allow_credentials is True
                assert "GET" in config.allowed_methods
                assert "POST" in config.allowed_methods
                assert "Authorization" in config.allowed_headers
                assert "Content-Type" in config.allowed_headers
                assert config.max_age == 86400  # 24h preflight cache

    def test_cors_config_expose_headers(self):
        """Test that proper headers are exposed to clients."""
        config = CORSConfig()
        # Headers clients must be able to read (rate limiting, tracing).
        expected_exposed = [
            "X-Process-Time",
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
            "X-RateLimit-Reset",
            "X-Request-ID",
        ]
        for header in expected_exposed:
            assert header in config.exposed_headers
class TestCORSEnvironmentConfiguration:
    """Test environment-specific CORS configuration."""

    def test_production_cors_origins(self):
        """Test production CORS origins are restrictive."""
        # clear=True wipes the environment so no stray CORS_ORIGINS leaks in.
        with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=True):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = [
                    "https://app.maverick-mcp.com",
                    "https://maverick-mcp.com",
                ]
                config = SecurityConfig()
                assert "localhost" not in str(config.cors.allowed_origins).lower()
                assert "127.0.0.1" not in str(config.cors.allowed_origins).lower()
                # Production origins must all be HTTPS.
                assert all(
                    origin.startswith("https://")
                    for origin in config.cors.allowed_origins
                )

    def test_development_cors_origins(self):
        """Test development CORS origins include localhost."""
        with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=True):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = [
                    "http://localhost:3000",
                    "http://127.0.0.1:3000",
                ]
                config = SecurityConfig()
                localhost_found = any(
                    "localhost" in origin for origin in config.cors.allowed_origins
                )
                assert localhost_found

    def test_staging_cors_origins(self):
        """Test staging CORS origins are appropriate."""
        with patch.dict(os.environ, {"ENVIRONMENT": "staging"}, clear=True):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = [
                    "https://staging.maverick-mcp.com",
                    "http://localhost:3000",
                ]
                config = SecurityConfig()
                staging_found = any(
                    "staging" in origin for origin in config.cors.allowed_origins
                )
                assert staging_found

    def test_custom_cors_origins_from_env(self):
        """Test custom CORS origins from environment variable."""
        custom_origins = "https://custom1.com,https://custom2.com"
        with patch.dict(os.environ, {"CORS_ORIGINS": custom_origins}, clear=False):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = [
                    "https://custom1.com",
                    "https://custom2.com",
                ]
                config = SecurityConfig()
                assert "https://custom1.com" in config.cors.allowed_origins
                assert "https://custom2.com" in config.cors.allowed_origins
class TestCORSValidation:
"""Test CORS security validation."""
def test_validate_security_config_valid_cors(self):
"""Test security validation passes with valid CORS config."""
with patch("maverick_mcp.config.security.get_security_config") as mock_config:
mock_security_config = MagicMock()
mock_security_config.cors.allowed_origins = ["https://example.com"]
mock_security_config.cors.allow_credentials = True
mock_security_config.is_production.return_value = False
mock_security_config.force_https = True
mock_security_config.headers.x_frame_options = "DENY"
mock_config.return_value = mock_security_config
result = validate_security_config()
assert result["valid"] is True
assert len(result["issues"]) == 0
def test_validate_security_config_wildcard_with_credentials(self):
"""Test security validation fails with wildcard + credentials."""
with patch("maverick_mcp.config.security.get_security_config") as mock_config:
mock_security_config = MagicMock()
mock_security_config.cors.allowed_origins = ["*"]
mock_security_config.cors.allow_credentials = True
mock_security_config.is_production.return_value = False
mock_security_config.force_https = True
mock_security_config.headers.x_frame_options = "DENY"
mock_config.return_value = mock_security_config
result = validate_security_config()
assert result["valid"] is False
assert any(
"Wildcard CORS origins with credentials enabled" in issue
for issue in result["issues"]
)
def test_validate_security_config_production_wildcards(self):
"""Test security validation fails with wildcards in production."""
with patch("maverick_mcp.config.security.get_security_config") as mock_config:
mock_security_config = MagicMock()
mock_security_config.cors.allowed_origins = ["*"]
mock_security_config.cors.allow_credentials = False
mock_security_config.is_production.return_value = True
mock_security_config.force_https = True
mock_security_config.headers.x_frame_options = "DENY"
mock_config.return_value = mock_security_config
result = validate_security_config()
assert result["valid"] is False
assert any(
"Wildcard CORS origins in production" in issue
for issue in result["issues"]
)
def test_validate_security_config_production_localhost_warning(self):
    """Localhost origins in production should warn but not invalidate the config."""
    with patch("maverick_mcp.config.security.get_security_config") as mock_config:
        cfg = MagicMock()
        cfg.cors.allowed_origins = ["https://app.com", "http://localhost:3000"]
        cfg.cors.allow_credentials = True
        cfg.is_production.return_value = True
        cfg.force_https = True
        cfg.headers.x_frame_options = "DENY"
        mock_config.return_value = cfg

        result = validate_security_config()

        # A localhost origin in production is only a warning, not a hard failure.
        assert result["valid"] is True
        assert any("localhost" in warning.lower() for warning in result["warnings"])
class TestCORSMiddlewareIntegration:
    """Test CORS middleware integration with FastAPI.

    Each test installs the CORS middleware via ``apply_cors_to_fastapi`` with a
    mocked security config, then exercises real requests through ``TestClient``.
    """

    @staticmethod
    def _make_mock_config(**overrides):
        """Build a mock security config with sensible CORS middleware defaults.

        Keyword arguments override individual keys of the dict returned by
        ``get_cors_middleware_config`` (e.g. ``expose_headers=[...]``), which
        removes the near-identical config dicts previously duplicated in every
        test below.
        """
        cors = {
            "allow_origins": ["https://allowed.com"],
            "allow_credentials": True,
            "allow_methods": ["GET", "POST"],
            "allow_headers": ["Content-Type", "Authorization"],
            "expose_headers": [],
            "max_age": 86400,
        }
        cors.update(overrides)
        mock_config = MagicMock()
        mock_config.get_cors_middleware_config.return_value = cors
        return mock_config

    def create_test_app(self, security_config=None):
        """Create a test FastAPI app with CORS applied.

        If ``security_config`` is given, it is patched in as the active
        security configuration while the CORS middleware is installed.
        """
        app = FastAPI()
        if security_config:
            with patch(
                "maverick_mcp.config.security_utils.get_security_config",
                return_value=security_config,
            ):
                apply_cors_to_fastapi(app)
        else:
            apply_cors_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        @app.post("/test")
        async def test_post_endpoint():
            return {"message": "post test"}

        return app

    def test_cors_middleware_allows_configured_origins(self):
        """Test that CORS middleware allows configured origins."""
        mock_config = self._make_mock_config()

        # Mock validation to pass so the middleware is actually installed.
        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
            app = self.create_test_app(mock_config)
            client = TestClient(app)

            # Preflight request from the allowed origin.
            response = client.options(
                "/test",
                headers={
                    "Origin": "https://allowed.com",
                    "Access-Control-Request-Method": "POST",
                    "Access-Control-Request-Headers": "Content-Type",
                },
            )
            assert response.status_code == 200
            assert (
                response.headers.get("Access-Control-Allow-Origin")
                == "https://allowed.com"
            )
            assert "POST" in response.headers.get("Access-Control-Allow-Methods", "")

    def test_cors_middleware_blocks_unauthorized_origins(self):
        """Test that CORS middleware blocks unauthorized origins."""
        mock_config = self._make_mock_config(allow_headers=["Content-Type"])

        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
            app = self.create_test_app(mock_config)
            client = TestClient(app)

            # The request itself succeeds (CORS is browser-enforced), but the
            # response must not grant the unauthorized origin.
            response = client.get(
                "/test", headers={"Origin": "https://unauthorized.com"}
            )
            assert response.status_code == 200
            cors_origin = response.headers.get("Access-Control-Allow-Origin")
            assert cors_origin != "https://unauthorized.com"

    def test_cors_middleware_credentials_handling(self):
        """Test CORS middleware credentials handling."""
        mock_config = self._make_mock_config(allow_headers=["Content-Type"])

        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
            app = self.create_test_app(mock_config)
            client = TestClient(app)

            response = client.options(
                "/test",
                headers={
                    "Origin": "https://allowed.com",
                    "Access-Control-Request-Method": "POST",
                },
            )
            assert response.headers.get("Access-Control-Allow-Credentials") == "true"

    def test_cors_middleware_exposed_headers(self):
        """Test that CORS middleware exposes configured headers."""
        mock_config = self._make_mock_config(
            allow_methods=["GET"],
            allow_headers=["Content-Type"],
            expose_headers=["X-Custom-Header", "X-Rate-Limit"],
        )

        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
            app = self.create_test_app(mock_config)
            client = TestClient(app)

            response = client.get("/test", headers={"Origin": "https://allowed.com"})
            exposed_headers = response.headers.get("Access-Control-Expose-Headers", "")
            assert "X-Custom-Header" in exposed_headers
            assert "X-Rate-Limit" in exposed_headers
class TestCORSSecurityValidation:
    """Test CORS security validation and safety measures."""

    def test_apply_cors_fails_with_invalid_config(self):
        """Test that applying CORS fails with invalid configuration."""
        app = FastAPI()

        # Force validation to report an invalid configuration.
        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {
                "valid": False,
                "issues": ["Wildcard CORS origins with credentials"],
                "warnings": [],
            }
            with pytest.raises(ValueError, match="Security configuration is invalid"):
                apply_cors_to_fastapi(app)

    def test_get_safe_cors_config_production_fallback(self):
        """Test safe CORS config fallback for production."""
        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {
                "valid": False,
                "issues": ["Invalid config"],
                "warnings": [],
            }
            with patch(
                "maverick_mcp.config.security_utils.get_security_config"
            ) as mock_config:
                cfg = MagicMock()
                cfg.is_production.return_value = True
                mock_config.return_value = cfg

                safe_config = get_safe_cors_config()

                # Production fallback pins a single HTTPS origin, never localhost.
                assert safe_config["allow_origins"] == ["https://maverick-mcp.com"]
                assert safe_config["allow_credentials"] is True
                assert "localhost" not in str(safe_config["allow_origins"])

    def test_get_safe_cors_config_development_fallback(self):
        """Test safe CORS config fallback for development."""
        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            mock_validate.return_value = {
                "valid": False,
                "issues": ["Invalid config"],
                "warnings": [],
            }
            with patch(
                "maverick_mcp.config.security_utils.get_security_config"
            ) as mock_config:
                cfg = MagicMock()
                cfg.is_production.return_value = False
                mock_config.return_value = cfg

                safe_config = get_safe_cors_config()

                assert safe_config["allow_origins"] == ["http://localhost:3000"]
                assert safe_config["allow_credentials"] is True

    def test_check_security_config_function(self):
        """Test security config check function."""
        with patch(
            "maverick_mcp.config.security_utils.validate_security_config"
        ) as mock_validate:
            # Valid configuration reports True.
            mock_validate.return_value = {"valid": True, "issues": [], "warnings": []}
            assert check_security_config() is True

            # Invalid configuration reports False.
            mock_validate.return_value = {
                "valid": False,
                "issues": ["Error"],
                "warnings": [],
            }
            assert check_security_config() is False
class TestCORSPreflightRequests:
    """Test CORS preflight request handling."""

    @staticmethod
    def _client_with_cors(register_post=False, **cors_kwargs):
        """Return a TestClient for a one-route app with the given CORS settings.

        The app allows the single origin ``https://example.com``; extra CORS
        middleware options are passed through via ``cors_kwargs``. When
        ``register_post`` is true the route is registered as POST instead of GET.
        """
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["https://example.com"],
            **cors_kwargs,
        )

        if register_post:

            @app.post("/test")
            async def test_endpoint():
                return {"message": "test"}

        else:

            @app.get("/test")
            async def test_endpoint():
                return {"message": "test"}

        return TestClient(app)

    def test_preflight_request_max_age(self):
        """Test CORS preflight max-age header."""
        client = self._client_with_cors(
            allow_methods=["GET", "POST"],
            allow_headers=["Content-Type"],
            max_age=3600,
        )
        response = client.options(
            "/test",
            headers={
                "Origin": "https://example.com",
                "Access-Control-Request-Method": "GET",
            },
        )
        assert response.headers.get("Access-Control-Max-Age") == "3600"

    def test_preflight_request_methods(self):
        """Test CORS preflight allowed methods."""
        client = self._client_with_cors(
            allow_methods=["GET", "POST", "PUT"],
            allow_headers=["Content-Type"],
        )
        response = client.options(
            "/test",
            headers={
                "Origin": "https://example.com",
                "Access-Control-Request-Method": "PUT",
            },
        )
        allowed_methods = response.headers.get("Access-Control-Allow-Methods", "")
        for method in ("PUT", "GET", "POST"):
            assert method in allowed_methods

    def test_preflight_request_headers(self):
        """Test CORS preflight allowed headers."""
        client = self._client_with_cors(
            register_post=True,
            allow_methods=["POST"],
            allow_headers=["Content-Type", "Authorization", "X-Custom"],
        )
        response = client.options(
            "/test",
            headers={
                "Origin": "https://example.com",
                "Access-Control-Request-Method": "POST",
                "Access-Control-Request-Headers": "Content-Type, Authorization",
            },
        )
        allowed_headers = response.headers.get("Access-Control-Allow-Headers", "")
        assert "Content-Type" in allowed_headers
        assert "Authorization" in allowed_headers
class TestCORSEdgeCases:
    """Test CORS edge cases and security scenarios."""

    @staticmethod
    def _get_client(allowed_origin):
        """Return a TestClient whose app allows exactly one CORS origin (GET only)."""
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=[allowed_origin],
            allow_methods=["GET"],
            allow_headers=["Content-Type"],
        )

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        return TestClient(app)

    def test_cors_with_vary_header(self):
        """Test that CORS responses include Vary header."""
        client = self._get_client("https://example.com")
        response = client.get("/test", headers={"Origin": "https://example.com"})
        assert "Origin" in response.headers.get("Vary", "")

    def test_cors_null_origin_handling(self):
        """Test CORS handling of null origin (file:// protocol)."""
        # "null" is sometimes needed as an allowed origin for file:// pages.
        client = self._get_client("null")
        response = client.get("/test", headers={"Origin": "null"})
        # Should handle null origin appropriately
        assert response.status_code == 200

    def test_cors_case_insensitive_origin(self):
        """Origin matching is case-sensitive, so a differently-cased origin must not match.

        NOTE: despite the historical test name, this verifies case-SENSITIVE
        matching, which is the correct behavior.
        """
        client = self._get_client("https://Example.com")  # Capital E
        response = client.get(
            "/test",
            headers={"Origin": "https://example.com"},  # lowercase e
        )
        # Should not match due to case sensitivity
        cors_origin = response.headers.get("Access-Control-Allow-Origin")
        assert cors_origin != "https://example.com"
# Allow running this test module directly with verbose pytest output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
```
--------------------------------------------------------------------------------
/tests/core/test_technical_analysis.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for maverick_mcp.core.technical_analysis module.
This module contains comprehensive tests for all technical analysis functions
to ensure accurate financial calculations and proper error handling.
"""
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.core.technical_analysis import (
add_technical_indicators,
analyze_bollinger_bands,
analyze_macd,
analyze_rsi,
analyze_stochastic,
analyze_trend,
analyze_volume,
calculate_atr,
generate_outlook,
identify_chart_patterns,
identify_resistance_levels,
identify_support_levels,
)
class TestTechnicalIndicators:
    """Test the add_technical_indicators function."""

    def test_add_technical_indicators_basic(self):
        """Test basic technical indicators calculation on a random price series."""
        # Seed the RNG so the generated price series (and therefore the
        # indicator assertions below) are reproducible instead of flaky.
        np.random.seed(42)

        # Create sample data with enough data points for all indicators
        dates = pd.date_range("2024-01-01", periods=100, freq="D")
        data = {
            "Date": dates,
            "Open": np.random.uniform(100, 200, 100),
            "High": np.random.uniform(150, 250, 100),
            "Low": np.random.uniform(50, 150, 100),
            "Close": np.random.uniform(100, 200, 100),
            "Volume": np.random.randint(1000000, 10000000, 100),
        }
        df = pd.DataFrame(data)
        df = df.set_index("Date")

        # Add some realistic price movement: each close is a small random
        # step from the previous close, with high/low bracketing open/close.
        for i in range(1, len(df)):
            df.loc[df.index[i], "Close"] = df.iloc[i - 1]["Close"] * np.random.uniform(
                0.98, 1.02
            )
            df.loc[df.index[i], "High"] = max(
                df.iloc[i]["Open"], df.iloc[i]["Close"]
            ) * np.random.uniform(1.0, 1.02)
            df.loc[df.index[i], "Low"] = min(
                df.iloc[i]["Open"], df.iloc[i]["Close"]
            ) * np.random.uniform(0.98, 1.0)

        result = add_technical_indicators(df)

        # Check that all expected indicators are added
        expected_indicators = [
            "ema_21",
            "sma_50",
            "sma_200",
            "rsi",
            "macd_12_26_9",
            "macds_12_26_9",
            "macdh_12_26_9",
            "sma_20",
            "bbu_20_2.0",
            "bbl_20_2.0",
            "stdev",
            "atr",
            "stochk_14_3_3",
            "stochd_14_3_3",
            "adx_14",
        ]
        for indicator in expected_indicators:
            assert indicator in result.columns

        # Check that indicators have reasonable values (not all NaN)
        assert not result["rsi"].iloc[-10:].isna().all()
        assert not result["ema_21"].iloc[-10:].isna().all()
        assert not result["sma_50"].iloc[-10:].isna().all()

    def test_add_technical_indicators_column_case_insensitive(self):
        """Test that the function handles different column case properly."""
        data = {
            "OPEN": [100, 101, 102],
            "HIGH": [105, 106, 107],
            "LOW": [95, 96, 97],
            "CLOSE": [103, 104, 105],
            "VOLUME": [1000000, 1100000, 1200000],
        }
        df = pd.DataFrame(data)
        result = add_technical_indicators(df)
        # Check that columns are normalized to lowercase
        assert "close" in result.columns
        assert "high" in result.columns
        assert "low" in result.columns

    def test_add_technical_indicators_insufficient_data(self):
        """Test behavior with insufficient data."""
        data = {
            "Open": [100],
            "High": [105],
            "Low": [95],
            "Close": [103],
            "Volume": [1000000],
        }
        df = pd.DataFrame(data)
        result = add_technical_indicators(df)
        # Should handle insufficient data gracefully
        assert "rsi" in result.columns
        assert pd.isna(result["rsi"].iloc[0])  # Should be NaN for insufficient data

    def test_add_technical_indicators_empty_dataframe(self):
        """Test behavior with empty dataframe."""
        df = pd.DataFrame()
        with pytest.raises(KeyError):
            add_technical_indicators(df)

    @pytest.mark.parametrize(
        "bb_columns",
        [
            ("BBM_20_2.0", "BBU_20_2.0", "BBL_20_2.0"),
            ("BBM_20_2", "BBU_20_2", "BBL_20_2"),
        ],
    )
    def test_add_technical_indicators_supports_bbands_column_aliases(
        self, monkeypatch, bb_columns
    ):
        """Ensure Bollinger Band column name variations are handled."""
        index = pd.date_range("2024-01-01", periods=40, freq="D")
        base_series = np.linspace(100, 140, len(index))
        data = {
            "open": base_series,
            "high": base_series + 1,
            "low": base_series - 1,
            "close": base_series,
            "volume": np.full(len(index), 1_000_000),
        }
        df = pd.DataFrame(data, index=index)
        mid_column, upper_column, lower_column = bb_columns

        def fake_bbands(close, *args, **kwargs):
            # Return bands under whichever column-name variant is being tested.
            band_values = pd.Series(base_series, index=close.index)
            return pd.DataFrame(
                {
                    mid_column: band_values,
                    upper_column: band_values + 2,
                    lower_column: band_values - 2,
                }
            )

        monkeypatch.setattr(
            "maverick_mcp.core.technical_analysis.ta.bbands",
            fake_bbands,
        )

        result = add_technical_indicators(df)

        np.testing.assert_allclose(result["sma_20"], base_series)
        np.testing.assert_allclose(result["bbu_20_2.0"], base_series + 2)
        np.testing.assert_allclose(result["bbl_20_2.0"], base_series - 2)
class TestSupportResistanceLevels:
    """Test support and resistance level identification."""

    @pytest.fixture
    def sample_data(self):
        """Create sample price data for testing (a repeating 10-bar pattern)."""
        highs = [105, 110, 108, 115, 112, 120, 118, 125, 122, 130]
        lows = [95, 100, 98, 105, 102, 110, 108, 115, 112, 120]
        closes = [100, 105, 103, 110, 107, 115, 113, 120, 117, 125]
        return pd.DataFrame(
            {"high": highs * 5, "low": lows * 5, "close": closes * 5}
        )

    def test_identify_support_levels(self, sample_data):
        """Test support level identification."""
        support_levels = identify_support_levels(sample_data)

        assert isinstance(support_levels, list)
        assert len(support_levels) > 0
        for level in support_levels:
            assert isinstance(level, float | int | np.number)
        assert support_levels == sorted(support_levels)  # Should be sorted

    def test_identify_resistance_levels(self, sample_data):
        """Test resistance level identification."""
        resistance_levels = identify_resistance_levels(sample_data)

        assert isinstance(resistance_levels, list)
        assert len(resistance_levels) > 0
        for level in resistance_levels:
            assert isinstance(level, float | int | np.number)
        assert resistance_levels == sorted(resistance_levels)  # Should be sorted

    def test_support_resistance_with_small_dataset(self):
        """Test with dataset smaller than 30 days."""
        frame = pd.DataFrame(
            {
                "high": [105, 110, 108],
                "low": [95, 100, 98],
                "close": [100, 105, 103],
            }
        )
        support_levels = identify_support_levels(frame)
        resistance_levels = identify_resistance_levels(frame)
        assert len(support_levels) > 0
        assert len(resistance_levels) > 0
class TestTrendAnalysis:
    """Test trend analysis functionality."""

    @pytest.fixture
    def trending_data(self):
        """Create data with clear upward trend."""
        dates = pd.date_range("2024-01-01", periods=60, freq="D")
        close_prices = np.linspace(100, 150, 60)  # Clear upward trend
        frame = pd.DataFrame(
            {
                "close": close_prices,
                "high": close_prices * 1.02,
                "low": close_prices * 0.98,
                "volume": np.random.randint(1000000, 2000000, 60),
            },
            index=dates,
        )
        return add_technical_indicators(frame)

    def test_analyze_trend_uptrend(self, trending_data):
        """Test trend analysis with upward trending data."""
        strength = analyze_trend(trending_data)
        assert isinstance(strength, int)
        assert 0 <= strength <= 7
        assert strength > 3  # Should detect strong uptrend

    def test_analyze_trend_empty_dataframe(self):
        """Test trend analysis with empty dataframe."""
        assert analyze_trend(pd.DataFrame({"close": []})) == 0

    def test_analyze_trend_missing_indicators(self):
        """Test trend analysis with missing indicators."""
        frame = pd.DataFrame({"close": [100, 101, 102, 103, 104]})
        # Should handle missing indicators gracefully
        assert analyze_trend(frame) == 0
class TestRSIAnalysis:
    """Test RSI analysis functionality."""

    @staticmethod
    def _frame(closes, rsis):
        """Build a minimal dataframe with close prices and an rsi column."""
        return pd.DataFrame({"close": closes, "rsi": rsis})

    @pytest.fixture
    def rsi_data(self):
        """Create data with RSI indicator."""
        return self._frame(
            [100, 105, 103, 110, 107, 115, 113, 120, 117, 125],
            [50, 55, 52, 65, 60, 70, 68, 75, 72, 80],
        )

    def test_analyze_rsi_overbought(self, rsi_data):
        """Test RSI analysis with overbought conditions."""
        result = analyze_rsi(rsi_data)
        assert result["current"] == 80.0
        assert result["signal"] == "overbought"
        assert "overbought" in result["description"]

    def test_analyze_rsi_oversold(self):
        """Test RSI analysis with oversold conditions."""
        result = analyze_rsi(self._frame([100, 95, 90, 85, 80], [50, 40, 30, 25, 20]))
        assert result["current"] == 20.0
        assert result["signal"] == "oversold"

    def test_analyze_rsi_bullish(self):
        """Test RSI analysis with bullish conditions."""
        result = analyze_rsi(self._frame([100, 105, 110], [50, 55, 60]))
        assert result["current"] == 60.0
        assert result["signal"] == "bullish"

    def test_analyze_rsi_bearish(self):
        """Test RSI analysis with bearish conditions."""
        result = analyze_rsi(self._frame([100, 95, 90], [50, 45, 40]))
        assert result["current"] == 40.0
        assert result["signal"] == "bearish"

    def test_analyze_rsi_empty_dataframe(self):
        """Test RSI analysis with empty dataframe."""
        result = analyze_rsi(pd.DataFrame())
        assert result["current"] is None
        assert result["signal"] == "unavailable"

    def test_analyze_rsi_missing_column(self):
        """Test RSI analysis without RSI column."""
        result = analyze_rsi(pd.DataFrame({"close": [100, 105, 110]}))
        assert result["current"] is None
        assert result["signal"] == "unavailable"

    def test_analyze_rsi_nan_values(self):
        """Test RSI analysis with NaN values."""
        result = analyze_rsi(self._frame([100, 105, 110], [50, 55, np.nan]))
        assert result["current"] is None
        assert result["signal"] == "unavailable"
class TestMACDAnalysis:
    """Test MACD analysis functionality."""

    @pytest.fixture
    def macd_data(self):
        """Create MACD data whose latest bar is bearish.

        On the final row MACD (2.8) is below the signal line (3.2) and the
        histogram is negative (-0.4).
        """
        data = {
            "macd_12_26_9": [1.5, 2.0, 2.5, 3.0, 2.8],
            "macds_12_26_9": [1.0, 1.8, 2.2, 2.7, 3.2],
            "macdh_12_26_9": [0.5, 0.2, 0.3, 0.3, -0.4],
        }
        return pd.DataFrame(data)

    def test_analyze_macd_bullish(self, macd_data):
        """Test MACD analysis classification of the fixture's latest bar.

        NOTE: despite the historical test name, the fixture's final bar is
        bearish (MACD 2.8 < signal 3.2 with a negative histogram), so the
        analysis is expected to report "bearish". The previous docstring
        incorrectly described this as testing bullish signals.
        """
        result = analyze_macd(macd_data)
        assert result["macd"] == 2.8
        assert result["signal"] == 3.2
        assert result["histogram"] == -0.4
        assert result["indicator"] == "bearish"  # macd < signal and histogram < 0

    def test_analyze_macd_crossover_detection(self):
        """Test MACD crossover detection."""
        data = {
            "macd_12_26_9": [1.0, 2.0, 3.0],
            "macds_12_26_9": [2.0, 1.8, 2.5],
            "macdh_12_26_9": [-1.0, 0.2, 0.5],
        }
        df = pd.DataFrame(data)
        result = analyze_macd(df)
        # Check that crossover detection works (test the logic rather than
        # pinning a specific crossover direction).
        assert "crossover" in result
        assert result["crossover"] in [
            "bullish crossover detected",
            "bearish crossover detected",
            "no recent crossover",
        ]

    def test_analyze_macd_missing_data(self):
        """Test MACD analysis with missing data."""
        data = {
            "macd_12_26_9": [np.nan],
            "macds_12_26_9": [np.nan],
            "macdh_12_26_9": [np.nan],
        }
        df = pd.DataFrame(data)
        result = analyze_macd(df)
        assert result["macd"] is None
        assert result["indicator"] == "unavailable"
class TestStochasticAnalysis:
    """Test Stochastic Oscillator analysis."""

    @staticmethod
    def _frame(k_values, d_values):
        """Build a dataframe carrying %K and %D stochastic columns."""
        return pd.DataFrame(
            {"stochk_14_3_3": k_values, "stochd_14_3_3": d_values}
        )

    @pytest.fixture
    def stoch_data(self):
        """Create data with Stochastic indicators."""
        return self._frame([20, 30, 40, 50, 60], [25, 35, 45, 55, 65])

    def test_analyze_stochastic_bearish(self, stoch_data):
        """Test Stochastic analysis with bearish signal."""
        result = analyze_stochastic(stoch_data)
        assert result["k"] == 60.0
        assert result["d"] == 65.0
        assert result["signal"] == "bearish"  # k < d

    def test_analyze_stochastic_overbought(self):
        """Test Stochastic analysis with overbought conditions."""
        result = analyze_stochastic(self._frame([85], [83]))
        assert result["signal"] == "overbought"

    def test_analyze_stochastic_oversold(self):
        """Test Stochastic analysis with oversold conditions."""
        result = analyze_stochastic(self._frame([15], [18]))
        assert result["signal"] == "oversold"

    def test_analyze_stochastic_crossover(self):
        """Test Stochastic crossover detection (%K crossing above %D)."""
        result = analyze_stochastic(self._frame([30, 45], [40, 35]))
        assert result["crossover"] == "bullish crossover detected"
class TestBollingerBands:
    """Test Bollinger Bands analysis."""

    @staticmethod
    def _frame(closes, uppers, lowers, middles):
        """Assemble a dataframe with close prices and Bollinger Band columns."""
        return pd.DataFrame(
            {
                "close": closes,
                "bbu_20_2.0": uppers,
                "bbl_20_2.0": lowers,
                "sma_20": middles,
            }
        )

    @pytest.fixture
    def bb_data(self):
        """Create data with Bollinger Bands."""
        return self._frame(
            [100, 105, 110, 108, 112],
            [115, 116, 117, 116, 118],
            [85, 86, 87, 86, 88],
            [100, 101, 102, 101, 103],
        )

    def test_analyze_bollinger_bands_above_middle(self, bb_data):
        """Test Bollinger Bands with price above middle band."""
        result = analyze_bollinger_bands(bb_data)
        assert result["upper_band"] == 118.0
        assert result["middle_band"] == 103.0
        assert result["lower_band"] == 88.0
        assert result["position"] == "above middle band"
        assert result["signal"] == "bullish"

    def test_analyze_bollinger_bands_above_upper(self):
        """Test Bollinger Bands with price above upper band."""
        result = analyze_bollinger_bands(self._frame([120], [115], [85], [100]))
        assert result["position"] == "above upper band"
        assert result["signal"] == "overbought"

    def test_analyze_bollinger_bands_below_lower(self):
        """Test Bollinger Bands with price below lower band."""
        result = analyze_bollinger_bands(self._frame([80], [115], [85], [100]))
        assert result["position"] == "below lower band"
        assert result["signal"] == "oversold"

    def test_analyze_bollinger_bands_volatility_calculation(self):
        """Test Bollinger Bands volatility calculation."""
        # Bands narrow over time, so volatility should read as contracting.
        result = analyze_bollinger_bands(
            self._frame(
                [100, 100, 100, 100, 100],
                [110, 108, 106, 104, 102],
                [90, 92, 94, 96, 98],
                [100, 100, 100, 100, 100],
            )
        )
        assert "contracting" in result["volatility"]
class TestVolumeAnalysis:
    """Test volume analysis functionality."""

    @pytest.fixture
    def volume_data(self):
        """Create data with volume information."""
        return pd.DataFrame(
            {
                "volume": [1000000, 1100000, 1200000, 1500000, 2000000],
                "close": [100, 101, 102, 105, 108],
            }
        )

    def test_analyze_volume_high_volume_up_move(self, volume_data):
        """Test volume analysis with high volume on up move."""
        result = analyze_volume(volume_data)
        assert result["current"] == 2000000
        assert result["ratio"] >= 1.4  # More lenient threshold
        # Volume analysis is working; the exact signal depends on the ratio.
        assert result["description"] in ["above average", "average"]
        assert result["signal"] in ["bullish (high volume on up move)", "neutral"]

    def test_analyze_volume_low_volume(self):
        """Test volume analysis with low volume."""
        frame = pd.DataFrame(
            {
                "volume": [1000000, 1100000, 1200000, 1300000, 600000],
                "close": [100, 101, 102, 103, 104],
            }
        )
        result = analyze_volume(frame)
        assert result["ratio"] < 0.7
        assert result["description"] == "below average"
        assert result["signal"] == "weak conviction"

    def test_analyze_volume_insufficient_data(self):
        """Test volume analysis with insufficient data."""
        result = analyze_volume(pd.DataFrame({"volume": [1000000], "close": [100]}))
        # A single data point should still produce sensible values.
        assert result["current"] == 1000000
        assert result["average"] == 1000000
        assert result["ratio"] == 1.0

    def test_analyze_volume_invalid_data(self):
        """Test volume analysis with invalid data."""
        result = analyze_volume(pd.DataFrame({"volume": [np.nan], "close": [100]}))
        assert result["current"] is None
        assert result["signal"] == "unavailable"
class TestChartPatterns:
    """Test chart pattern identification."""

    def test_identify_chart_patterns_double_bottom(self):
        """Test double bottom pattern identification."""
        # Price series shaped like a double bottom: two troughs at 90
        # separated by a recovery back to 100.
        lows = [100] * 10 + [90] * 5 + [100] * 10 + [90] * 5 + [100] * 10
        frame = pd.DataFrame(
            {
                "low": lows,
                "high": [p + 10 for p in lows],
                "close": [p + 5 for p in lows],
            }
        )
        patterns = identify_chart_patterns(frame)
        # Note: The pattern detection is quite strict, so we just test it runs
        assert isinstance(patterns, list)

    def test_identify_chart_patterns_insufficient_data(self):
        """Test chart pattern identification with insufficient data."""
        frame = pd.DataFrame(
            {
                "low": [90, 95, 92],
                "high": [100, 105, 102],
                "close": [95, 100, 97],
            }
        )
        patterns = identify_chart_patterns(frame)
        assert isinstance(patterns, list)
        assert len(patterns) == 0  # Not enough data for patterns
class TestATRCalculation:
    """Test Average True Range calculation."""

    @pytest.fixture
    def atr_data(self):
        """Create data for ATR calculation."""
        return pd.DataFrame(
            {
                "High": [105, 110, 108, 115, 112],
                "Low": [95, 100, 98, 105, 102],
                "Close": [100, 105, 103, 110, 107],
            }
        )

    def test_calculate_atr_basic(self, atr_data):
        """Test basic ATR calculation."""
        series = calculate_atr(atr_data, period=3)
        assert isinstance(series, pd.Series)
        assert len(series) == len(atr_data)
        # ATR values should be positive where calculated
        assert (series.dropna() >= 0).all()

    def test_calculate_atr_custom_period(self, atr_data):
        """Test ATR calculation with custom period."""
        series = calculate_atr(atr_data, period=2)
        assert isinstance(series, pd.Series)
        assert len(series) == len(atr_data)

    def test_calculate_atr_insufficient_data(self):
        """Test ATR calculation with insufficient data."""
        frame = pd.DataFrame({"High": [105], "Low": [95], "Close": [100]})
        series = calculate_atr(frame)
        # Should handle insufficient data gracefully
        assert isinstance(series, pd.Series)
class TestOutlookGeneration:
    """Test overall outlook generation."""

    @staticmethod
    def _outlook(closes, trend, rsi_signal, macd_indicator, macd_crossover, stoch_signal):
        """Invoke generate_outlook from compact signal descriptions."""
        return generate_outlook(
            pd.DataFrame({"close": closes}),
            trend,
            {"signal": rsi_signal},
            {"indicator": macd_indicator, "crossover": macd_crossover},
            {"signal": stoch_signal},
        )

    def test_generate_outlook_bullish(self):
        """Test outlook generation with bullish signals."""
        outlook = self._outlook(
            [100, 105, 110],
            "uptrend",
            "bullish",
            "bullish",
            "bullish crossover detected",
            "bullish",
        )
        assert "bullish" in outlook

    def test_generate_outlook_bearish(self):
        """Test outlook generation with bearish signals."""
        outlook = self._outlook(
            [100, 95, 90],
            "downtrend",
            "bearish",
            "bearish",
            "bearish crossover detected",
            "bearish",
        )
        assert "bearish" in outlook

    def test_generate_outlook_neutral(self):
        """Test outlook generation with mixed signals."""
        outlook = self._outlook(
            [100, 100, 100],
            "sideways",
            "neutral",
            "neutral",
            "no recent crossover",
            "neutral",
        )
        assert outlook == "neutral"

    def test_generate_outlook_strongly_bullish(self):
        """Test outlook generation with very bullish signals."""
        # Oversold RSI/stochastic readings count as bullish inputs here.
        outlook = self._outlook(
            [100, 105, 110],
            "uptrend",
            "oversold",
            "bullish",
            "bullish crossover detected",
            "oversold",
        )
        assert "strongly bullish" in outlook
# Allow running this test module directly: `python <file>` invokes pytest.
if __name__ == "__main__":
    pytest.main([__file__])
```