This is page 23 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/scripts/load_tiingo_data.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Tiingo Data Loader for MaverickMCP
Loads market data from the Tiingo API into the self-contained MaverickMCP database.
Supports batch loading, rate limiting, progress tracking, and technical indicator calculation.
"""
import argparse
import asyncio
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
import aiohttp
import numpy as np
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
from maverick_mcp.data.models import (
Stock,
bulk_insert_price_data,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Configuration - following tiingo-python patterns from api.py
# Base URL without version suffix (will be added per endpoint)
DEFAULT_BASE_URL = os.getenv("TIINGO_BASE_URL", "https://api.tiingo.com")
# API token from environment
TIINGO_API_TOKEN = os.getenv("TIINGO_API_TOKEN")
# Rate limiting configuration - can be overridden by command line
DEFAULT_MAX_CONCURRENT = int(os.getenv("TIINGO_MAX_CONCURRENT", "5"))
DEFAULT_RATE_LIMIT_PER_HOUR = int(os.getenv("TIINGO_RATE_LIMIT", "2400"))
DEFAULT_CHECKPOINT_FILE = os.getenv(
"TIINGO_CHECKPOINT_FILE", "tiingo_load_progress.json"
)
# Default timeout for requests (from tiingo-python)
DEFAULT_TIMEOUT = int(os.getenv("TIINGO_TIMEOUT", "10"))
class TiingoDataLoader:
"""Handles loading data from Tiingo API into MaverickMCP database.
Following the design patterns from tiingo-python library.
"""
def __init__(
self,
api_token: str | None = None,
db_url: str | None = None,
base_url: str | None = None,
timeout: int | None = None,
rate_limit_per_hour: int | None = None,
checkpoint_file: str | None = None,
):
"""Initialize the Tiingo data loader.
Args:
api_token: Tiingo API token (defaults to env var)
db_url: Database URL (defaults to env var)
base_url: Base URL for Tiingo API (defaults to env var)
timeout: Request timeout in seconds
rate_limit_per_hour: Max requests per hour
checkpoint_file: Path to checkpoint file
"""
# API configuration (following tiingo-python patterns)
self.api_token = api_token or TIINGO_API_TOKEN
if not self.api_token:
raise ValueError(
"API token required. Set TIINGO_API_TOKEN env var or pass api_token parameter."
)
# Database configuration
self.db_url = db_url or os.getenv("DATABASE_URL")
if not self.db_url:
raise ValueError(
"Database URL required. Set DATABASE_URL env var or pass db_url parameter."
)
self.engine = create_engine(self.db_url)
self.SessionLocal = sessionmaker(bind=self.engine)
# API endpoint configuration
self.base_url = base_url or DEFAULT_BASE_URL
self.timeout = timeout or DEFAULT_TIMEOUT
# Rate limiting
self.request_count = 0
self.start_time = datetime.now()
self.rate_limit_per_hour = rate_limit_per_hour or DEFAULT_RATE_LIMIT_PER_HOUR
self.rate_limit_delay = (
3600 / self.rate_limit_per_hour
) # seconds between requests
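# e.g. the default 2400 requests/hour works out to a 1.5 s pause before each request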
# Checkpoint configuration
self.checkpoint_file = checkpoint_file or DEFAULT_CHECKPOINT_FILE
self.checkpoint_data = self.load_checkpoint()
# Session configuration (following tiingo-python)
self._session = None
def load_checkpoint(self) -> dict:
"""Load checkpoint data if exists."""
if Path(self.checkpoint_file).exists():
try:
with open(self.checkpoint_file) as f:
return json.load(f)
except Exception as e:
logger.warning(f"Could not load checkpoint: {e}")
return {"completed_symbols": [], "last_symbol": None}
def save_checkpoint(self, symbol: str):
"""Save checkpoint data."""
self.checkpoint_data["completed_symbols"].append(symbol)
self.checkpoint_data["last_symbol"] = symbol
self.checkpoint_data["timestamp"] = datetime.now().isoformat()
try:
with open(self.checkpoint_file, "w") as f:
json.dump(self.checkpoint_data, f, indent=2)
except Exception as e:
logger.error(f"Could not save checkpoint: {e}")
def _get_headers(self) -> dict[str, str]:
"""Get request headers following tiingo-python patterns."""
return {
"Content-Type": "application/json",
"Authorization": f"Token {self.api_token}",
"User-Agent": "tiingo-python-client/maverick-mcp",
}
async def _request(
self,
session: aiohttp.ClientSession,
endpoint: str,
params: dict[str, Any] | None = None,
max_retries: int = 3,
) -> Any | None:
"""Make HTTP request with rate limiting and error handling.
Following tiingo-python's request patterns from api.py.
Args:
session: aiohttp session
endpoint: API endpoint (will be appended to base_url)
params: Query parameters
max_retries: Maximum number of retries
Returns:
Response data or None if failed
"""
# Rate limiting
await asyncio.sleep(self.rate_limit_delay)
self.request_count += 1
# Build URL
url = f"{self.base_url}{endpoint}"
if params:
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
url = f"{url}?{param_str}"
headers = self._get_headers()
for attempt in range(max_retries):
try:
timeout = aiohttp.ClientTimeout(total=self.timeout)
async with session.get(
url, headers=headers, timeout=timeout
) as response:
if response.status == 200:
return await response.json()
elif response.status == 400:
error_msg = await response.text()
logger.error(f"Bad request (400): {error_msg}")
return None
elif response.status == 404:
logger.warning(f"Resource not found (404): {endpoint}")
return None
elif response.status == 429:
# Rate limited, exponential backoff
retry_after = response.headers.get("Retry-After")
if retry_after:
wait_time = int(retry_after)
else:
wait_time = min(60 * (2**attempt), 300)
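# exponential backoff when no Retry-After header: 60s, 120s, 240s, capped at 300s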
logger.warning(f"Rate limited, waiting {wait_time}s...")
await asyncio.sleep(wait_time)
continue
elif response.status >= 500:
# Server error, retry with backoff
if attempt < max_retries - 1:
wait_time = min(5 * (2**attempt), 60)
logger.warning(
f"Server error {response.status}, retry in {wait_time}s..."
)
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"Server error after {max_retries} attempts")
return None
else:
error_text = await response.text()
logger.error(f"HTTP {response.status}: {error_text}")
return None
except TimeoutError:
if attempt < max_retries - 1:
wait_time = min(10 * (2**attempt), 60)
logger.warning(f"Timeout, retry in {wait_time}s...")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"Timeout after {max_retries} attempts")
return None
except Exception as e:
if attempt < max_retries - 1:
wait_time = min(10 * (2**attempt), 60)
logger.warning(f"Error: {e}, retry in {wait_time}s...")
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"Failed after {max_retries} attempts: {e}")
return None
return None
def _process_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
"""Process raw DataFrame from Tiingo API.
Following tiingo-python's DataFrame processing patterns.
Args:
df: Raw DataFrame from API
Returns:
Processed DataFrame with proper column names and index
"""
if df.empty:
return df
# Handle date column following tiingo-python
if "date" in df.columns:
df["date"] = pd.to_datetime(df["date"])
# Standardize column names to match expected format
column_mapping = {
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"volume": "Volume",
"adjOpen": "Adj Open",
"adjHigh": "Adj High",
"adjLow": "Adj Low",
"adjClose": "Adj Close",
"adjVolume": "Adj Volume",
"divCash": "Dividend",
"splitFactor": "Split Factor",
}
# Only rename columns that exist
rename_dict = {
old: new for old, new in column_mapping.items() if old in df.columns
}
if rename_dict:
df = df.rename(columns=rename_dict)
# Set date as index
df = df.set_index("date")
# Localize to UTC following tiingo-python approach
if df.index.tz is None:
df.index = df.index.tz_localize("UTC")
# For database storage, convert to date only (no time component)
df.index = df.index.date
return df
async def get_ticker_metadata(
self, session: aiohttp.ClientSession, symbol: str
) -> dict[str, Any] | None:
"""Get metadata for a specific ticker.
Following tiingo-python's get_ticker_metadata pattern.
"""
endpoint = f"/tiingo/daily/{symbol}"
return await self._request(session, endpoint)
async def get_available_symbols(
self,
session: aiohttp.ClientSession,
asset_types: list[str] | None = None,
exchanges: list[str] | None = None,
) -> list[str]:
"""Get list of available symbols from Tiingo with optional filtering.
Following tiingo-python's list_tickers pattern.
"""
endpoint = "/tiingo/daily/supported_tickers"
data = await self._request(session, endpoint)
if data:
# Default filters if not provided
asset_types = asset_types or ["Stock"]
exchanges = exchanges or ["NYSE", "NASDAQ"]
symbols = []
for ticker_info in data:
if (
ticker_info.get("exchange") in exchanges
and ticker_info.get("assetType") in asset_types
and ticker_info.get("priceCurrency") == "USD"
):
symbols.append(ticker_info["ticker"])
return symbols
return []
async def get_daily_price_history(
self,
session: aiohttp.ClientSession,
symbol: str,
start_date: str | None = None,
end_date: str | None = None,
frequency: str = "daily",
columns: list[str] | None = None,
) -> pd.DataFrame:
"""Fetch historical price data for a symbol.
Following tiingo-python's get_dataframe pattern from api.py.
Args:
session: aiohttp session
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Data frequency (daily, weekly, monthly, annually)
columns: Specific columns to return
Returns:
DataFrame with price history
"""
endpoint = f"/tiingo/daily/{symbol}/prices"
# Build params following tiingo-python
params = {
"format": "json",
"resampleFreq": frequency,
}
if start_date:
params["startDate"] = start_date
if end_date:
params["endDate"] = end_date
if columns:
params["columns"] = ",".join(columns)
data = await self._request(session, endpoint, params)
if data:
try:
df = pd.DataFrame(data)
if not df.empty:
# Process DataFrame following tiingo-python patterns
df = self._process_dataframe(df)
# Validate data integrity
if len(df) == 0:
logger.warning(f"Empty dataset returned for {symbol}")
return pd.DataFrame()
# Check for required columns
required_cols = ["Open", "High", "Low", "Close", "Volume"]
missing_cols = [
col for col in required_cols if col not in df.columns
]
if missing_cols:
logger.warning(f"Missing columns for {symbol}: {missing_cols}")
return df
except Exception as e:
logger.error(f"Error processing data for {symbol}: {e}")
return pd.DataFrame()
return pd.DataFrame()
def calculate_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Calculate technical indicators for the data."""
if df.empty or len(df) < 200:
return df
try:
# Moving averages
df["SMA_20"] = df["Close"].rolling(window=20).mean()
df["SMA_50"] = df["Close"].rolling(window=50).mean()
df["SMA_150"] = df["Close"].rolling(window=150).mean()
df["SMA_200"] = df["Close"].rolling(window=200).mean()
df["EMA_21"] = df["Close"].ewm(span=21, adjust=False).mean()
# RSI
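# (simple moving averages of gains/losses, i.e. Cutler's variant; Wilder's original RSI uses exponential smoothing)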
delta = df["Close"].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
rs = gain / loss
df["RSI"] = 100 - (100 / (1 + rs))
# MACD
exp1 = df["Close"].ewm(span=12, adjust=False).mean()
exp2 = df["Close"].ewm(span=26, adjust=False).mean()
df["MACD"] = exp1 - exp2
df["MACD_Signal"] = df["MACD"].ewm(span=9, adjust=False).mean()
df["MACD_Histogram"] = df["MACD"] - df["MACD_Signal"]
# ATR
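# true range = max(high-low, |high - prev close|, |low - prev close|), smoothed with a 14-period simple average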
high_low = df["High"] - df["Low"]
high_close = np.abs(df["High"] - df["Close"].shift())
low_close = np.abs(df["Low"] - df["Close"].shift())
ranges = pd.concat([high_low, high_close, low_close], axis=1)
true_range = np.max(ranges, axis=1)
df["ATR"] = true_range.rolling(14).mean()
# ADR (Average Daily Range) as percentage
df["ADR_PCT"] = (
((df["High"] - df["Low"]) / df["Close"] * 100).rolling(20).mean()
)
# Volume indicators
df["Volume_SMA_30"] = df["Volume"].rolling(window=30).mean()
df["Volume_Ratio"] = df["Volume"] / df["Volume_SMA_30"]
# Momentum Score (simplified)
returns = df["Close"].pct_change(periods=252) # 1-year returns
df["Momentum_Score"] = returns.rank(pct=True) * 100
except Exception as e:
logger.error(f"Error calculating indicators: {e}")
return df
def run_maverick_screening(self, df: pd.DataFrame, symbol: str) -> dict | None:
"""Run Maverick momentum screening algorithm."""
if df.empty or len(df) < 200:
return None
try:
latest = df.iloc[-1]
# Maverick criteria
price_above_ema21 = latest["Close"] > latest.get("EMA_21", 0)
ema21_above_sma50 = latest.get("EMA_21", 0) > latest.get("SMA_50", 0)
sma50_above_sma200 = latest.get("SMA_50", 0) > latest.get("SMA_200", 0)
strong_momentum = latest.get("Momentum_Score", 0) > 70
# Calculate combined score
score = 0
if price_above_ema21:
score += 25
if ema21_above_sma50:
score += 25
if sma50_above_sma200:
score += 25
if strong_momentum:
score += 25
if score >= 75: # Meets criteria
return {
"stock": symbol,
"close": float(latest["Close"]),
"volume": int(latest["Volume"]),
"momentum_score": float(latest.get("Momentum_Score", 0)),
"combined_score": score,
"adr_pct": float(latest.get("ADR_PCT", 0)),
"atr": float(latest.get("ATR", 0)),
"ema_21": float(latest.get("EMA_21", 0)),
"sma_50": float(latest.get("SMA_50", 0)),
"sma_150": float(latest.get("SMA_150", 0)),
"sma_200": float(latest.get("SMA_200", 0)),
}
except Exception as e:
logger.error(f"Error in Maverick screening for {symbol}: {e}")
return None
def run_bear_screening(self, df: pd.DataFrame, symbol: str) -> dict | None:
"""Run Bear market screening algorithm."""
if df.empty or len(df) < 200:
return None
try:
latest = df.iloc[-1]
# Bear criteria
price_below_ema21 = latest["Close"] < latest.get("EMA_21", float("inf"))
ema21_below_sma50 = latest.get("EMA_21", float("inf")) < latest.get(
"SMA_50", float("inf")
)
weak_momentum = latest.get("Momentum_Score", 100) < 30
negative_macd = latest.get("MACD", 0) < 0
# Calculate bear score
score = 0
if price_below_ema21:
score += 25
if ema21_below_sma50:
score += 25
if weak_momentum:
score += 25
if negative_macd:
score += 25
if score >= 75: # Meets bear criteria
return {
"stock": symbol,
"close": float(latest["Close"]),
"volume": int(latest["Volume"]),
"momentum_score": float(latest.get("Momentum_Score", 0)),
"score": score,
"rsi_14": float(latest.get("RSI", 0)),
"macd": float(latest.get("MACD", 0)),
"macd_signal": float(latest.get("MACD_Signal", 0)),
"macd_histogram": float(latest.get("MACD_Histogram", 0)),
"adr_pct": float(latest.get("ADR_PCT", 0)),
"atr": float(latest.get("ATR", 0)),
"ema_21": float(latest.get("EMA_21", 0)),
"sma_50": float(latest.get("SMA_50", 0)),
"sma_200": float(latest.get("SMA_200", 0)),
}
except Exception as e:
logger.error(f"Error in Bear screening for {symbol}: {e}")
return None
def run_supply_demand_screening(self, df: pd.DataFrame, symbol: str) -> dict | None:
"""Run Supply/Demand breakout screening algorithm."""
if df.empty or len(df) < 200:
return None
try:
latest = df.iloc[-1]
# Supply/Demand criteria (accumulation phase)
close = latest["Close"]
sma_50 = latest.get("SMA_50", 0)
sma_150 = latest.get("SMA_150", 0)
sma_200 = latest.get("SMA_200", 0)
# Check for proper alignment
price_above_all = close > sma_50 > sma_150 > sma_200
strong_momentum = latest.get("Momentum_Score", 0) > 80
# Volume confirmation
volume_confirmation = latest.get("Volume_Ratio", 0) > 1.2
if price_above_all and strong_momentum and volume_confirmation:
return {
"stock": symbol,
"close": float(close),
"volume": int(latest["Volume"]),
"momentum_score": float(latest.get("Momentum_Score", 0)),
"adr_pct": float(latest.get("ADR_PCT", 0)),
"atr": float(latest.get("ATR", 0)),
"ema_21": float(latest.get("EMA_21", 0)),
"sma_50": float(sma_50),
"sma_150": float(sma_150),
"sma_200": float(sma_200),
"avg_volume_30d": float(latest.get("Volume_SMA_30", 0)),
}
except Exception as e:
logger.error(f"Error in Supply/Demand screening for {symbol}: {e}")
return None
async def process_symbol(
self,
session: aiohttp.ClientSession,
symbol: str,
start_date: str,
end_date: str,
calculate_indicators: bool = True,
run_screening: bool = True,
) -> tuple[bool, dict | None]:
"""Process a single symbol - fetch data, calculate indicators, run screening."""
try:
# Skip if already processed
if symbol in self.checkpoint_data.get("completed_symbols", []):
logger.info(f"Skipping {symbol} - already processed")
return True, None
# Fetch historical data using tiingo-python pattern
df = await self.get_daily_price_history(
session, symbol, start_date, end_date
)
if df.empty:
logger.warning(f"No data available for {symbol}")
return False, None
# Store in database
with self.SessionLocal() as db_session:
# Create or get stock record
Stock.get_or_create(db_session, symbol)
# Bulk insert price data
records_inserted = bulk_insert_price_data(db_session, symbol, df)
logger.info(f"Inserted {records_inserted} records for {symbol}")
screening_results = {}
if calculate_indicators:
# Calculate technical indicators
df = self.calculate_technical_indicators(df)
if run_screening:
# Run screening algorithms
maverick_result = self.run_maverick_screening(df, symbol)
if maverick_result:
screening_results["maverick"] = maverick_result
bear_result = self.run_bear_screening(df, symbol)
if bear_result:
screening_results["bear"] = bear_result
supply_demand_result = self.run_supply_demand_screening(df, symbol)
if supply_demand_result:
screening_results["supply_demand"] = supply_demand_result
# Save checkpoint
self.save_checkpoint(symbol)
return True, screening_results
except Exception as e:
logger.error(f"Error processing {symbol}: {e}")
return False, None
async def load_symbols(
self,
symbols: list[str],
start_date: str,
end_date: str,
calculate_indicators: bool = True,
run_screening: bool = True,
max_concurrent: int | None = None,
):
"""Load data for multiple symbols with concurrent processing."""
logger.info(
f"Loading data for {len(symbols)} symbols from {start_date} to {end_date}"
)
# Filter out already processed symbols if resuming
symbols_to_process = [
s
for s in symbols
if s not in self.checkpoint_data.get("completed_symbols", [])
]
if len(symbols_to_process) < len(symbols):
logger.info(
f"Resuming: {len(symbols) - len(symbols_to_process)} symbols already processed"
)
screening_results = {"maverick": [], "bear": [], "supply_demand": []}
# Use provided max_concurrent or default
concurrent_limit = max_concurrent or DEFAULT_MAX_CONCURRENT
async with aiohttp.ClientSession() as session:
# Process in batches with semaphore for rate limiting
semaphore = asyncio.Semaphore(concurrent_limit)
async def process_with_semaphore(symbol):
async with semaphore:
return await self.process_symbol(
session,
symbol,
start_date,
end_date,
calculate_indicators,
run_screening,
)
# Create tasks with progress bar
tasks = []
for symbol in symbols_to_process:
tasks.append(process_with_semaphore(symbol))
# Process with progress bar
with tqdm(total=len(tasks), desc="Processing symbols") as pbar:
for coro in asyncio.as_completed(tasks):
success, results = await coro
if results:
for screen_type, data in results.items():
screening_results[screen_type].append(data)
pbar.update(1)
# Store screening results in database
if run_screening:
self.store_screening_results(screening_results)
logger.info(f"Completed loading {len(symbols_to_process)} symbols")
logger.info(
f"Screening results - Maverick: {len(screening_results['maverick'])}, "
f"Bear: {len(screening_results['bear'])}, "
f"Supply/Demand: {len(screening_results['supply_demand'])}"
)
def store_screening_results(self, results: dict):
"""Store screening results in database."""
with self.SessionLocal() as db_session:
# Store Maverick results
for _data in results["maverick"]:
# Implementation would create MaverickStocks records
pass
# Store Bear results
for _data in results["bear"]:
# Implementation would create MaverickBearStocks records
pass
# Store Supply/Demand results
for _data in results["supply_demand"]:
# Implementation would create SupplyDemandBreakoutStocks records
pass
db_session.commit()
def get_test_symbols() -> list[str]:
"""Get a small test set of symbols for development.
These are just for testing - production use should load from
external sources or command line arguments.
"""
# Test symbols from different sectors for comprehensive testing
return [
"AAPL", # Apple - Tech
"MSFT", # Microsoft - Tech
"GOOGL", # Alphabet - Tech
"AMZN", # Amazon - Consumer Discretionary
"NVDA", # NVIDIA - Tech
"META", # Meta - Communication
"TSLA", # Tesla - Consumer Discretionary
"UNH", # UnitedHealth - Healthcare
"JPM", # JPMorgan Chase - Financials
"V", # Visa - Financials
"WMT", # Walmart - Consumer Staples
"JNJ", # Johnson & Johnson - Healthcare
"MA", # Mastercard - Financials
"HD", # Home Depot - Consumer Discretionary
"PG", # Procter & Gamble - Consumer Staples
"XOM", # ExxonMobil - Energy
"CVX", # Chevron - Energy
"KO", # Coca-Cola - Consumer Staples
"PEP", # PepsiCo - Consumer Staples
"ADBE", # Adobe - Tech
"NFLX", # Netflix - Communication
"CRM", # Salesforce - Tech
"DIS", # Disney - Communication
"COST", # Costco - Consumer Staples
"MRK", # Merck - Healthcare
]
def get_sp500_symbols() -> list[str]:
"""Get S&P 500 symbols list from external source or file.
This function should load S&P 500 symbols from:
1. Environment variable SP500_SYMBOLS_FILE pointing to a file
2. Download from a public data source
3. Return empty list with warning if unavailable
"""
# Try to load from file specified in environment
symbols_file = os.getenv("SP500_SYMBOLS_FILE")
if symbols_file and Path(symbols_file).exists():
try:
with open(symbols_file) as f:
symbols = [line.strip() for line in f if line.strip()]
logger.info(
f"Loaded {len(symbols)} S&P 500 symbols from {symbols_file}"
)
return symbols
except Exception as e:
logger.warning(f"Could not load S&P 500 symbols from {symbols_file}: {e}")
# Try to fetch from a public source (like Wikipedia or Yahoo Finance)
try:
# Using pandas to read S&P 500 list from Wikipedia
url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
tables = pd.read_html(url)
sp500_table = tables[0] # First table contains the S&P 500 list
symbols = sp500_table["Symbol"].tolist()
logger.info(f"Fetched {len(symbols)} S&P 500 symbols from Wikipedia")
# Optional: Save to cache file for future use
cache_file = os.getenv("SP500_CACHE_FILE", "sp500_symbols_cache.txt")
try:
with open(cache_file, "w") as f:
for symbol in symbols:
f.write(f"{symbol}\n")
logger.info(f"Cached S&P 500 symbols to {cache_file}")
except Exception as e:
logger.debug(f"Could not cache symbols: {e}")
return symbols
except Exception as e:
logger.warning(f"Could not fetch S&P 500 symbols from web: {e}")
# Try to load from cache if web fetch failed
cache_file = os.getenv("SP500_CACHE_FILE", "sp500_symbols_cache.txt")
if Path(cache_file).exists():
try:
with open(cache_file) as f:
symbols = [line.strip() for line in f if line.strip()]
logger.info(f"Loaded {len(symbols)} S&P 500 symbols from cache")
return symbols
except Exception as e:
logger.warning(f"Could not load from cache: {e}")
logger.error("Unable to load S&P 500 symbols. Please specify --file or --symbols")
return []
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Load market data from Tiingo API")
parser.add_argument("--symbols", nargs="+", help="List of symbols to load")
parser.add_argument("--file", help="Load symbols from file (one per line)")
parser.add_argument(
"--test", action="store_true", help="Load test set of 25 symbols"
)
parser.add_argument("--sp500", action="store_true", help="Load S&P 500 symbols")
parser.add_argument(
"--years", type=int, default=2, help="Number of years of history"
)
parser.add_argument("--start-date", help="Start date (YYYY-MM-DD)")
parser.add_argument("--end-date", help="End date (YYYY-MM-DD)")
parser.add_argument(
"--calculate-indicators",
action="store_true",
help="Calculate technical indicators",
)
parser.add_argument(
"--run-screening", action="store_true", help="Run screening algorithms"
)
parser.add_argument(
"--max-concurrent", type=int, default=5, help="Maximum concurrent requests"
)
parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
parser.add_argument("--db-url", help="Database URL override")
args = parser.parse_args()
# Check for API token
if not TIINGO_API_TOKEN:
logger.error("TIINGO_API_TOKEN environment variable not set")
sys.exit(1)
# Determine database URL
db_url = args.db_url or os.getenv("MCP_DATABASE_URL") or os.getenv("DATABASE_URL")
if not db_url:
logger.error("Database URL not configured")
sys.exit(1)
# Determine symbols to load
symbols = []
if args.symbols:
symbols = args.symbols
elif args.file:
# Load symbols from file
try:
with open(args.file) as f:
symbols = [line.strip() for line in f if line.strip()]
logger.info(f"Loaded {len(symbols)} symbols from {args.file}")
except Exception as e:
logger.error(f"Could not read symbols from file: {e}")
sys.exit(1)
elif args.test:
symbols = get_test_symbols()
logger.info(f"Using test set of {len(symbols)} symbols")
elif args.sp500:
symbols = get_sp500_symbols()
logger.info(f"Using S&P 500 symbols ({len(symbols)} total)")
else:
logger.error("No symbols specified. Use --symbols, --file, --test, or --sp500")
sys.exit(1)
# Determine date range
end_date = args.end_date or datetime.now().strftime("%Y-%m-%d")
if args.start_date:
start_date = args.start_date
else:
start_date = (datetime.now() - timedelta(days=365 * args.years)).strftime(
"%Y-%m-%d"
)
# Create loader using tiingo-python style initialization
loader = TiingoDataLoader(
api_token=TIINGO_API_TOKEN,
db_url=db_url,
rate_limit_per_hour=DEFAULT_RATE_LIMIT_PER_HOUR,
)
# Run async loading
asyncio.run(
loader.load_symbols(
symbols,
start_date,
end_date,
calculate_indicators=args.calculate_indicators,
run_screening=args.run_screening,
max_concurrent=args.max_concurrent,
)
)
logger.info("Data loading complete!")
# Clean up checkpoint if completed successfully
checkpoint_file = DEFAULT_CHECKPOINT_FILE
if not args.resume and Path(checkpoint_file).exists():
os.remove(checkpoint_file)
logger.info("Removed checkpoint file")
if __name__ == "__main__":
main()
```
--------------------------------------------------------------------------------
/tests/performance/test_stress.py:
--------------------------------------------------------------------------------
```python
"""
Stress Testing for Resource Usage Under Load.
This test suite covers:
- Sustained load testing (1+ hour)
- Memory leak detection over time
- CPU utilization monitoring under stress
- Database connection pool exhaustion
- File descriptor limits testing
- Network connection limits
- Queue overflow scenarios
- System stability under extreme conditions
"""
import asyncio
import gc
import logging
import resource
import threading
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import numpy as np
import pandas as pd
import psutil
import pytest
from maverick_mcp.backtesting import VectorBTEngine
from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
logger = logging.getLogger(__name__)
@dataclass
class ResourceSnapshot:
"""Snapshot of system resources at a point in time."""
timestamp: float
memory_rss_mb: float
memory_vms_mb: float
memory_percent: float
cpu_percent: float
threads: int
file_descriptors: int
connections: int
swap_usage_mb: float
class ResourceMonitor:
"""Monitor system resources over time."""
def __init__(self, interval: float = 1.0):
self.interval = interval
self.snapshots: list[ResourceSnapshot] = []
self.monitoring = False
self.monitor_thread = None
self.process = psutil.Process()
def start_monitoring(self):
"""Start continuous resource monitoring."""
self.monitoring = True
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
self.monitor_thread.start()
logger.info("Resource monitoring started")
def stop_monitoring(self):
"""Stop resource monitoring."""
self.monitoring = False
if self.monitor_thread:
self.monitor_thread.join(timeout=2.0)
logger.info(
f"Resource monitoring stopped. Collected {len(self.snapshots)} snapshots"
)
def _monitor_loop(self):
"""Continuous monitoring loop."""
while self.monitoring:
try:
snapshot = self._take_snapshot()
self.snapshots.append(snapshot)
time.sleep(self.interval)
except Exception as e:
logger.error(f"Error in resource monitoring: {e}")
def _take_snapshot(self) -> ResourceSnapshot:
"""Take a resource snapshot."""
memory_info = self.process.memory_info()
# Get file descriptor count
try:
fd_count = self.process.num_fds()
except AttributeError:
# Windows doesn't have num_fds()
fd_count = len(self.process.open_files())
# Get connection count
try:
connections = len(self.process.connections())
except (psutil.AccessDenied, psutil.NoSuchProcess):
connections = 0
# Get swap usage
try:
swap = psutil.swap_memory()
swap_used_mb = swap.used / 1024 / 1024
except Exception:
swap_used_mb = 0
return ResourceSnapshot(
timestamp=time.time(),
memory_rss_mb=memory_info.rss / 1024 / 1024,
memory_vms_mb=memory_info.vms / 1024 / 1024,
memory_percent=self.process.memory_percent(),
cpu_percent=self.process.cpu_percent(),
threads=self.process.num_threads(),
file_descriptors=fd_count,
connections=connections,
swap_usage_mb=swap_used_mb,
)
def get_current_snapshot(self) -> ResourceSnapshot:
"""Get current resource snapshot."""
return self._take_snapshot()
def analyze_trends(self) -> dict[str, Any]:
"""Analyze resource usage trends."""
if len(self.snapshots) < 2:
return {"error": "Insufficient data for trend analysis"}
# Calculate trends
timestamps = [s.timestamp for s in self.snapshots]
memories = [s.memory_rss_mb for s in self.snapshots]
cpus = [s.cpu_percent for s in self.snapshots]
fds = [s.file_descriptors for s in self.snapshots]
# Linear regression for memory trend
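# least-squares slope of memory (MB) vs. time (s): (n*sum(t*m) - sum(t)*sum(m)) / (n*sum(t*t) - sum(t)^2);
# scaled to MB/hour in the returned dict below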
n = len(timestamps)
sum_t = sum(timestamps)
sum_m = sum(memories)
sum_tm = sum(t * m for t, m in zip(timestamps, memories, strict=False))
sum_tt = sum(t * t for t in timestamps)
memory_slope = (
(n * sum_tm - sum_t * sum_m) / (n * sum_tt - sum_t * sum_t)
if n * sum_tt != sum_t * sum_t
else 0
)
return {
"duration_seconds": timestamps[-1] - timestamps[0],
"initial_memory_mb": memories[0],
"final_memory_mb": memories[-1],
"memory_growth_mb": memories[-1] - memories[0],
"memory_growth_rate_mb_per_hour": memory_slope * 3600,
"peak_memory_mb": max(memories),
"avg_cpu_percent": sum(cpus) / len(cpus),
"peak_cpu_percent": max(cpus),
"initial_file_descriptors": fds[0],
"final_file_descriptors": fds[-1],
"fd_growth": fds[-1] - fds[0],
"peak_file_descriptors": max(fds),
"snapshots_count": len(self.snapshots),
}
class StressTestRunner:
"""Run various stress tests."""
def __init__(self, data_provider):
self.data_provider = data_provider
self.resource_monitor = ResourceMonitor(interval=2.0)
async def sustained_load_test(
self, duration_minutes: int = 60, concurrent_load: int = 10
) -> dict[str, Any]:
"""Run sustained load test for extended duration."""
logger.info(
f"Starting sustained load test: {duration_minutes} minutes with {concurrent_load} concurrent operations"
)
self.resource_monitor.start_monitoring()
start_time = time.time()
end_time = start_time + (duration_minutes * 60)
total_operations = 0
total_errors = 0
operation_times = []
try:
# Create semaphore for concurrent control
semaphore = asyncio.Semaphore(concurrent_load)
async def sustained_operation(operation_id: int):
"""Single sustained operation."""
nonlocal total_operations, total_errors
engine = VectorBTEngine(data_provider=self.data_provider)
symbol = f"STRESS_{operation_id % 20}" # Cycle through 20 symbols
strategy = ["sma_cross", "rsi", "macd"][
operation_id % 3
] # Cycle through strategies
try:
async with semaphore:
op_start = time.time()
await engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=STRATEGY_TEMPLATES[strategy]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
op_time = time.time() - op_start
operation_times.append(op_time)
total_operations += 1
if total_operations % 100 == 0:
logger.info(f"Completed {total_operations} operations")
except Exception as e:
total_errors += 1
logger.error(f"Operation {operation_id} failed: {e}")
# Run operations continuously until duration expires
operation_id = 0
active_tasks = []
while time.time() < end_time:
# Start new operation
task = asyncio.create_task(sustained_operation(operation_id))
active_tasks.append(task)
operation_id += 1
# Clean up completed tasks
active_tasks = [t for t in active_tasks if not t.done()]
# Control task creation rate
await asyncio.sleep(0.1)
# Prevent task accumulation
if len(active_tasks) > concurrent_load * 2:
await asyncio.sleep(1.0)
# Wait for remaining tasks to complete
if active_tasks:
await asyncio.gather(*active_tasks, return_exceptions=True)
finally:
self.resource_monitor.stop_monitoring()
actual_duration = time.time() - start_time
trend_analysis = self.resource_monitor.analyze_trends()
return {
"duration_minutes": actual_duration / 60,
"total_operations": total_operations,
"total_errors": total_errors,
"error_rate": total_errors / total_operations
if total_operations > 0
else 0,
"operations_per_minute": total_operations / (actual_duration / 60),
"avg_operation_time": sum(operation_times) / len(operation_times)
if operation_times
else 0,
"resource_trends": trend_analysis,
"concurrent_load": concurrent_load,
}
async def memory_leak_detection_test(
self, iterations: int = 1000
) -> dict[str, Any]:
"""Test for memory leaks over many iterations."""
logger.info(f"Starting memory leak detection test with {iterations} iterations")
self.resource_monitor.start_monitoring()
engine = VectorBTEngine(data_provider=self.data_provider)
initial_memory = self.resource_monitor.get_current_snapshot().memory_rss_mb
memory_measurements = []
try:
for i in range(iterations):
# Run backtest operation
symbol = f"LEAK_TEST_{i % 10}"
await engine.run_backtest(
symbol=symbol,
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
# Force garbage collection every 50 iterations
if i % 50 == 0:
gc.collect()
snapshot = self.resource_monitor.get_current_snapshot()
memory_measurements.append(
{
"iteration": i,
"memory_mb": snapshot.memory_rss_mb,
"memory_growth": snapshot.memory_rss_mb - initial_memory,
}
)
if i % 200 == 0:
logger.info(
f"Iteration {i}: Memory = {snapshot.memory_rss_mb:.1f}MB "
f"(+{snapshot.memory_rss_mb - initial_memory:.1f}MB)"
)
finally:
self.resource_monitor.stop_monitoring()
final_memory = self.resource_monitor.get_current_snapshot().memory_rss_mb
total_growth = final_memory - initial_memory
# Analyze memory leak pattern
if len(memory_measurements) > 2:
iterations_list = [m["iteration"] for m in memory_measurements]
growth_list = [m["memory_growth"] for m in memory_measurements]
# Linear regression to detect memory leak
n = len(iterations_list)
sum_x = sum(iterations_list)
sum_y = sum(growth_list)
sum_xy = sum(
x * y for x, y in zip(iterations_list, growth_list, strict=False)
)
sum_xx = sum(x * x for x in iterations_list)
if n * sum_xx != sum_x * sum_x:
leak_rate = (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x)
else:
leak_rate = 0
# Memory leak per 1000 iterations
leak_per_1000_iterations = leak_rate * 1000
else:
leak_rate = 0
leak_per_1000_iterations = 0
return {
"iterations": iterations,
"initial_memory_mb": initial_memory,
"final_memory_mb": final_memory,
"total_memory_growth_mb": total_growth,
"leak_rate_mb_per_iteration": leak_rate,
"leak_per_1000_iterations_mb": leak_per_1000_iterations,
"memory_measurements": memory_measurements,
"leak_detected": abs(leak_per_1000_iterations)
> 10.0, # More than 10MB per 1000 iterations
}
async def cpu_stress_test(
self, duration_minutes: int = 10, cpu_target: float = 0.9
) -> dict[str, Any]:
"""Test CPU utilization under stress."""
logger.info(
f"Starting CPU stress test: {duration_minutes} minutes at {cpu_target * 100}% target"
)
self.resource_monitor.start_monitoring()
# Create CPU-intensive background load
stop_event = threading.Event()
cpu_threads = []
def cpu_intensive_task():
"""CPU-intensive computation."""
while not stop_event.is_set():
# Perform CPU-intensive work
for _ in range(10000):
_ = sum(i**2 for i in range(100))
time.sleep(0.001) # Brief pause
try:
# Start CPU load threads
num_cpu_threads = max(1, int(psutil.cpu_count() * cpu_target))
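# e.g. 8 logical cores at a 0.7 target -> 5 busy-loop threads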
for _ in range(num_cpu_threads):
thread = threading.Thread(target=cpu_intensive_task, daemon=True)
thread.start()
cpu_threads.append(thread)
# Run backtests under CPU stress
engine = VectorBTEngine(data_provider=self.data_provider)
start_time = time.time()
end_time = start_time + (duration_minutes * 60)
operations_completed = 0
cpu_stress_errors = 0
response_times = []
while time.time() < end_time:
try:
op_start = time.time()
symbol = f"CPU_STRESS_{operations_completed % 5}"
await asyncio.wait_for(
engine.run_backtest(
symbol=symbol,
strategy_type="rsi",
parameters=STRATEGY_TEMPLATES["rsi"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
),
timeout=30.0, # Prevent hanging under CPU stress
)
response_time = time.time() - op_start
response_times.append(response_time)
operations_completed += 1
except TimeoutError:
cpu_stress_errors += 1
logger.warning("Operation timed out under CPU stress")
except Exception as e:
cpu_stress_errors += 1
logger.error(f"Operation failed under CPU stress: {e}")
# Brief pause between operations
await asyncio.sleep(1.0)
finally:
# Stop CPU stress
stop_event.set()
for thread in cpu_threads:
thread.join(timeout=1.0)
self.resource_monitor.stop_monitoring()
trend_analysis = self.resource_monitor.analyze_trends()
return {
"duration_minutes": duration_minutes,
"cpu_target_percent": cpu_target * 100,
"operations_completed": operations_completed,
"cpu_stress_errors": cpu_stress_errors,
"error_rate": cpu_stress_errors / (operations_completed + cpu_stress_errors)
if (operations_completed + cpu_stress_errors) > 0
else 0,
"avg_response_time": sum(response_times) / len(response_times)
if response_times
else 0,
"max_response_time": max(response_times) if response_times else 0,
"avg_cpu_utilization": trend_analysis["avg_cpu_percent"],
"peak_cpu_utilization": trend_analysis["peak_cpu_percent"],
}
async def database_connection_exhaustion_test(
self, db_session, max_connections: int = 50
) -> dict[str, Any]:
"""Test database behavior under connection exhaustion."""
logger.info(
f"Starting database connection exhaustion test with {max_connections} connections"
)
# Generate test data
engine = VectorBTEngine(data_provider=self.data_provider)
test_results = []
for i in range(5):
result = await engine.run_backtest(
symbol=f"DB_EXHAUST_{i}",
strategy_type="macd",
parameters=STRATEGY_TEMPLATES["macd"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
test_results.append(result)
# Test connection exhaustion
async def database_operation(conn_id: int) -> dict[str, Any]:
"""Single database operation holding connection."""
try:
with BacktestPersistenceManager(session=db_session) as persistence:
# Hold connection and perform operations
saved_ids = []
for result in test_results:
backtest_id = persistence.save_backtest_result(
vectorbt_results=result,
execution_time=2.0,
notes=f"Connection exhaustion test {conn_id}",
)
saved_ids.append(backtest_id)
# Perform queries
for backtest_id in saved_ids:
persistence.get_backtest_by_id(backtest_id)
# Hold connection for some time
await asyncio.sleep(2.0)
return {
"connection_id": conn_id,
"operations_completed": len(saved_ids) * 2, # Save + retrieve
"success": True,
}
except Exception as e:
return {
"connection_id": conn_id,
"error": str(e),
"success": False,
}
# Create many concurrent database operations
start_time = time.time()
connection_tasks = [database_operation(i) for i in range(max_connections)]
# Execute with timeout to prevent hanging
try:
results = await asyncio.wait_for(
asyncio.gather(*connection_tasks, return_exceptions=True), timeout=60.0
)
except TimeoutError:
logger.warning("Database connection test timed out")
results = []
execution_time = time.time() - start_time
# Analyze results
successful_connections = sum(
1 for r in results if isinstance(r, dict) and r.get("success", False)
)
failed_connections = len(results) - successful_connections
total_operations = sum(
r.get("operations_completed", 0)
for r in results
if isinstance(r, dict) and r.get("success", False)
)
return {
"max_connections_attempted": max_connections,
"successful_connections": successful_connections,
"failed_connections": failed_connections,
"connection_success_rate": successful_connections / max_connections
if max_connections > 0
else 0,
"total_operations": total_operations,
"execution_time": execution_time,
"operations_per_second": total_operations / execution_time
if execution_time > 0
else 0,
}
async def file_descriptor_exhaustion_test(self) -> dict[str, Any]:
"""Test file descriptor usage patterns."""
logger.info("Starting file descriptor exhaustion test")
initial_snapshot = self.resource_monitor.get_current_snapshot()
initial_fds = initial_snapshot.file_descriptors
self.resource_monitor.start_monitoring()
# Get system file descriptor limit
try:
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
except Exception:
soft_limit, hard_limit = 1024, 4096 # Default assumptions
logger.info(
f"FD limits - Soft: {soft_limit}, Hard: {hard_limit}, Initial: {initial_fds}"
)
try:
engine = VectorBTEngine(data_provider=self.data_provider)
# Run many operations to stress file descriptor usage
fd_measurements = []
max_operations = min(100, soft_limit // 10) # Conservative approach
for i in range(max_operations):
await engine.run_backtest(
symbol=f"FD_TEST_{i}",
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-01-01",
end_date="2023-12-31",
)
if i % 10 == 0:
snapshot = self.resource_monitor.get_current_snapshot()
fd_measurements.append(
{
"iteration": i,
"file_descriptors": snapshot.file_descriptors,
"fd_growth": snapshot.file_descriptors - initial_fds,
}
)
if snapshot.file_descriptors > soft_limit * 0.8:
logger.warning(
f"High FD usage detected: {snapshot.file_descriptors}/{soft_limit}"
)
finally:
self.resource_monitor.stop_monitoring()
final_snapshot = self.resource_monitor.get_current_snapshot()
final_fds = final_snapshot.file_descriptors
fd_growth = final_fds - initial_fds
# Analyze FD usage pattern
peak_fds = max([m["file_descriptors"] for m in fd_measurements] + [final_fds])
fd_utilization = peak_fds / soft_limit if soft_limit > 0 else 0
return {
"initial_file_descriptors": initial_fds,
"final_file_descriptors": final_fds,
"peak_file_descriptors": peak_fds,
"fd_growth": fd_growth,
"soft_limit": soft_limit,
"hard_limit": hard_limit,
"fd_utilization_percent": fd_utilization * 100,
"fd_measurements": fd_measurements,
"operations_completed": max_operations,
}
class TestStressTesting:
"""Stress testing suite."""
@pytest.fixture
async def stress_data_provider(self):
"""Create data provider optimized for stress testing."""
provider = Mock()
# Cache to reduce computation overhead during stress tests
data_cache = {}
def get_stress_test_data(symbol: str) -> pd.DataFrame:
"""Get cached data for stress testing."""
if symbol not in data_cache:
# Generate smaller dataset for faster stress testing
dates = pd.date_range(start="2023-06-01", end="2023-12-31", freq="D")
seed = hash(symbol) % 1000
np.random.seed(seed)
returns = np.random.normal(0.001, 0.02, len(dates))
prices = 100 * np.cumprod(1 + returns)
data_cache[symbol] = pd.DataFrame(
{
"Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
"High": prices * np.random.uniform(1.01, 1.03, len(dates)),
"Low": prices * np.random.uniform(0.97, 0.99, len(dates)),
"Close": prices,
"Volume": np.random.randint(1000000, 5000000, len(dates)),
"Adj Close": prices,
},
index=dates,
)
return data_cache[symbol].copy()
provider.get_stock_data.side_effect = get_stress_test_data
return provider
@pytest.mark.slow
async def test_sustained_load_15_minutes(self, stress_data_provider):
"""Test sustained load for 15 minutes (abbreviated from 1 hour for CI)."""
stress_runner = StressTestRunner(stress_data_provider)
result = await stress_runner.sustained_load_test(
duration_minutes=15, # Reduced for CI/testing
concurrent_load=8,
)
# Assertions for sustained load
assert result["error_rate"] <= 0.05, (
f"Error rate too high: {result['error_rate']:.3f}"
)
assert result["operations_per_minute"] >= 10, (
f"Throughput too low: {result['operations_per_minute']:.1f} ops/min"
)
# Resource growth should be reasonable
trends = result["resource_trends"]
assert trends["memory_growth_rate_mb_per_hour"] <= 100, (
f"Memory growth rate too high: {trends['memory_growth_rate_mb_per_hour']:.1f} MB/hour"
)
logger.info(
f"✓ Sustained load test completed: {result['total_operations']} operations in {result['duration_minutes']:.1f} minutes"
)
return result
async def test_memory_leak_detection(self, stress_data_provider):
"""Test for memory leaks over many iterations."""
stress_runner = StressTestRunner(stress_data_provider)
result = await stress_runner.memory_leak_detection_test(iterations=200)
# Memory leak assertions
assert not result["leak_detected"], (
f"Memory leak detected: {result['leak_per_1000_iterations_mb']:.2f} MB per 1000 iterations"
)
assert result["total_memory_growth_mb"] <= 300, (
f"Total memory growth too high: {result['total_memory_growth_mb']:.1f} MB"
)
logger.info(
f"✓ Memory leak test completed: {result['total_memory_growth_mb']:.1f}MB growth over {result['iterations']} iterations"
)
return result
async def test_cpu_stress_resilience(self, stress_data_provider):
"""Test system resilience under CPU stress."""
stress_runner = StressTestRunner(stress_data_provider)
result = await stress_runner.cpu_stress_test(
duration_minutes=5, # Reduced for testing
cpu_target=0.7, # 70% CPU utilization
)
# CPU stress assertions
assert result["error_rate"] <= 0.2, (
f"Error rate too high under CPU stress: {result['error_rate']:.3f}"
)
assert result["avg_response_time"] <= 10.0, (
f"Response time too slow under CPU stress: {result['avg_response_time']:.2f}s"
)
assert result["operations_completed"] >= 10, (
f"Too few operations completed: {result['operations_completed']}"
)
logger.info(
f"✓ CPU stress test completed: {result['operations_completed']} operations with {result['avg_cpu_utilization']:.1f}% avg CPU"
)
return result
async def test_database_connection_stress(self, stress_data_provider, db_session):
"""Test database performance under connection stress."""
stress_runner = StressTestRunner(stress_data_provider)
result = await stress_runner.database_connection_exhaustion_test(
db_session=db_session,
max_connections=20, # Reduced for testing
)
# Database stress assertions
assert result["connection_success_rate"] >= 0.8, (
f"Connection success rate too low: {result['connection_success_rate']:.3f}"
)
assert result["operations_per_second"] >= 5.0, (
f"Database throughput too low: {result['operations_per_second']:.2f} ops/s"
)
logger.info(
f"✓ Database stress test completed: {result['successful_connections']}/{result['max_connections_attempted']} connections succeeded"
)
return result
async def test_file_descriptor_management(self, stress_data_provider):
"""Test file descriptor usage under stress."""
stress_runner = StressTestRunner(stress_data_provider)
result = await stress_runner.file_descriptor_exhaustion_test()
# File descriptor assertions
assert result["fd_utilization_percent"] <= 50.0, (
f"FD utilization too high: {result['fd_utilization_percent']:.1f}%"
)
assert result["fd_growth"] <= 100, f"FD growth too high: {result['fd_growth']}"
logger.info(
f"✓ File descriptor test completed: {result['peak_file_descriptors']} peak FDs ({result['fd_utilization_percent']:.1f}% utilization)"
)
return result
async def test_queue_overflow_scenarios(self, stress_data_provider):
"""Test queue management under overflow conditions."""
# Simulate queue overflow by creating more tasks than can be processed
max_queue_size = 50
overflow_tasks = 100
semaphore = asyncio.Semaphore(5) # Limit concurrent processing
processed_tasks = 0
overflow_errors = 0
async def queue_task(task_id: int):
nonlocal processed_tasks, overflow_errors
try:
async with semaphore:
engine = VectorBTEngine(data_provider=stress_data_provider)
await engine.run_backtest(
symbol=f"QUEUE_{task_id % 10}",
strategy_type="sma_cross",
parameters=STRATEGY_TEMPLATES["sma_cross"]["parameters"],
start_date="2023-06-01",
end_date="2023-12-31",
)
processed_tasks += 1
except Exception as e:
overflow_errors += 1
logger.error(f"Queue task {task_id} failed: {e}")
# Create tasks faster than they can be processed
start_time = time.time()
tasks = []
for i in range(overflow_tasks):
task = asyncio.create_task(queue_task(i))
tasks.append(task)
# Create tasks rapidly to test queue management
if i < max_queue_size:
await asyncio.sleep(0.01) # Rapid creation
else:
await asyncio.sleep(0.1) # Slower creation after queue fills
# Wait for all tasks to complete
await asyncio.gather(*tasks, return_exceptions=True)
execution_time = time.time() - start_time
# Queue overflow assertions
processing_success_rate = processed_tasks / overflow_tasks
assert processing_success_rate >= 0.8, (
f"Queue processing success rate too low: {processing_success_rate:.3f}"
)
assert execution_time < 120.0, (
f"Queue processing took too long: {execution_time:.1f}s"
)
logger.info(
f"✓ Queue overflow test completed: {processed_tasks}/{overflow_tasks} tasks processed in {execution_time:.1f}s"
)
return {
"overflow_tasks": overflow_tasks,
"processed_tasks": processed_tasks,
"overflow_errors": overflow_errors,
"processing_success_rate": processing_success_rate,
"execution_time": execution_time,
}
async def test_comprehensive_stress_suite(self, stress_data_provider, db_session):
"""Run comprehensive stress testing suite."""
logger.info("Starting Comprehensive Stress Testing Suite...")
stress_results = {}
# Run individual stress tests
stress_results["sustained_load"] = await self.test_sustained_load_15_minutes(
stress_data_provider
)
stress_results["memory_leak"] = await self.test_memory_leak_detection(
stress_data_provider
)
stress_results["cpu_stress"] = await self.test_cpu_stress_resilience(
stress_data_provider
)
stress_results["database_stress"] = await self.test_database_connection_stress(
stress_data_provider, db_session
)
stress_results["file_descriptors"] = await self.test_file_descriptor_management(
stress_data_provider
)
stress_results["queue_overflow"] = await self.test_queue_overflow_scenarios(
stress_data_provider
)
# Aggregate stress test analysis
total_tests = len(stress_results)
passed_tests = 0
critical_failures = []
for test_name, result in stress_results.items():
# Simple pass/fail based on whether test completed without major issues
test_passed = True
if test_name == "sustained_load" and result["error_rate"] > 0.1:
test_passed = False
critical_failures.append(
f"Sustained load error rate: {result['error_rate']:.3f}"
)
elif test_name == "memory_leak" and result["leak_detected"]:
test_passed = False
critical_failures.append(
f"Memory leak detected: {result['leak_per_1000_iterations_mb']:.2f} MB/1k iterations"
)
elif test_name == "cpu_stress" and result["error_rate"] > 0.3:
test_passed = False
critical_failures.append(
f"CPU stress error rate: {result['error_rate']:.3f}"
)
if test_passed:
passed_tests += 1
overall_pass_rate = passed_tests / total_tests
logger.info(
f"\n{'=' * 60}\n"
f"COMPREHENSIVE STRESS TEST REPORT\n"
f"{'=' * 60}\n"
f"Total Tests: {total_tests}\n"
f"Passed: {passed_tests}\n"
f"Overall Pass Rate: {overall_pass_rate:.1%}\n"
f"Critical Failures: {len(critical_failures)}\n"
f"{'=' * 60}\n"
)
# Assert overall stress test success
assert overall_pass_rate >= 0.8, (
f"Overall stress test pass rate too low: {overall_pass_rate:.1%}"
)
assert len(critical_failures) <= 1, (
f"Too many critical failures: {critical_failures}"
)
return {
"overall_pass_rate": overall_pass_rate,
"critical_failures": critical_failures,
"stress_results": stress_results,
}
if __name__ == "__main__":
# Run stress testing suite
pytest.main(
[
__file__,
"-v",
"--tb=short",
"--asyncio-mode=auto",
"--timeout=1800", # 30 minute timeout for stress tests
"-m",
"not slow", # Skip slow tests by default
]
)
```
--------------------------------------------------------------------------------
/tests/test_database_pool_config.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for DatabasePoolConfig.
This module tests the enhanced database pool configuration that provides
validation, monitoring, and optimization capabilities. Tests cover:
- Pool validation logic against database limits
- Warning conditions for insufficient pool sizing
- Environment variable overrides
- Factory methods (development, production, high-concurrency)
- Monitoring thresholds and SQLAlchemy event listeners
- Integration with existing DatabaseConfig
- Production validation checks
"""
import os
import warnings
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.pool import QueuePool
from maverick_mcp.config.database import (
DatabasePoolConfig,
create_engine_with_enhanced_config,
create_monitored_engine_kwargs,
get_default_pool_config,
get_development_pool_config,
get_high_concurrency_pool_config,
get_pool_config_from_settings,
validate_production_config,
)
from maverick_mcp.providers.interfaces.persistence import DatabaseConfig
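# Illustrative sketch (not a test and not collected by pytest): the typical flow
# exercised by the tests below, assuming a PostgreSQL URL. The helper name and
# the example values are ours, not part of the original module.
def _example_pool_config_usage() -> dict:
    config = DatabasePoolConfig(
        pool_size=10,
        max_overflow=5,
        expected_concurrent_users=10,
        connections_per_user=1.2,
        max_database_connections=50,
    )
    # Thresholds consumed by the pool monitoring event listeners (80% warning, 95% critical).
    thresholds = config.get_monitoring_thresholds()
    # Full keyword arguments (URL included) suitable for sqlalchemy.create_engine(**kwargs).
    engine_kwargs = create_monitored_engine_kwargs(
        "postgresql://user:pass@localhost/example", config
    )
    return {"thresholds": thresholds, "engine_kwargs": engine_kwargs}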
class TestDatabasePoolConfig:
"""Test the main DatabasePoolConfig class."""
def test_default_configuration(self):
"""Test default configuration values."""
config = DatabasePoolConfig()
# Test environment variable defaults
assert config.pool_size == int(os.getenv("DB_POOL_SIZE", "20"))
assert config.max_overflow == int(os.getenv("DB_MAX_OVERFLOW", "10"))
assert config.pool_timeout == int(os.getenv("DB_POOL_TIMEOUT", "30"))
assert config.pool_recycle == int(os.getenv("DB_POOL_RECYCLE", "3600"))
assert config.max_database_connections == int(
os.getenv("DB_MAX_CONNECTIONS", "100")
)
assert config.reserved_superuser_connections == int(
os.getenv("DB_RESERVED_SUPERUSER_CONNECTIONS", "3")
)
assert config.expected_concurrent_users == int(
os.getenv("DB_EXPECTED_CONCURRENT_USERS", "20")
)
assert config.connections_per_user == float(
os.getenv("DB_CONNECTIONS_PER_USER", "1.2")
)
assert config.pool_pre_ping == (
os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
)
assert config.echo_pool == (
os.getenv("DB_ECHO_POOL", "false").lower() == "true"
)
@patch.dict(
os.environ,
{
"DB_POOL_SIZE": "25",
"DB_MAX_OVERFLOW": "10",
"DB_POOL_TIMEOUT": "45",
"DB_POOL_RECYCLE": "1800",
"DB_MAX_CONNECTIONS": "80",
"DB_RESERVED_SUPERUSER_CONNECTIONS": "2",
"DB_EXPECTED_CONCURRENT_USERS": "25",
"DB_CONNECTIONS_PER_USER": "1.2",
"DB_POOL_PRE_PING": "false",
"DB_ECHO_POOL": "true",
},
)
def test_environment_variable_overrides(self):
"""Test that environment variables override defaults."""
config = DatabasePoolConfig()
assert config.pool_size == 25
assert config.max_overflow == 10
assert config.pool_timeout == 45
assert config.pool_recycle == 1800
assert config.max_database_connections == 80
assert config.reserved_superuser_connections == 2
assert config.expected_concurrent_users == 25
assert config.connections_per_user == 1.2
assert config.pool_pre_ping is False
assert config.echo_pool is True
def test_valid_configuration(self):
"""Test a valid configuration passes validation."""
config = DatabasePoolConfig(
pool_size=10,
max_overflow=5,
max_database_connections=50,
reserved_superuser_connections=3,
expected_concurrent_users=10,
connections_per_user=1.2,
)
# Should not raise any exceptions
assert config.pool_size == 10
assert config.max_overflow == 5
# Calculated values
total_app_connections = config.pool_size + config.max_overflow
available_connections = (
config.max_database_connections - config.reserved_superuser_connections
)
assert total_app_connections <= available_connections
def test_validation_exceeds_database_capacity(self):
"""Test validation failure when pool exceeds database capacity."""
with pytest.raises(
ValueError, match="Pool configuration exceeds database capacity"
):
DatabasePoolConfig(
pool_size=50,
max_overflow=30, # Total = 80
max_database_connections=70, # Available = 67 (70-3)
reserved_superuser_connections=3,
)
def test_validation_insufficient_for_expected_load(self):
"""Test validation failure when pool is insufficient for expected load."""
with pytest.raises(
ValueError, match="Total connection capacity .* is insufficient"
):
DatabasePoolConfig(
pool_size=5,
max_overflow=0, # Total capacity = 5
expected_concurrent_users=10,
connections_per_user=1.0, # Expected demand = 10
max_database_connections=50,
)
def test_validation_warning_for_small_pool(self):
"""Test warning when pool size may be insufficient."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
DatabasePoolConfig(
pool_size=5, # Small pool
max_overflow=15, # But enough overflow to meet demand
expected_concurrent_users=10,
connections_per_user=1.5, # Expected demand = 15
max_database_connections=50,
)
# Should generate a warning
assert len(w) > 0
assert "Pool size (5) may be insufficient" in str(w[0].message)
def test_field_validation_ranges(self):
"""Test field validation for valid ranges."""
from pydantic import ValidationError
# Test valid ranges with proper expected demand
config = DatabasePoolConfig(
pool_size=5, # Minimum safe size
max_overflow=0, # Minimum
pool_timeout=1, # Minimum
pool_recycle=300, # Minimum
expected_concurrent_users=3, # Lower expected demand
connections_per_user=1.0,
)
assert config.pool_size == 5
config = DatabasePoolConfig(
pool_size=80, # Large but fits in database capacity
max_overflow=15, # Reasonable overflow
pool_timeout=300, # Maximum
pool_recycle=7200, # Maximum
expected_concurrent_users=85, # Fit within total capacity of 95
connections_per_user=1.0,
max_database_connections=120, # Higher limit to accommodate pool
)
assert config.pool_size == 80
# Test invalid ranges
with pytest.raises(ValidationError):
DatabasePoolConfig(pool_size=0) # Below minimum
with pytest.raises(ValidationError):
DatabasePoolConfig(pool_size=101) # Above maximum
with pytest.raises(ValidationError):
DatabasePoolConfig(max_overflow=-1) # Below minimum
with pytest.raises(ValidationError):
DatabasePoolConfig(max_overflow=51) # Above maximum
def test_get_pool_kwargs(self):
"""Test SQLAlchemy pool configuration generation."""
config = DatabasePoolConfig(
pool_size=15,
max_overflow=8,
pool_timeout=45,
pool_recycle=1800,
pool_pre_ping=True,
echo_pool=True,
expected_concurrent_users=18, # Match capacity
connections_per_user=1.0,
)
kwargs = config.get_pool_kwargs()
expected = {
"poolclass": QueuePool,
"pool_size": 15,
"max_overflow": 8,
"pool_timeout": 45,
"pool_recycle": 1800,
"pool_pre_ping": True,
"echo_pool": True,
}
assert kwargs == expected
def test_get_monitoring_thresholds(self):
"""Test monitoring threshold calculation."""
config = DatabasePoolConfig(pool_size=20, max_overflow=10)
thresholds = config.get_monitoring_thresholds()
expected = {
"warning_threshold": int(20 * 0.8), # 16
"critical_threshold": int(20 * 0.95), # 19
"pool_size": 20,
"max_overflow": 10,
"total_capacity": 30,
}
assert thresholds == expected
def test_validate_against_database_limits_matching(self):
"""Test validation when actual limits match configuration."""
config = DatabasePoolConfig(max_database_connections=100)
# Should not raise any exceptions when limits match
config.validate_against_database_limits(100)
assert config.max_database_connections == 100
def test_validate_against_database_limits_higher_actual(self):
"""Test validation when actual limits are higher."""
config = DatabasePoolConfig(max_database_connections=100)
with patch("maverick_mcp.config.database.logger") as mock_logger:
config.validate_against_database_limits(150)
# Should update configuration and log info
assert config.max_database_connections == 150
mock_logger.info.assert_called_once()
def test_validate_against_database_limits_lower_actual_safe(self):
"""Test validation when actual limits are lower but pool still fits."""
config = DatabasePoolConfig(
pool_size=10,
max_overflow=5, # Total = 15
max_database_connections=100,
reserved_superuser_connections=3,
expected_concurrent_users=12, # Fit within total capacity of 15
connections_per_user=1.0,
)
with patch("maverick_mcp.config.database.logger") as mock_logger:
# Actual limit is 80, available is 77, pool needs 15 - should be fine
config.validate_against_database_limits(80)
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
assert "lower than configured" in warning_call
def test_validate_against_database_limits_lower_actual_unsafe(self):
"""Test validation failure when actual limits are too low."""
config = DatabasePoolConfig(
pool_size=30,
max_overflow=20, # Total = 50
max_database_connections=100,
reserved_superuser_connections=3,
)
with pytest.raises(
ValueError, match="Configuration invalid for actual database limits"
):
# Actual limit is 40, available is 37, pool needs 50 - should fail
config.validate_against_database_limits(40)
def test_to_legacy_config(self):
"""Test conversion to legacy DatabaseConfig."""
config = DatabasePoolConfig(
pool_size=15,
max_overflow=8,
pool_timeout=45,
pool_recycle=1800,
echo_pool=True,
expected_concurrent_users=18, # Fit within total capacity of 23
connections_per_user=1.0,
)
database_url = "postgresql://user:pass@localhost/test"
legacy_config = config.to_legacy_config(database_url)
assert isinstance(legacy_config, DatabaseConfig)
assert legacy_config.database_url == database_url
assert legacy_config.pool_size == 15
assert legacy_config.max_overflow == 8
assert legacy_config.pool_timeout == 45
assert legacy_config.pool_recycle == 1800
assert legacy_config.echo is True
assert legacy_config.autocommit is False
assert legacy_config.autoflush is True
assert legacy_config.expire_on_commit is True
def test_from_legacy_config(self):
"""Test creation from legacy DatabaseConfig."""
legacy_config = DatabaseConfig(
database_url="postgresql://user:pass@localhost/test",
pool_size=12,
max_overflow=6,
pool_timeout=60,
pool_recycle=2400,
echo=False,
)
enhanced_config = DatabasePoolConfig.from_legacy_config(
legacy_config,
expected_concurrent_users=15, # Override
max_database_connections=80, # Override
)
assert enhanced_config.pool_size == 12
assert enhanced_config.max_overflow == 6
assert enhanced_config.pool_timeout == 60
assert enhanced_config.pool_recycle == 2400
assert enhanced_config.echo_pool is False
assert enhanced_config.expected_concurrent_users == 15 # Override applied
assert enhanced_config.max_database_connections == 80 # Override applied
def test_setup_pool_monitoring(self):
"""Test SQLAlchemy event listener setup."""
config = DatabasePoolConfig(
pool_size=10,
echo_pool=True,
expected_concurrent_users=15, # Fit within capacity
connections_per_user=1.0,
)
# Create a mock engine with pool
mock_engine = Mock()
mock_pool = Mock()
mock_pool.checkedout.return_value = 5
mock_pool.checkedin.return_value = 3
mock_engine.pool = mock_pool
# Mock the event listener registration
with patch("maverick_mcp.config.database.event") as mock_event:
config.setup_pool_monitoring(mock_engine)
# Verify event listeners were registered
assert (
mock_event.listens_for.call_count == 5
) # connect, checkout, checkin, invalidate, soft_invalidate
# Test the event listener functions were called correctly
expected_events = [
"connect",
"checkout",
"checkin",
"invalidate",
"soft_invalidate",
]
for call_args in mock_event.listens_for.call_args_list:
assert call_args[0][0] is mock_engine
assert call_args[0][1] in expected_events
class TestFactoryFunctions:
"""Test factory functions for different configuration types."""
def test_get_default_pool_config(self):
"""Test default pool configuration factory."""
config = get_default_pool_config()
assert isinstance(config, DatabasePoolConfig)
# Should use environment variable defaults
assert config.pool_size == int(os.getenv("DB_POOL_SIZE", "20"))
def test_get_development_pool_config(self):
"""Test development pool configuration factory."""
config = get_development_pool_config()
assert isinstance(config, DatabasePoolConfig)
assert config.pool_size == 5
assert config.max_overflow == 2
assert config.pool_timeout == 30
assert config.pool_recycle == 3600
assert config.expected_concurrent_users == 5
assert config.connections_per_user == 1.0
assert config.max_database_connections == 20
assert config.reserved_superuser_connections == 2
assert config.echo_pool is True # Debug enabled in development
def test_get_high_concurrency_pool_config(self):
"""Test high concurrency pool configuration factory."""
config = get_high_concurrency_pool_config()
assert isinstance(config, DatabasePoolConfig)
assert config.pool_size == 50
assert config.max_overflow == 30
assert config.pool_timeout == 60
assert config.pool_recycle == 1800 # 30 minutes
assert config.expected_concurrent_users == 60
assert config.connections_per_user == 1.3
assert config.max_database_connections == 200
assert config.reserved_superuser_connections == 5
def test_get_pool_config_from_settings_development(self):
"""Test configuration from settings in development."""
# Create a mock settings module and settings object
mock_settings_module = Mock()
mock_settings_obj = Mock()
mock_settings_obj.environment = "development"
# Configure hasattr to return False for 'db' to avoid overrides path
mock_settings_obj.configure_mock(**{"db": None})
mock_settings_module.settings = mock_settings_obj
# Patch the import to return our mock
with patch.dict(
"sys.modules", {"maverick_mcp.config.settings": mock_settings_module}
):
# Also patch hasattr to return False for the db attribute
with patch("builtins.hasattr", side_effect=lambda obj, attr: attr != "db"):
config = get_pool_config_from_settings()
# Should return development configuration
assert config.pool_size == 5 # Development default
assert config.echo_pool is True
def test_get_pool_config_from_settings_production(self):
"""Test configuration from settings in production."""
# Create a mock settings module and settings object
mock_settings_module = Mock()
mock_settings_obj = Mock()
mock_settings_obj.environment = "production"
mock_settings_module.settings = mock_settings_obj
# Patch the import to return our mock
with patch.dict(
"sys.modules", {"maverick_mcp.config.settings": mock_settings_module}
):
# Also patch hasattr to return False for the db attribute
with patch("builtins.hasattr", side_effect=lambda obj, attr: attr != "db"):
config = get_pool_config_from_settings()
# Should return high concurrency configuration
assert config.pool_size == 50 # Production default
assert config.max_overflow == 30
def test_get_pool_config_from_settings_with_overrides(self):
"""Test configuration from settings with database-specific overrides."""
# Create a mock settings module and settings object
mock_settings_module = Mock()
mock_settings_obj = Mock()
mock_settings_obj.environment = "development"
# Create proper mock for db settings with real values, not Mock objects
class MockDbSettings:
pool_size = 8
pool_max_overflow = 3
pool_timeout = 60
mock_settings_obj.db = MockDbSettings()
mock_settings_module.settings = mock_settings_obj
# Patch the import to return our mock
with patch.dict(
"sys.modules", {"maverick_mcp.config.settings": mock_settings_module}
):
config = get_pool_config_from_settings()
# Should use overrides
assert config.pool_size == 8
assert config.max_overflow == 3
assert config.pool_timeout == 60
# Other development defaults should remain
assert config.echo_pool is True
def test_get_pool_config_from_settings_import_error(self):
"""Test fallback when settings import fails."""
# Create a mock import function that raises ImportError for settings module
def mock_import(name, *args, **kwargs):
if name == "maverick_mcp.config.settings":
raise ImportError("No module named 'maverick_mcp.config.settings'")
return __import__(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=mock_import):
with patch("maverick_mcp.config.database.logger") as mock_logger:
config = get_pool_config_from_settings()
# Should fall back to default
assert isinstance(config, DatabasePoolConfig)
# Should call warning twice: import error + pool size warning
assert mock_logger.warning.call_count == 2
import_warning_call = mock_logger.warning.call_args_list[0]
assert (
"Could not import settings, using default pool configuration"
in str(import_warning_call)
)
class TestUtilityFunctions:
"""Test utility functions."""
def test_create_monitored_engine_kwargs(self):
"""Test monitored engine kwargs creation."""
config = DatabasePoolConfig(
pool_size=15,
max_overflow=8,
pool_timeout=45,
pool_recycle=1800,
pool_pre_ping=True,
echo_pool=False,
expected_concurrent_users=18, # Reduce to fit total capacity of 23
connections_per_user=1.0,
)
database_url = "postgresql://user:pass@localhost/test"
kwargs = create_monitored_engine_kwargs(database_url, config)
expected = {
"url": database_url,
"poolclass": QueuePool,
"pool_size": 15,
"max_overflow": 8,
"pool_timeout": 45,
"pool_recycle": 1800,
"pool_pre_ping": True,
"echo_pool": False,
"connect_args": {
"application_name": "maverick_mcp",
},
}
assert kwargs == expected
@patch("sqlalchemy.create_engine")
@patch("maverick_mcp.config.database.get_pool_config_from_settings")
def test_create_engine_with_enhanced_config(
self, mock_get_config, mock_create_engine
):
"""Test complete engine creation with monitoring."""
mock_config = Mock(spec=DatabasePoolConfig)
mock_config.pool_size = 20
mock_config.max_overflow = 10
mock_config.get_pool_kwargs.return_value = {"pool_size": 20}
mock_config.setup_pool_monitoring = Mock()
mock_get_config.return_value = mock_config
mock_engine = Mock()
mock_create_engine.return_value = mock_engine
database_url = "postgresql://user:pass@localhost/test"
result = create_engine_with_enhanced_config(database_url)
# Verify engine creation and monitoring setup
assert result is mock_engine
mock_create_engine.assert_called_once()
mock_config.setup_pool_monitoring.assert_called_once_with(mock_engine)
def test_validate_production_config_valid(self):
"""Test production validation for valid configuration."""
config = DatabasePoolConfig(
pool_size=25,
max_overflow=15,
pool_timeout=30,
pool_recycle=3600,
)
with patch("maverick_mcp.config.database.logger") as mock_logger:
result = validate_production_config(config)
assert result is True
mock_logger.info.assert_called_with(
"Production configuration validation passed"
)
def test_validate_production_config_warnings(self):
"""Test production validation with warnings."""
config = DatabasePoolConfig(
pool_size=5, # Too small
max_overflow=0, # No overflow
pool_timeout=30,
pool_recycle=7200, # Maximum allowed (was 8000, too high)
expected_concurrent_users=4, # Reduce to fit capacity of 5
connections_per_user=1.0,
)
with patch("maverick_mcp.config.database.logger") as mock_logger:
result = validate_production_config(config)
assert result is True # Warnings don't fail validation
# Should log multiple warnings
warning_calls = list(mock_logger.warning.call_args_list)
assert (
len(warning_calls) == 2
) # Small pool, no overflow (recycle=7200 is max allowed, not "too long")
# Check final info message mentions warnings
info_call = mock_logger.info.call_args[0][0]
assert "warnings" in info_call
def test_validate_production_config_errors(self):
"""Test production validation with errors."""
config = DatabasePoolConfig(
pool_size=15,
max_overflow=5,
pool_timeout=5, # Too aggressive
pool_recycle=3600,
expected_concurrent_users=18, # Reduce to fit capacity of 20
connections_per_user=1.0,
)
with pytest.raises(
ValueError, match="Production configuration validation failed"
):
validate_production_config(config)
class TestEventListenerBehavior:
"""Test SQLAlchemy event listener behavior with real scenarios."""
def test_connect_event_logging(self):
"""Test connect event logging behavior."""
config = DatabasePoolConfig(
pool_size=10,
echo_pool=True,
expected_concurrent_users=8, # Reduce expected demand to fit capacity
connections_per_user=1.0,
)
# Mock engine and pool
mock_engine = Mock()
mock_pool = Mock()
mock_pool.checkedout.return_value = 7 # 70% usage
mock_pool.checkedin.return_value = 3
mock_engine.pool = mock_pool
# Mock the event registration and capture listener functions
captured_listeners = {}
def mock_listens_for(target, event_name):
def decorator(func):
captured_listeners[event_name] = func
return func
return decorator
with patch(
"maverick_mcp.config.database.event.listens_for",
side_effect=mock_listens_for,
):
config.setup_pool_monitoring(mock_engine)
# Verify we captured the connect listener
assert "connect" in captured_listeners
connect_listener = captured_listeners["connect"]
# Test the listener function
with patch("maverick_mcp.config.database.logger") as mock_logger:
connect_listener(None, None) # dbapi_connection, connection_record
# At 70% usage (below the 80% warning threshold), no warning should be logged
mock_logger.warning.assert_not_called()
def test_connect_event_warning_threshold(self):
"""Test connect event warning threshold."""
config = DatabasePoolConfig(
pool_size=10,
echo_pool=True,
expected_concurrent_users=8, # Reduce expected demand
connections_per_user=1.0,
)
mock_engine = Mock()
mock_pool = Mock()
mock_pool.checkedout.return_value = 9 # 90% usage (above 80% warning threshold)
mock_pool.checkedin.return_value = 1
mock_engine.pool = mock_pool
# Mock the event registration and capture listener functions
captured_listeners = {}
def mock_listens_for(target, event_name):
def decorator(func):
captured_listeners[event_name] = func
return func
return decorator
with patch(
"maverick_mcp.config.database.event.listens_for",
side_effect=mock_listens_for,
):
config.setup_pool_monitoring(mock_engine)
# Verify we captured the connect listener
assert "connect" in captured_listeners
connect_listener = captured_listeners["connect"]
# Test warning threshold
with patch("maverick_mcp.config.database.logger") as mock_logger:
connect_listener(None, None)
# Should log warning
mock_logger.warning.assert_called_once()
warning_message = mock_logger.warning.call_args[0][0]
assert "Pool usage approaching capacity" in warning_message
def test_connect_event_critical_threshold(self):
"""Test connect event critical threshold."""
config = DatabasePoolConfig(
pool_size=10,
echo_pool=True,
expected_concurrent_users=8, # Reduce expected demand
connections_per_user=1.0,
)
mock_engine = Mock()
mock_pool = Mock()
mock_pool.checkedout.return_value = (
10 # 100% usage (above 95% critical threshold)
)
mock_pool.checkedin.return_value = 0
mock_engine.pool = mock_pool
# Mock the event registration and capture listener functions
captured_listeners = {}
def mock_listens_for(target, event_name):
def decorator(func):
captured_listeners[event_name] = func
return func
return decorator
with patch(
"maverick_mcp.config.database.event.listens_for",
side_effect=mock_listens_for,
):
config.setup_pool_monitoring(mock_engine)
# Verify we captured the connect listener
assert "connect" in captured_listeners
connect_listener = captured_listeners["connect"]
# Test critical threshold
with patch("maverick_mcp.config.database.logger") as mock_logger:
connect_listener(None, None)
# Should log both warning and error
mock_logger.warning.assert_called_once()
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pool usage critical" in error_message
def test_invalidate_event_logging(self):
"""Test connection invalidation event logging."""
config = DatabasePoolConfig(
pool_size=10,
echo_pool=True,
expected_concurrent_users=8, # Reduce expected demand
connections_per_user=1.0,
)
mock_engine = Mock()
# Mock the event registration and capture listener functions
captured_listeners = {}
def mock_listens_for(target, event_name):
def decorator(func):
captured_listeners[event_name] = func
return func
return decorator
with patch(
"maverick_mcp.config.database.event.listens_for",
side_effect=mock_listens_for,
):
config.setup_pool_monitoring(mock_engine)
# Verify we captured the invalidate listener
assert "invalidate" in captured_listeners
invalidate_listener = captured_listeners["invalidate"]
# Test the listener function
with patch("maverick_mcp.config.database.logger") as mock_logger:
test_exception = Exception("Connection lost")
invalidate_listener(None, None, test_exception)
mock_logger.warning.assert_called_once()
warning_message = mock_logger.warning.call_args[0][0]
assert "Connection invalidated due to error" in warning_message
assert "Connection lost" in warning_message
class TestRealWorldScenarios:
"""Test realistic usage scenarios."""
def test_microservice_configuration(self):
"""Test configuration suitable for microservice deployment."""
config = DatabasePoolConfig(
pool_size=8,
max_overflow=4,
pool_timeout=30,
pool_recycle=1800,
expected_concurrent_users=10,
connections_per_user=1.0,
max_database_connections=50,
reserved_superuser_connections=2,
)
# Should be valid
assert config.pool_size == 8
# Test monitoring setup
thresholds = config.get_monitoring_thresholds()
assert thresholds["warning_threshold"] == 6 # 80% of 8
assert thresholds["critical_threshold"] == 7 # 95% of 8
def test_high_traffic_web_app_configuration(self):
"""Test configuration for high-traffic web application."""
config = get_high_concurrency_pool_config()
# Validate it's production-ready
assert validate_production_config(config) is True
# Should handle expected load
total_capacity = config.pool_size + config.max_overflow
expected_demand = config.expected_concurrent_users * config.connections_per_user
assert total_capacity >= expected_demand
def test_development_to_production_migration(self):
"""Test migrating from development to production configuration."""
# Start with development config
dev_config = get_development_pool_config()
assert dev_config.echo_pool is True # Debug enabled
assert dev_config.pool_size == 5 # Small pool
# Convert to legacy for compatibility testing
legacy_config = dev_config.to_legacy_config("postgresql://localhost/test")
assert isinstance(legacy_config, DatabaseConfig)
# Upgrade to production config
prod_config = DatabasePoolConfig.from_legacy_config(
legacy_config,
pool_size=30, # Production sizing
max_overflow=20,
expected_concurrent_users=40,
max_database_connections=150,
echo_pool=False, # Disable debug
)
# Should be production-ready
assert validate_production_config(prod_config) is True
assert prod_config.echo_pool is False
assert prod_config.pool_size == 30
def test_database_upgrade_scenario(self):
"""Test handling database capacity upgrades."""
# Original configuration for smaller database
config = DatabasePoolConfig(
pool_size=20,
max_overflow=10,
max_database_connections=100,
)
# Database upgraded to higher capacity
config.validate_against_database_limits(200)
# Configuration should be updated
assert config.max_database_connections == 200
# Can now safely increase pool size
larger_config = DatabasePoolConfig(
pool_size=40,
max_overflow=20,
max_database_connections=200,
expected_concurrent_users=50,
connections_per_user=1.2,
)
# Should validate successfully
assert larger_config.pool_size == 40
def test_connection_exhaustion_prevention(self):
"""Test that configuration prevents connection exhaustion."""
# Configuration that would exhaust connections
with pytest.raises(ValueError, match="exceeds database capacity"):
DatabasePoolConfig(
pool_size=45,
max_overflow=35, # Total = 80
max_database_connections=75, # Available = 72 (75-3)
reserved_superuser_connections=3,
)
# Safe configuration
safe_config = DatabasePoolConfig(
pool_size=30,
max_overflow=20, # Total = 50
max_database_connections=75, # Available = 72 (75-3)
reserved_superuser_connections=3,
)
# Should leave room for other applications and admin access
total_used = safe_config.pool_size + safe_config.max_overflow
available = (
safe_config.max_database_connections
- safe_config.reserved_superuser_connections
)
assert total_used < available # Should not use ALL available connections
```
--------------------------------------------------------------------------------
/tests/test_speed_optimization_validation.py:
--------------------------------------------------------------------------------
```python
"""
Speed Optimization Validation Test Suite for MaverickMCP Research Agents
This comprehensive test suite validates the speed optimizations implemented in the research system:
- Validates 2-3x speed improvement claims
- Tests emergency mode completion under 30s
- Verifies fast model selection (Gemini 2.5 Flash, GPT-4o Mini)
- Resolves previous timeout issues (138s, 129s failures)
- Compares before/after performance
Speed Optimization Features Being Tested:
1. Adaptive Model Selection (emergency, fast, balanced modes)
2. Progressive Token Budgeting with time awareness
3. Parallel LLM Processing with intelligent batching
4. Optimized Prompt Engineering for speed
5. Early Termination based on confidence thresholds
6. Content Filtering to reduce processing overhead
"""
import asyncio
import logging
import statistics
import time
from datetime import datetime
from enum import Enum
from typing import Any
from unittest.mock import AsyncMock, MagicMock
try:
import pytest
except ImportError:
# For standalone use without pytest
pytest = None
from maverick_mcp.agents.deep_research import DeepResearchAgent
from maverick_mcp.agents.optimized_research import OptimizedDeepResearchAgent
from maverick_mcp.providers.openrouter_provider import TaskType
from maverick_mcp.utils.llm_optimization import (
AdaptiveModelSelector,
ConfidenceTracker,
IntelligentContentFilter,
ParallelLLMProcessor,
ProgressiveTokenBudgeter,
)
logger = logging.getLogger(__name__)
# Speed optimization validation thresholds
SPEED_THRESHOLDS = {
"simple_query_max_time": 15.0, # Simple queries should complete in <15s
"moderate_query_max_time": 25.0, # Moderate queries should complete in <25s
"complex_query_max_time": 35.0, # Complex queries should complete in <35s
"emergency_mode_max_time": 30.0, # Emergency mode should complete in <30s
"minimum_speedup_factor": 2.0, # Minimum 2x speedup over baseline
"target_speedup_factor": 3.0, # Target 3x speedup over baseline
"timeout_failure_threshold": 0.05, # Max 5% timeout failures allowed
}
# Test query complexity definitions
class QueryComplexity(Enum):
SIMPLE = "simple"
MODERATE = "moderate"
COMPLEX = "complex"
EMERGENCY = "emergency"
# Test query templates by complexity
SPEED_TEST_QUERIES = {
QueryComplexity.SIMPLE: [
"Apple Inc current stock price and basic sentiment",
"Tesla recent news and market overview",
"Microsoft quarterly earnings summary",
"NVIDIA stock performance this month",
],
QueryComplexity.MODERATE: [
"Apple Inc comprehensive financial analysis and competitive position in smartphone market",
"Tesla Inc market outlook considering EV competition and regulatory changes",
"Microsoft Corp cloud business growth prospects and AI integration strategy",
"NVIDIA competitive analysis in semiconductor and AI acceleration markets",
],
QueryComplexity.COMPLEX: [
"Apple Inc deep fundamental analysis including supply chain risks, product lifecycle assessment, regulatory challenges across global markets, competitive positioning against Samsung and Google, and 5-year growth trajectory considering AR/VR investments and services expansion",
"Tesla Inc comprehensive investment thesis covering production scaling challenges, battery technology competitive advantages, autonomous driving timeline and regulatory risks, energy business growth potential, and Elon Musk leadership impact on stock volatility",
"Microsoft Corp strategic analysis of cloud infrastructure competition with AWS and Google, AI monetization through Copilot integration, gaming division performance post-Activision acquisition, and enterprise software market share defense against Salesforce and Oracle",
"NVIDIA Corp detailed semiconductor industry analysis covering data center growth drivers, gaming market maturity, automotive AI partnerships, geopolitical chip manufacturing risks, and competitive threats from AMD, Intel, and custom silicon development by major cloud providers",
],
QueryComplexity.EMERGENCY: [
"Quick Apple sentiment - bullish or bearish right now?",
"Tesla stock - buy, hold, or sell this week?",
"Microsoft earnings - beat or miss expectations?",
"NVIDIA - momentum trade opportunity today?",
],
}
# Expected model selections for each scenario
EXPECTED_MODEL_SELECTIONS = {
QueryComplexity.SIMPLE: ["google/gemini-2.5-flash", "openai/gpt-4o-mini"],
QueryComplexity.MODERATE: ["openai/gpt-4o-mini", "google/gemini-2.5-flash"],
QueryComplexity.COMPLEX: [
"anthropic/claude-sonnet-4",
"google/gemini-2.5-pro",
],
QueryComplexity.EMERGENCY: ["google/gemini-2.5-flash", "openai/gpt-4o-mini"],
}
# Token generation speeds (tokens/second) for validation
MODEL_SPEED_BENCHMARKS = {
"google/gemini-2.5-flash": 199,
"openai/gpt-4o-mini": 126,
"anthropic/claude-sonnet-4": 45,
"google/gemini-2.5-pro": 25,
"anthropic/claude-haiku": 89,
}
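# Illustrative sketch: the latency model applied by the mock provider below, i.e.
# output tokens divided by the benchmarked tokens/second plus a fixed API overhead.
# For example, 1000 tokens on google/gemini-2.5-flash is roughly 1000 / 199 + 0.5 ≈ 5.5s.
# The helper name and default overhead value are ours.
def _estimated_response_seconds(
    model_id: str, max_tokens: int, overhead_seconds: float = 0.5
) -> float:
    speed = MODEL_SPEED_BENCHMARKS.get(model_id, 50)  # tokens/second; default mirrors the mock
    return (max_tokens / speed) + overhead_seconds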
class SpeedTestMonitor:
"""Monitors speed optimization performance during test execution."""
def __init__(self, test_name: str, complexity: QueryComplexity):
self.test_name = test_name
self.complexity = complexity
self.start_time: float = 0
self.end_time: float = 0
self.phase_timings: dict[str, float] = {}
self.model_selections: list[str] = []
self.optimization_metrics: dict[str, Any] = {}
def __enter__(self):
"""Start speed monitoring."""
self.start_time = time.time()
logger.info(f"Starting speed test: {self.test_name} ({self.complexity.value})")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Complete speed monitoring and log results."""
self.end_time = time.time()
total_time = self.end_time - self.start_time
logger.info(
f"Speed test completed: {self.test_name} - "
f"Time: {total_time:.2f}s, "
f"Complexity: {self.complexity.value}, "
f"Models: {self.model_selections}"
)
def record_phase(self, phase_name: str, duration: float):
"""Record timing for a specific phase."""
self.phase_timings[phase_name] = duration
def record_model_selection(self, model_id: str):
"""Record which model was selected."""
self.model_selections.append(model_id)
def record_optimization_metric(self, metric_name: str, value: Any):
"""Record optimization-specific metrics."""
self.optimization_metrics[metric_name] = value
@property
def total_execution_time(self) -> float:
"""Get total execution time."""
return self.end_time - self.start_time if self.end_time > 0 else 0
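# Illustrative usage sketch for SpeedTestMonitor (the integration tests below drive
# it through the speed_monitor_factory fixture). The coroutine name and the
# asyncio.sleep stand-in for a real research call are ours.
async def _example_monitor_usage() -> float:
    with SpeedTestMonitor("example_speed_test", QueryComplexity.SIMPLE) as monitor:
        await asyncio.sleep(0.01)  # placeholder for the research call being timed
        monitor.record_phase("search", 0.01)
        monitor.record_model_selection("openai/gpt-4o-mini")
        monitor.record_optimization_metric("sources_processed", 3)
    return monitor.total_execution_time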
class MockOpenRouterProvider:
"""Mock OpenRouter provider that simulates realistic API response times."""
def __init__(self, simulate_model_speeds: bool = True):
self.simulate_model_speeds = simulate_model_speeds
self.call_history: list[dict[str, Any]] = []
def get_llm(self, model_override: str = None, **kwargs):
"""Get mock LLM with realistic speed simulation."""
model_id = model_override or "openai/gpt-4o-mini"
mock_llm = AsyncMock()
mock_llm.model_id = model_id
# Simulate realistic response times based on model speed
if self.simulate_model_speeds:
speed = MODEL_SPEED_BENCHMARKS.get(model_id, 50)
max_tokens = kwargs.get("max_tokens", 1000)
# Calculate response time: (tokens / speed) + API overhead
response_time = (max_tokens / speed) + 0.5 # 0.5s API overhead
else:
response_time = 0.1 # Fast mock response
async def mock_ainvoke(messages):
await asyncio.sleep(response_time)
# Record the call
self.call_history.append(
{
"model_id": model_id,
"response_time": response_time,
"max_tokens": kwargs.get("max_tokens", 1000),
"timestamp": time.time(),
"messages": len(messages),
}
)
# Return mock response
mock_response = MagicMock()
mock_response.content = (
f"Mock response from {model_id} (simulated {response_time:.2f}s)"
)
return mock_response
mock_llm.ainvoke = mock_ainvoke
return mock_llm
class SpeedOptimizationValidator:
"""Validates speed optimization claims and performance improvements."""
@staticmethod
async def test_adaptive_model_selection(
time_budget: float, complexity: float, expected_models: list[str]
) -> dict[str, Any]:
"""Test that adaptive model selection chooses appropriate fast models."""
provider = MockOpenRouterProvider(simulate_model_speeds=True)
selector = AdaptiveModelSelector(provider)
# Test model selection for time budget
model_config = selector.select_model_for_time_budget(
task_type=TaskType.MARKET_ANALYSIS,
time_remaining_seconds=time_budget,
complexity_score=complexity,
content_size_tokens=1000,
)
return {
"selected_model": model_config.model_id,
"max_tokens": model_config.max_tokens,
"timeout_seconds": model_config.timeout_seconds,
"expected_models": expected_models,
"model_appropriate": model_config.model_id in expected_models,
"speed_optimized": model_config.model_id
in ["google/gemini-2.5-flash", "openai/gpt-4o-mini"],
}
@staticmethod
async def test_emergency_mode_performance(query: str) -> dict[str, Any]:
"""Test emergency mode performance (< 30s completion)."""
provider = MockOpenRouterProvider(simulate_model_speeds=True)
# Create optimized research agent
agent = OptimizedDeepResearchAgent(
openrouter_provider=provider,
persona="moderate",
optimization_enabled=True,
)
# Mock the search providers to avoid actual API calls
agent.search_providers = [MockSearchProvider()]
start_time = time.time()
try:
# Test with strict emergency time budget
result = await agent.research_comprehensive(
topic=query,
session_id="emergency_test",
depth="basic",
time_budget_seconds=25.0, # Strict emergency budget
target_confidence=0.6, # Lower confidence for speed
)
execution_time = time.time() - start_time
return {
"success": True,
"execution_time": execution_time,
"within_budget": execution_time
< SPEED_THRESHOLDS["emergency_mode_max_time"],
"result_status": result.get("status", "unknown"),
"emergency_mode_used": result.get("emergency_mode", False),
"optimization_metrics": result.get("optimization_metrics", {}),
}
except Exception as e:
execution_time = time.time() - start_time
return {
"success": False,
"execution_time": execution_time,
"error": str(e),
"within_budget": execution_time
< SPEED_THRESHOLDS["emergency_mode_max_time"],
}
@staticmethod
async def test_baseline_vs_optimized_performance(
query: str, complexity: QueryComplexity
) -> dict[str, Any]:
"""Compare baseline vs optimized agent performance."""
provider = MockOpenRouterProvider(simulate_model_speeds=True)
# Test baseline agent (non-optimized)
baseline_agent = DeepResearchAgent(
llm=provider.get_llm(),
persona="moderate",
enable_parallel_execution=False,
)
baseline_agent.search_providers = [MockSearchProvider()]
# Test optimized agent
optimized_agent = OptimizedDeepResearchAgent(
openrouter_provider=provider,
persona="moderate",
optimization_enabled=True,
)
optimized_agent.search_providers = [MockSearchProvider()]
# Run baseline test
baseline_start = time.time()
try:
baseline_result = await baseline_agent.research_comprehensive(
topic=query,
session_id="baseline_test",
depth="standard",
)
baseline_time = time.time() - baseline_start
baseline_success = True
except Exception as e:
baseline_time = time.time() - baseline_start
baseline_success = False
baseline_result = {"error": str(e)}
# Run optimized test
optimized_start = time.time()
try:
optimized_result = await optimized_agent.research_comprehensive(
topic=query,
session_id="optimized_test",
depth="standard",
time_budget_seconds=60.0,
)
optimized_time = time.time() - optimized_start
optimized_success = True
except Exception as e:
optimized_time = time.time() - optimized_start
optimized_success = False
optimized_result = {"error": str(e)}
# Calculate performance metrics
speedup_factor = (
baseline_time / max(optimized_time, 0.001) if optimized_time > 0 else 0
)
return {
"baseline_time": baseline_time,
"optimized_time": optimized_time,
"speedup_factor": speedup_factor,
"baseline_success": baseline_success,
"optimized_success": optimized_success,
"meets_2x_target": speedup_factor
>= SPEED_THRESHOLDS["minimum_speedup_factor"],
"meets_3x_target": speedup_factor
>= SPEED_THRESHOLDS["target_speedup_factor"],
"baseline_result": baseline_result,
"optimized_result": optimized_result,
}
class MockSearchProvider:
"""Mock search provider for testing without external API calls."""
async def search(self, query: str, num_results: int = 5) -> list[dict[str, Any]]:
"""Return mock search results."""
await asyncio.sleep(0.1) # Simulate API delay
return [
{
"title": f"Mock search result {i + 1} for: {query[:30]}",
"url": f"https://example.com/result{i + 1}",
"content": f"Mock content for result {i + 1}. " * 50, # ~50 words
"published_date": datetime.now().isoformat(),
"credibility_score": 0.8,
"relevance_score": 0.9 - (i * 0.1),
}
for i in range(num_results)
]
# Test fixtures (conditional on pytest availability)
if pytest:
@pytest.fixture
def mock_openrouter_provider():
"""Provide mock OpenRouter provider."""
return MockOpenRouterProvider(simulate_model_speeds=True)
@pytest.fixture
def speed_validator():
"""Provide speed optimization validator."""
return SpeedOptimizationValidator()
@pytest.fixture
def speed_monitor_factory():
"""Factory for creating speed test monitors."""
def _create_monitor(test_name: str, complexity: QueryComplexity):
return SpeedTestMonitor(test_name, complexity)
return _create_monitor
# Core Speed Optimization Tests
if pytest:
@pytest.mark.unit
class TestSpeedOptimizations:
"""Core tests for speed optimization functionality."""
async def test_adaptive_model_selector_emergency_mode(
self, mock_openrouter_provider
):
"""Test that emergency mode selects fastest models."""
selector = AdaptiveModelSelector(mock_openrouter_provider)
# Test ultra-emergency mode (< 10s)
config = selector.select_model_for_time_budget(
task_type=TaskType.QUICK_ANSWER,
time_remaining_seconds=8.0,
complexity_score=0.5,
content_size_tokens=500,
)
# Should select fastest model
assert config.model_id in ["google/gemini-2.5-flash", "openai/gpt-4o-mini"]
assert config.timeout_seconds < 10
assert config.max_tokens < 1000
# Test moderate emergency (< 25s)
config = selector.select_model_for_time_budget(
task_type=TaskType.MARKET_ANALYSIS,
time_remaining_seconds=20.0,
complexity_score=0.7,
content_size_tokens=1000,
)
# Should still prefer fast models
assert config.model_id in ["google/gemini-2.5-flash", "openai/gpt-4o-mini"]
assert config.timeout_seconds < 25
async def test_progressive_token_budgeter_time_constraints(self):
"""Test progressive token budgeter adapts to time pressure."""
# Test emergency budget
emergency_budgeter = ProgressiveTokenBudgeter(
total_time_budget_seconds=20.0, confidence_target=0.6
)
allocation = emergency_budgeter.allocate_tokens_for_phase(
phase=emergency_budgeter.phase_budgets.__class__.CONTENT_ANALYSIS,
sources_count=3,
current_confidence=0.3,
complexity_score=0.5,
)
# Emergency mode should have reduced tokens and shorter timeout
assert allocation.output_tokens < 1000
assert allocation.timeout_seconds < 15
# Test standard budget
standard_budgeter = ProgressiveTokenBudgeter(
total_time_budget_seconds=120.0, confidence_target=0.75
)
allocation = standard_budgeter.allocate_tokens_for_phase(
phase=standard_budgeter.phase_budgets.__class__.CONTENT_ANALYSIS,
sources_count=3,
current_confidence=0.3,
complexity_score=0.5,
)
# Standard mode should allow more tokens and time
assert allocation.output_tokens >= 1000
assert allocation.timeout_seconds >= 15
async def test_parallel_llm_processor_speed_optimization(
self, mock_openrouter_provider
):
"""Test parallel LLM processor speed optimizations."""
processor = ParallelLLMProcessor(mock_openrouter_provider, max_concurrent=4)
# Create mock sources
sources = [
{
"title": f"Source {i}",
"content": f"Mock content {i} " * 100, # ~100 words
"url": f"https://example.com/{i}",
}
for i in range(6)
]
start_time = time.time()
results = await processor.parallel_content_analysis(
sources=sources,
analysis_type="sentiment",
persona="moderate",
time_budget_seconds=15.0, # Tight budget
current_confidence=0.0,
)
execution_time = time.time() - start_time
# Should complete within time budget
assert execution_time < 20.0 # Some buffer for test environment
assert len(results) > 0 # Should produce results
# Verify all results have required analysis structure
for result in results:
assert "analysis" in result
analysis = result["analysis"]
assert "sentiment" in analysis
assert "batch_processed" in analysis
async def test_confidence_tracker_early_termination(self):
"""Test confidence tracker enables early termination."""
tracker = ConfidenceTracker(
target_confidence=0.8,
min_sources=2,
max_sources=10,
)
# Simulate high-confidence evidence
high_confidence_evidence = {
"sentiment": {"direction": "bullish", "confidence": 0.9},
"insights": ["Strong positive insight", "Another strong insight"],
"risk_factors": ["Minor risk"],
"opportunities": ["Major opportunity", "Growth catalyst"],
"relevance_score": 0.95,
}
# Process minimum sources first
for _i in range(2):
result = tracker.update_confidence(high_confidence_evidence, 0.9)
if not result["should_continue"]:
break
# After high-confidence sources, should suggest early termination
final_result = tracker.update_confidence(high_confidence_evidence, 0.9)
assert final_result["current_confidence"] > 0.7
# Early termination logic should trigger with high confidence
async def test_intelligent_content_filter_speed_optimization(self):
"""Test intelligent content filtering reduces processing overhead."""
filter = IntelligentContentFilter()
# Create sources with varying relevance
sources = [
{
"title": "Apple Inc Q4 Earnings Beat Expectations",
"content": "Apple Inc reported strong Q4 earnings with revenue growth of 15%. "
+ "The company's iPhone sales exceeded analysts' expectations. "
* 20,
"url": "https://reuters.com/apple-earnings",
"published_date": datetime.now().isoformat(),
},
{
"title": "Random Tech News Not About Apple",
"content": "Some unrelated tech news content. " * 50,
"url": "https://example.com/random",
"published_date": "2023-01-01T00:00:00",
},
{
"title": "Apple Supply Chain Analysis",
"content": "Apple's supply chain faces challenges but shows resilience. "
+ "Manufacturing partnerships in Asia remain strong. " * 15,
"url": "https://wsj.com/apple-supply-chain",
"published_date": datetime.now().isoformat(),
},
]
filtered_sources = await filter.filter_and_prioritize_sources(
sources=sources,
research_focus="fundamental",
time_budget=20.0, # Tight budget
current_confidence=0.0,
)
# Should prioritize relevant, high-quality sources
assert len(filtered_sources) <= len(sources)
if filtered_sources:
# First source should be most relevant
assert "apple" in filtered_sources[0]["title"].lower()
# Should have preprocessing applied
assert "original_length" in filtered_sources[0]
# Speed Validation Tests by Query Complexity
if pytest:
@pytest.mark.integration
class TestQueryComplexitySpeedValidation:
"""Test speed validation across different query complexities."""
@pytest.mark.parametrize("complexity", list(QueryComplexity))
async def test_query_completion_time_thresholds(
self, complexity: QueryComplexity, speed_monitor_factory, speed_validator
):
"""Test queries complete within time thresholds by complexity."""
queries = SPEED_TEST_QUERIES[complexity]
results = []
for query in queries[:2]: # Test 2 queries per complexity
with speed_monitor_factory(
f"complexity_test_{complexity.value}", complexity
) as monitor:
if complexity == QueryComplexity.EMERGENCY:
result = await speed_validator.test_emergency_mode_performance(
query
)
else:
# Use baseline vs optimized for other complexities
result = await speed_validator.test_baseline_vs_optimized_performance(
query, complexity
)
monitor.record_optimization_metric(
"completion_time", monitor.total_execution_time
)
results.append(
{
"query": query,
"execution_time": monitor.total_execution_time,
"result": result,
}
)
# Validate time thresholds based on complexity
threshold_map = {
QueryComplexity.SIMPLE: SPEED_THRESHOLDS["simple_query_max_time"],
QueryComplexity.MODERATE: SPEED_THRESHOLDS["moderate_query_max_time"],
QueryComplexity.COMPLEX: SPEED_THRESHOLDS["complex_query_max_time"],
QueryComplexity.EMERGENCY: SPEED_THRESHOLDS["emergency_mode_max_time"],
}
max_allowed_time = threshold_map[complexity]
for result in results:
execution_time = result["execution_time"]
assert execution_time < max_allowed_time, (
f"{complexity.value} query exceeded time threshold: "
f"{execution_time:.2f}s > {max_allowed_time}s"
)
# Log performance summary
avg_time = statistics.mean([r["execution_time"] for r in results])
logger.info(
f"{complexity.value} queries - Avg time: {avg_time:.2f}s "
f"(threshold: {max_allowed_time}s)"
)
async def test_emergency_mode_model_selection(self, mock_openrouter_provider):
"""Test emergency mode selects fastest models."""
selector = AdaptiveModelSelector(mock_openrouter_provider)
# Test various emergency time budgets
emergency_scenarios = [5, 10, 15, 20, 25]
for time_budget in emergency_scenarios:
config = selector.select_model_for_time_budget(
task_type=TaskType.QUICK_ANSWER,
time_remaining_seconds=time_budget,
complexity_score=0.3, # Low complexity for emergency
content_size_tokens=200,
)
# Should always select fastest models in emergency scenarios
expected_models = EXPECTED_MODEL_SELECTIONS[QueryComplexity.EMERGENCY]
assert config.model_id in expected_models, (
f"Emergency mode with {time_budget}s budget should select fast model, "
f"got {config.model_id}"
)
# Timeout should be appropriate for time budget
assert config.timeout_seconds < time_budget * 0.8, (
f"Timeout too long for emergency budget: "
f"{config.timeout_seconds}s for {time_budget}s budget"
)
# Performance Comparison Tests
if pytest:
@pytest.mark.integration
class TestSpeedImprovementValidation:
"""Validate claimed speed improvements (2-3x faster)."""
async def test_2x_minimum_speedup_validation(self, speed_validator):
"""Validate minimum 2x speedup is achieved."""
moderate_queries = SPEED_TEST_QUERIES[QueryComplexity.MODERATE]
speedup_results = []
for query in moderate_queries[:2]: # Test subset for CI speed
result = await speed_validator.test_baseline_vs_optimized_performance(
query, QueryComplexity.MODERATE
)
if result["baseline_success"] and result["optimized_success"]:
speedup_results.append(result["speedup_factor"])
logger.info(
f"Speedup test: {result['speedup_factor']:.2f}x "
f"({result['baseline_time']:.2f}s -> {result['optimized_time']:.2f}s)"
)
# Validate minimum 2x speedup achieved
if speedup_results:
avg_speedup = statistics.mean(speedup_results)
                min_speedup = min(speedup_results)
                assert avg_speedup >= SPEED_THRESHOLDS["minimum_speedup_factor"], (
                    f"Average speedup {avg_speedup:.2f}x "
                    f"(minimum {min_speedup:.2f}x) below 2x minimum threshold"
                )
# At least 80% of tests should meet minimum speedup
meeting_threshold = sum(
1
for s in speedup_results
if s >= SPEED_THRESHOLDS["minimum_speedup_factor"]
)
threshold_rate = meeting_threshold / len(speedup_results)
assert threshold_rate >= 0.8, (
f"Only {threshold_rate:.1%} of tests met 2x speedup threshold "
f"(should be >= 80%)"
)
else:
pytest.skip("No successful speedup comparisons completed")
async def test_3x_target_speedup_aspiration(self, speed_validator):
"""Test aspirational 3x speedup target for simple queries."""
simple_queries = SPEED_TEST_QUERIES[QueryComplexity.SIMPLE]
speedup_results = []
for query in simple_queries:
result = await speed_validator.test_baseline_vs_optimized_performance(
query, QueryComplexity.SIMPLE
)
if result["baseline_success"] and result["optimized_success"]:
speedup_results.append(result["speedup_factor"])
if speedup_results:
avg_speedup = statistics.mean(speedup_results)
max_speedup = max(speedup_results)
logger.info(
f"3x target test - Avg: {avg_speedup:.2f}x, Max: {max_speedup:.2f}x"
)
# This is aspirational - log results but don't fail
target_met = avg_speedup >= SPEED_THRESHOLDS["target_speedup_factor"]
if target_met:
logger.info("🎉 3x speedup target achieved!")
else:
logger.info(f"3x target not yet achieved (current: {avg_speedup:.2f}x)")
# Still assert we're making good progress toward 3x
assert avg_speedup >= 1.5, (
f"Should show significant speedup progress, got {avg_speedup:.2f}x"
)
# Timeout Resolution Tests
if pytest:
@pytest.mark.integration
class TestTimeoutResolution:
"""Test resolution of previous timeout issues (138s, 129s failures)."""
async def test_no_timeout_failures_in_emergency_mode(self, speed_validator):
"""Test emergency mode prevents timeout failures."""
emergency_queries = SPEED_TEST_QUERIES[QueryComplexity.EMERGENCY]
timeout_failures = 0
total_tests = 0
for query in emergency_queries:
total_tests += 1
result = await speed_validator.test_emergency_mode_performance(query)
# Check if execution exceeded emergency time budget
if result["execution_time"] >= SPEED_THRESHOLDS["emergency_mode_max_time"]:
timeout_failures += 1
logger.warning(
f"Emergency mode timeout: {result['execution_time']:.2f}s "
f"for query: {query[:50]}..."
)
# Calculate failure rate
timeout_failure_rate = timeout_failures / max(total_tests, 1)
# Should have very low timeout failure rate
assert timeout_failure_rate <= SPEED_THRESHOLDS["timeout_failure_threshold"], (
f"Timeout failure rate too high: {timeout_failure_rate:.1%} > "
f"{SPEED_THRESHOLDS['timeout_failure_threshold']:.1%}"
)
logger.info(
f"Timeout resolution test: {timeout_failure_rate:.1%} failure rate "
f"({timeout_failures}/{total_tests} timeouts)"
)
async def test_graceful_degradation_under_time_pressure(self, speed_validator):
"""Test system degrades gracefully under extreme time pressure."""
# Simulate very tight time budgets that previously caused 138s/129s failures
tight_budgets = [10, 15, 20, 25] # Various emergency scenarios
degradation_results = []
for budget in tight_budgets:
provider = MockOpenRouterProvider(simulate_model_speeds=True)
agent = OptimizedDeepResearchAgent(
openrouter_provider=provider,
persona="moderate",
optimization_enabled=True,
)
agent.search_providers = [MockSearchProvider()]
start_time = time.time()
try:
result = await agent.research_comprehensive(
topic="Apple Inc urgent analysis needed",
session_id=f"degradation_test_{budget}s",
depth="basic",
time_budget_seconds=budget,
target_confidence=0.5, # Lower expectations
)
execution_time = time.time() - start_time
degradation_results.append(
{
"budget": budget,
"execution_time": execution_time,
"success": True,
"within_budget": execution_time <= budget + 5, # 5s buffer
"emergency_mode": result.get("emergency_mode", False),
}
)
except Exception as e:
execution_time = time.time() - start_time
degradation_results.append(
{
"budget": budget,
"execution_time": execution_time,
"success": False,
"error": str(e),
"within_budget": execution_time <= budget + 5,
}
)
# Validate graceful degradation
successful_tests = [r for r in degradation_results if r["success"]]
within_budget_tests = [r for r in degradation_results if r["within_budget"]]
success_rate = len(successful_tests) / len(degradation_results)
budget_compliance_rate = len(within_budget_tests) / len(degradation_results)
# Should succeed most of the time and stay within budget
assert success_rate >= 0.75, (
f"Success rate too low under time pressure: {success_rate:.1%}"
)
assert budget_compliance_rate >= 0.80, (
f"Budget compliance too low: {budget_compliance_rate:.1%}"
)
logger.info(
f"Graceful degradation test: {success_rate:.1%} success rate, "
f"{budget_compliance_rate:.1%} budget compliance"
)
if __name__ == "__main__":
# Allow running specific test categories
import sys
if len(sys.argv) > 1:
pytest.main([sys.argv[1], "-v", "-s", "--tb=short"])
else:
# Run all speed validation tests by default
pytest.main([__file__, "-v", "-s", "--tb=short"])
```
--------------------------------------------------------------------------------
/maverick_mcp/config/settings.py:
--------------------------------------------------------------------------------
```python
"""
Configuration settings for Maverick-MCP.
This module provides configuration settings that can be customized
through environment variables or a settings file.
"""
import logging
import os
from decimal import Decimal
from pydantic import BaseModel, Field
from maverick_mcp.config.constants import CONFIG
# Set up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("maverick_mcp.config")
class DatabaseSettings(BaseModel):
"""Database configuration settings."""
host: str = Field(default="localhost", description="Database host")
port: int = Field(default=5432, description="Database port")
username: str = Field(default="postgres", description="Database username")
password: str = Field(default="", description="Database password")
database: str = Field(default="maverick_mcp", description="Database name")
max_connections: int = Field(
default=10, description="Maximum number of connections"
)
@property
def url(self) -> str:
"""Get database URL string."""
# Check for environment variable first
env_url = os.getenv("DATABASE_URL") or os.getenv("POSTGRES_URL")
if env_url:
return env_url
# Default to SQLite for development
return "sqlite:///maverick_mcp.db"
class APISettings(BaseModel):
"""API configuration settings."""
host: str = Field(default="0.0.0.0", description="API host")
port: int = Field(default=8000, description="API port")
debug: bool = Field(default=False, description="Debug mode")
log_level: str = Field(default="info", description="Log level")
cache_timeout: int = Field(default=300, description="Cache timeout in seconds")
cors_origins: list[str] = Field(
default=["http://localhost:3000", "http://localhost:3001"],
description="CORS allowed origins",
)
# Web search API key
exa_api_key: str | None = Field(
default_factory=lambda: os.getenv("EXA_API_KEY"),
description="Exa API key",
)
class DataProviderSettings(BaseModel):
"""Data provider configuration settings."""
api_key: str | None = Field(default=None, description="API key for data provider")
use_cache: bool = Field(default=True, description="Use cache for data")
cache_dir: str = Field(
default="/tmp/maverick_mcp/cache", description="Cache directory"
)
cache_expiry: int = Field(default=86400, description="Cache expiry in seconds")
rate_limit: int = Field(default=5, description="Rate limit per minute")
# Research-specific settings
max_search_results: int = Field(
default=100, description="Max search results per query"
)
research_cache_ttl: int = Field(
default=3600, description="Research cache TTL in seconds"
)
content_max_length: int = Field(
default=2000, description="Max content length per source"
)
class RedisSettings(BaseModel):
"""Redis configuration settings."""
host: str = Field(
default_factory=lambda: CONFIG["redis"]["host"], description="Redis host"
)
port: int = Field(
default_factory=lambda: CONFIG["redis"]["port"], description="Redis port"
)
db: int = Field(
default_factory=lambda: CONFIG["redis"]["db"],
description="Redis database number",
)
username: str | None = Field(
default_factory=lambda: CONFIG["redis"]["username"],
description="Redis username",
)
password: str | None = Field(
default_factory=lambda: CONFIG["redis"]["password"],
description="Redis password",
)
ssl: bool = Field(
default_factory=lambda: CONFIG["redis"]["ssl"],
description="Use SSL for Redis connection",
)
@property
def url(self) -> str:
"""Get Redis URL string."""
scheme = "rediss" if self.ssl else "redis"
auth = ""
if self.username and self.password:
auth = f"{self.username}:{self.password}@"
elif self.password:
auth = f":{self.password}@"
return f"{scheme}://{auth}{self.host}:{self.port}/{self.db}"
class ResearchSettings(BaseModel):
"""Research and web search configuration settings."""
# API key for web search provider
exa_api_key: str | None = Field(
default_factory=lambda: os.getenv("EXA_API_KEY"),
description="Exa API key for web search",
)
# Research parameters
default_max_sources: int = Field(
default=50, description="Default max sources per research"
)
default_research_depth: str = Field(
default="comprehensive", description="Default research depth"
)
cache_ttl_hours: int = Field(default=4, description="Research cache TTL in hours")
# Content analysis settings
max_content_length: int = Field(
default=2000, description="Max content length per source"
)
sentiment_confidence_threshold: float = Field(
default=0.7, description="Sentiment confidence threshold"
)
credibility_score_threshold: float = Field(
default=0.6, description="Source credibility threshold"
)
# Rate limiting
search_rate_limit: int = Field(default=10, description="Search requests per minute")
content_analysis_batch_size: int = Field(
default=5, description="Content analysis batch size"
)
# Domain filtering
trusted_domains: list[str] = Field(
default=[
"reuters.com",
"bloomberg.com",
"wsj.com",
"ft.com",
"marketwatch.com",
"cnbc.com",
"yahoo.com",
"seekingalpha.com",
],
description="Trusted news domains for research",
)
blocked_domains: list[str] = Field(
default=[], description="Blocked domains for research"
)
@property
def api_keys(self) -> dict[str, str | None]:
"""Get API keys as dictionary."""
return {"exa_api_key": self.exa_api_key}
class DataLimitsConfig(BaseModel):
"""Data limits and constraints configuration settings."""
# API Rate limits
max_api_requests_per_minute: int = Field(
default_factory=lambda: int(os.getenv("MAX_API_REQUESTS_PER_MINUTE", "60")),
description="Maximum API requests per minute",
)
max_api_requests_per_hour: int = Field(
default_factory=lambda: int(os.getenv("MAX_API_REQUESTS_PER_HOUR", "1000")),
description="Maximum API requests per hour",
)
# Data size limits
max_data_rows_per_request: int = Field(
default_factory=lambda: int(os.getenv("MAX_DATA_ROWS_PER_REQUEST", "10000")),
description="Maximum data rows per request",
)
max_symbols_per_batch: int = Field(
default_factory=lambda: int(os.getenv("MAX_SYMBOLS_PER_BATCH", "100")),
description="Maximum symbols per batch request",
)
max_response_size_mb: int = Field(
default_factory=lambda: int(os.getenv("MAX_RESPONSE_SIZE_MB", "50")),
description="Maximum response size in MB",
)
# Research limits
max_research_sources: int = Field(
default_factory=lambda: int(os.getenv("MAX_RESEARCH_SOURCES", "100")),
description="Maximum research sources per query",
)
max_research_depth_level: int = Field(
default_factory=lambda: int(os.getenv("MAX_RESEARCH_DEPTH_LEVEL", "5")),
description="Maximum research depth level",
)
max_content_analysis_items: int = Field(
default_factory=lambda: int(os.getenv("MAX_CONTENT_ANALYSIS_ITEMS", "50")),
description="Maximum content items for analysis",
)
# Agent limits
max_agent_iterations: int = Field(
default_factory=lambda: int(os.getenv("MAX_AGENT_ITERATIONS", "10")),
description="Maximum agent workflow iterations",
)
max_parallel_agents: int = Field(
default_factory=lambda: int(os.getenv("MAX_PARALLEL_AGENTS", "5")),
description="Maximum parallel agents in orchestration",
)
max_agent_execution_time_seconds: int = Field(
default_factory=lambda: int(
os.getenv("MAX_AGENT_EXECUTION_TIME_SECONDS", "720")
),
description="Maximum agent execution time in seconds",
)
# Cache limits
max_cache_size_mb: int = Field(
default_factory=lambda: int(os.getenv("MAX_CACHE_SIZE_MB", "500")),
description="Maximum cache size in MB",
)
max_cached_items: int = Field(
default_factory=lambda: int(os.getenv("MAX_CACHED_ITEMS", "10000")),
description="Maximum number of cached items",
)
# Database limits
max_db_connections: int = Field(
default_factory=lambda: int(os.getenv("MAX_DB_CONNECTIONS", "100")),
description="Maximum database connections",
)
max_query_results: int = Field(
default_factory=lambda: int(os.getenv("MAX_QUERY_RESULTS", "50000")),
description="Maximum query results",
)
class ExternalDataSettings(BaseModel):
"""External data API configuration settings."""
api_key: str | None = Field(
default_factory=lambda: os.getenv("EXTERNAL_DATA_API_KEY"),
description="API key for external data API",
)
base_url: str = Field(
default="https://external-data-api.com",
description="Base URL for external data API",
)
class EmailSettings(BaseModel):
"""Email service configuration settings."""
enabled: bool = Field(
default_factory=lambda: os.getenv("EMAIL_ENABLED", "true").lower() == "true",
description="Enable email sending",
)
mailgun_api_key: str = Field(
default_factory=lambda: os.getenv("MAILGUN_API_KEY", ""),
description="Mailgun API key",
)
mailgun_domain: str = Field(
default_factory=lambda: os.getenv("MAILGUN_DOMAIN", ""),
description="Mailgun sending domain",
)
from_address: str = Field(
default_factory=lambda: os.getenv("EMAIL_FROM_ADDRESS", "noreply@localhost"),
description="Default from email address",
)
from_name: str = Field(
default_factory=lambda: os.getenv("EMAIL_FROM_NAME", "MaverickMCP"),
description="Default from name",
)
support_email: str = Field(
default_factory=lambda: os.getenv("EMAIL_SUPPORT", "support@localhost"),
description="Support email address",
)
class FinancialConfig(BaseModel):
"""Financial calculations and portfolio management settings."""
# Portfolio defaults
default_account_size: Decimal = Field(
default_factory=lambda: Decimal(os.getenv("DEFAULT_ACCOUNT_SIZE", "100000")),
description="Default account size for calculations (USD)",
)
@property
def api_keys(self) -> dict[str, str | None]:
"""Get API keys as dictionary (placeholder for financial data APIs)."""
return {}
# Risk management
max_position_size_conservative: float = Field(
default_factory=lambda: float(
os.getenv("MAX_POSITION_SIZE_CONSERVATIVE", "0.05")
),
description="Maximum position size for conservative investors (5%)",
)
max_position_size_moderate: float = Field(
default_factory=lambda: float(os.getenv("MAX_POSITION_SIZE_MODERATE", "0.10")),
description="Maximum position size for moderate investors (10%)",
)
max_position_size_aggressive: float = Field(
default_factory=lambda: float(
os.getenv("MAX_POSITION_SIZE_AGGRESSIVE", "0.20")
),
description="Maximum position size for aggressive investors (20%)",
)
max_position_size_day_trader: float = Field(
default_factory=lambda: float(
os.getenv("MAX_POSITION_SIZE_DAY_TRADER", "0.25")
),
description="Maximum position size for day traders (25%)",
)
# Stop loss multipliers
stop_loss_multiplier_conservative: float = Field(
default_factory=lambda: float(
os.getenv("STOP_LOSS_MULTIPLIER_CONSERVATIVE", "1.5")
),
description="Stop loss multiplier for conservative investors",
)
stop_loss_multiplier_moderate: float = Field(
default_factory=lambda: float(
os.getenv("STOP_LOSS_MULTIPLIER_MODERATE", "1.2")
),
description="Stop loss multiplier for moderate investors",
)
stop_loss_multiplier_aggressive: float = Field(
default_factory=lambda: float(
os.getenv("STOP_LOSS_MULTIPLIER_AGGRESSIVE", "1.0")
),
description="Stop loss multiplier for aggressive investors",
)
stop_loss_multiplier_day_trader: float = Field(
default_factory=lambda: float(
os.getenv("STOP_LOSS_MULTIPLIER_DAY_TRADER", "0.8")
),
description="Stop loss multiplier for day traders",
)
# Risk tolerance ranges (0-100 scale)
risk_tolerance_conservative_min: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_CONSERVATIVE_MIN", "10")),
description="Minimum risk tolerance for conservative investors",
)
risk_tolerance_conservative_max: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_CONSERVATIVE_MAX", "30")),
description="Maximum risk tolerance for conservative investors",
)
risk_tolerance_moderate_min: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_MODERATE_MIN", "30")),
description="Minimum risk tolerance for moderate investors",
)
risk_tolerance_moderate_max: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_MODERATE_MAX", "60")),
description="Maximum risk tolerance for moderate investors",
)
risk_tolerance_aggressive_min: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_AGGRESSIVE_MIN", "60")),
description="Minimum risk tolerance for aggressive investors",
)
risk_tolerance_aggressive_max: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_AGGRESSIVE_MAX", "90")),
description="Maximum risk tolerance for aggressive investors",
)
risk_tolerance_day_trader_min: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_DAY_TRADER_MIN", "70")),
description="Minimum risk tolerance for day traders",
)
risk_tolerance_day_trader_max: int = Field(
default_factory=lambda: int(os.getenv("RISK_TOLERANCE_DAY_TRADER_MAX", "95")),
description="Maximum risk tolerance for day traders",
)
# Technical analysis weights
rsi_weight: float = Field(
default_factory=lambda: float(os.getenv("TECHNICAL_RSI_WEIGHT", "2.0")),
description="Weight for RSI in technical analysis scoring",
)
macd_weight: float = Field(
default_factory=lambda: float(os.getenv("TECHNICAL_MACD_WEIGHT", "1.5")),
description="Weight for MACD in technical analysis scoring",
)
momentum_weight: float = Field(
default_factory=lambda: float(os.getenv("TECHNICAL_MOMENTUM_WEIGHT", "1.0")),
description="Weight for momentum indicators in technical analysis scoring",
)
volume_weight: float = Field(
default_factory=lambda: float(os.getenv("TECHNICAL_VOLUME_WEIGHT", "1.0")),
description="Weight for volume indicators in technical analysis scoring",
)
# Trend identification thresholds
uptrend_threshold: float = Field(
default_factory=lambda: float(os.getenv("UPTREND_THRESHOLD", "1.2")),
description="Threshold multiplier for identifying uptrends",
)
downtrend_threshold: float = Field(
default_factory=lambda: float(os.getenv("DOWNTREND_THRESHOLD", "0.8")),
description="Threshold multiplier for identifying downtrends",
)
class PerformanceConfig(BaseModel):
"""Performance settings for timeouts, retries, batch sizes, and cache TTLs."""
# Timeout settings
api_request_timeout: int = Field(
default_factory=lambda: int(os.getenv("API_REQUEST_TIMEOUT", "120")),
description="Default API request timeout in seconds",
)
yfinance_timeout: int = Field(
default_factory=lambda: int(os.getenv("YFINANCE_TIMEOUT_SECONDS", "60")),
description="yfinance API timeout in seconds",
)
database_timeout: int = Field(
default_factory=lambda: int(os.getenv("DATABASE_TIMEOUT", "60")),
description="Database operation timeout in seconds",
)
# Search provider timeouts
search_timeout_base: int = Field(
default_factory=lambda: int(os.getenv("SEARCH_TIMEOUT_BASE", "60")),
description="Base search timeout in seconds for simple queries",
)
search_timeout_complex: int = Field(
default_factory=lambda: int(os.getenv("SEARCH_TIMEOUT_COMPLEX", "120")),
description="Search timeout in seconds for complex queries",
)
search_timeout_max: int = Field(
default_factory=lambda: int(os.getenv("SEARCH_TIMEOUT_MAX", "180")),
description="Maximum search timeout in seconds",
)
# Retry settings
max_retry_attempts: int = Field(
default_factory=lambda: int(os.getenv("MAX_RETRY_ATTEMPTS", "3")),
description="Maximum number of retry attempts for failed operations",
)
retry_backoff_factor: float = Field(
default_factory=lambda: float(os.getenv("RETRY_BACKOFF_FACTOR", "2.0")),
description="Exponential backoff factor for retries",
)
# Batch processing
default_batch_size: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_BATCH_SIZE", "50")),
description="Default batch size for processing operations",
)
max_batch_size: int = Field(
default_factory=lambda: int(os.getenv("MAX_BATCH_SIZE", "1000")),
description="Maximum batch size allowed",
)
parallel_screening_workers: int = Field(
default_factory=lambda: int(os.getenv("PARALLEL_SCREENING_WORKERS", "4")),
description="Number of worker processes for parallel screening",
)
# Cache settings
cache_ttl_seconds: int = Field(
default_factory=lambda: int(os.getenv("CACHE_TTL_SECONDS", "604800")), # 7 days
description="Default cache TTL in seconds",
)
quick_cache_ttl: int = Field(
default_factory=lambda: int(os.getenv("QUICK_CACHE_TTL", "300")), # 5 minutes
description="Quick cache TTL for frequently accessed data",
)
agent_cache_ttl: int = Field(
default_factory=lambda: int(os.getenv("AGENT_CACHE_TTL", "3600")), # 1 hour
description="Agent state cache TTL in seconds",
)
# Rate limiting
api_rate_limit_per_minute: int = Field(
default_factory=lambda: int(os.getenv("API_RATE_LIMIT_PER_MINUTE", "60")),
description="API rate limit requests per minute",
)
data_provider_rate_limit: int = Field(
default_factory=lambda: int(os.getenv("DATA_PROVIDER_RATE_LIMIT", "5")),
description="Data provider rate limit per minute",
)
class UIConfig(BaseModel):
"""UI and user experience configuration settings."""
# Pagination defaults
default_page_size: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_PAGE_SIZE", "20")),
description="Default number of items per page",
)
max_page_size: int = Field(
default_factory=lambda: int(os.getenv("MAX_PAGE_SIZE", "100")),
description="Maximum number of items per page",
)
# Data display limits
max_stocks_per_screening: int = Field(
default_factory=lambda: int(os.getenv("MAX_STOCKS_PER_SCREENING", "100")),
description="Maximum number of stocks returned in screening results",
)
default_screening_limit: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_SCREENING_LIMIT", "20")),
description="Default number of stocks in screening results",
)
max_portfolio_stocks: int = Field(
default_factory=lambda: int(os.getenv("MAX_PORTFOLIO_STOCKS", "30")),
description="Maximum number of stocks in portfolio analysis",
)
default_portfolio_stocks: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_PORTFOLIO_STOCKS", "10")),
description="Default number of stocks in portfolio analysis",
)
# Historical data defaults
default_history_days: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_HISTORY_DAYS", "365")),
description="Default number of days of historical data",
)
min_history_days: int = Field(
default_factory=lambda: int(os.getenv("MIN_HISTORY_DAYS", "30")),
description="Minimum number of days of historical data",
)
max_history_days: int = Field(
default_factory=lambda: int(os.getenv("MAX_HISTORY_DAYS", "1825")), # 5 years
description="Maximum number of days of historical data",
)
# Technical analysis periods
default_rsi_period: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_RSI_PERIOD", "14")),
description="Default RSI calculation period",
)
default_sma_period: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_SMA_PERIOD", "20")),
description="Default SMA calculation period",
)
default_trend_period: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_TREND_PERIOD", "50")),
description="Default trend identification period",
)
# Symbol validation
min_symbol_length: int = Field(
default_factory=lambda: int(os.getenv("MIN_SYMBOL_LENGTH", "1")),
description="Minimum stock symbol length",
)
max_symbol_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_SYMBOL_LENGTH", "10")),
description="Maximum stock symbol length",
)
class ProviderConfig(BaseModel):
"""Data provider API limits and configuration settings."""
# External data API limits
external_data_requests_per_minute: int = Field(
default_factory=lambda: int(
os.getenv("EXTERNAL_DATA_REQUESTS_PER_MINUTE", "60")
),
description="External data API requests per minute",
)
external_data_timeout: int = Field(
default_factory=lambda: int(os.getenv("EXTERNAL_DATA_TIMEOUT", "120")),
description="External data API timeout in seconds",
)
# Yahoo Finance limits
yfinance_requests_per_minute: int = Field(
default_factory=lambda: int(os.getenv("YFINANCE_REQUESTS_PER_MINUTE", "120")),
description="Yahoo Finance requests per minute",
)
yfinance_max_symbols_per_request: int = Field(
default_factory=lambda: int(
os.getenv("YFINANCE_MAX_SYMBOLS_PER_REQUEST", "50")
),
description="Maximum symbols per Yahoo Finance request",
)
# Finviz limits
finviz_requests_per_minute: int = Field(
default_factory=lambda: int(os.getenv("FINVIZ_REQUESTS_PER_MINUTE", "30")),
description="Finviz requests per minute",
)
finviz_timeout: int = Field(
default_factory=lambda: int(os.getenv("FINVIZ_TIMEOUT", "60")),
description="Finviz timeout in seconds",
)
# News API limits
news_api_requests_per_day: int = Field(
default_factory=lambda: int(os.getenv("NEWS_API_REQUESTS_PER_DAY", "1000")),
description="News API requests per day",
)
max_news_articles: int = Field(
default_factory=lambda: int(os.getenv("MAX_NEWS_ARTICLES", "50")),
description="Maximum news articles to fetch",
)
default_news_limit: int = Field(
default_factory=lambda: int(os.getenv("DEFAULT_NEWS_LIMIT", "5")),
description="Default number of news articles to return",
)
# Cache configuration per provider
stock_data_cache_hours: int = Field(
default_factory=lambda: int(os.getenv("STOCK_DATA_CACHE_HOURS", "4")),
description="Stock data cache duration in hours",
)
market_data_cache_minutes: int = Field(
default_factory=lambda: int(os.getenv("MARKET_DATA_CACHE_MINUTES", "15")),
description="Market data cache duration in minutes",
)
news_cache_hours: int = Field(
default_factory=lambda: int(os.getenv("NEWS_CACHE_HOURS", "2")),
description="News data cache duration in hours",
)
class AgentConfig(BaseModel):
"""Agent and AI workflow configuration settings."""
# Cache settings
agent_cache_ttl_seconds: int = Field(
default_factory=lambda: int(os.getenv("AGENT_CACHE_TTL_SECONDS", "300")),
description="Agent cache TTL in seconds (5 minutes default)",
)
conversation_cache_ttl_hours: int = Field(
default_factory=lambda: int(os.getenv("CONVERSATION_CACHE_TTL_HOURS", "1")),
description="Conversation cache TTL in hours",
)
# Circuit breaker settings
circuit_breaker_failure_threshold: int = Field(
default_factory=lambda: int(
os.getenv("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5")
),
description="Number of failures before opening circuit",
)
circuit_breaker_recovery_timeout: int = Field(
default_factory=lambda: int(
os.getenv("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60")
),
description="Seconds to wait before testing recovery",
)
# Search-specific circuit breaker settings (more tolerant)
search_circuit_breaker_failure_threshold: int = Field(
default_factory=lambda: int(
os.getenv("SEARCH_CIRCUIT_BREAKER_FAILURE_THRESHOLD", "8")
),
description="Number of failures before opening search circuit (more tolerant)",
)
search_circuit_breaker_recovery_timeout: int = Field(
default_factory=lambda: int(
os.getenv("SEARCH_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "30")
),
description="Seconds to wait before testing search recovery (faster recovery)",
)
search_timeout_failure_threshold: int = Field(
default_factory=lambda: int(
os.getenv("SEARCH_TIMEOUT_FAILURE_THRESHOLD", "12")
),
description="Number of timeout failures before disabling search provider",
)
# Market data limits for sentiment analysis
sentiment_news_limit: int = Field(
default_factory=lambda: int(os.getenv("SENTIMENT_NEWS_LIMIT", "50")),
description="Maximum news articles for sentiment analysis",
)
market_movers_gainers_limit: int = Field(
default_factory=lambda: int(os.getenv("MARKET_MOVERS_GAINERS_LIMIT", "50")),
description="Maximum gainers to fetch for market analysis",
)
market_movers_losers_limit: int = Field(
default_factory=lambda: int(os.getenv("MARKET_MOVERS_LOSERS_LIMIT", "50")),
description="Maximum losers to fetch for market analysis",
)
market_movers_active_limit: int = Field(
default_factory=lambda: int(os.getenv("MARKET_MOVERS_ACTIVE_LIMIT", "20")),
description="Maximum most active stocks to fetch",
)
# Screening limits
screening_limit_default: int = Field(
default_factory=lambda: int(os.getenv("SCREENING_LIMIT_DEFAULT", "20")),
description="Default limit for screening results",
)
screening_limit_max: int = Field(
default_factory=lambda: int(os.getenv("SCREENING_LIMIT_MAX", "100")),
description="Maximum limit for screening results",
)
screening_min_volume_default: int = Field(
default_factory=lambda: int(
os.getenv("SCREENING_MIN_VOLUME_DEFAULT", "1000000")
),
description="Default minimum volume filter for screening",
)
class DatabaseConfig(BaseModel):
"""Database connection and pooling configuration settings."""
# Connection pool settings
pool_size: int = Field(
default_factory=lambda: int(os.getenv("DB_POOL_SIZE", "20")),
description="Database connection pool size",
)
pool_max_overflow: int = Field(
default_factory=lambda: int(os.getenv("DB_POOL_MAX_OVERFLOW", "10")),
description="Maximum overflow connections above pool size",
)
pool_timeout: int = Field(
default_factory=lambda: int(os.getenv("DB_POOL_TIMEOUT", "30")),
description="Pool connection timeout in seconds",
)
statement_timeout: int = Field(
default_factory=lambda: int(os.getenv("DB_STATEMENT_TIMEOUT", "30000")),
description="Database statement timeout in milliseconds",
)
# Redis connection settings
redis_max_connections: int = Field(
default_factory=lambda: int(os.getenv("REDIS_MAX_CONNECTIONS", "50")),
description="Maximum Redis connections in pool",
)
redis_socket_timeout: int = Field(
default_factory=lambda: int(os.getenv("REDIS_SOCKET_TIMEOUT", "5")),
description="Redis socket timeout in seconds",
)
redis_socket_connect_timeout: int = Field(
default_factory=lambda: int(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5")),
description="Redis socket connection timeout in seconds",
)
redis_retry_on_timeout: bool = Field(
default_factory=lambda: os.getenv("REDIS_RETRY_ON_TIMEOUT", "true").lower()
== "true",
description="Retry Redis operations on timeout",
)
class MiddlewareConfig(BaseModel):
"""Middleware and request handling configuration settings."""
# Rate limiting
api_rate_limit_per_minute: int = Field(
default_factory=lambda: int(os.getenv("API_RATE_LIMIT_PER_MINUTE", "60")),
description="API rate limit per minute",
)
# Security headers
security_header_max_age: int = Field(
default_factory=lambda: int(os.getenv("SECURITY_HEADER_MAX_AGE", "86400")),
description="Security header max age in seconds (24 hours default)",
)
# Request handling
sse_queue_timeout: int = Field(
default_factory=lambda: int(os.getenv("SSE_QUEUE_TIMEOUT", "30")),
description="SSE message queue timeout in seconds",
)
api_request_timeout_default: int = Field(
default_factory=lambda: int(os.getenv("API_REQUEST_TIMEOUT_DEFAULT", "10")),
description="Default API request timeout in seconds",
)
# Thread pool settings
thread_pool_max_workers: int = Field(
default_factory=lambda: int(os.getenv("THREAD_POOL_MAX_WORKERS", "10")),
description="Maximum workers in thread pool executor",
)
class ValidationConfig(BaseModel):
"""Input validation configuration settings."""
# String length constraints
min_symbol_length: int = Field(
default_factory=lambda: int(os.getenv("MIN_SYMBOL_LENGTH", "1")),
description="Minimum stock symbol length",
)
max_symbol_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_SYMBOL_LENGTH", "10")),
description="Maximum stock symbol length",
)
min_portfolio_name_length: int = Field(
default_factory=lambda: int(os.getenv("MIN_PORTFOLIO_NAME_LENGTH", "2")),
description="Minimum portfolio name length",
)
max_portfolio_name_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_PORTFOLIO_NAME_LENGTH", "20")),
description="Maximum portfolio name length",
)
min_screening_name_length: int = Field(
default_factory=lambda: int(os.getenv("MIN_SCREENING_NAME_LENGTH", "2")),
description="Minimum screening strategy name length",
)
max_screening_name_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_SCREENING_NAME_LENGTH", "30")),
description="Maximum screening strategy name length",
)
# General text validation
min_text_field_length: int = Field(
default_factory=lambda: int(os.getenv("MIN_TEXT_FIELD_LENGTH", "1")),
description="Minimum length for general text fields",
)
max_text_field_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_TEXT_FIELD_LENGTH", "100")),
description="Maximum length for general text fields",
)
max_description_length: int = Field(
default_factory=lambda: int(os.getenv("MAX_DESCRIPTION_LENGTH", "500")),
description="Maximum length for description fields",
)
class Settings(BaseModel):
"""Main application settings."""
app_name: str = Field(default="MaverickMCP", description="Application name")
environment: str = Field(
default_factory=lambda: os.getenv("ENVIRONMENT", "development"),
description="Environment (development, production)",
)
api: APISettings = Field(default_factory=APISettings, description="API settings")
database: DatabaseSettings = Field(
default_factory=DatabaseSettings, description="Database settings"
)
data_provider: DataProviderSettings = Field(
default_factory=DataProviderSettings, description="Data provider settings"
)
redis: RedisSettings = Field(
default_factory=RedisSettings, description="Redis settings"
)
external_data: ExternalDataSettings = Field(
default_factory=ExternalDataSettings,
description="External data API settings",
)
email: EmailSettings = Field(
default_factory=EmailSettings, description="Email service configuration"
)
financial: FinancialConfig = Field(
default_factory=FinancialConfig, description="Financial settings"
)
research: ResearchSettings = Field(
default_factory=ResearchSettings, description="Research settings"
)
data_limits: DataLimitsConfig = Field(
default_factory=DataLimitsConfig, description="Data limits settings"
)
performance: PerformanceConfig = Field(
default_factory=PerformanceConfig, description="Performance settings"
)
ui: UIConfig = Field(default_factory=UIConfig, description="UI configuration")
provider: ProviderConfig = Field(
default_factory=ProviderConfig, description="Provider configuration"
)
agent: AgentConfig = Field(
default_factory=AgentConfig, description="Agent configuration"
)
db: DatabaseConfig = Field(
default_factory=DatabaseConfig, description="Database connection settings"
)
middleware: MiddlewareConfig = Field(
default_factory=MiddlewareConfig, description="Middleware settings"
)
validation: ValidationConfig = Field(
default_factory=ValidationConfig, description="Validation settings"
)
def load_settings_from_environment() -> Settings:
"""
Load settings from environment variables.
Environment variables should be prefixed with MAVERICK_MCP_,
e.g., MAVERICK_MCP_API__PORT=8000
Returns:
Settings object with values loaded from environment
"""
return Settings()
def get_settings() -> Settings:
"""
Get application settings.
This function loads settings from environment variables and
any custom overrides specified in the constants.
Returns:
Settings object with all configured values
"""
settings = load_settings_from_environment()
# Apply any overrides from constants
if hasattr(CONFIG, "SETTINGS"):
# This would update settings with values from CONFIG.SETTINGS
pass
# Override with environment-specific settings if needed
if settings.environment == "production":
# Apply production-specific settings
# e.g., disable debug mode, set higher rate limits, etc.
settings.api.debug = False
settings.api.log_level = "warning"
settings.data_provider.rate_limit = 20
return settings
# Create a singleton instance of settings
settings = get_settings()
```
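For reference, a minimal consumer of this module might look like the following sketch. It relies only on names defined above (`get_settings`, `Settings`, `APISettings`) and on the documented fallbacks; the printed values are illustrative and depend on the runtime environment.
```python
# Minimal usage sketch; assumes the maverick_mcp package is importable.
from maverick_mcp.config.settings import APISettings, Settings, get_settings

settings = get_settings()

# Falls back to SQLite unless DATABASE_URL or POSTGRES_URL is set.
print(settings.database.url)

# Assembled from the CONFIG["redis"] defaults (scheme, auth, host, port, db).
print(settings.redis.url)

# Nested sections are plain Pydantic models, so explicit overrides are possible too.
custom = Settings(api=APISettings(port=9000, debug=True))
print(custom.api.port)  # 9000
```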
--------------------------------------------------------------------------------
/tests/integration/test_chaos_engineering.py:
--------------------------------------------------------------------------------
```python
"""
Chaos Engineering Tests for Resilience Testing.
This test suite covers:
- API failures and recovery mechanisms
- Database connection drops and reconnection
- Cache failures and fallback behavior
- Circuit breaker behavior under load
- Network timeouts and retries
- Memory pressure scenarios
- CPU overload situations
- External service outages
"""
import asyncio
import logging
import random
import threading
import time
from contextlib import ExitStack, contextmanager
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pandas as pd
import pytest
from maverick_mcp.backtesting import VectorBTEngine
from maverick_mcp.backtesting.persistence import BacktestPersistenceManager
from maverick_mcp.backtesting.strategies import STRATEGY_TEMPLATES
logger = logging.getLogger(__name__)
class ChaosInjector:
"""Utility class for injecting various types of failures."""
@staticmethod
@contextmanager
def api_failure_injection(failure_rate: float = 0.3):
"""Inject API failures at specified rate."""
original_get_stock_data = None
def failing_get_stock_data(*args, **kwargs):
if random.random() < failure_rate:
if random.random() < 0.5:
raise ConnectionError("Simulated API connection failure")
else:
raise TimeoutError("Simulated API timeout")
return (
original_get_stock_data(*args, **kwargs)
if original_get_stock_data
else Mock()
)
try:
            # Patch the engine's data fetch with the failing version for the duration of the context
with patch.object(
VectorBTEngine,
"get_historical_data",
side_effect=failing_get_stock_data,
):
yield
finally:
pass
@staticmethod
@contextmanager
def database_failure_injection(failure_rate: float = 0.2):
"""Inject database failures at specified rate."""
def failing_db_operation(*args, **kwargs):
if random.random() < failure_rate:
if random.random() < 0.33:
raise ConnectionError("Database connection lost")
elif random.random() < 0.66:
raise Exception("Database query timeout")
else:
raise Exception("Database lock timeout")
return MagicMock() # Return mock successful result
try:
with patch.object(
BacktestPersistenceManager,
"save_backtest_result",
side_effect=failing_db_operation,
):
yield
finally:
pass
@staticmethod
@contextmanager
def memory_pressure_injection(pressure_mb: int = 500):
"""Inject memory pressure by allocating large arrays."""
pressure_arrays = []
try:
# Create memory pressure
for _ in range(pressure_mb // 10):
arr = np.random.random((1280, 1000)) # ~10MB each
pressure_arrays.append(arr)
yield
finally:
# Clean up memory pressure
del pressure_arrays
@staticmethod
@contextmanager
def cpu_load_injection(load_intensity: float = 0.8, duration: float = 5.0):
"""Inject CPU load using background threads."""
stop_event = threading.Event()
load_threads = []
def cpu_intensive_task():
"""CPU-intensive task for load injection."""
while not stop_event.is_set():
# Perform CPU-intensive computation
for _ in range(int(100000 * load_intensity)):
_ = sum(i**2 for i in range(100))
time.sleep(0.01) # Brief pause
try:
# Start CPU load threads
num_threads = max(1, int(4 * load_intensity)) # Scale with intensity
for _ in range(num_threads):
thread = threading.Thread(target=cpu_intensive_task)
thread.daemon = True
thread.start()
load_threads.append(thread)
yield
finally:
# Stop CPU load
stop_event.set()
for thread in load_threads:
thread.join(timeout=1.0)
@staticmethod
@contextmanager
def network_instability_injection(
delay_range: tuple = (0.1, 2.0), timeout_rate: float = 0.1
):
"""Inject network instability with delays and timeouts."""
async def unstable_network_call(original_func, *args, **kwargs):
# Random delay
delay = random.uniform(*delay_range)
await asyncio.sleep(delay)
# Random timeout
if random.random() < timeout_rate:
raise TimeoutError("Simulated network timeout")
return await original_func(*args, **kwargs)
# This is a simplified version - real implementation would patch actual network calls
yield
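# Usage sketch (illustrative): the injectors above are ordinary context managers
# and can be combined in a single `with` statement, e.g.
#
#     with ChaosInjector.api_failure_injection(failure_rate=0.3), \
#          ChaosInjector.memory_pressure_injection(pressure_mb=500):
#         ...  # run backtests under combined API and memory chaos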
class TestChaosEngineering:
"""Chaos engineering tests for system resilience."""
@pytest.fixture
async def resilient_data_provider(self):
"""Create data provider with built-in resilience patterns."""
provider = Mock()
async def resilient_get_data(symbol: str, *args, **kwargs):
"""Data provider with retry logic and fallback."""
max_retries = 3
retry_delay = 0.1
for attempt in range(max_retries):
try:
# Simulate data generation (can fail randomly)
if random.random() < 0.1: # 10% failure rate
raise ConnectionError(f"API failure for {symbol}")
# Generate mock data
dates = pd.date_range(
start="2023-01-01", end="2023-12-31", freq="D"
)
returns = np.random.normal(0.0008, 0.02, len(dates))
prices = 100 * np.cumprod(1 + returns)
return pd.DataFrame(
{
"Open": prices * np.random.uniform(0.99, 1.01, len(dates)),
"High": prices * np.random.uniform(1.00, 1.03, len(dates)),
"Low": prices * np.random.uniform(0.97, 1.00, len(dates)),
"Close": prices,
"Volume": np.random.randint(100000, 5000000, len(dates)),
"Adj Close": prices,
},
index=dates,
)
except Exception:
if attempt == max_retries - 1:
# Final attempt failed, return minimal fallback data
logger.warning(
f"All retries failed for {symbol}, using fallback data"
)
dates = pd.date_range(start="2023-01-01", periods=10, freq="D")
prices = np.full(len(dates), 100.0)
return pd.DataFrame(
{
"Open": prices,
"High": prices * 1.01,
"Low": prices * 0.99,
"Close": prices,
"Volume": np.full(len(dates), 1000000),
"Adj Close": prices,
},
index=dates,
)
await asyncio.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
provider.get_stock_data.side_effect = resilient_get_data
return provider
async def test_api_failures_and_recovery(
self, resilient_data_provider, benchmark_timer
):
"""Test API failure scenarios and recovery mechanisms."""
symbols = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA"]
strategy = "sma_cross"
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
# Test with different failure rates
failure_scenarios = [
{"name": "low_failure", "rate": 0.1},
{"name": "moderate_failure", "rate": 0.3},
{"name": "high_failure", "rate": 0.6},
]
scenario_results = {}
for scenario in failure_scenarios:
with ChaosInjector.api_failure_injection(failure_rate=scenario["rate"]):
with benchmark_timer() as timer:
results = []
failures = []
engine = VectorBTEngine(data_provider=resilient_data_provider)
for symbol in symbols:
try:
result = await engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
results.append(result)
logger.info(
f"✓ {symbol} succeeded under {scenario['name']} conditions"
)
except Exception as e:
failures.append({"symbol": symbol, "error": str(e)})
logger.error(
f"✗ {symbol} failed under {scenario['name']} conditions: {e}"
)
execution_time = timer.elapsed
success_rate = len(results) / len(symbols)
recovery_rate = 1 - (
scenario["rate"] * (1 - success_rate)
) # Account for injected failures
scenario_results[scenario["name"]] = {
"failure_rate_injected": scenario["rate"],
"success_rate_achieved": success_rate,
"recovery_effectiveness": recovery_rate,
"execution_time": execution_time,
"successful_backtests": len(results),
"failed_backtests": len(failures),
}
logger.info(
f"{scenario['name'].upper()} Failure Scenario:\n"
f" • Injected Failure Rate: {scenario['rate']:.1%}\n"
f" • Achieved Success Rate: {success_rate:.1%}\n"
f" • Recovery Effectiveness: {recovery_rate:.1%}\n"
f" • Execution Time: {execution_time:.2f}s"
)
# Assert minimum recovery effectiveness
assert success_rate >= 0.5, (
f"Success rate too low for {scenario['name']}: {success_rate:.1%}"
)
return scenario_results
async def test_database_connection_drops(
self, resilient_data_provider, db_session, benchmark_timer
):
"""Test database connection drops and reconnection logic."""
symbols = ["AAPL", "GOOGL", "MSFT"]
strategy = "rsi"
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
engine = VectorBTEngine(data_provider=resilient_data_provider)
# Generate backtest results first
backtest_results = []
for symbol in symbols:
result = await engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
backtest_results.append(result)
# Test database operations under chaos
with ChaosInjector.database_failure_injection(failure_rate=0.3):
with benchmark_timer() as timer:
persistence_results = []
persistence_failures = []
# Attempt to save results with intermittent database failures
for result in backtest_results:
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
with BacktestPersistenceManager(
session=db_session
) as persistence:
backtest_id = persistence.save_backtest_result(
vectorbt_results=result,
execution_time=2.0,
notes=f"Chaos test - {result['symbol']}",
)
persistence_results.append(
{
"symbol": result["symbol"],
"backtest_id": backtest_id,
"retry_count": retry_count,
}
)
break # Success, break retry loop
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
persistence_failures.append(
{
"symbol": result["symbol"],
"error": str(e),
"retry_count": retry_count,
}
)
else:
await asyncio.sleep(
0.1 * retry_count
) # Exponential backoff
persistence_time = timer.elapsed
# Analyze results
persistence_success_rate = len(persistence_results) / len(backtest_results)
avg_retries = (
np.mean([r["retry_count"] for r in persistence_results])
if persistence_results
else 0
)
# Test recovery by attempting to retrieve saved data
retrieval_successes = 0
if persistence_results:
for saved_result in persistence_results:
try:
with BacktestPersistenceManager(session=db_session) as persistence:
retrieved = persistence.get_backtest_by_id(
saved_result["backtest_id"]
)
if retrieved:
retrieval_successes += 1
except Exception as e:
logger.error(f"Retrieval failed for {saved_result['symbol']}: {e}")
retrieval_success_rate = (
retrieval_successes / len(persistence_results) if persistence_results else 0
)
logger.info(
f"Database Connection Drops Test Results:\n"
f" • Backtest Results: {len(backtest_results)}\n"
f" • Persistence Successes: {len(persistence_results)}\n"
f" • Persistence Failures: {len(persistence_failures)}\n"
f" • Persistence Success Rate: {persistence_success_rate:.1%}\n"
f" • Average Retries: {avg_retries:.1f}\n"
f" • Retrieval Success Rate: {retrieval_success_rate:.1%}\n"
f" • Total Time: {persistence_time:.2f}s"
)
# Assert resilience requirements
assert persistence_success_rate >= 0.7, (
f"Persistence success rate too low: {persistence_success_rate:.1%}"
)
assert retrieval_success_rate >= 0.9, (
f"Retrieval success rate too low: {retrieval_success_rate:.1%}"
)
return {
"persistence_success_rate": persistence_success_rate,
"retrieval_success_rate": retrieval_success_rate,
"avg_retries": avg_retries,
}
async def test_cache_failures_and_fallback(
self, resilient_data_provider, benchmark_timer
):
"""Test cache failures and fallback behavior."""
symbols = ["CACHE_TEST_1", "CACHE_TEST_2", "CACHE_TEST_3"]
strategy = "macd"
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
engine = VectorBTEngine(data_provider=resilient_data_provider)
# Test cache behavior under failures
cache_scenarios = [
{"name": "normal_cache", "inject_failure": False},
{"name": "cache_failures", "inject_failure": True},
]
scenario_results = {}
for scenario in cache_scenarios:
if scenario["inject_failure"]:
# Mock cache to randomly fail
def failing_cache_get(key):
if random.random() < 0.4: # 40% cache failure rate
raise ConnectionError("Cache connection failed")
return None # Cache miss
def failing_cache_set(key, value, ttl=None):
if random.random() < 0.3: # 30% cache set failure rate
raise ConnectionError("Cache set operation failed")
return True
cache_patches = [
patch(
"maverick_mcp.core.cache.CacheManager.get",
side_effect=failing_cache_get,
),
patch(
"maverick_mcp.core.cache.CacheManager.set",
side_effect=failing_cache_set,
),
]
else:
cache_patches = []
with benchmark_timer() as timer:
results = []
cache_errors = []
# Apply cache patches if needed
with ExitStack() as stack:
for patch_context in cache_patches:
stack.enter_context(patch_context)
# Run backtests - should fallback gracefully on cache failures
for symbol in symbols:
try:
result = await engine.run_backtest(
symbol=symbol,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
results.append(result)
except Exception as e:
cache_errors.append({"symbol": symbol, "error": str(e)})
logger.error(
f"Backtest failed for {symbol} under {scenario['name']}: {e}"
)
execution_time = timer.elapsed
success_rate = len(results) / len(symbols)
scenario_results[scenario["name"]] = {
"execution_time": execution_time,
"success_rate": success_rate,
"cache_errors": len(cache_errors),
}
logger.info(
f"{scenario['name'].upper()} Cache Scenario:\n"
f" • Execution Time: {execution_time:.2f}s\n"
f" • Success Rate: {success_rate:.1%}\n"
f" • Cache Errors: {len(cache_errors)}"
)
# Cache failures should not prevent backtests from completing
assert success_rate >= 0.8, (
f"Success rate too low with cache issues: {success_rate:.1%}"
)
# Cache failures might slightly increase execution time but shouldn't break functionality
time_ratio = (
scenario_results["cache_failures"]["execution_time"]
/ scenario_results["normal_cache"]["execution_time"]
)
assert time_ratio < 3.0, (
f"Cache failure time penalty too high: {time_ratio:.1f}x"
)
return scenario_results
async def test_circuit_breaker_behavior(
self, resilient_data_provider, benchmark_timer
):
"""Test circuit breaker behavior under load and failures."""
symbols = ["CB_TEST_1", "CB_TEST_2", "CB_TEST_3", "CB_TEST_4", "CB_TEST_5"]
strategy = "sma_cross"
parameters = STRATEGY_TEMPLATES[strategy]["parameters"]
# Mock circuit breaker states
circuit_breaker_state = {"failures": 0, "state": "CLOSED", "last_failure": 0}
failure_threshold = 3
recovery_timeout = 2.0
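        # State machine of this ad-hoc breaker: CLOSED -> OPEN after
        # failure_threshold consecutive failures; OPEN -> HALF_OPEN once
        # recovery_timeout has elapsed; HALF_OPEN -> CLOSED on the next
        # success (failure count reset) or back to OPEN on another failure.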
def circuit_breaker_wrapper(func):
"""Simple circuit breaker implementation."""
async def wrapper(*args, **kwargs):
current_time = time.time()
# Check if circuit should reset
if (
circuit_breaker_state["state"] == "OPEN"
and current_time - circuit_breaker_state["last_failure"]
> recovery_timeout
):
circuit_breaker_state["state"] = "HALF_OPEN"
logger.info("Circuit breaker moved to HALF_OPEN state")
# Circuit is open, reject immediately
if circuit_breaker_state["state"] == "OPEN":
raise Exception("Circuit breaker is OPEN")
try:
# Inject failures for testing
if random.random() < 0.4: # 40% failure rate
raise ConnectionError("Simulated service failure")
result = await func(*args, **kwargs)
# Success - reset failure count if in HALF_OPEN state
if circuit_breaker_state["state"] == "HALF_OPEN":
circuit_breaker_state["state"] = "CLOSED"
circuit_breaker_state["failures"] = 0
logger.info("Circuit breaker CLOSED after successful recovery")
return result
except Exception as e:
circuit_breaker_state["failures"] += 1
circuit_breaker_state["last_failure"] = current_time
if circuit_breaker_state["failures"] >= failure_threshold:
circuit_breaker_state["state"] = "OPEN"
logger.warning(
f"Circuit breaker OPENED after {circuit_breaker_state['failures']} failures"
)
raise e
return wrapper
# Apply circuit breaker to engine operations
engine = VectorBTEngine(data_provider=resilient_data_provider)
with benchmark_timer() as timer:
results = []
circuit_breaker_trips = 0
recovery_attempts = 0
for _i, symbol in enumerate(symbols):
try:
# Simulate circuit breaker behavior
current_symbol = symbol
@circuit_breaker_wrapper
async def protected_backtest(symbol_to_use=current_symbol):
return await engine.run_backtest(
symbol=symbol_to_use,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
result = await protected_backtest()
results.append(result)
logger.info(
f"✓ {symbol} succeeded (CB state: {circuit_breaker_state['state']})"
)
except Exception as e:
if "Circuit breaker is OPEN" in str(e):
circuit_breaker_trips += 1
logger.warning(f"⚡ {symbol} blocked by circuit breaker")
# Wait for potential recovery
await asyncio.sleep(recovery_timeout + 0.1)
recovery_attempts += 1
# Try once more after recovery timeout
try:
recovery_symbol = symbol
@circuit_breaker_wrapper
async def recovery_backtest(symbol_to_use=recovery_symbol):
return await engine.run_backtest(
symbol=symbol_to_use,
strategy_type=strategy,
parameters=parameters,
start_date="2023-01-01",
end_date="2023-12-31",
)
result = await recovery_backtest()
results.append(result)
logger.info(
f"✓ {symbol} succeeded after circuit breaker recovery"
)
except Exception as recovery_error:
logger.error(
f"✗ {symbol} failed even after recovery: {recovery_error}"
)
else:
logger.error(f"✗ {symbol} failed: {e}")
execution_time = timer.elapsed
success_rate = len(results) / len(symbols)
circuit_breaker_effectiveness = (
circuit_breaker_trips > 0
) # Circuit breaker activated
logger.info(
f"Circuit Breaker Behavior Test Results:\n"
f" • Symbols Tested: {len(symbols)}\n"
f" • Successful Results: {len(results)}\n"
f" • Success Rate: {success_rate:.1%}\n"
f" • Circuit Breaker Trips: {circuit_breaker_trips}\n"
f" • Recovery Attempts: {recovery_attempts}\n"
f" • Circuit Breaker Effectiveness: {circuit_breaker_effectiveness}\n"
f" • Final CB State: {circuit_breaker_state['state']}\n"
f" • Execution Time: {execution_time:.2f}s"
)
# Circuit breaker should provide some protection
assert circuit_breaker_effectiveness, "Circuit breaker should have activated"
assert success_rate >= 0.4, (
f"Success rate too low even with circuit breaker: {success_rate:.1%}"
)
return {
"success_rate": success_rate,
"circuit_breaker_trips": circuit_breaker_trips,
"recovery_attempts": recovery_attempts,
"final_state": circuit_breaker_state["state"],
}

    async def test_memory_pressure_resilience(
        self, resilient_data_provider, benchmark_timer
    ):
        """Test system resilience under memory pressure."""
        symbols = ["MEM_TEST_1", "MEM_TEST_2", "MEM_TEST_3"]
        strategy = "bollinger"
        parameters = STRATEGY_TEMPLATES[strategy]["parameters"]

        # Test under different memory pressure levels
        pressure_levels = [0, 500, 1000]  # MB of memory pressure
        pressure_results = {}
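
        # Sweep design: run the same backtests at increasing levels of injected
        # memory pressure (0 MB is the no-chaos baseline) and record per-level
        # success rate, runtime, and memory-related errors for comparison below.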
        for pressure_mb in pressure_levels:
            with ChaosInjector.memory_pressure_injection(pressure_mb=pressure_mb):
                with benchmark_timer() as timer:
                    results = []
                    memory_errors = []
                    engine = VectorBTEngine(data_provider=resilient_data_provider)

                    for symbol in symbols:
                        try:
                            result = await engine.run_backtest(
                                symbol=symbol,
                                strategy_type=strategy,
                                parameters=parameters,
                                start_date="2023-01-01",
                                end_date="2023-12-31",
                            )
                            results.append(result)
                        except Exception as e:  # Exception already covers MemoryError
                            memory_errors.append({"symbol": symbol, "error": str(e)})
                            logger.error(
                                f"Memory pressure caused failure for {symbol}: {e}"
                            )

                execution_time = timer.elapsed
                success_rate = len(results) / len(symbols)

                pressure_results[f"{pressure_mb}mb"] = {
                    "pressure_mb": pressure_mb,
                    "execution_time": execution_time,
                    "success_rate": success_rate,
                    "memory_errors": len(memory_errors),
                }

                logger.info(
                    f"Memory Pressure {pressure_mb}MB Results:\n"
                    f" • Execution Time: {execution_time:.2f}s\n"
                    f" • Success Rate: {success_rate:.1%}\n"
                    f" • Memory Errors: {len(memory_errors)}"
                )

        # System should be resilient to moderate memory pressure
        moderate_pressure_result = pressure_results["500mb"]
        high_pressure_result = pressure_results["1000mb"]

        assert moderate_pressure_result["success_rate"] >= 0.8, (
            "Should handle moderate memory pressure"
        )
        assert high_pressure_result["success_rate"] >= 0.5, (
            "Should partially handle high memory pressure"
        )

        return pressure_results

    async def test_cpu_overload_resilience(
        self, resilient_data_provider, benchmark_timer
    ):
        """Test system resilience under CPU overload."""
        symbols = ["CPU_TEST_1", "CPU_TEST_2"]
        strategy = "momentum"
        parameters = STRATEGY_TEMPLATES[strategy]["parameters"]

        # Test under CPU load
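        # The injected load (~80% CPU for up to 10s) can starve the event loop, so
        # each backtest below is wrapped in asyncio.wait_for with a 30s ceiling to
        # turn a hang into a recorded TimeoutError instead of stalling the test.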
        with ChaosInjector.cpu_load_injection(load_intensity=0.8, duration=10.0):
            with benchmark_timer() as timer:
                results = []
                timeout_errors = []
                engine = VectorBTEngine(data_provider=resilient_data_provider)

                for symbol in symbols:
                    try:
                        # Add timeout to prevent hanging under CPU load
                        result = await asyncio.wait_for(
                            engine.run_backtest(
                                symbol=symbol,
                                strategy_type=strategy,
                                parameters=parameters,
                                start_date="2023-01-01",
                                end_date="2023-12-31",
                            ),
                            timeout=30.0,  # 30 second timeout
                        )
                        results.append(result)
                    except TimeoutError:
                        timeout_errors.append(
                            {"symbol": symbol, "error": "CPU overload timeout"}
                        )
                        logger.error(f"CPU overload caused timeout for {symbol}")
                    except Exception as e:
                        timeout_errors.append({"symbol": symbol, "error": str(e)})
                        logger.error(f"CPU overload caused failure for {symbol}: {e}")

        execution_time = timer.elapsed
        success_rate = len(results) / len(symbols)
        timeout_rate = len(
            [e for e in timeout_errors if "timeout" in e["error"]]
        ) / len(symbols)

        logger.info(
            f"CPU Overload Resilience Results:\n"
            f" • Symbols Tested: {len(symbols)}\n"
            f" • Successful Results: {len(results)}\n"
            f" • Success Rate: {success_rate:.1%}\n"
            f" • Timeout Rate: {timeout_rate:.1%}\n"
            f" • Execution Time: {execution_time:.2f}s"
        )

        # System should handle some CPU pressure, though performance may degrade
        assert success_rate >= 0.5, (
            f"Success rate too low under CPU load: {success_rate:.1%}"
        )
        assert execution_time < 60.0, (
            f"Execution time too long under CPU load: {execution_time:.1f}s"
        )

        return {
            "success_rate": success_rate,
            "timeout_rate": timeout_rate,
            "execution_time": execution_time,
        }

    async def test_cascading_failure_recovery(
        self, resilient_data_provider, benchmark_timer
    ):
        """Test recovery from cascading failures across multiple components."""
        symbols = ["CASCADE_1", "CASCADE_2", "CASCADE_3"]
        strategy = "rsi"
        parameters = STRATEGY_TEMPLATES[strategy]["parameters"]

        # Simulate cascading failures: API -> Cache -> Database
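        # In this test the cascade is approximated by layering two chaos contexts,
        # API failure injection (50% failure rate) on top of 300 MB of memory
        # pressure, then giving each symbol up to three attempts with backoff.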
        with ChaosInjector.api_failure_injection(failure_rate=0.5):
            with ChaosInjector.memory_pressure_injection(pressure_mb=300):
                with benchmark_timer() as timer:
                    results = []
                    cascading_failures = []
                    engine = VectorBTEngine(data_provider=resilient_data_provider)

                    for symbol in symbols:
                        failure_chain = []
                        final_result = None

                        # Multiple recovery attempts with progressive backoff
                        for attempt in range(3):
                            try:
                                result = await engine.run_backtest(
                                    symbol=symbol,
                                    strategy_type=strategy,
                                    parameters=parameters,
                                    start_date="2023-01-01",
                                    end_date="2023-12-31",
                                )
                                final_result = result
                                break  # Success, exit retry loop
                            except Exception as e:
                                failure_chain.append(
                                    f"Attempt {attempt + 1}: {str(e)[:50]}"
                                )
                                if attempt < 2:
                                    # Progressive backoff before the next attempt
                                    await asyncio.sleep(0.5 * (attempt + 1))

                        if final_result:
                            results.append(final_result)
                            logger.info(
                                f"✓ {symbol} recovered after {len(failure_chain)} failures"
                            )
                        else:
                            cascading_failures.append(
                                {"symbol": symbol, "failure_chain": failure_chain}
                            )
                            logger.error(
                                f"✗ {symbol} failed completely: {failure_chain}"
                            )

        execution_time = timer.elapsed
        recovery_rate = len(results) / len(symbols)
        avg_failures_before_recovery = (
            np.mean([len(cf["failure_chain"]) for cf in cascading_failures])
            if cascading_failures
            else 0
        )

        logger.info(
            f"Cascading Failure Recovery Results:\n"
            f" • Symbols Tested: {len(symbols)}\n"
            f" • Successfully Recovered: {len(results)}\n"
            f" • Complete Failures: {len(cascading_failures)}\n"
            f" • Recovery Rate: {recovery_rate:.1%}\n"
            f" • Avg Failures Before Recovery: {avg_failures_before_recovery:.1f}\n"
            f" • Execution Time: {execution_time:.2f}s"
        )

        # System should show some recovery capability even under cascading failures
        assert recovery_rate >= 0.3, (
            f"Recovery rate too low for cascading failures: {recovery_rate:.1%}"
        )

        return {
            "recovery_rate": recovery_rate,
            "cascading_failures": len(cascading_failures),
            "avg_failures_before_recovery": avg_failures_before_recovery,
        }
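

# NOTE: the flags below assume the pytest-asyncio plugin (--asyncio-mode) and the
# pytest-timeout plugin (--timeout) are installed; --durations is built into pytest.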
if __name__ == "__main__":
    # Run chaos engineering tests
    pytest.main(
        [
            __file__,
            "-v",
            "--tb=short",
            "--asyncio-mode=auto",
            "--timeout=900",  # 15 minute timeout for chaos tests
            "--durations=15",  # Show 15 slowest tests
        ]
    )
```