This is page 10 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/api/middleware/mcp_logging.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive MCP Logging Middleware for debugging tool calls and protocol communication.
This middleware provides:
- Tool call lifecycle logging
- MCP protocol message logging
- Request/response payload logging
- Error tracking with full context
- Performance metrics collection
- Timeout detection and logging
"""
import asyncio
import functools
import inspect
import json
import logging
import sys
import time
import traceback
import uuid
from typing import Any

from fastmcp import FastMCP
try:
    from fastmcp.server.middleware import Middleware, MiddlewareContext
except ImportError:
    # FastMCP releases without the middleware API: expose inert stand-ins so the
    # rest of this module still imports, and record that hooks are unavailable.
    MIDDLEWARE_AVAILABLE = False

    class Middleware:  # type: ignore
        """Fallback Middleware class for older FastMCP versions."""

    MiddlewareContext = Any
else:
    MIDDLEWARE_AVAILABLE = True
from maverick_mcp.utils.logging import (
get_logger,
request_id_var,
request_start_var,
tool_name_var,
)
# Module-level logger used for setup / registration messages emitted by the
# helper functions in this module (the middleware itself uses its own logger).
logger = get_logger("maverick_mcp.middleware.mcp_logging")
class MCPLoggingMiddleware(Middleware if MIDDLEWARE_AVAILABLE else object):
    """
    Comprehensive MCP protocol and tool call logging middleware for FastMCP 2.0+.

    Logs:
    - Tool call lifecycle (start, success, timeout, error) with execution time
    - Resource access with outcome and duration
    - Error conditions with full traceback context
    - Timeout detection: tool calls are cancelled after ``timeout_seconds``

    Structured records go to the ``maverick_mcp.mcp_protocol`` logger; short
    human-readable lines are printed to stderr. stderr is used because with
    the stdio transport stdout carries the MCP JSON-RPC stream, and any other
    output written there corrupts the protocol.
    """

    def __init__(
        self,
        include_payloads: bool = True,
        max_payload_length: int = 2000,
        log_level: int = logging.INFO,
        timeout_seconds: float = 25.0,
    ):
        """
        Args:
            include_payloads: Include tool arguments and result previews in the
                structured log records.
            max_payload_length: Maximum number of characters kept from each
                serialized payload.
            log_level: Stored as ``self.log_level``; the records emitted by this
                class use fixed INFO/ERROR levels and do not consult it.
            timeout_seconds: Upper bound for a single tool call. Calls running
                longer are cancelled, logged as timeouts, and re-raised.
        """
        if MIDDLEWARE_AVAILABLE:
            super().__init__()
        self.include_payloads = include_payloads
        self.max_payload_length = max_payload_length
        self.log_level = log_level
        self.timeout_seconds = timeout_seconds
        self.logger = get_logger("maverick_mcp.mcp_protocol")

    async def on_call_tool(self, context: MiddlewareContext, call_next) -> Any:
        """
        Log the lifecycle of a tool call and enforce ``self.timeout_seconds``.

        Also publishes the request id, start time and tool name to the logging
        context variables so log records emitted inside the tool can be
        correlated with this call.

        Raises:
            TimeoutError: The tool did not finish within ``self.timeout_seconds``.
            Exception: Any exception raised by the tool is logged and re-raised.
        """
        if not MIDDLEWARE_AVAILABLE:
            return await call_next(context)

        request_id = str(uuid.uuid4())
        start_time = time.time()
        request_start_var.set(start_time)
        request_id_var.set(request_id)

        tool_name = getattr(context.message, "name", "unknown_tool")
        tool_name_var.set(tool_name)

        # The message may omit arguments entirely (None); normalise to a dict.
        arguments = getattr(context.message, "arguments", None) or {}

        self._log_tool_call_start(request_id, tool_name, arguments)

        try:
            result = await asyncio.wait_for(
                call_next(context), timeout=self.timeout_seconds
            )
        except (TimeoutError, asyncio.TimeoutError):
            # Before Python 3.11 asyncio.TimeoutError is distinct from the
            # builtin TimeoutError; catch both so timeouts are classified as such.
            self._log_tool_call_timeout(
                request_id, tool_name, time.time() - start_time
            )
            raise
        except Exception as e:
            self._log_tool_call_error(
                request_id, tool_name, e, time.time() - start_time, arguments
            )
            raise

        self._log_tool_call_success(
            request_id, tool_name, result, time.time() - start_time
        )
        return result

    async def on_read_resource(self, context: MiddlewareContext, call_next) -> Any:
        """Log resource access together with its outcome and duration."""
        if not MIDDLEWARE_AVAILABLE:
            return await call_next(context)

        resource_uri = getattr(context.message, "uri", "unknown_resource")
        start_time = time.time()
        print(f"🔗 RESOURCE ACCESS: {resource_uri}", file=sys.stderr)

        try:
            result = await call_next(context)
        except Exception as e:
            execution_time = time.time() - start_time
            print(
                f"❌ RESOURCE ERROR: {resource_uri} ({execution_time:.2f}s) - {type(e).__name__}: {str(e)}",
                file=sys.stderr,
            )
            raise

        execution_time = time.time() - start_time
        print(
            f"✅ RESOURCE SUCCESS: {resource_uri} ({execution_time:.2f}s)",
            file=sys.stderr,
        )
        return result

    def _serialize_payload(self, payload: Any) -> str:
        """
        Render ``payload`` for logging, truncated to ``max_payload_length``.

        Strings are truncated as-is. Everything else is JSON-encoded; objects
        that are not JSON-native are rendered with ``str()`` so previews still
        work for rich tool-result objects.

        Raises:
            Exception: If serialization still fails (e.g. circular references).
        """
        if isinstance(payload, str):
            return payload[: self.max_payload_length]
        return json.dumps(payload, default=str)[: self.max_payload_length]

    def _log_tool_call_start(
        self, request_id: str, tool_name: str, arguments: dict | None
    ):
        """Log tool call initiation (structured record plus a console line)."""
        log_data = {
            "request_id": request_id,
            "direction": "incoming",
            "tool_name": tool_name,
            "timestamp": time.time(),
        }

        # Add arguments if requested (debug mode)
        if self.include_payloads and arguments:
            try:
                log_data["arguments"] = self._serialize_payload(arguments)
            except Exception as e:
                log_data["args_error"] = str(e)

        self.logger.info("TOOL_CALL_START", extra=log_data)

        # Short console preview of the arguments for immediate visibility
        args_preview = ""
        if arguments:
            args_str = str(arguments)
            args_preview = f" with {args_str[:50]}{'...' if len(args_str) > 50 else ''}"
        print(
            f"🔧 TOOL CALL: {tool_name}{args_preview} [{request_id[:8]}]",
            file=sys.stderr,
        )

    def _log_tool_call_success(
        self, request_id: str, tool_name: str, result: Any, execution_time: float
    ):
        """Log successful tool completion with an optional result preview."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "success",
            "timestamp": time.time(),
        }

        # Add result preview if requested (debug mode)
        if self.include_payloads and result is not None:
            try:
                log_data["result_preview"] = self._serialize_payload(result)
                log_data["result_type"] = type(result).__name__
            except Exception as e:
                log_data["result_error"] = str(e)

        self.logger.info("TOOL_CALL_SUCCESS", extra=log_data)

        # Console icon escalates with execution time: <5s, <15s, otherwise slow
        status_icon = (
            "🟢" if execution_time < 5.0 else "🟡" if execution_time < 15.0 else "🟠"
        )
        print(
            f"{status_icon} TOOL SUCCESS: {tool_name} [{request_id[:8]}] {execution_time:.2f}s",
            file=sys.stderr,
        )

    def _log_tool_call_timeout(
        self, request_id: str, tool_name: str, execution_time: float
    ):
        """Log a tool call that exceeded ``self.timeout_seconds``."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "timeout",
            "timeout_seconds": self.timeout_seconds,
            "error_type": "timeout",
            "timestamp": time.time(),
        }

        self.logger.error("TOOL_CALL_TIMEOUT", extra=log_data)
        print(
            f"⏰ TOOL TIMEOUT: {tool_name} [{request_id[:8]}] {execution_time:.2f}s (exceeded {self.timeout_seconds:g}s limit)",
            file=sys.stderr,
        )

    def _log_tool_call_error(
        self,
        request_id: str,
        tool_name: str,
        error: Exception,
        execution_time: float,
        arguments: dict | None,
    ):
        """Log a tool error with traceback and (optionally) the call arguments."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "error",
            "error_type": type(error).__name__,
            "error_message": str(error),
            # Format from the exception object so the result does not depend on
            # being called inside an active ``except`` block.
            "traceback": "".join(
                traceback.format_exception(type(error), error, error.__traceback__)
            ),
            "timestamp": time.time(),
        }

        # Add arguments for debugging
        if self.include_payloads and arguments:
            try:
                log_data["arguments"] = self._serialize_payload(arguments)
            except Exception as e:
                log_data["args_error"] = str(e)

        self.logger.error("TOOL_CALL_ERROR", extra=log_data)

        print(
            f"❌ TOOL ERROR: {tool_name} [{request_id[:8]}] {execution_time:.2f}s - {type(error).__name__}: {str(error)}",
            file=sys.stderr,
        )
class ToolExecutionLogger:
    """
    Logger for individual tool execution steps.

    Use this inside a tool to log progress and failures. Every record carries
    the request id and tool name so records from one call can be grouped.
    Console progress lines are written to stderr so they never mix with an MCP
    stdio protocol stream on stdout.
    """

    def __init__(self, tool_name: str, request_id: str | None = None):
        """
        Args:
            tool_name: Name of the tool; also used to derive the logger name
                ``maverick_mcp.tools.<tool_name>``.
            request_id: Explicit request id. Falls back to the request id in the
                current logging context, then to a fresh UUID4.
        """
        self.tool_name = tool_name
        self.request_id = (
            request_id or self._context_request_id() or str(uuid.uuid4())
        )
        self.logger = get_logger(f"maverick_mcp.tools.{tool_name}")
        self.start_time = time.time()
        # step name -> seconds elapsed since construction when the step was logged
        self.step_times: dict[str, float] = {}

    @staticmethod
    def _context_request_id():
        """
        Return the request id from the logging context, or None if unset.

        ``ContextVar.get()`` raises LookupError when the variable has no value
        and no default (e.g. when used outside a middleware-handled request).
        """
        try:
            return request_id_var.get()
        except LookupError:
            return None

    def step(self, step_name: str, message: str | None = None):
        """
        Log that execution reached ``step_name``.

        Both ``step_duration`` and ``total_duration`` are measured from logger
        construction, i.e. they record when the step was reached, not how long
        the step itself took.
        """
        elapsed = time.time() - self.start_time
        self.step_times[step_name] = elapsed

        self.logger.info(
            message or f"Executing step: {step_name}",
            extra={
                "request_id": self.request_id,
                "tool_name": self.tool_name,
                "step": step_name,
                "step_duration": elapsed,
                "total_duration": elapsed,
            },
        )

        # Console progress indicator
        print(
            f" 📊 {self.tool_name} -> {step_name} ({elapsed:.2f}s)", file=sys.stderr
        )

    def error(self, step_name: str, error: Exception, message: str | None = None):
        """
        Log an error raised while executing ``step_name``.

        The traceback is taken from ``error`` itself, so this can be called
        either inside or outside the ``except`` block that caught it.
        """
        elapsed = time.time() - self.start_time

        self.logger.error(
            message or f"Error in step: {step_name}",
            extra={
                "request_id": self.request_id,
                "tool_name": self.tool_name,
                "step": step_name,
                "step_duration": elapsed,
                "total_duration": elapsed,
                "error_type": type(error).__name__,
                "error_message": str(error),
                "traceback": "".join(
                    traceback.format_exception(
                        type(error), error, error.__traceback__
                    )
                ),
            },
        )

        # Console error indicator
        print(
            f" ❌ {self.tool_name} -> {step_name} ERROR: {type(error).__name__}: {str(error)}",
            file=sys.stderr,
        )

    def complete(self, result_summary: str | None = None):
        """Log completion of tool execution together with all recorded step times."""
        total_duration = time.time() - self.start_time

        self.logger.info(
            result_summary or "Tool execution completed",
            extra={
                "request_id": self.request_id,
                "tool_name": self.tool_name,
                "total_duration": total_duration,
                "step_times": self.step_times,
                "status": "completed",
            },
        )

        # Console completion
        print(
            f" ✅ {self.tool_name} completed ({total_duration:.2f}s)", file=sys.stderr
        )
def add_mcp_logging_middleware(
    server: FastMCP,
    include_payloads: bool = True,
    max_payload_length: int = 2000,
    log_level: int = logging.INFO,
):
    """
    Add comprehensive MCP logging middleware to a FastMCP server.

    Registration is attempted via ``server.add_middleware``, then via a
    ``server.middleware`` attribute, and finally by wrapping ``server.tool``.
    A registration failure is logged and reported on stderr but not raised,
    and in that case no "setup completed" record is emitted.

    Args:
        server: FastMCP server instance
        include_payloads: Whether to log request/response payloads (debug mode)
        max_payload_length: Maximum length of logged payloads
        log_level: Logging level passed to the middleware and reported in the
            setup record
    """
    if not MIDDLEWARE_AVAILABLE:
        logger.warning("FastMCP middleware not available - requires FastMCP 2.9+")
        print(
            "⚠️ FastMCP middleware not available - tool logging will be limited",
            file=sys.stderr,
        )
        return

    middleware = MCPLoggingMiddleware(
        include_payloads=include_payloads,
        max_payload_length=max_payload_length,
        log_level=log_level,
    )

    try:
        if hasattr(server, "add_middleware"):
            server.add_middleware(middleware)
            logger.info("✅ FastMCP 2.0 middleware registered successfully")
        elif hasattr(server, "middleware"):
            # Fallback for servers exposing a plain middleware attribute
            if isinstance(server.middleware, list):
                server.middleware.append(middleware)
            else:
                server.middleware = [middleware]
            logger.info("✅ FastMCP middleware registered via fallback method")
        else:
            # Last resort: wrap tool registration with console logging
            logger.warning("Using decorator-style middleware registration")
            _apply_middleware_as_decorators(server, middleware)
    except Exception as e:
        # Registration problems are reported but deliberately not propagated.
        logger.error(f"Failed to register FastMCP middleware: {e}")
        print(f"⚠️ Middleware registration failed: {e}", file=sys.stderr)
        return

    logger.info(
        "MCP logging middleware setup completed",
        extra={
            "include_payloads": include_payloads,
            "max_payload_length": max_payload_length,
            "log_level": logging.getLevelName(log_level),
        },
    )
def _apply_middleware_as_decorators(server: FastMCP, middleware: MCPLoggingMiddleware):
    """
    Fallback: replace ``server.tool`` so every registered tool logs to stderr.

    Used when the server supports neither ``add_middleware`` nor a
    ``middleware`` attribute. Only simple start/success/error console lines
    with timings are produced; ``middleware`` is accepted for signature
    compatibility but is not consulted.

    Both registration styles are supported:
    ``@server.tool(...)`` (decorator factory) and ``@server.tool`` /
    ``server.tool(fn, ...)`` (function passed directly). Both async and
    synchronous tool functions are wrapped correctly.
    """
    original_tool_method = server.tool

    def _report_success(func_name: str, start_time: float) -> None:
        """Print the success line for a finished tool call."""
        execution_time = time.time() - start_time
        print(
            f"🟢 TOOL SUCCESS: {func_name} ({execution_time:.2f}s)", file=sys.stderr
        )

    def _report_error(func_name: str, start_time: float, error: Exception) -> None:
        """Print the error line for a failed tool call."""
        execution_time = time.time() - start_time
        print(
            f"❌ TOOL ERROR: {func_name} ({execution_time:.2f}s) - {type(error).__name__}: {str(error)}",
            file=sys.stderr,
        )

    def _wrap(func):
        """Return ``func`` wrapped with console logging, preserving sync/async."""
        func_name = getattr(func, "__name__", "unknown_tool")

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*func_args, **func_kwargs):
                print(f"🔧 TOOL CALL: {func_name}", file=sys.stderr)
                start_time = time.time()
                try:
                    result = await func(*func_args, **func_kwargs)
                except Exception as e:
                    _report_error(func_name, start_time, e)
                    raise
                _report_success(func_name, start_time)
                return result

            return async_wrapper

        # Synchronous tools must not be awaited.
        @functools.wraps(func)
        def sync_wrapper(*func_args, **func_kwargs):
            print(f"🔧 TOOL CALL: {func_name}", file=sys.stderr)
            start_time = time.time()
            try:
                result = func(*func_args, **func_kwargs)
            except Exception as e:
                _report_error(func_name, start_time, e)
                raise
            _report_success(func_name, start_time)
            return result

        return sync_wrapper

    def logging_tool_decorator(*args, **kwargs):
        # Direct usage (``@server.tool`` or ``server.tool(fn, ...)``) passes the
        # function itself first; register the wrapped function immediately.
        if args and callable(args[0]):
            return original_tool_method(_wrap(args[0]), *args[1:], **kwargs)

        def decorator(func):
            # Register the wrapped function with the original options
            return original_tool_method(*args, **kwargs)(_wrap(func))

        return decorator

    # Replace the server's tool decorator
    server.tool = logging_tool_decorator
    logger.info("Applied middleware as tool decorators (fallback mode)")
# Shortcut for tool authors who want step-level execution logging.
def get_tool_logger(tool_name: str) -> ToolExecutionLogger:
    """
    Build a ToolExecutionLogger for ``tool_name``.

    No request id is passed, so the logger derives one itself (from the
    current logging context when available).
    """
    tool_logger = ToolExecutionLogger(tool_name=tool_name)
    return tool_logger
```
--------------------------------------------------------------------------------
/maverick_mcp/domain/portfolio.py:
--------------------------------------------------------------------------------
```python
"""
Portfolio domain entities for MaverickMCP.
This module implements pure business logic for portfolio management following
Domain-Driven Design (DDD) principles. These entities are framework-independent
and contain the core portfolio logic including cost basis averaging and P&L calculations.
Cost Basis Method: Average Cost
- Simplest for educational purposes
- Total cost / total shares
- Does not change on partial sales
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from decimal import ROUND_HALF_UP, Decimal
from typing import Optional
@dataclass
class Position:
"""
Value object representing a single portfolio position.
A position tracks shares held in a specific ticker with cost basis information.
Uses immutable operations - modifications return new Position instances.
Attributes:
ticker: Stock ticker symbol (e.g., "AAPL")
shares: Number of shares owned (supports fractional shares)
average_cost_basis: Average cost per share
total_cost: Total capital invested (shares × average_cost_basis)
purchase_date: Earliest purchase date for this position
notes: Optional user notes about the position
"""
ticker: str
shares: Decimal
average_cost_basis: Decimal
total_cost: Decimal
purchase_date: datetime
notes: str | None = None
def __post_init__(self) -> None:
"""Validate position invariants after initialization."""
if self.shares <= 0:
raise ValueError(f"Shares must be positive, got {self.shares}")
if self.average_cost_basis <= 0:
raise ValueError(
f"Average cost basis must be positive, got {self.average_cost_basis}"
)
if self.total_cost <= 0:
raise ValueError(f"Total cost must be positive, got {self.total_cost}")
# Normalize ticker to uppercase
object.__setattr__(self, "ticker", self.ticker.upper())
def add_shares(self, shares: Decimal, price: Decimal, date: datetime) -> "Position":
"""
Add shares to position with automatic cost basis averaging.
This creates a new Position instance with updated shares and averaged cost basis.
The average cost method is used: (total_cost + new_cost) / total_shares
Args:
shares: Number of shares to add (must be > 0)
price: Purchase price per share (must be > 0)
date: Purchase date
Returns:
New Position instance with averaged cost basis
Raises:
ValueError: If shares or price is not positive
Example:
>>> pos = Position("AAPL", Decimal("10"), Decimal("150"), Decimal("1500"), datetime.now())
>>> pos = pos.add_shares(Decimal("10"), Decimal("170"), datetime.now())
>>> pos.shares
Decimal('20')
>>> pos.average_cost_basis
Decimal('160.00')
"""
if shares <= 0:
raise ValueError(f"Shares to add must be positive, got {shares}")
if price <= 0:
raise ValueError(f"Price must be positive, got {price}")
new_total_shares = self.shares + shares
new_total_cost = self.total_cost + (shares * price)
new_avg_cost = (new_total_cost / new_total_shares).quantize(
Decimal("0.0001"), rounding=ROUND_HALF_UP
)
return Position(
ticker=self.ticker,
shares=new_total_shares,
average_cost_basis=new_avg_cost,
total_cost=new_total_cost,
purchase_date=min(self.purchase_date, date),
notes=self.notes,
)
def remove_shares(self, shares: Decimal) -> Optional["Position"]:
"""
Remove shares from position.
Returns None if the removal would close the position entirely (sold_shares >= held_shares).
For partial sales, average cost basis remains unchanged.
Args:
shares: Number of shares to remove (must be > 0)
Returns:
New Position instance with reduced shares, or None if position closed
Raises:
ValueError: If shares is not positive
Example:
>>> pos = Position("AAPL", Decimal("20"), Decimal("160"), Decimal("3200"), datetime.now())
>>> pos = pos.remove_shares(Decimal("10"))
>>> pos.shares
Decimal('10')
>>> pos.average_cost_basis # Unchanged
Decimal('160.00')
"""
if shares <= 0:
raise ValueError(f"Shares to remove must be positive, got {shares}")
if shares >= self.shares:
# Full position close
return None
new_shares = self.shares - shares
new_total_cost = new_shares * self.average_cost_basis
return Position(
ticker=self.ticker,
shares=new_shares,
average_cost_basis=self.average_cost_basis,
total_cost=new_total_cost,
purchase_date=self.purchase_date,
notes=self.notes,
)
def calculate_current_value(self, current_price: Decimal) -> dict[str, Decimal]:
"""
Calculate live position value and P&L metrics.
Args:
current_price: Current market price per share
Returns:
Dictionary containing:
- current_value: Current market value (shares × price)
- unrealized_pnl: Unrealized profit/loss (current_value - total_cost)
- pnl_percentage: P&L as percentage of total cost
Example:
>>> pos = Position("AAPL", Decimal("20"), Decimal("160"), Decimal("3200"), datetime.now())
>>> metrics = pos.calculate_current_value(Decimal("175.50"))
>>> metrics["current_value"]
Decimal('3510.00')
>>> metrics["unrealized_pnl"]
Decimal('310.00')
>>> metrics["pnl_percentage"]
Decimal('9.6875')
"""
current_value = (self.shares * current_price).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
unrealized_pnl = (current_value - self.total_cost).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
if self.total_cost > 0:
pnl_percentage = (unrealized_pnl / self.total_cost * 100).quantize(
Decimal("0.01"), rounding=ROUND_HALF_UP
)
else:
pnl_percentage = Decimal("0.00")
return {
"current_value": current_value,
"unrealized_pnl": unrealized_pnl,
"pnl_percentage": pnl_percentage,
}
def to_dict(self) -> dict:
"""
Convert position to dictionary for serialization.
Returns:
Dictionary representation with float values for JSON compatibility
"""
return {
"ticker": self.ticker,
"shares": float(self.shares),
"average_cost_basis": float(self.average_cost_basis),
"total_cost": float(self.total_cost),
"purchase_date": self.purchase_date.isoformat(),
"notes": self.notes,
}
@dataclass
class Portfolio:
    """
    Aggregate root for a user portfolio.

    Manages a collection of positions with operations for adding, removing, and
    analyzing holdings. Every mutating operation refreshes ``updated_at``.

    Attributes:
        portfolio_id: Unique identifier (UUID as string)
        user_id: User identifier (conventionally "default" in the single-user
            system; there is no default value, it must be supplied)
        name: Portfolio display name
        positions: List of Position value objects; ``add_position`` merges
            purchases so each ticker appears at most once
        created_at: Portfolio creation timestamp (timezone-aware, UTC)
        updated_at: Last modification timestamp (timezone-aware, UTC)
    """

    portfolio_id: str
    user_id: str
    name: str
    positions: list[Position] = field(default_factory=list)
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    def _touch(self) -> None:
        """Stamp ``updated_at`` with the current UTC time after a mutation."""
        self.updated_at = datetime.now(UTC)

    def _find_index(self, ticker: str) -> int | None:
        """Return the list index of the position for ``ticker`` (case-insensitive), or None."""
        wanted = ticker.upper()
        return next(
            (i for i, pos in enumerate(self.positions) if pos.ticker == wanted),
            None,
        )

    def add_position(
        self,
        ticker: str,
        shares: Decimal,
        price: Decimal,
        date: datetime,
        notes: str | None = None,
    ) -> None:
        """
        Add or update a position with automatic cost basis averaging.

        If the ticker already exists, shares are added and the cost basis is
        averaged via ``Position.add_shares``. Otherwise a new position is created.

        Args:
            ticker: Stock ticker symbol (normalized to upper case)
            shares: Number of shares to add
            price: Purchase price per share
            date: Purchase date
            notes: Optional notes (only used for new positions)

        Example:
            >>> portfolio = Portfolio("id", "default", "My Portfolio")
            >>> portfolio.add_position("AAPL", Decimal("10"), Decimal("150"), datetime.now())
            >>> portfolio.add_position("AAPL", Decimal("10"), Decimal("170"), datetime.now())
            >>> portfolio.get_position("AAPL").shares
            Decimal('20')
        """
        ticker = ticker.upper()
        index = self._find_index(ticker)
        if index is None:
            self.positions.append(
                Position(
                    ticker=ticker,
                    shares=shares,
                    average_cost_basis=price,
                    total_cost=shares * price,
                    purchase_date=date,
                    notes=notes,
                )
            )
        else:
            self.positions[index] = self.positions[index].add_shares(
                shares, price, date
            )
        self._touch()

    def remove_position(self, ticker: str, shares: Decimal | None = None) -> bool:
        """
        Remove a position or reduce its share count.

        Args:
            ticker: Stock ticker symbol (case-insensitive)
            shares: Number of shares to remove. ``None`` or a value >= the held
                shares removes the entire position.

        Returns:
            True if the position was found and removed/reduced, False otherwise.

        Example:
            >>> portfolio.remove_position("AAPL", Decimal("10"))  # Partial
            True
            >>> portfolio.remove_position("AAPL")  # Full removal
            True
        """
        index = self._find_index(ticker)
        if index is None:
            return False

        pos = self.positions[index]
        if shares is None or shares >= pos.shares:
            del self.positions[index]
        else:
            updated_pos = pos.remove_shares(shares)
            # A falsy result from remove_shares means nothing is left to hold.
            if updated_pos:
                self.positions[index] = updated_pos
            else:
                del self.positions[index]
        self._touch()
        return True

    def get_position(self, ticker: str) -> Position | None:
        """
        Get a position by ticker symbol.

        Args:
            ticker: Stock ticker symbol (case-insensitive)

        Returns:
            The Position if found, None otherwise.
        """
        index = self._find_index(ticker)
        return None if index is None else self.positions[index]

    def get_total_invested(self) -> Decimal:
        """
        Calculate total capital invested across all positions.

        Returns:
            Sum of all position total costs (``Decimal("0")`` when empty).
        """
        return sum((pos.total_cost for pos in self.positions), Decimal("0"))

    def calculate_portfolio_metrics(self, current_prices: dict[str, Decimal]) -> dict:
        """
        Calculate portfolio metrics using live prices.

        Args:
            current_prices: Mapping of ticker symbol to current price. Tickers
                that are missing, or mapped to ``None``, are valued at their
                average cost basis. Non-Decimal numbers (int/float/str) are
                converted to Decimal via ``str`` to avoid float artifacts.

        Returns:
            Dictionary containing (monetary values as float):
            - total_value: Current market value of all positions
            - total_invested: Total capital invested
            - total_pnl: Total unrealized profit/loss, rounded to cents
            - total_pnl_percentage: Total P&L as a percentage of invested capital
            - position_count: Number of positions
            - positions: List of per-position details with current metrics

        Example:
            >>> prices = {"AAPL": Decimal("175.50"), "MSFT": Decimal("380.00")}
            >>> metrics = portfolio.calculate_portfolio_metrics(prices)
            >>> sorted(metrics)[:3]
            ['position_count', 'positions', 'total_invested']
        """
        cent = Decimal("0.01")
        total_value = Decimal("0")
        total_cost = Decimal("0")
        position_details = []

        for pos in self.positions:
            quoted = current_prices.get(pos.ticker)
            if quoted is None:
                # No usable quote (absent or explicitly None): fall back to cost basis.
                current_price = pos.average_cost_basis
            elif isinstance(quoted, Decimal):
                current_price = quoted
            else:
                # Decimal * float raises TypeError; go through str for an exact value.
                current_price = Decimal(str(quoted))

            metrics = pos.calculate_current_value(current_price)
            total_value += metrics["current_value"]
            total_cost += pos.total_cost
            position_details.append(
                {
                    "ticker": pos.ticker,
                    "shares": float(pos.shares),
                    "cost_basis": float(pos.average_cost_basis),
                    "current_price": float(current_price),
                    "current_value": float(metrics["current_value"]),
                    "unrealized_pnl": float(metrics["unrealized_pnl"]),
                    "pnl_percentage": float(metrics["pnl_percentage"]),
                    "purchase_date": pos.purchase_date.isoformat(),
                    "notes": pos.notes,
                }
            )

        # Round like the per-position unrealized_pnl so totals and details agree.
        total_pnl = (total_value - total_cost).quantize(cent, rounding=ROUND_HALF_UP)
        total_pnl_pct = (
            (total_pnl / total_cost * 100).quantize(cent, rounding=ROUND_HALF_UP)
            if total_cost > 0
            else Decimal("0.00")
        )

        return {
            "total_value": float(total_value),
            "total_invested": float(total_cost),
            "total_pnl": float(total_pnl),
            "total_pnl_percentage": float(total_pnl_pct),
            "position_count": len(self.positions),
            "positions": position_details,
        }

    def clear_all_positions(self) -> None:
        """
        Remove all positions from the portfolio.

        ⚠️ WARNING: This operation cannot be undone.
        """
        self.positions.clear()
        self._touch()

    def to_dict(self) -> dict:
        """
        Convert the portfolio to a dictionary for serialization.

        Returns:
            Dictionary with float monetary values and ISO-8601 timestamps,
            suitable for JSON serialization.
        """
        return {
            "portfolio_id": self.portfolio_id,
            "user_id": self.user_id,
            "name": self.name,
            "positions": [pos.to_dict() for pos in self.positions],
            "position_count": len(self.positions),
            "total_invested": float(self.get_total_invested()),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
```
--------------------------------------------------------------------------------
/maverick_mcp/api/error_handling.py:
--------------------------------------------------------------------------------
```python
"""
Enhanced error handling framework for MaverickMCP API.
This module provides centralized error handling with structured responses,
proper logging, monitoring integration, and client-friendly error messages.
"""
import asyncio
import uuid
from collections.abc import Callable
from typing import Any
from fastapi import HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from maverick_mcp.exceptions import (
APIRateLimitError,
AuthenticationError,
AuthorizationError,
CacheConnectionError,
CircuitBreakerError,
ConflictError,
DatabaseConnectionError,
DataIntegrityError,
DataNotFoundError,
ExternalServiceError,
MaverickException,
NotFoundError,
RateLimitError,
ValidationError,
WebhookError,
)
from maverick_mcp.utils.logging import get_logger
from maverick_mcp.utils.monitoring import get_monitoring_service
from maverick_mcp.validation.responses import error_response, validation_error_response
# Module-level logger used by ErrorHandler._log_error and validation_exception_handler.
logger = get_logger(__name__)
# Monitoring client obtained once at import time; ErrorHandler._send_to_monitoring
# forwards captured exceptions to it via capture_exception().
monitoring = get_monitoring_service()
class ErrorHandler:
    """Centralized error handler with logging and monitoring integration.

    Maps an exception to an HTTP status code, a machine-readable error code and
    a log level. It then logs the failure, forwards it to the monitoring service
    and renders a structured JSON error response that carries a fresh trace ID.
    """

    # Client-facing messages for errors whose raw text may leak internals.
    _SAFE_MESSAGES = {
        "INTERNAL_ERROR": "An unexpected error occurred. Please try again later.",
        "DATABASE_INTEGRITY_ERROR": "Data conflict detected. Please check your input.",
        "DATABASE_OPERATIONAL_ERROR": "Database temporarily unavailable.",
        "INVALID_REQUEST": "Invalid request format.",
        "MISSING_REQUIRED_FIELD": "Required field missing from request.",
        "TYPE_ERROR": "Invalid data type in request.",
    }

    def __init__(self):
        self.error_mappings = self._build_error_mappings()

    def _build_error_mappings(self) -> dict[type[Exception], dict[str, Any]]:
        """Build the mapping of exception types to response details.

        Dict order matters: ``_get_error_info`` falls back to the first
        ``isinstance`` match, so specific types are listed before the generic
        builtins (ValueError/KeyError/TypeError) at the end.
        """
        return {
            # MaverickMCP exceptions
            ValidationError: {
                "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
                "code": "VALIDATION_ERROR",
                "log_level": "warning",
            },
            AuthenticationError: {
                "status_code": status.HTTP_401_UNAUTHORIZED,
                "code": "AUTHENTICATION_ERROR",
                "log_level": "warning",
            },
            AuthorizationError: {
                "status_code": status.HTTP_403_FORBIDDEN,
                "code": "AUTHORIZATION_ERROR",
                "log_level": "warning",
            },
            DataNotFoundError: {
                "status_code": status.HTTP_404_NOT_FOUND,
                "code": "DATA_NOT_FOUND",
                "log_level": "info",
            },
            APIRateLimitError: {
                "status_code": status.HTTP_429_TOO_MANY_REQUESTS,
                "code": "RATE_LIMIT_EXCEEDED",
                "log_level": "warning",
            },
            CircuitBreakerError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "SERVICE_UNAVAILABLE",
                "log_level": "error",
            },
            DatabaseConnectionError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "DATABASE_CONNECTION_ERROR",
                "log_level": "error",
            },
            CacheConnectionError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "CACHE_CONNECTION_ERROR",
                "log_level": "error",
            },
            DataIntegrityError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "DATA_INTEGRITY_ERROR",
                "log_level": "error",
            },
            # API errors from validation module
            NotFoundError: {
                "status_code": status.HTTP_404_NOT_FOUND,
                "code": "NOT_FOUND",
                "log_level": "info",
            },
            ConflictError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "CONFLICT",
                "log_level": "warning",
            },
            RateLimitError: {
                "status_code": status.HTTP_429_TOO_MANY_REQUESTS,
                "code": "RATE_LIMIT_EXCEEDED",
                "log_level": "warning",
            },
            ExternalServiceError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "EXTERNAL_SERVICE_ERROR",
                "log_level": "error",
            },
            WebhookError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "WEBHOOK_ERROR",
                "log_level": "warning",
            },
            # SQLAlchemy exceptions
            IntegrityError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "DATABASE_INTEGRITY_ERROR",
                "log_level": "error",
            },
            OperationalError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "DATABASE_OPERATIONAL_ERROR",
                "log_level": "error",
            },
            # Third-party API exceptions
            ValueError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "INVALID_REQUEST",
                "log_level": "warning",
            },
            KeyError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "MISSING_REQUIRED_FIELD",
                "log_level": "warning",
            },
            TypeError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "TYPE_ERROR",
                "log_level": "warning",
            },
        }

    def handle_exception(
        self,
        request: Request,
        exception: Exception,
        context: dict[str, Any] | None = None,
    ) -> JSONResponse:
        """
        Handle an exception and return a structured error response.

        Args:
            request: FastAPI request object
            exception: The exception to handle
            context: Additional context for logging and monitoring

        Returns:
            JSONResponse with the structured error body. For ``HTTPException``
            its own status code and headers (e.g. ``WWW-Authenticate``) are kept.
        """
        trace_id = str(uuid.uuid4())
        error_info = self._get_error_info(exception)

        self._log_error(
            exception=exception,
            trace_id=trace_id,
            request=request,
            error_info=error_info,
            context=context,
        )
        self._send_to_monitoring(
            exception=exception,
            trace_id=trace_id,
            request=request,
            context=context,
            error_info=error_info,
        )

        response_data = self._build_error_response(
            exception=exception,
            error_info=error_info,
            trace_id=trace_id,
        )
        headers = exception.headers if isinstance(exception, HTTPException) else None
        return JSONResponse(
            status_code=error_info["status_code"],
            content=response_data,
            headers=headers,
        )

    @staticmethod
    def _http_exception_info(exception: HTTPException) -> dict[str, Any]:
        """Derive error info from an HTTPException, honouring its status code."""
        status_code = exception.status_code
        return {
            "status_code": status_code,
            "code": "HTTP_ERROR",
            "log_level": "error" if status_code >= 500 else "warning",
        }

    def _get_error_info(self, exception: Exception) -> dict[str, Any]:
        """Get error information for the exception type.

        Lookup order: exact type match, first ``isinstance`` match in mapping
        order, ``HTTPException`` (uses the exception's own status code), then a
        generic 500.
        """
        exc_type = type(exception)
        if exc_type in self.error_mappings:
            return self.error_mappings[exc_type]

        for error_type, info in self.error_mappings.items():
            if isinstance(exception, error_type):
                return info

        # Without this, an HTTPException (e.g. a 404 raised in an endpoint and
        # caught by with_error_handling) would be answered as a 500.
        if isinstance(exception, HTTPException):
            return self._http_exception_info(exception)

        return {
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "code": "INTERNAL_ERROR",
            "log_level": "error",
        }

    def _log_error(
        self,
        exception: Exception,
        trace_id: str,
        request: Request,
        error_info: dict[str, Any],
        context: dict[str, Any] | None = None,
    ) -> None:
        """Log the error with request metadata at the mapped log level."""
        log_data = {
            "trace_id": trace_id,
            "error_type": type(exception).__name__,
            "error_code": error_info["code"],
            "status_code": error_info["status_code"],
            "method": request.method,
            "path": request.url.path,
            "client_host": request.client.host if request.client else None,
            "user_agent": request.headers.get("user-agent"),
        }

        if isinstance(exception, MaverickException):
            log_data["error_details"] = exception.to_dict()

        if context:
            log_data["context"] = context

        log_level = error_info["log_level"]
        if log_level == "error":
            # Pass the exception itself: exc_info=True only works while an
            # exception is being handled, which is not guaranteed here.
            logger.error(
                f"Error handling request: {str(exception)}",
                exc_info=exception,
                extra=log_data,
            )
        elif log_level == "warning":
            logger.warning(
                f"Request failed: {str(exception)}",
                extra=log_data,
            )
        else:
            logger.info(
                f"Request rejected: {str(exception)}",
                extra=log_data,
            )

    def _send_to_monitoring(
        self,
        exception: Exception,
        trace_id: str,
        request: Request,
        context: dict[str, Any] | None = None,
        error_info: dict[str, Any] | None = None,
    ) -> None:
        """Send the error to the monitoring service (Sentry).

        Only errors mapped to the "error" or "warning" log level are captured.
        ``error_info`` may be passed to avoid recomputing the mapping lookup.
        """
        monitoring_context = {
            "trace_id": trace_id,
            "request": {
                "method": request.method,
                "path": request.url.path,
                "query": str(request.url.query),
            },
        }
        if context:
            monitoring_context["custom_context"] = context

        if error_info is None:
            error_info = self._get_error_info(exception)
        if error_info["log_level"] in ["error", "warning"]:
            monitoring.capture_exception(exception, **monitoring_context)

    def _build_error_response(
        self,
        exception: Exception,
        error_info: dict[str, Any],
        trace_id: str,
    ) -> dict[str, Any]:
        """Build the client-facing error body.

        MaverickException and HTTPException messages are considered safe to
        expose; anything else goes through ``_get_safe_error_message``.
        """
        if isinstance(exception, MaverickException):
            message = exception.message
            context = exception.context
        elif isinstance(exception, HTTPException):
            message = exception.detail
            context = None
        else:
            message = self._get_safe_error_message(exception, error_info["code"])
            context = None

        return error_response(
            code=error_info["code"],
            message=message,
            status_code=error_info["status_code"],
            context=context,
            trace_id=trace_id,
        )

    def _get_safe_error_message(self, exception: Exception, code: str) -> str:
        """Return a canned message for ``code``, or the exception text if none is defined."""
        return self._SAFE_MESSAGES.get(code, str(exception))
# Module-level ErrorHandler used by handle_api_error(). It is constructed once at
# import time, so the exception-type mapping is built a single time.
error_handler = ErrorHandler()
def handle_api_error(
    request: Request,
    exception: Exception,
    context: dict[str, Any] | None = None,
) -> JSONResponse:
    """
    Public entry point for turning an exception into an API error response.

    Delegates to the module-level ``error_handler`` instance.

    Args:
        request: Incoming FastAPI request
        exception: The exception that was raised
        context: Optional extra context attached to logs/monitoring

    Returns:
        Structured JSON error response.
    """
    return error_handler.handle_exception(
        request=request,
        exception=exception,
        context=context,
    )
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """
    Handle FastAPI request-validation errors.

    Converts every entry of ``exc.errors()`` into the API's structured error
    format, logs them at WARNING level under a fresh trace ID, and returns a
    422 response.

    The offending ``input`` value can be an arbitrary object (bytes, a model
    instance, ...). ``JSONResponse`` cannot serialize every such value, which
    would make this handler raise while reporting the original error. It is
    therefore passed through ``jsonable_encoder``, falling back to ``repr``.
    """
    from fastapi.encoders import jsonable_encoder

    def _encode_input(value: Any) -> Any:
        try:
            return jsonable_encoder(value)
        except Exception:  # never fail while rendering an error response
            return repr(value)

    errors = [
        {
            "code": "VALIDATION_ERROR",
            "field": ".".join(str(loc) for loc in error.get("loc", ())),
            "message": error["msg"],
            "context": {"input": _encode_input(error.get("input"))},
        }
        for error in exc.errors()
    ]

    trace_id = str(uuid.uuid4())

    logger.warning(
        "Request validation failed",
        extra={
            "trace_id": trace_id,
            "path": request.url.path,
            "errors": errors,
        },
    )

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=validation_error_response(errors, trace_id),
    )
def create_error_handlers() -> dict[Any, Callable]:
    """
    Build the exception-handler table to register on a FastAPI app.

    Request validation errors get the dedicated 422 handler; every other
    exception is routed through ``handle_api_error`` without extra context.
    """

    def _fallback_handler(request, exc):
        return handle_api_error(request, exc)

    handlers: dict[Any, Callable] = {}
    handlers[RequestValidationError] = validation_exception_handler
    handlers[Exception] = _fallback_handler
    return handlers
# Decorator for wrapping functions with error handling
def with_error_handling(context_fn: Callable[[Any], dict[str, Any]] | None = None):
"""
Decorator to wrap functions with proper error handling.
Args:
context_fn: Optional function to extract context from arguments
"""
def decorator(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
# Extract context if function provided
context = context_fn(*args, **kwargs) if context_fn else {}
# Get request from args/kwargs
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request and "request" in kwargs:
request = kwargs["request"]
if request:
return handle_api_error(request, e, context)
else:
# Re-raise if no request object
raise
def sync_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
# Extract context if function provided
context = context_fn(*args, **kwargs) if context_fn else {}
# Get request from args/kwargs
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request and "request" in kwargs:
request = kwargs["request"]
if request:
return handle_api_error(request, e, context)
else:
# Re-raise if no request object
raise
# Return appropriate wrapper based on function type
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/logging_init.py:
--------------------------------------------------------------------------------
```python
"""
Logging initialization module for the backtesting system.
This module provides a centralized initialization point for all logging
components including structured logging, performance monitoring, debug
utilities, and log aggregation.
"""
import logging
import os
from typing import Any
from maverick_mcp.config.logging_settings import (
LoggingSettings,
configure_logging_for_environment,
get_logging_settings,
validate_logging_settings,
)
from maverick_mcp.utils.debug_utils import (
disable_debug_mode,
enable_debug_mode,
)
from maverick_mcp.utils.debug_utils import (
print_debug_summary as debug_print_summary,
)
from maverick_mcp.utils.structured_logger import (
StructuredLoggerManager,
get_logger_manager,
)
class LoggingInitializer:
    """Comprehensive logging system initializer.

    Resolves a LoggingSettings object for the requested environment, applies
    caller overrides, then wires up structured logging, optional debug logging,
    per-component performance loggers and log-directory management. Calling
    ``initialize_logging_system`` again returns the existing settings unless
    ``force_reinit`` is True.
    """

    # Loggers raised to DEBUG level that share the optional debug file handler.
    _DEBUG_LOGGER_NAMES = (
        "maverick_mcp.debug",
        "maverick_mcp.requests",
        "maverick_mcp.errors",
    )

    def __init__(self):
        self._initialized = False
        self._settings: LoggingSettings | None = None
        self._manager: StructuredLoggerManager | None = None
        # Debug file handler currently attached to the debug loggers. Tracked so
        # that re-running debug setup replaces it instead of stacking duplicates
        # (duplicate lines per record plus leaked file descriptors).
        self._debug_handler: logging.Handler | None = None

    def initialize_logging_system(
        self,
        environment: str | None = None,
        custom_settings: dict[str, Any] | None = None,
        force_reinit: bool = False,
    ) -> LoggingSettings:
        """
        Initialize the complete logging system.

        Args:
            environment: Environment name (development, testing, production).
                Falls back to $MAVERICK_ENVIRONMENT, then "development". Other
                names use ``get_logging_settings()`` defaults.
            custom_settings: Attribute overrides applied to the settings object.
                Unknown keys are reported on stdout and ignored.
            force_reinit: Force reinitialization even if already initialized

        Returns:
            LoggingSettings: The final logging configuration
        """
        if self._initialized and not force_reinit:
            return self._settings

        if not environment:
            environment = os.getenv("MAVERICK_ENVIRONMENT", "development")

        if environment in ["development", "testing", "production"]:
            self._settings = configure_logging_for_environment(environment)
        else:
            self._settings = get_logging_settings()

        if custom_settings:
            unknown_keys = []
            for key, value in custom_settings.items():
                if hasattr(self._settings, key):
                    setattr(self._settings, key, value)
                else:
                    unknown_keys.append(key)
            if unknown_keys:
                # Surface typos instead of silently dropping them.
                print(
                    "⚠️ Ignoring unknown logging settings: "
                    + ", ".join(sorted(unknown_keys))
                )

        warnings = validate_logging_settings(self._settings)
        if warnings:
            print("⚠️ Logging configuration warnings:")
            for warning in warnings:
                print(f" - {warning}")

        self._initialize_structured_logging()

        if self._settings.debug_enabled:
            enable_debug_mode()
            self._setup_debug_logging()
        else:
            # On a forced re-init with debug now off, drop any stale debug handler.
            self._remove_debug_handler()

        self._initialize_performance_monitoring()
        self._setup_log_management()
        self._print_initialization_summary(environment)

        self._initialized = True
        return self._settings

    def _initialize_structured_logging(self):
        """Initialize structured logging infrastructure."""
        self._manager = get_logger_manager()

        # Make sure the log directory exists before any file handler opens a file in it.
        if self._settings.enable_file_logging:
            self._settings.ensure_log_directory()

        self._manager.setup_structured_logging(
            log_level=self._settings.log_level,
            log_format=self._settings.log_format,
            log_file=self._settings.log_file_path
            if self._settings.enable_file_logging
            else None,
            enable_async=self._settings.enable_async_logging,
            enable_rotation=self._settings.enable_log_rotation,
            max_log_size=self._settings.max_log_size_mb * 1024 * 1024,
            backup_count=self._settings.backup_count,
            console_output=self._settings.console_output,
        )

        if self._settings.debug_enabled:
            for module in self._settings.get_debug_modules():
                self._manager.debug_manager.enable_verbose_logging(module)

            if self._settings.log_request_response:
                self._manager.debug_manager.add_debug_filter(
                    "backtesting_requests",
                    {
                        "log_request_response": True,
                        "operations": [
                            "run_backtest",
                            "optimize_parameters",
                            "get_historical_data",
                            "calculate_technical_indicators",
                        ],
                    },
                )

    @staticmethod
    def _debug_log_path(log_file_path) -> str:
        """Derive the debug log path by inserting "_debug" before the extension.

        "logs/app.log" -> "logs/app_debug.log"; a path without an extension
        gets ".log" appended, so the result never equals the main log path.
        """
        base, ext = os.path.splitext(os.fspath(log_file_path))
        return f"{base}_debug{ext or '.log'}"

    def _remove_debug_handler(self):
        """Detach and close the debug file handler, if one is attached."""
        handler = self._debug_handler
        if handler is None:
            return
        for name in self._DEBUG_LOGGER_NAMES:
            logging.getLogger(name).removeHandler(handler)
        handler.close()
        self._debug_handler = None

    def _setup_debug_logging(self):
        """Raise debug loggers to DEBUG and (re)attach the debug file handler."""
        debug_loggers = [logging.getLogger(name) for name in self._DEBUG_LOGGER_NAMES]
        for debug_logger in debug_loggers:
            debug_logger.setLevel(logging.DEBUG)

        # Idempotent: replace any handler from a previous call.
        self._remove_debug_handler()

        if not self._settings.enable_file_logging:
            return

        debug_handler = logging.FileHandler(
            self._debug_log_path(self._settings.log_file_path)
        )
        debug_handler.setLevel(logging.DEBUG)

        from maverick_mcp.utils.structured_logger import EnhancedStructuredFormatter

        debug_handler.setFormatter(
            EnhancedStructuredFormatter(
                include_performance=True, include_resources=True
            )
        )
        for debug_logger in debug_loggers:
            debug_logger.addHandler(debug_handler)
        self._debug_handler = debug_handler

    def _initialize_performance_monitoring(self):
        """Create a performance logger for each key component (if enabled)."""
        if not self._settings.enable_performance_logging:
            return

        components = [
            "vectorbt_engine",
            "data_provider",
            "cache_manager",
            "technical_analysis",
            "portfolio_optimization",
            "strategy_execution",
        ]

        for component in components:
            perf_logger = self._manager.get_performance_logger(
                f"performance.{component}"
            )
            perf_logger.logger.info(
                f"Performance monitoring initialized for {component}"
            )

    def _setup_log_management(self):
        """Setup log rotation and cleanup mechanisms."""
        if (
            not self._settings.enable_file_logging
            or not self._settings.enable_log_rotation
        ):
            return

        # Log rotation is handled by RotatingFileHandler.
        # Additional cleanup could be implemented here for old log files.
        self._settings.ensure_log_directory()

    def _print_initialization_summary(self, environment: str):
        """Print logging initialization summary."""
        print("\n" + "=" * 80)
        print("MAVERICK MCP LOGGING SYSTEM INITIALIZED")
        print("=" * 80)
        print(f"Environment: {environment}")
        print(f"Log Level: {self._settings.log_level}")
        print(f"Log Format: {self._settings.log_format}")
        print(
            f"Debug Mode: {'✅ Enabled' if self._settings.debug_enabled else '❌ Disabled'}"
        )
        print(
            f"Performance Monitoring: {'✅ Enabled' if self._settings.enable_performance_logging else '❌ Disabled'}"
        )
        print(
            f"File Logging: {'✅ Enabled' if self._settings.enable_file_logging else '❌ Disabled'}"
        )
        if self._settings.enable_file_logging:
            print(f"Log File: {self._settings.log_file_path}")
            print(
                f"Log Rotation: {'✅ Enabled' if self._settings.enable_log_rotation else '❌ Disabled'}"
            )
        print(
            f"Async Logging: {'✅ Enabled' if self._settings.enable_async_logging else '❌ Disabled'}"
        )
        print(
            f"Resource Tracking: {'✅ Enabled' if self._settings.enable_resource_tracking else '❌ Disabled'}"
        )
        if self._settings.debug_enabled:
            print("\n🐛 DEBUG MODE FEATURES:")
            print(
                f" - Request/Response Logging: {'✅' if self._settings.log_request_response else '❌'}"
            )
            print(f" - Verbose Modules: {len(self._settings.get_debug_modules())}")
            print(f" - Max Payload Size: {self._settings.max_payload_length} chars")
        if self._settings.enable_performance_logging:
            print("\n📊 PERFORMANCE MONITORING:")
            print(f" - Threshold: {self._settings.performance_log_threshold_ms}ms")
            print(
                f" - Business Metrics: {'✅' if self._settings.enable_business_metrics else '❌'}"
            )
        print("\n" + "=" * 80 + "\n")

    def get_settings(self) -> LoggingSettings | None:
        """Get current logging settings (None before initialization)."""
        return self._settings

    def get_manager(self) -> StructuredLoggerManager | None:
        """Get logging manager instance (None before initialization)."""
        return self._manager

    def enable_debug_mode_runtime(self):
        """Enable debug mode at runtime (no-op before initialization)."""
        if self._settings:
            self._settings.debug_enabled = True
            enable_debug_mode()
            self._setup_debug_logging()
            print("🐛 Debug mode enabled at runtime")

    def disable_debug_mode_runtime(self):
        """Disable debug mode at runtime and detach the debug file handler."""
        if self._settings:
            self._settings.debug_enabled = False
            disable_debug_mode()
            self._remove_debug_handler()
            print("🐛 Debug mode disabled at runtime")

    def print_debug_summary_if_enabled(self):
        """Print debug summary if debug mode is enabled."""
        if self._settings and self._settings.debug_enabled:
            debug_print_summary()

    def reconfigure_log_level(self, new_level: str):
        """Reconfigure the root log level at runtime.

        Raises:
            RuntimeError: If the logging system has not been initialized.
            ValueError: If ``new_level`` is not a standard level name.
        """
        if not self._settings:
            raise RuntimeError("Logging system not initialized")

        level_name = new_level.upper()
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if level_name not in valid_levels:
            raise ValueError(f"Invalid log level: {new_level}")

        self._settings.log_level = level_name
        logging.getLogger().setLevel(getattr(logging, level_name))
        print(f"📊 Log level changed to: {level_name}")

    def get_performance_summary(self) -> dict[str, Any]:
        """Get comprehensive performance summary."""
        if not self._manager:
            return {"error": "Logging system not initialized"}
        return self._manager.create_dashboard_metrics()

    def cleanup_logging_system(self):
        """Close and detach logging handlers.

        Closed handlers are also removed from the root logger. Leaving them
        attached (as a close-only cleanup would) lets them keep receiving
        records, and a FileHandler reopens its file on the next emit.
        """
        if self._manager:
            self._remove_debug_handler()
            root_logger = logging.getLogger()
            # Iterate over a copy: we mutate root_logger.handlers in the loop.
            for handler in list(root_logger.handlers):
                handler.close()
                root_logger.removeHandler(handler)
            self._initialized = False
            print("🧹 Logging system cleaned up")
# Process-wide initializer, created lazily by get_logging_initializer().
_logging_initializer: LoggingInitializer | None = None


def get_logging_initializer() -> LoggingInitializer:
    """Return the shared LoggingInitializer, creating it on first access."""
    global _logging_initializer
    instance = _logging_initializer
    if instance is None:
        instance = LoggingInitializer()
        _logging_initializer = instance
    return instance
def initialize_for_environment(
    environment: str | None, **custom_settings
) -> LoggingSettings:
    """
    Initialize logging for a specific environment via the shared initializer.

    Args:
        environment: Environment name, or None to auto-detect from
            $MAVERICK_ENVIRONMENT (initialize_backtesting_logging passes None).
        **custom_settings: Attribute overrides for the LoggingSettings object.

    Returns:
        The active LoggingSettings. If the shared initializer was already
        initialized, the existing settings are returned unchanged.
    """
    initializer = get_logging_initializer()
    return initializer.initialize_logging_system(
        environment=environment, custom_settings=custom_settings
    )
def initialize_for_development(**custom_settings) -> LoggingSettings:
    """Set up logging with the ``development`` environment profile."""
    target_env = "development"
    return initialize_for_environment(target_env, **custom_settings)


def initialize_for_testing(**custom_settings) -> LoggingSettings:
    """Set up logging with the ``testing`` environment profile."""
    target_env = "testing"
    return initialize_for_environment(target_env, **custom_settings)


def initialize_for_production(**custom_settings) -> LoggingSettings:
    """Set up logging with the ``production`` environment profile."""
    target_env = "production"
    return initialize_for_environment(target_env, **custom_settings)
def initialize_backtesting_logging(
    environment: str | None = None, debug_mode: bool = False, **custom_settings
) -> LoggingSettings:
    """
    Convenient function to initialize logging specifically for backtesting.

    Args:
        environment: Target environment (auto-detected if None)
        debug_mode: Enable debug mode. This forces ``debug_enabled=True`` and
            defaults ``log_request_response=True`` and
            ``performance_log_threshold_ms=100.0``. Values for those two
            passed explicitly in ``custom_settings`` are respected rather than
            overwritten.
        **custom_settings: Additional custom settings

    Returns:
        LoggingSettings: Final logging configuration
    """
    if debug_mode:
        custom_settings["debug_enabled"] = True
        custom_settings.setdefault("log_request_response", True)
        custom_settings.setdefault("performance_log_threshold_ms", 100.0)

    return initialize_for_environment(environment, **custom_settings)
# Convenience functions for runtime control (all delegate to the shared initializer)
def enable_debug_mode_runtime():
    """Turn debug mode on without reinitializing logging."""
    initializer = get_logging_initializer()
    initializer.enable_debug_mode_runtime()


def disable_debug_mode_runtime():
    """Turn debug mode off without reinitializing logging."""
    initializer = get_logging_initializer()
    initializer.disable_debug_mode_runtime()


def change_log_level(new_level: str):
    """Switch the root log level (e.g. "INFO") at runtime."""
    initializer = get_logging_initializer()
    initializer.reconfigure_log_level(new_level)


def get_performance_summary() -> dict[str, Any]:
    """Return dashboard-style performance metrics from the logging manager."""
    initializer = get_logging_initializer()
    return initializer.get_performance_summary()


def print_debug_summary():
    """Print the debug summary, but only when debug mode is on."""
    initializer = get_logging_initializer()
    initializer.print_debug_summary_if_enabled()


def cleanup_logging():
    """Release logging resources held by the shared initializer."""
    initializer = get_logging_initializer()
    initializer.cleanup_logging_system()
# Environment detection and auto-initialization
def auto_initialize_logging() -> LoggingSettings:
    """
    Initialize backtesting logging from environment variables.

    Reads $MAVERICK_ENVIRONMENT (default "development") and $MAVERICK_DEBUG
    (debug mode when equal to "true", case-insensitive). On import, this runs
    only when $MAVERICK_AUTO_INIT_LOGGING is "true"; it also runs when the
    module is executed as a script. It can be called manually at any time.
    """
    env_name = os.getenv("MAVERICK_ENVIRONMENT", "development")
    debug_flag = os.getenv("MAVERICK_DEBUG", "false")
    return initialize_backtesting_logging(
        environment=env_name, debug_mode=debug_flag.lower() == "true"
    )
# Auto-initialize when executed as a script, or on import when the
# MAVERICK_AUTO_INIT_LOGGING environment variable equals "true" (case-insensitive).
if __name__ == "__main__":
    settings = auto_initialize_logging()
    print("Logging system initialized from command line")
    print_debug_summary()
elif os.getenv("MAVERICK_AUTO_INIT_LOGGING", "false").lower() == "true":
    # Opt-in import-time side effect: configures logging as soon as the module loads.
    auto_initialize_logging()
```
--------------------------------------------------------------------------------
/tests/test_database_pool_config_simple.py:
--------------------------------------------------------------------------------
```python
"""
Simplified tests for DatabasePoolConfig focusing on core functionality.
This module tests the essential features of the enhanced database pool configuration:
- Basic configuration and validation
- Pool validation logic
- Factory methods
- Monitoring thresholds
- Environment variable integration
"""
import os
import warnings
from unittest.mock import patch
import pytest
from sqlalchemy.pool import QueuePool
from maverick_mcp.config.database import (
DatabasePoolConfig,
get_default_pool_config,
get_development_pool_config,
get_high_concurrency_pool_config,
validate_production_config,
)
from maverick_mcp.providers.interfaces.persistence import DatabaseConfig
class TestDatabasePoolConfigBasics:
"""Test basic DatabasePoolConfig functionality."""
def test_default_configuration(self):
    """A default-constructed config exposes sane, positive pool settings."""
    cfg = DatabasePoolConfig()

    assert cfg.pool_size >= 5
    assert cfg.max_overflow >= 0
    # Timeouts, recycling and DB capacity must all be strictly positive.
    for attr in ("pool_timeout", "pool_recycle", "max_database_connections"):
        assert getattr(cfg, attr) > 0, attr
def test_valid_configuration(self):
    """A configuration within database limits validates and keeps its values."""
    settings = {
        "pool_size": 10,
        "max_overflow": 5,
        "max_database_connections": 50,
        "reserved_superuser_connections": 3,
        "expected_concurrent_users": 10,
        "connections_per_user": 1.2,
    }
    cfg = DatabasePoolConfig(**settings)

    assert cfg.pool_size == 10
    assert cfg.max_overflow == 5

    # The app pool must fit in what the DB leaves after superuser reservations.
    usable = cfg.max_database_connections - cfg.reserved_superuser_connections
    assert cfg.pool_size + cfg.max_overflow <= usable
def test_validation_exceeds_database_capacity(self):
    """A pool larger than the database can serve is rejected."""
    # Requested 50 + 30 = 80 connections vs. 70 - 3 = 67 available.
    # Expected demand (60 * 1.0) is set so only the capacity check trips.
    oversized = {
        "pool_size": 50,
        "max_overflow": 30,
        "max_database_connections": 70,
        "reserved_superuser_connections": 3,
        "expected_concurrent_users": 60,
        "connections_per_user": 1.0,
    }
    with pytest.raises(
        ValueError, match="Pool configuration exceeds database capacity"
    ):
        DatabasePoolConfig(**oversized)
def test_validation_insufficient_for_expected_load(self):
    """A pool too small for the expected demand is rejected."""
    # Capacity 5 + 0 = 5 vs. demand 10 users * 1.0 = 10.
    undersized = {
        "pool_size": 5,
        "max_overflow": 0,
        "expected_concurrent_users": 10,
        "connections_per_user": 1.0,
        "max_database_connections": 50,
    }
    with pytest.raises(
        ValueError, match="Total connection capacity .* is insufficient"
    ):
        DatabasePoolConfig(**undersized)
def test_validation_warning_for_small_pool(self):
    """A small base pool (rescued by overflow) emits a sizing warning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        DatabasePoolConfig(
            pool_size=5,  # Small pool
            max_overflow=15,  # But enough overflow to meet demand
            expected_concurrent_users=10,
            connections_per_user=1.5,  # Expected demand = 15
            max_database_connections=50,
        )

    messages = [str(item.message) for item in caught]
    # Search every recorded warning: unrelated warnings (e.g. deprecations
    # from dependencies) may be recorded before the one under test.
    assert any("Pool size (5) may be insufficient" in m for m in messages), messages
def test_get_pool_kwargs(self):
    """get_pool_kwargs() maps the config onto SQLAlchemy QueuePool arguments."""
    cfg = DatabasePoolConfig(
        pool_size=15,
        max_overflow=8,
        pool_timeout=45,
        pool_recycle=1800,
        pool_pre_ping=True,
        echo_pool=True,
        expected_concurrent_users=18,
        connections_per_user=1.0,
    )

    assert cfg.get_pool_kwargs() == {
        "poolclass": QueuePool,
        "pool_size": 15,
        "max_overflow": 8,
        "pool_timeout": 45,
        "pool_recycle": 1800,
        "pool_pre_ping": True,
        "echo_pool": True,
    }
def test_get_monitoring_thresholds(self):
    """Warning/critical thresholds are 80% / 95% of the base pool size."""
    base = 20
    overflow = 10
    cfg = DatabasePoolConfig(
        pool_size=base,
        max_overflow=overflow,
        expected_concurrent_users=25,
        connections_per_user=1.0,
    )

    assert cfg.get_monitoring_thresholds() == {
        "warning_threshold": int(base * 0.8),  # 16
        "critical_threshold": int(base * 0.95),  # 19
        "pool_size": base,
        "max_overflow": overflow,
        "total_capacity": base + overflow,  # 30
    }
def test_to_legacy_config(self):
    """to_legacy_config() carries pool settings over to a DatabaseConfig."""
    url = "postgresql://user:pass@localhost/test"
    legacy = DatabasePoolConfig(
        pool_size=15,
        max_overflow=8,
        pool_timeout=45,
        pool_recycle=1800,
        echo_pool=True,
        expected_concurrent_users=20,
        connections_per_user=1.0,
    ).to_legacy_config(url)

    assert isinstance(legacy, DatabaseConfig)
    expected_attrs = {
        "database_url": url,
        "pool_size": 15,
        "max_overflow": 8,
        "pool_timeout": 45,
        "pool_recycle": 1800,
    }
    for attr_name, expected_value in expected_attrs.items():
        assert getattr(legacy, attr_name) == expected_value, attr_name
    # echo_pool on the enhanced config becomes echo on the legacy one.
    assert legacy.echo is True
def test_from_legacy_config(self):
    """from_legacy_config() copies legacy pool settings and applies overrides."""
    source = DatabaseConfig(
        database_url="postgresql://user:pass@localhost/test",
        pool_size=12,
        max_overflow=6,
        pool_timeout=60,
        pool_recycle=2400,
        echo=False,
    )

    upgraded = DatabasePoolConfig.from_legacy_config(
        source,
        expected_concurrent_users=15,
        max_database_connections=80,
    )

    # Values inherited from the legacy config.
    for attr_name, expected_value in (
        ("pool_size", 12),
        ("max_overflow", 6),
        ("pool_timeout", 60),
        ("pool_recycle", 2400),
    ):
        assert getattr(upgraded, attr_name) == expected_value, attr_name
    assert upgraded.echo_pool is False
    # Values supplied as keyword overrides.
    assert upgraded.expected_concurrent_users == 15
    assert upgraded.max_database_connections == 80
class TestFactoryMethods:
    """Exercise the pool-config factory helpers and production validation."""

    def test_get_default_pool_config(self):
        """The default factory yields a usable DatabasePoolConfig."""
        default_cfg = get_default_pool_config()
        assert isinstance(default_cfg, DatabasePoolConfig)
        assert default_cfg.pool_size > 0

    def test_get_development_pool_config(self):
        """The development factory uses a small pool with pool echo enabled."""
        dev_cfg = get_development_pool_config()
        assert isinstance(dev_cfg, DatabasePoolConfig)
        assert dev_cfg.pool_size == 5
        assert dev_cfg.max_overflow == 2
        assert dev_cfg.echo_pool is True  # pool debugging on in development

    def test_get_high_concurrency_pool_config(self):
        """The high-concurrency factory sizes the pool for 60 users."""
        busy_cfg = get_high_concurrency_pool_config()
        assert isinstance(busy_cfg, DatabasePoolConfig)
        assert busy_cfg.pool_size == 50
        assert busy_cfg.max_overflow == 30
        assert busy_cfg.expected_concurrent_users == 60

    def test_validate_production_config_valid(self):
        """A production-sized config validates and logs at info level."""
        prod_cfg = DatabasePoolConfig(
            pool_size=25,
            max_overflow=15,
            pool_timeout=30,
            pool_recycle=3600,
            expected_concurrent_users=35,
            connections_per_user=1.0,
        )
        with patch("maverick_mcp.config.database.logger") as fake_logger:
            outcome = validate_production_config(prod_cfg)
        assert outcome is True
        fake_logger.info.assert_called()

    def test_validate_production_config_warnings(self):
        """A small-but-workable pool validates while logging warnings."""
        small_cfg = DatabasePoolConfig(
            pool_size=5,  # small for production
            max_overflow=10,  # overflow still covers the expected demand
            pool_timeout=30,
            pool_recycle=3600,
            expected_concurrent_users=10,
            connections_per_user=1.0,
        )
        with patch("maverick_mcp.config.database.logger") as fake_logger:
            outcome = validate_production_config(small_cfg)
        # Warnings are logged but do not fail validation.
        assert outcome is True
        assert fake_logger.warning.called

    def test_validate_production_config_errors(self):
        """A config that passes model validation can still fail production checks."""
        # The model itself accepts these values. Only the stricter production
        # checks in validate_production_config are expected to reject them.
        # NOTE(review): presumably the low pool_timeout is what trips the
        # production rules; confirm against validate_production_config.
        strict_cfg = DatabasePoolConfig(
            pool_size=15,
            max_overflow=5,
            pool_timeout=5,
            pool_recycle=3600,
            expected_concurrent_users=18,
            connections_per_user=1.0,
        )
        with pytest.raises(
            ValueError, match="Production configuration validation failed"
        ):
            validate_production_config(strict_cfg)
class TestEnvironmentVariables:
    """Check that DB_* environment variables feed into DatabasePoolConfig."""

    # Environment overrides applied via patch.dict for the duration of a test.
    _POOL_OVERRIDES = {
        "DB_POOL_SIZE": "25",
        "DB_MAX_OVERFLOW": "10",
        "DB_EXPECTED_CONCURRENT_USERS": "25",
        "DB_CONNECTIONS_PER_USER": "1.2",
    }
    _BOOLEAN_OVERRIDES = {
        "DB_ECHO_POOL": "true",
        "DB_POOL_PRE_PING": "false",
    }

    @patch.dict(os.environ, _POOL_OVERRIDES)
    def test_environment_variable_overrides(self):
        """Numeric DB_* variables replace the built-in defaults."""
        from_env = DatabasePoolConfig()
        expected = {
            "pool_size": 25,
            "max_overflow": 10,
            "expected_concurrent_users": 25,
            "connections_per_user": 1.2,
        }
        for field_name, expected_value in expected.items():
            assert getattr(from_env, field_name) == expected_value, field_name

    @patch.dict(os.environ, _BOOLEAN_OVERRIDES)
    def test_boolean_environment_variables(self):
        """'true'/'false' strings are parsed into booleans."""
        from_env = DatabasePoolConfig()
        assert from_env.echo_pool is True
        assert from_env.pool_pre_ping is False
class TestValidationScenarios:
    """Cover validate_against_database_limits() against varying server limits."""

    def test_database_limits_validation(self):
        """Matching configured and actual limits leaves the config as-is."""
        cfg = DatabasePoolConfig(
            pool_size=10,
            max_overflow=5,
            max_database_connections=100,
            expected_concurrent_users=12,
            connections_per_user=1.0,
        )
        cfg.validate_against_database_limits(100)
        assert cfg.max_database_connections == 100

    def test_database_limits_higher_actual(self):
        """A higher actual server limit is adopted and logged at info level."""
        cfg = DatabasePoolConfig(
            pool_size=10,
            max_overflow=5,
            max_database_connections=50,
            expected_concurrent_users=12,
            connections_per_user=1.0,
        )
        with patch("maverick_mcp.config.database.logger") as fake_logger:
            cfg.validate_against_database_limits(100)
        # The configured limit is raised to match the server.
        assert cfg.max_database_connections == 100
        fake_logger.info.assert_called()

    def test_database_limits_too_low(self):
        """An actual limit below the pool's needs is rejected."""
        cfg = DatabasePoolConfig(
            pool_size=30,
            max_overflow=20,
            max_database_connections=100,
            expected_concurrent_users=40,
            connections_per_user=1.0,
        )
        # The pool needs 30 + 20 = 50 connections. A 40-connection server
        # leaves 37 usable, assuming the default 3 reserved superuser slots.
        with pytest.raises(
            ValueError, match="Configuration invalid for actual database limits"
        ):
            cfg.validate_against_database_limits(40)
class TestRealWorldScenarios:
    """Scenario-style checks mirroring typical deployment shapes."""

    def test_microservice_configuration(self):
        """A small service pool is valid and reports 8 + 4 total capacity."""
        service_cfg = DatabasePoolConfig(
            pool_size=8,
            max_overflow=4,
            expected_concurrent_users=10,
            connections_per_user=1.0,
            max_database_connections=50,
        )
        assert service_cfg.pool_size == 8
        assert service_cfg.get_monitoring_thresholds()["total_capacity"] == 12

    def test_development_to_production_migration(self):
        """A dev config round-trips through the legacy format into production."""
        dev_cfg = get_development_pool_config()
        assert dev_cfg.echo_pool is True
        assert dev_cfg.pool_size == 5

        # Go through the legacy representation, then upgrade with overrides.
        legacy = dev_cfg.to_legacy_config("postgresql://localhost/test")
        prod_cfg = DatabasePoolConfig.from_legacy_config(
            legacy,
            pool_size=30,
            max_overflow=20,
            expected_concurrent_users=40,
            echo_pool=False,
        )

        assert validate_production_config(prod_cfg) is True
        assert prod_cfg.echo_pool is False
        assert prod_cfg.pool_size == 30

    def test_connection_exhaustion_prevention(self):
        """Configs that could exhaust the server are rejected; safe ones leave headroom."""
        # 45 + 35 = 80 requested versus 75 - 3 = 72 usable -> rejected.
        with pytest.raises(ValueError, match="exceeds database capacity"):
            DatabasePoolConfig(
                pool_size=45,
                max_overflow=35,
                max_database_connections=75,
                expected_concurrent_users=60,
                connections_per_user=1.0,
            )

        # 30 + 20 = 50 requested versus 72 usable -> accepted.
        safe_cfg = DatabasePoolConfig(
            pool_size=30,
            max_overflow=20,
            max_database_connections=75,
            expected_concurrent_users=45,
            connections_per_user=1.0,
        )
        usable = (
            safe_cfg.max_database_connections
            - safe_cfg.reserved_superuser_connections
        )
        requested = safe_cfg.pool_size + safe_cfg.max_overflow
        headroom = usable - requested
        # Strictly positive headroom leaves room for other clients.
        assert headroom > 0
```
--------------------------------------------------------------------------------
/alembic/versions/013_add_backtest_persistence_models.py:
--------------------------------------------------------------------------------
```python
"""Add backtest persistence models
Revision ID: 013_add_backtest_persistence_models
Revises: fix_database_integrity_issues
Create Date: 2025-01-16 12:00:00.000000
This migration adds comprehensive backtesting persistence models:
1. BacktestResult - Main backtest results with comprehensive metrics
2. BacktestTrade - Individual trade records from backtests
3. OptimizationResult - Parameter optimization results
4. WalkForwardTest - Walk-forward validation test results
5. BacktestPortfolio - Portfolio-level backtests with multiple symbols
All tables include proper indexes for common query patterns and foreign key
relationships for data integrity.
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# Identifier of this migration.
revision = "013_add_backtest_persistence_models"
# Parent migration this one applies on top of (matches "Revises:" in the docstring).
down_revision = "fix_database_integrity_issues"
# No branch labels and no extra cross-branch dependencies.
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the five backtest persistence tables and their indexes.

    mcp_backtest_results is created first. mcp_backtest_trades,
    mcp_optimization_results and mcp_walk_forward_tests each declare a
    foreign key to mcp_backtest_results.backtest_id with ON DELETE CASCADE.
    mcp_backtest_portfolios declares no foreign key (component_backtest_ids
    is a plain JSON column).
    """
    # Create BacktestResult table
    op.create_table(
        "mcp_backtest_results",
        sa.Column("backtest_id", sa.Uuid(), nullable=False, primary_key=True),
        # Basic metadata
        sa.Column("symbol", sa.String(length=10), nullable=False),
        sa.Column("strategy_type", sa.String(length=50), nullable=False),
        sa.Column("backtest_date", sa.DateTime(timezone=True), nullable=False),
        # Date range and setup
        sa.Column("start_date", sa.Date(), nullable=False),
        sa.Column("end_date", sa.Date(), nullable=False),
        # NOTE(review): server default 10000.0 here vs 100000.0 for
        # mcp_backtest_portfolios.initial_capital below; confirm intended.
        sa.Column(
            "initial_capital",
            sa.Numeric(precision=15, scale=2),
            server_default="10000.0",
        ),
        # Trading costs
        sa.Column("fees", sa.Numeric(precision=6, scale=4), server_default="0.001"),
        sa.Column("slippage", sa.Numeric(precision=6, scale=4), server_default="0.001"),
        # Strategy parameters
        sa.Column("parameters", sa.JSON()),
        # Performance metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("annualized_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("sortino_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("calmar_ratio", sa.Numeric(precision=8, scale=4)),
        # Risk metrics
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown_duration", sa.Integer()),
        sa.Column("volatility", sa.Numeric(precision=8, scale=4)),
        sa.Column("downside_volatility", sa.Numeric(precision=8, scale=4)),
        # Trade statistics
        sa.Column("total_trades", sa.Integer(), server_default="0"),
        sa.Column("winning_trades", sa.Integer(), server_default="0"),
        sa.Column("losing_trades", sa.Integer(), server_default="0"),
        sa.Column("win_rate", sa.Numeric(precision=5, scale=4)),
        # P&L statistics
        sa.Column("profit_factor", sa.Numeric(precision=8, scale=4)),
        sa.Column("average_win", sa.Numeric(precision=12, scale=4)),
        sa.Column("average_loss", sa.Numeric(precision=12, scale=4)),
        sa.Column("largest_win", sa.Numeric(precision=12, scale=4)),
        sa.Column("largest_loss", sa.Numeric(precision=12, scale=4)),
        # Portfolio values
        sa.Column("final_portfolio_value", sa.Numeric(precision=15, scale=2)),
        sa.Column("peak_portfolio_value", sa.Numeric(precision=15, scale=2)),
        # Market analysis
        sa.Column("beta", sa.Numeric(precision=8, scale=4)),
        sa.Column("alpha", sa.Numeric(precision=8, scale=4)),
        # Time series data
        sa.Column("equity_curve", sa.JSON()),
        sa.Column("drawdown_series", sa.JSON()),
        # Execution metadata
        sa.Column("execution_time_seconds", sa.Numeric(precision=8, scale=3)),
        sa.Column("data_points", sa.Integer()),
        # Status and notes
        sa.Column("status", sa.String(length=20), server_default="completed"),
        sa.Column("error_message", sa.Text()),
        sa.Column("notes", sa.Text()),
        # Timestamps
        # NOTE(review): NOT NULL with no server_default, so every INSERT must
        # supply these values (presumably set by the ORM models).
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )

    # Create indexes for BacktestResult
    # NOTE(review): the single-column symbol index is also served by the
    # leading column of mcp_backtest_results_symbol_strategy_idx below; likely
    # redundant. Left as-is because editing an applied migration is unsafe.
    op.create_index(
        "mcp_backtest_results_symbol_idx", "mcp_backtest_results", ["symbol"]
    )
    op.create_index(
        "mcp_backtest_results_strategy_idx", "mcp_backtest_results", ["strategy_type"]
    )
    op.create_index(
        "mcp_backtest_results_date_idx", "mcp_backtest_results", ["backtest_date"]
    )
    op.create_index(
        "mcp_backtest_results_sharpe_idx", "mcp_backtest_results", ["sharpe_ratio"]
    )
    op.create_index(
        "mcp_backtest_results_total_return_idx",
        "mcp_backtest_results",
        ["total_return"],
    )
    op.create_index(
        "mcp_backtest_results_symbol_strategy_idx",
        "mcp_backtest_results",
        ["symbol", "strategy_type"],
    )

    # Create BacktestTrade table
    op.create_table(
        "mcp_backtest_trades",
        sa.Column("trade_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("backtest_id", sa.Uuid(), nullable=False),
        # Trade identification
        sa.Column("trade_number", sa.Integer(), nullable=False),
        # Entry details
        sa.Column("entry_date", sa.Date(), nullable=False),
        sa.Column("entry_price", sa.Numeric(precision=12, scale=4), nullable=False),
        sa.Column("entry_time", sa.DateTime(timezone=True)),
        # Exit details (nullable: not set for a trade without an exit)
        sa.Column("exit_date", sa.Date()),
        sa.Column("exit_price", sa.Numeric(precision=12, scale=4)),
        sa.Column("exit_time", sa.DateTime(timezone=True)),
        # Position details
        sa.Column("position_size", sa.Numeric(precision=15, scale=2)),
        sa.Column("direction", sa.String(length=5), nullable=False),
        # P&L
        sa.Column("pnl", sa.Numeric(precision=12, scale=4)),
        sa.Column("pnl_percent", sa.Numeric(precision=8, scale=4)),
        # Risk metrics
        sa.Column("mae", sa.Numeric(precision=8, scale=4)),  # Maximum Adverse Excursion
        sa.Column(
            "mfe", sa.Numeric(precision=8, scale=4)
        ),  # Maximum Favorable Excursion
        # Duration
        sa.Column("duration_days", sa.Integer()),
        sa.Column("duration_hours", sa.Numeric(precision=8, scale=2)),
        # Exit details
        sa.Column("exit_reason", sa.String(length=50)),
        sa.Column("fees_paid", sa.Numeric(precision=10, scale=4), server_default="0"),
        sa.Column(
            "slippage_cost", sa.Numeric(precision=10, scale=4), server_default="0"
        ),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint: deleting a backtest result deletes its trades
        sa.ForeignKeyConstraint(
            ["backtest_id"], ["mcp_backtest_results.backtest_id"], ondelete="CASCADE"
        ),
    )

    # Create indexes for BacktestTrade
    # NOTE(review): this index is covered by the leading column of
    # mcp_backtest_trades_backtest_entry_idx below; likely redundant.
    op.create_index(
        "mcp_backtest_trades_backtest_idx", "mcp_backtest_trades", ["backtest_id"]
    )
    op.create_index(
        "mcp_backtest_trades_entry_date_idx", "mcp_backtest_trades", ["entry_date"]
    )
    op.create_index(
        "mcp_backtest_trades_exit_date_idx", "mcp_backtest_trades", ["exit_date"]
    )
    op.create_index("mcp_backtest_trades_pnl_idx", "mcp_backtest_trades", ["pnl"])
    op.create_index(
        "mcp_backtest_trades_backtest_entry_idx",
        "mcp_backtest_trades",
        ["backtest_id", "entry_date"],
    )

    # Create OptimizationResult table
    op.create_table(
        "mcp_optimization_results",
        sa.Column("optimization_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("backtest_id", sa.Uuid(), nullable=False),
        # Optimization metadata
        sa.Column("optimization_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("parameter_set", sa.Integer(), nullable=False),
        # Parameters and results
        sa.Column("parameters", sa.JSON(), nullable=False),
        sa.Column("objective_function", sa.String(length=50)),
        sa.Column("objective_value", sa.Numeric(precision=12, scale=6)),
        # Key metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("win_rate", sa.Numeric(precision=5, scale=4)),
        sa.Column("profit_factor", sa.Numeric(precision=8, scale=4)),
        sa.Column("total_trades", sa.Integer()),
        # Ranking
        sa.Column("rank", sa.Integer()),
        # Statistical significance
        sa.Column("is_statistically_significant", sa.Boolean(), server_default="false"),
        sa.Column("p_value", sa.Numeric(precision=8, scale=6)),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint: cascades from the parent backtest result
        sa.ForeignKeyConstraint(
            ["backtest_id"], ["mcp_backtest_results.backtest_id"], ondelete="CASCADE"
        ),
    )

    # Create indexes for OptimizationResult
    op.create_index(
        "mcp_optimization_results_backtest_idx",
        "mcp_optimization_results",
        ["backtest_id"],
    )
    op.create_index(
        "mcp_optimization_results_param_set_idx",
        "mcp_optimization_results",
        ["parameter_set"],
    )
    op.create_index(
        "mcp_optimization_results_objective_idx",
        "mcp_optimization_results",
        ["objective_value"],
    )

    # Create WalkForwardTest table
    op.create_table(
        "mcp_walk_forward_tests",
        sa.Column("walk_forward_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("parent_backtest_id", sa.Uuid(), nullable=False),
        # Test configuration
        sa.Column("test_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("window_size_months", sa.Integer(), nullable=False),
        sa.Column("step_size_months", sa.Integer(), nullable=False),
        # Time periods
        sa.Column("training_start", sa.Date(), nullable=False),
        sa.Column("training_end", sa.Date(), nullable=False),
        sa.Column("test_period_start", sa.Date(), nullable=False),
        sa.Column("test_period_end", sa.Date(), nullable=False),
        # Training results
        sa.Column("optimal_parameters", sa.JSON()),
        sa.Column("training_performance", sa.Numeric(precision=10, scale=4)),
        # Out-of-sample results
        sa.Column("out_of_sample_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("out_of_sample_sharpe", sa.Numeric(precision=8, scale=4)),
        sa.Column("out_of_sample_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("out_of_sample_trades", sa.Integer()),
        # Performance analysis
        sa.Column("performance_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("degradation_factor", sa.Numeric(precision=8, scale=4)),
        # Validation
        sa.Column("is_profitable", sa.Boolean()),
        sa.Column("is_statistically_significant", sa.Boolean(), server_default="false"),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint (column is named parent_backtest_id here)
        sa.ForeignKeyConstraint(
            ["parent_backtest_id"],
            ["mcp_backtest_results.backtest_id"],
            ondelete="CASCADE",
        ),
    )

    # Create indexes for WalkForwardTest
    op.create_index(
        "mcp_walk_forward_tests_parent_idx",
        "mcp_walk_forward_tests",
        ["parent_backtest_id"],
    )
    op.create_index(
        "mcp_walk_forward_tests_period_idx",
        "mcp_walk_forward_tests",
        ["test_period_start"],
    )
    op.create_index(
        "mcp_walk_forward_tests_performance_idx",
        "mcp_walk_forward_tests",
        ["out_of_sample_return"],
    )

    # Create BacktestPortfolio table (no foreign keys)
    op.create_table(
        "mcp_backtest_portfolios",
        sa.Column("portfolio_backtest_id", sa.Uuid(), nullable=False, primary_key=True),
        # Portfolio identification
        sa.Column("portfolio_name", sa.String(length=100), nullable=False),
        sa.Column("description", sa.Text()),
        # Test metadata
        sa.Column("backtest_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("start_date", sa.Date(), nullable=False),
        sa.Column("end_date", sa.Date(), nullable=False),
        # Portfolio composition
        sa.Column("symbols", sa.JSON(), nullable=False),
        sa.Column("weights", sa.JSON()),
        sa.Column("rebalance_frequency", sa.String(length=20)),
        # Portfolio parameters
        sa.Column(
            "initial_capital",
            sa.Numeric(precision=15, scale=2),
            server_default="100000.0",
        ),
        sa.Column("max_positions", sa.Integer()),
        sa.Column("position_sizing_method", sa.String(length=50)),
        # Risk management
        sa.Column("portfolio_stop_loss", sa.Numeric(precision=6, scale=4)),
        sa.Column("max_sector_allocation", sa.Numeric(precision=5, scale=4)),
        sa.Column("correlation_threshold", sa.Numeric(precision=5, scale=4)),
        # Performance metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("annualized_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("sortino_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("volatility", sa.Numeric(precision=8, scale=4)),
        # Portfolio-specific metrics
        sa.Column("diversification_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("concentration_index", sa.Numeric(precision=8, scale=4)),
        sa.Column("turnover_rate", sa.Numeric(precision=8, scale=4)),
        # References and time series
        # NOTE(review): component_backtest_ids is JSON, not a foreign key, so
        # the database does not enforce that referenced backtests exist.
        sa.Column("component_backtest_ids", sa.JSON()),
        sa.Column("portfolio_equity_curve", sa.JSON()),
        sa.Column("portfolio_weights_history", sa.JSON()),
        # Status
        sa.Column("status", sa.String(length=20), server_default="completed"),
        sa.Column("notes", sa.Text()),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )

    # Create indexes for BacktestPortfolio
    op.create_index(
        "mcp_backtest_portfolios_name_idx",
        "mcp_backtest_portfolios",
        ["portfolio_name"],
    )
    op.create_index(
        "mcp_backtest_portfolios_date_idx", "mcp_backtest_portfolios", ["backtest_date"]
    )
    op.create_index(
        "mcp_backtest_portfolios_return_idx",
        "mcp_backtest_portfolios",
        ["total_return"],
    )
def downgrade() -> None:
    """Drop the tables created by ``upgrade``.

    Tables holding foreign keys to mcp_backtest_results are dropped before
    mcp_backtest_results itself.
    """
    tables_in_drop_order = (
        "mcp_backtest_portfolios",
        "mcp_walk_forward_tests",
        "mcp_optimization_results",
        "mcp_backtest_trades",
        "mcp_backtest_results",  # referenced parent table goes last
    )
    for table_name in tables_in_drop_order:
        op.drop_table(table_name)
```
--------------------------------------------------------------------------------
/maverick_mcp/tools/risk_management.py:
--------------------------------------------------------------------------------
```python
"""
Risk management tools for position sizing, stop loss calculation, and portfolio risk analysis.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from maverick_mcp.agents.base import PersonaAwareTool
from maverick_mcp.core.technical_analysis import calculate_atr
from maverick_mcp.providers.stock_data import StockDataProvider
# Module-level logger; the tools below report calculation failures through it.
logger = logging.getLogger(__name__)
class PositionSizeInput(BaseModel):
    """Input for position sizing calculations."""

    # Argument schema for PositionSizeTool (assigned as its args_schema).
    # No range constraints are declared on these fields.
    account_size: float = Field(description="Total account size in dollars")
    entry_price: float = Field(description="Planned entry price")
    stop_loss_price: float = Field(description="Stop loss price")
    # Percent units: 2.0 means 2% of account_size (the tool divides by 100).
    risk_percentage: float = Field(
        default=2.0, description="Percentage of account to risk (default 2%)"
    )
class TechnicalStopsInput(BaseModel):
    """Input for technical stop calculations."""

    # Argument schema for TechnicalStopsTool (assigned as its args_schema).
    symbol: str = Field(description="Stock symbol")
    # Window (in rows of price data) for the support / swing-low calculations.
    lookback_days: int = Field(default=20, description="Days to look back for analysis")
    # Stop distance is atr_multiplier x ATR below the current price.
    atr_multiplier: float = Field(
        default=2.0, description="ATR multiplier for stop distance"
    )
class RiskMetricsInput(BaseModel):
    """Input for portfolio risk metrics."""

    # Argument schema for RiskMetricsTool (assigned as its args_schema).
    symbols: list[str] = Field(description="List of symbols in portfolio")
    # Must have the same length as symbols when given; normalized to sum to 1.
    weights: list[float] | None = Field(
        default=None, description="Portfolio weights (equal weight if not provided)"
    )
    # Calendar days of history requested (the tool adds a 30-day buffer).
    lookback_days: int = Field(
        default=252, description="Days for correlation calculation"
    )
class PositionSizeTool(PersonaAwareTool):
    """Calculate position size based on risk management rules."""

    # The recommended size is the smaller of the persona-adjusted risk size
    # and the Kelly-scaled size, capped at the persona's maximum allocation
    # (10% of the account when no persona is set).

    name: str = "calculate_position_size"
    description: str = (
        "Calculate position size based on account risk, with Kelly Criterion "
        "and persona adjustments"
    )
    args_schema: type[BaseModel] = PositionSizeInput

    def _run(
        self,
        account_size: float,
        entry_price: float,
        stop_loss_price: float,
        risk_percentage: float = 2.0,
    ) -> str:
        """Calculate position size synchronously.

        Args:
            account_size: Total account size in dollars; must be positive.
            entry_price: Planned entry price; must be positive.
            stop_loss_price: Stop loss price; must differ from ``entry_price``.
            risk_percentage: Percent of the account to risk (2.0 means 2%);
                must not be negative.

        Returns:
            ``str()`` of the persona-formatted result dict, or a message that
            starts with ``"Error"`` when inputs are invalid or a step fails.
        """
        try:
            # Reject inputs that would otherwise surface as a bare
            # ZeroDivisionError or produce a negative share count.
            if account_size <= 0:
                return "Error: Account size must be positive"
            if entry_price <= 0:
                return "Error: Entry price must be positive"
            if risk_percentage < 0:
                return "Error: Risk percentage cannot be negative"

            # Basic risk calculation: dollars at risk / dollars lost per share
            risk_amount = account_size * (risk_percentage / 100)
            price_risk = abs(entry_price - stop_loss_price)

            if price_risk == 0:
                return "Error: Entry and stop loss prices cannot be the same"

            # Calculate base position size
            base_shares = risk_amount / price_risk
            base_position_value = base_shares * entry_price

            # Apply persona adjustments
            adjusted_shares = self.adjust_for_risk(base_shares, "position_size")
            adjusted_value = adjusted_shares * entry_price

            # Calculate Kelly fraction if persona is set
            kelly_fraction = 0.25  # Default conservative Kelly
            if self.persona:
                risk_factor = sum(self.persona.risk_tolerance) / 100
                kelly_fraction = self._calculate_kelly_fraction(risk_factor)

            kelly_shares = base_shares * kelly_fraction
            kelly_value = kelly_shares * entry_price

            # Ensure position doesn't exceed max allocation
            max_position_pct = self.persona.position_size_max if self.persona else 0.10
            max_position_value = account_size * max_position_pct

            final_shares = min(adjusted_shares, kelly_shares)
            if final_shares * entry_price > max_position_value:
                final_shares = max_position_value / entry_price

            # Only whole shares can be bought. Derive the reported value from
            # the truncated share count so shares, value and percentage agree.
            recommended_shares = int(final_shares)
            final_value = recommended_shares * entry_price

            result = {
                "status": "success",
                "position_sizing": {
                    "recommended_shares": recommended_shares,
                    "position_value": round(final_value, 2),
                    "position_percentage": round((final_value / account_size) * 100, 2),
                    "risk_amount": round(risk_amount, 2),
                    "price_risk_per_share": round(price_risk, 2),
                    "r_multiple_target": round(
                        2.0 * price_risk / entry_price * 100, 2
                    ),  # 2R target
                },
                "calculations": {
                    "base_shares": int(base_shares),
                    "base_position_value": round(base_position_value, 2),
                    "kelly_shares": int(kelly_shares),
                    "kelly_value": round(kelly_value, 2),
                    "persona_adjusted_shares": int(adjusted_shares),
                    "persona_adjusted_value": round(adjusted_value, 2),
                    "kelly_fraction": round(kelly_fraction, 3),
                    "max_allowed_value": round(max_position_value, 2),
                },
            }

            # Add persona insights if available
            if self.persona:
                result["persona_insights"] = {
                    "investor_type": self.persona.name,
                    "risk_tolerance": self.persona.risk_tolerance,
                    "max_position_size": f"{self.persona.position_size_max * 100:.1f}%",
                    "suitable_for_profile": final_value <= max_position_value,
                }

            # Format for return
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return f"Error calculating position size: {str(e)}"
class TechnicalStopsTool(PersonaAwareTool):
    """Calculate stop loss levels based on technical analysis."""

    # Produces ATR, support, swing-low and moving-average stop candidates
    # below the current price and recommends one according to the persona.

    name: str = "calculate_technical_stops"
    description: str = (
        "Calculate stop loss levels using ATR, support levels, and moving averages"
    )
    args_schema: type[BaseModel] = TechnicalStopsInput

    def _run(
        self, symbol: str, lookback_days: int = 20, atr_multiplier: float = 2.0
    ) -> str:
        """Calculate technical stops synchronously.

        Args:
            symbol: Ticker to analyse.
            lookback_days: Window (rows of price data) for the support and
                swing-low levels; must be at least 1.
            atr_multiplier: Number of ATRs below the current price for the
                ATR stop (may be adjusted by the persona).

        Returns:
            ``str()`` of the persona-formatted result dict, or a message that
            starts with ``"Error"`` on invalid input or failure.
        """
        try:
            if lookback_days < 1:
                return "Error: lookback_days must be at least 1"

            provider = StockDataProvider()

            # Fetch twice the lookback, but never less than 100 calendar days.
            end_date = datetime.now()
            start_date = end_date - timedelta(days=max(lookback_days * 2, 100))

            df = provider.get_stock_data(
                symbol,
                start_date.strftime("%Y-%m-%d"),
                end_date.strftime("%Y-%m-%d"),
                use_cache=True,
            )

            if df.empty:
                return f"Error: No price data available for {symbol}"

            current_price = df["Close"].iloc[-1]

            # ATR-based stop; the persona may rescale the multiplier first.
            atr_value = calculate_atr(df, period=14).iloc[-1]
            if self.persona:
                atr_multiplier = self.adjust_for_risk(atr_multiplier, "stop_loss")
            atr_stop = current_price - (atr_value * atr_multiplier)

            # Support-based stop: lowest low over the lookback window
            support_level = df["Low"].rolling(window=lookback_days).min().iloc[-1]

            # Moving average stops (None when there is not enough history)
            ma_20 = self._trailing_mean(df["Close"], 20)
            ma_50 = self._trailing_mean(df["Close"], 50)

            # Swing low stop (lowest low in recent period)
            swing_low = df["Low"].iloc[-lookback_days:].min()

            stops = {
                "current_price": round(current_price, 2),
                "atr_stop": round(atr_stop, 2),
                "support_stop": round(support_level, 2),
                "swing_low_stop": round(swing_low, 2),
                "ma_20_stop": round(ma_20, 2) if ma_20 is not None else None,
                "ma_50_stop": round(ma_50, 2) if ma_50 is not None else None,
                "atr_value": round(atr_value, 2),
                "stop_distances": {
                    "atr_stop_pct": round(
                        ((current_price - atr_stop) / current_price) * 100, 2
                    ),
                    "support_stop_pct": round(
                        ((current_price - support_level) / current_price) * 100, 2
                    ),
                    "swing_low_pct": round(
                        ((current_price - swing_low) / current_price) * 100, 2
                    ),
                },
            }

            recommended = self._recommend_stop(
                current_price, atr_stop, support_level, ma_20
            )
            stops["recommended_stop"] = round(recommended, 2)
            stops["recommended_stop_pct"] = round(
                ((current_price - recommended) / current_price) * 100, 2
            )

            result = {
                "status": "success",
                "symbol": symbol,
                "technical_stops": stops,
                "analysis_period": lookback_days,
                "atr_multiplier": atr_multiplier,
            }

            # Format for persona
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.error(f"Error calculating technical stops for {symbol}: {e}")
            return f"Error calculating technical stops: {str(e)}"

    @staticmethod
    def _trailing_mean(closes: pd.Series, window: int) -> float | None:
        """Return the latest ``window``-period mean of ``closes``, or None if too short."""
        if len(closes) < window:
            return None
        return float(closes.rolling(window=window).mean().iloc[-1])

    def _recommend_stop(
        self,
        current_price: float,
        atr_stop: float,
        support_level: float,
        ma_20: float | None,
    ) -> float:
        """Pick the persona-appropriate stop from the computed candidates."""
        if not self.persona:
            return atr_stop
        if self.persona.name == "Conservative":
            # Tightest stop that is still below the current price. A 20-day MA
            # above price would otherwise give a long stop that triggers at once.
            candidates = [atr_stop]
            if ma_20 is not None and ma_20 < current_price:
                candidates.append(ma_20)
            return max(candidates)
        if self.persona.name == "Day Trader":
            return atr_stop  # ATR-based for volatility
        return min(support_level, atr_stop)  # Balance
class RiskMetricsTool(PersonaAwareTool):
    """Calculate portfolio risk metrics including correlations and VaR."""

    name: str = "calculate_risk_metrics"
    description: str = (
        "Calculate portfolio risk metrics including correlation, beta, and VaR"
    )
    args_schema: type[BaseModel] = RiskMetricsInput  # type: ignore[assignment]

    def _run(
        self,
        symbols: list[str],
        weights: list[float] | None = None,
        lookback_days: int = 252,
    ) -> str:
        """Calculate risk metrics synchronously.

        Args:
            symbols: Tickers in the portfolio.
            weights: Optional weights aligned with ``symbols``. They are
                normalized to sum to 1; equal weights are used when omitted.
            lookback_days: Calendar days of history to request (plus a
                30-day buffer).

        Returns:
            ``str()`` of the persona-formatted result dict, or a message that
            starts with ``"Error"`` on invalid input or failure. Symbols
            without price data are dropped, the remaining weights are
            renormalized, and the dropped symbols are listed under
            ``"excluded_symbols"``.
        """
        try:
            if not symbols:
                return "Error: No symbols provided"

            # If no weights provided, use equal weight
            if weights is None:
                weights = [1.0 / len(symbols)] * len(symbols)
            elif len(weights) != len(symbols):
                return "Error: Number of weights must match number of symbols"

            # Key weights by symbol so each weight stays attached to its own
            # return series even if some symbols have no data.
            weight_by_symbol = self._weights_by_symbol(symbols, weights)
            if weight_by_symbol is None:
                return "Error: Portfolio weights must not sum to zero"

            provider = StockDataProvider()

            # Get price data for all symbols
            end_date = datetime.now()
            start_date = end_date - timedelta(days=lookback_days + 30)
            start_str = start_date.strftime("%Y-%m-%d")
            end_str = end_date.strftime("%Y-%m-%d")

            returns_data: dict[str, pd.Series] = {}
            for symbol in weight_by_symbol:
                df = provider.get_stock_data(symbol, start_str, end_str, use_cache=True)
                if not df.empty:
                    returns_data[symbol] = df["Close"].pct_change().dropna()

            if not returns_data:
                return "Error: No price data available for any symbols"

            # Keep only dates on which every fetched symbol has a return.
            returns_df = pd.DataFrame(returns_data).dropna()
            if returns_df.empty:
                return "Error: No overlapping price history for the requested symbols"

            # Renormalize over the symbols that actually have data.
            held_weights = pd.Series(
                {symbol: weight_by_symbol[symbol] for symbol in returns_df.columns}
            )
            held_total = held_weights.sum()
            if held_total == 0:
                return "Error: Portfolio weights must not sum to zero"
            held_weights = held_weights / held_total

            # Calculate correlation matrix
            correlation_matrix = returns_df.corr()

            # Weighted daily portfolio return; mul(axis=1) aligns on column labels.
            portfolio_returns = returns_df.mul(held_weights, axis=1).sum(axis=1)
            portfolio_std = portfolio_returns.std() * np.sqrt(252)  # Annualized

            # VaR (95%): 5th percentile of daily returns scaled by sqrt(252)
            var_95 = np.percentile(portfolio_returns, 5) * np.sqrt(252)

            portfolio_beta = self._portfolio_beta(
                provider, portfolio_returns, start_str, end_str
            )
            avg_corr = self._average_correlation(correlation_matrix)

            # Build result
            result: dict[str, Any] = {
                "status": "success",
                "portfolio_metrics": {
                    "annualized_volatility": round(portfolio_std * 100, 2),
                    "value_at_risk_95": round(var_95 * 100, 2),
                    "portfolio_beta": round(portfolio_beta, 2)
                    if portfolio_beta is not None
                    else None,
                    "avg_correlation": round(avg_corr, 3)
                    if avg_corr is not None
                    else None,
                },
                "correlations": correlation_matrix.to_dict(),
                "weights": {
                    symbol: round(float(weight), 3)
                    for symbol, weight in held_weights.items()
                },
                "risk_assessment": self._assess_portfolio_risk(
                    portfolio_std, var_95, correlation_matrix
                ),
            }

            excluded = [s for s in weight_by_symbol if s not in returns_data]
            if excluded:
                result["excluded_symbols"] = excluded

            # Format for persona
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.error(f"Error calculating risk metrics: {e}")
            return f"Error calculating risk metrics: {str(e)}"

    @staticmethod
    def _weights_by_symbol(
        symbols: list[str], weights: list[float]
    ) -> dict[str, float] | None:
        """Normalize ``weights`` to sum to 1 and key them by symbol.

        Weights of repeated symbols are summed. Returns None when the raw
        weights sum to zero and cannot be normalized.
        """
        total = float(np.sum(weights))
        if total == 0:
            return None
        by_symbol: dict[str, float] = {}
        for symbol, weight in zip(symbols, weights, strict=True):
            by_symbol[symbol] = by_symbol.get(symbol, 0.0) + weight / total
        return by_symbol

    @staticmethod
    def _portfolio_beta(
        provider: StockDataProvider,
        portfolio_returns: pd.Series,
        start_str: str,
        end_str: str,
    ) -> float | None:
        """Return the portfolio's beta versus SPY, or None if it cannot be computed."""
        spy_df = provider.get_stock_data("SPY", start_str, end_str, use_cache=True)
        if spy_df.empty:
            return None
        spy_returns = spy_df["Close"].pct_change().dropna()
        common_dates = portfolio_returns.index.intersection(spy_returns.index)
        if len(common_dates) == 0:
            return None
        spy_common = spy_returns.loc[common_dates]
        beta = portfolio_returns.loc[common_dates].cov(spy_common) / spy_common.var()
        # A flat benchmark or too few observations yields NaN/inf.
        return float(beta) if np.isfinite(beta) else None

    @staticmethod
    def _average_correlation(correlation_matrix: pd.DataFrame) -> float | None:
        """Mean of the off-diagonal correlations, or None if none are defined."""
        values = correlation_matrix.values[
            np.triu_indices_from(correlation_matrix.values, k=1)
        ]
        values = values[~np.isnan(values)]
        if values.size == 0:
            # Single holding (no pairs) or all-NaN correlations.
            return None
        return float(values.mean())

    def _assess_portfolio_risk(
        self, volatility: float, var: float, correlation_matrix: pd.DataFrame
    ) -> dict[str, Any]:
        """Assess portfolio risk level.

        Args:
            volatility: Annualized volatility as a fraction (0.25 == 25%).
            var: VaR as a fraction (negative for a loss).
            correlation_matrix: Pairwise return correlations.
        """
        risk_level = "Low"
        risk_warnings: list[str] = []

        # Check volatility
        if volatility > 0.25:  # 25% annual vol
            risk_level = "High"
            risk_warnings.append("High portfolio volatility")
        elif volatility > 0.15:
            risk_level = "Moderate"

        # Check VaR
        if abs(var) > 0.10:  # 10% VaR
            risk_warnings.append("High Value at Risk")

        # Check correlation (undefined for a single holding)
        avg_corr = self._average_correlation(correlation_matrix)
        if avg_corr is not None and avg_corr > 0.7:
            risk_warnings.append("High correlation between holdings")

        return {
            "risk_level": risk_level,
            "warnings": risk_warnings,
            "diversification_score": round(1 - avg_corr, 2)
            if avg_corr is not None
            else None,
        }
```
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_mcp_tools.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive tests for all MCP tool functions in Maverick-MCP.
This module tests all public MCP tools exposed by the server including:
- Stock data fetching
- Technical analysis
- Risk analysis
- Chart generation
- News sentiment
- Multi-ticker comparison
"""
from datetime import datetime
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
import pytest
from fastmcp import Client
from maverick_mcp.api.server import mcp
class TestMCPTools:
    """Test suite for all MCP tool functions using the new router structure."""

    @staticmethod
    def _parse_tool_result(result):
        """Decode the text payload of the first content item of a tool call.

        The text is parsed as JSON first and, if that fails, as a Python
        literal with ``ast.literal_eval``. Neither path executes code, unlike
        ``eval``. JSON parsing also accepts ``true``/``false``/``null``, which
        ``eval`` would reject with a NameError.
        """
        import ast
        import json

        text = result[0].text
        assert text is not None
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return ast.literal_eval(text)

    @pytest.fixture
    def mock_stock_data(self):
        """Create 250 days of sample OHLCV data for testing.

        A fixed seed keeps the fixture reproducible, so any test that depends
        on its values behaves the same on every run.
        """
        rng = np.random.default_rng(42)
        dates = pd.date_range(end=datetime.now(), periods=250, freq="D")
        return pd.DataFrame(
            {
                "Open": rng.uniform(90, 110, 250),
                "High": rng.uniform(95, 115, 250),
                "Low": rng.uniform(85, 105, 250),
                "Close": rng.uniform(90, 110, 250),
                "Volume": rng.integers(1000000, 10000000, 250),
            },
            index=dates,
        )

    @pytest.mark.asyncio
    async def test_fetch_stock_data(self, mock_stock_data):
        """Test basic stock data fetching."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                data = self._parse_tool_result(result)
                assert "ticker" in data
                assert data["ticker"] == "AAPL"
                assert "record_count" in data
                assert data["record_count"] == 250

    @pytest.mark.asyncio
    async def test_rsi_analysis(self, mock_stock_data):
        """Test RSI technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_rsi_analysis", {"ticker": "AAPL", "period": 14}
                )
                data = self._parse_tool_result(result)
                assert "analysis" in data
                assert "ticker" in data
                assert data["ticker"] == "AAPL"
                assert "current" in data["analysis"]
                assert "signal" in data["analysis"]
                assert data["analysis"]["signal"] in [
                    "oversold",
                    "neutral",
                    "overbought",
                    "bullish",
                    "bearish",
                ]

    @pytest.mark.asyncio
    async def test_macd_analysis(self, mock_stock_data):
        """Test MACD technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_macd_analysis", {"ticker": "MSFT"}
                )
                data = self._parse_tool_result(result)
                assert "analysis" in data
                assert "ticker" in data
                assert data["ticker"] == "MSFT"
                assert "macd" in data["analysis"]
                assert "signal" in data["analysis"]
                assert "histogram" in data["analysis"]
                assert "indicator" in data["analysis"]

    @pytest.mark.asyncio
    async def test_support_resistance(self, mock_stock_data):
        """Test support and resistance level detection."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Create data with clear support/resistance levels
            mock_data = mock_stock_data.copy()
            mock_data["High"] = [105 if i % 20 < 10 else 110 for i in range(250)]
            mock_data["Low"] = [95 if i % 20 < 10 else 100 for i in range(250)]
            mock_data["Close"] = [100 if i % 20 < 10 else 105 for i in range(250)]
            mock_get.return_value = mock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_support_resistance", {"ticker": "GOOGL"}
                )
                data = self._parse_tool_result(result)
                assert "support_levels" in data
                assert "resistance_levels" in data
                assert len(data["support_levels"]) > 0
                assert len(data["resistance_levels"]) > 0

    @pytest.mark.asyncio
    async def test_batch_stock_data(self, mock_stock_data):
        """Test batch stock data fetching."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_fetch_stock_data_batch",
                    {
                        "request": {
                            "tickers": ["AAPL", "MSFT", "GOOGL"],
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                data = self._parse_tool_result(result)
                assert "results" in data
                assert "success_count" in data
                assert "error_count" in data
                assert len(data["results"]) == 3
                assert data["success_count"] == 3
                assert data["error_count"] == 0

    @pytest.mark.asyncio
    async def test_portfolio_risk_analysis(self, mock_stock_data):
        """Test portfolio risk analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Lowercase column names plus a seeded random-walk close series
            rng = np.random.default_rng(7)
            daily_returns = rng.normal(0.001, 0.02, 250)
            mock_data = mock_stock_data.copy()
            mock_data.columns = mock_data.columns.str.lower()
            mock_data["close"] = 100 * np.exp(np.cumsum(daily_returns))
            mock_get.return_value = mock_data
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/portfolio_risk_adjusted_analysis",
                    {"ticker": "AAPL", "risk_level": 50.0},
                )
                data = self._parse_tool_result(result)
                assert "ticker" in data
                assert "risk_level" in data
                assert "position_sizing" in data
                assert "risk_management" in data

    @pytest.mark.asyncio
    async def test_maverick_screening(self):
        """Test Maverick stock screening."""
        with (
            patch("maverick_mcp.data.models.SessionLocal") as mock_session_cls,
            patch(
                "maverick_mcp.data.models.MaverickStocks.get_top_stocks"
            ) as mock_get_stocks,
        ):
            # Mock database session (not used but needed for session lifecycle)
            _ = mock_session_cls.return_value.__enter__.return_value

            # Mock return data
            class MockStock1:
                def to_dict(self):
                    return {
                        "stock": "AAPL",
                        "close": 150.0,
                        "combined_score": 92,
                        "momentum_score": 88,
                        "adr_pct": 2.5,
                    }

            class MockStock2:
                def to_dict(self):
                    return {
                        "stock": "MSFT",
                        "close": 300.0,
                        "combined_score": 89,
                        "momentum_score": 85,
                        "adr_pct": 2.1,
                    }

            mock_get_stocks.return_value = [MockStock1(), MockStock2()]
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/screening_get_maverick_stocks", {"limit": 10}
                )
                data = self._parse_tool_result(result)
                assert "stocks" in data
                assert len(data["stocks"]) == 2
                assert data["stocks"][0]["stock"] == "AAPL"

    @pytest.mark.asyncio
    async def test_news_sentiment(self):
        """Test news sentiment analysis."""
        with (
            patch("requests.get") as mock_get,
            patch(
                "maverick_mcp.config.settings.settings.external_data.api_key",
                "test_api_key",
            ),
            patch(
                "maverick_mcp.config.settings.settings.external_data.base_url",
                "https://test-api.com",
            ),
        ):
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "articles": [
                    {
                        "title": "Apple hits new highs",
                        "url": "https://example.com/1",
                        "summary": "Positive news about Apple",
                        "banner_image": "https://example.com/image1.jpg",
                        "time_published": "20240115T100000",
                        "overall_sentiment_score": 0.8,
                        "overall_sentiment_label": "Bullish",
                    }
                ]
            }
            mock_get.return_value = mock_response
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_get_news_sentiment", {"request": {"ticker": "AAPL"}}
                )
                data = self._parse_tool_result(result)
                assert "articles" in data
                assert len(data["articles"]) > 0
                assert data["articles"][0]["overall_sentiment_label"] == "Bullish"

    @pytest.mark.asyncio
    async def test_full_technical_analysis(self, mock_stock_data):
        """Test comprehensive technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Ensure lowercase column names for technical analysis
            mock_data_lowercase = mock_stock_data.copy()
            mock_data_lowercase.columns = mock_data_lowercase.columns.str.lower()
            mock_get.return_value = mock_data_lowercase
            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_full_technical_analysis", {"ticker": "AAPL"}
                )
                data = self._parse_tool_result(result)
                assert "indicators" in data
                assert "rsi" in data["indicators"]
                assert "macd" in data["indicators"]
                assert "bollinger_bands" in data["indicators"]
                assert "levels" in data
                assert "current_price" in data
                assert "last_updated" in data

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """Test error handling for invalid requests."""
        async with Client(mcp) as client:
            # Test invalid ticker format
            with pytest.raises(Exception) as exc_info:
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "INVALIDTICKER",  # Too long (max 10 chars)
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
            assert "validation error" in str(exc_info.value).lower()

            # Test invalid date range
            with pytest.raises(Exception) as exc_info:
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-12-31",
                            "end_date": "2024-01-01",  # End before start
                        }
                    },
                )
            assert (
                "end date" in str(exc_info.value).lower()
                and "start date" in str(exc_info.value).lower()
            )

    @pytest.mark.asyncio
    async def test_caching_behavior(self, mock_stock_data):
        """Test that caching reduces API calls."""
        call_count = 0

        def mock_get_data(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            return mock_stock_data

        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data",
            side_effect=mock_get_data,
        ):
            async with Client(mcp) as client:
                # First call
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                assert call_count == 1

                # Second call with same parameters should hit cache
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                # Note: In test environment without actual caching infrastructure,
                # the call count may be 2. This is expected behavior.
                assert call_count <= 2
if __name__ == "__main__":
    # Propagate pytest's exit status so shell and CI callers can see failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
```
--------------------------------------------------------------------------------
/maverick_mcp/database/optimization.py:
--------------------------------------------------------------------------------
```python
"""
Database optimization module for parallel backtesting performance.
Implements query optimization, bulk operations, and performance monitoring.
"""
import logging
import time
from contextlib import contextmanager
from typing import Any
import pandas as pd
from sqlalchemy import Index, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from maverick_mcp.data.models import (
PriceCache,
SessionLocal,
Stock,
engine,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class QueryOptimizer:
    """Database query optimization for backtesting performance.

    Keeps per-query timing statistics and connection-pool counters, creates
    supporting indexes on the price cache table, and supplies SQL templates
    for price retrieval.
    """

    def __init__(self, session_factory=None):
        """Initialize query optimizer.

        Args:
            session_factory: Zero-argument callable returning a new
                ``Session``; defaults to ``SessionLocal``.
        """
        self.session_factory = session_factory or SessionLocal
        # query name -> {"count", "total_time", "avg_time", "max_time", "slow_queries"}
        self._query_stats: dict[str, dict[str, Any]] = {}
        self._connection_pool_stats = {
            "active_connections": 0,
            "checked_out": 0,
            "total_queries": 0,
            "slow_queries": 0,
        }

    def create_backtesting_indexes(self, engine: Engine):
        """
        Create optimized indexes for backtesting queries.

        Every index lives on the price cache table. An index is skipped when
        it already exists (``checkfirst=True``). Each failure is logged and
        does not prevent the remaining indexes from being created.

        Args:
            engine: Engine whose database receives the indexes.
        """
        logger.info("Creating backtesting optimization indexes...")

        price_table = PriceCache.__table__
        cols = price_table.c
        index_specs = [
            # Per-stock date-range scans that only need the close price. An
            # index cannot span tables, so the stock is identified by
            # stock_id rather than mcp_stocks.ticker_symbol. (The earlier
            # cross-table definition made SQLAlchemy raise before any index
            # was created.)
            (
                "mcp_price_cache_symbol_date_range_idx",
                (cols.stock_id, cols.date, cols.close_price),
            ),
            # Volume-based queries (common in strategy analysis)
            ("mcp_price_cache_volume_date_idx", (cols.volume, cols.date)),
            # Covering index for OHLCV queries (includes all price data)
            (
                "mcp_price_cache_ohlcv_covering_idx",
                (
                    cols.stock_id,
                    cols.date,
                    cols.open_price,
                    cols.high_price,
                    cols.low_price,
                    cols.close_price,
                    cols.volume,
                ),
            ),
            # Latest price queries
            ("mcp_price_cache_latest_price_idx", (cols.stock_id, cols.date.desc())),
        ]

        # Building an Index from table-bound columns attaches it to the Table.
        # Reuse an already-attached instance so repeated calls do not pile up
        # duplicate Index objects on the metadata.
        attached = {ix.name: ix for ix in price_table.indexes}

        try:
            # AUTOCOMMIT makes each DDL statement commit on its own. A failing
            # statement can then no longer abort the enclosing transaction and
            # silently roll back indexes that were created successfully.
            with engine.connect().execution_options(
                isolation_level="AUTOCOMMIT"
            ) as conn:
                for name, columns in index_specs:
                    try:
                        index = attached.get(name)
                        if index is None:
                            index = Index(name, *columns)
                        index.create(conn, checkfirst=True)
                        logger.info(f"Created index: {name}")
                    except Exception as e:
                        logger.warning(f"Failed to create index {name}: {e}")

                if engine.dialect.name == "postgresql":
                    self._analyze_postgresql_tables(conn)
        except Exception as e:
            logger.error(f"Failed to create backtesting indexes: {e}")

    def _analyze_postgresql_tables(self, conn) -> None:
        """Refresh PostgreSQL planner statistics for the price tables.

        No partial "recent data" index (``WHERE date >= CURRENT_DATE - ...``)
        is created. PostgreSQL only allows IMMUTABLE functions in index
        predicates, and ``CURRENT_DATE`` is not immutable, so such an index
        can never be built.
        """
        try:
            conn.execute(text(f"ANALYZE {PriceCache.__table__.name}"))
            conn.execute(text(f"ANALYZE {Stock.__table__.name}"))
            logger.info("Updated table statistics")
        except Exception as e:
            logger.warning(f"Failed to create PostgreSQL optimizations: {e}")

    def optimize_connection_pool(self, engine: Engine):
        """Attach pool event listeners that feed the connection-pool statistics.

        Despite the name, no pool settings are changed; only monitoring
        listeners are registered. Calling this again for the same engine does
        not register the listeners twice.
        """
        logger.info("Optimizing connection pool for parallel backtesting...")
        for event_name, listener in (
            ("connect", self._on_pool_connect),
            ("checkout", self._on_pool_checkout),
            ("checkin", self._on_pool_checkin),
        ):
            if not event.contains(engine, event_name, listener):
                event.listen(engine, event_name, listener)

    def _on_pool_connect(self, dbapi_connection, connection_record):
        """Count each new DBAPI connection the pool opens (never decremented)."""
        self._connection_pool_stats["active_connections"] += 1

    def _on_pool_checkout(self, dbapi_connection, connection_record, connection_proxy):
        """Track a connection being handed out by the pool."""
        self._connection_pool_stats["checked_out"] += 1

    def _on_pool_checkin(self, dbapi_connection, connection_record):
        """Track a connection being returned to the pool."""
        self._connection_pool_stats["checked_out"] -= 1

    def create_bulk_insert_method(self):
        """Create optimized bulk insert method for price data."""

        def bulk_insert_price_data_optimized(
            session: Session,
            price_data_list: list[dict[str, Any]],
            batch_size: int = 1000,
        ):
            """
            Optimized bulk insert for price data with batching.

            Args:
                session: Database session
                price_data_list: List of price data dictionaries
                batch_size: Number of records per batch (must be >= 1)

            Raises:
                ValueError: If ``batch_size`` is less than 1.
                Exception: Any database error, re-raised after rollback.
            """
            if not price_data_list:
                return
            if batch_size < 1:
                raise ValueError("batch_size must be a positive integer")

            total = len(price_data_list)
            logger.info(f"Bulk inserting {total} price records")
            start_time = time.perf_counter()

            try:
                # Process in batches to bound the size of each INSERT
                for offset in range(0, total, batch_size):
                    batch = price_data_list[offset : offset + batch_size]
                    session.bulk_insert_mappings(PriceCache, batch)
                    # Flush between batches; the final batch goes out with commit
                    if offset + batch_size < total:
                        session.flush()
                session.commit()
            except Exception as e:
                logger.error(f"Bulk insert failed: {e}")
                session.rollback()
                raise

            # Logged outside the try: a logging/timing problem must never turn
            # an already-committed insert into a reported failure. Guard
            # against a zero interval on coarse clocks.
            elapsed = time.perf_counter() - start_time
            rate = total / elapsed if elapsed > 0 else float("inf")
            logger.info(
                f"Bulk insert completed in {elapsed:.2f}s ({rate:.0f} records/sec)"
            )

        return bulk_insert_price_data_optimized

    @contextmanager
    def query_performance_monitor(self, query_name: str):
        """Time the enclosed block and record it under ``query_name``.

        Runs that take longer than one second are also counted and logged as
        slow. Timing is recorded even when the block raises.
        """
        start_time = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start_time
            stats = self._query_stats.setdefault(
                query_name,
                {
                    "count": 0,
                    "total_time": 0.0,
                    "avg_time": 0.0,
                    "max_time": 0.0,
                    "slow_queries": 0,
                },
            )
            stats["count"] += 1
            stats["total_time"] += elapsed
            stats["avg_time"] = stats["total_time"] / stats["count"]
            stats["max_time"] = max(stats["max_time"], elapsed)

            # Mark slow queries (> 1 second)
            if elapsed > 1.0:
                stats["slow_queries"] += 1
                self._connection_pool_stats["slow_queries"] += 1
                logger.warning(f"Slow query detected: {query_name} took {elapsed:.2f}s")

            self._connection_pool_stats["total_queries"] += 1

    def get_optimized_price_query(self) -> str:
        """Get optimized SQL query for price data retrieval.

        Bind parameters: ``:symbol``, ``:start_date``, ``:end_date``.
        """
        return """
        SELECT
            pc.date,
            pc.open_price as "open",
            pc.high_price as "high",
            pc.low_price as "low",
            pc.close_price as "close",
            pc.volume
        FROM mcp_price_cache pc
        JOIN mcp_stocks s ON pc.stock_id = s.stock_id
        WHERE s.ticker_symbol = :symbol
        AND pc.date >= :start_date
        AND pc.date <= :end_date
        ORDER BY pc.date
        """

    def get_batch_price_query(self) -> str:
        """Get optimized SQL query for batch price data retrieval.

        Bind parameters: ``:symbols`` (a list), ``:start_date``, ``:end_date``.
        ``= ANY(:symbols)`` is PostgreSQL array syntax; other backends such as
        SQLite will reject this query.
        """
        return """
        SELECT
            s.ticker_symbol,
            pc.date,
            pc.open_price as "open",
            pc.high_price as "high",
            pc.low_price as "low",
            pc.close_price as "close",
            pc.volume
        FROM mcp_price_cache pc
        JOIN mcp_stocks s ON pc.stock_id = s.stock_id
        WHERE s.ticker_symbol = ANY(:symbols)
        AND pc.date >= :start_date
        AND pc.date <= :end_date
        ORDER BY s.ticker_symbol, pc.date
        """

    def execute_optimized_query(
        self,
        session: Session,
        query: str,
        params: dict[str, Any],
        query_name: str = "unnamed",
    ) -> pd.DataFrame:
        """Execute a SQL string with timing and return the rows as a DataFrame.

        When the query text mentions ``date``, the ``date`` column is parsed
        as datetimes and used as the index.

        Raises:
            Exception: Any database/pandas error, logged and re-raised.
        """
        use_date_index = "date" in query.lower()
        with self.query_performance_monitor(query_name):
            try:
                result = pd.read_sql(
                    text(query),
                    # get_bind() also resolves sessions configured with
                    # ``binds=`` mappings, where ``session.bind`` is None.
                    session.get_bind(),
                    params=params,
                    index_col="date" if use_date_index else None,
                    parse_dates=["date"] if use_date_index else None,
                )
                logger.debug(f"Query {query_name} returned {len(result)} rows")
                return result
            except Exception as e:
                logger.error(f"Query {query_name} failed: {e}")
                raise

    def get_statistics(self) -> dict[str, Any]:
        """Get query and connection pool statistics.

        Returns copies, so callers cannot mutate the live counters. The
        ``top_slow_queries`` field holds up to five ``(name, avg_time)``
        pairs, slowest first.
        """
        return {
            "query_stats": {
                name: dict(stats) for name, stats in self._query_stats.items()
            },
            "connection_pool_stats": self._connection_pool_stats.copy(),
            "top_slow_queries": sorted(
                (
                    (name, stats["avg_time"])
                    for name, stats in self._query_stats.items()
                ),
                key=lambda x: x[1],
                reverse=True,
            )[:5],
        }

    def reset_statistics(self):
        """Reset performance statistics.

        ``checked_out`` is preserved. It is a live gauge that the checkin
        listener decrements, so zeroing it while connections are out would
        drive it negative.
        """
        self._query_stats.clear()
        self._connection_pool_stats = {
            "active_connections": 0,
            "checked_out": self._connection_pool_stats["checked_out"],
            "total_queries": 0,
            "slow_queries": 0,
        }
class BatchQueryExecutor:
    """Efficient batch query execution for parallel backtesting."""

    def __init__(self, optimizer: QueryOptimizer | None = None):
        """Initialize batch query executor.

        Args:
            optimizer: Supplies the session factory, the batch SQL template
                and instrumented query execution. A new ``QueryOptimizer`` is
                created when omitted.
        """
        self.optimizer = optimizer or QueryOptimizer()

    async def fetch_multiple_symbols_data(
        self,
        symbols: list[str],
        start_date: str,
        end_date: str,
        session: Session | None = None,
    ) -> dict[str, pd.DataFrame]:
        """
        Efficiently fetch data for multiple symbols in a single query.

        NOTE(review): declared ``async`` but the database query runs
        synchronously, so it blocks the event loop while it executes.

        Args:
            symbols: List of stock symbols
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            session: Optional database session; when omitted, one is created
                from the optimizer's session factory and closed afterwards.

        Returns:
            Dictionary mapping each requested symbol to its DataFrame. A
            symbol with no rows maps to an empty DataFrame.
        """
        if not symbols:
            return {}

        owns_session = session is None
        if owns_session:
            session = self.optimizer.session_factory()

        try:
            # One round trip for all symbols
            result_df = self.optimizer.execute_optimized_query(
                session=session,
                query=self.optimizer.get_batch_price_query(),
                params={
                    "symbols": symbols,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                query_name="batch_symbol_fetch",
            )

            if result_df.empty:
                symbol_data = {symbol: pd.DataFrame() for symbol in symbols}
            else:
                # Split the result in a single grouping pass instead of
                # rescanning the whole frame once per symbol. Row order within
                # each symbol is preserved.
                grouped = {
                    ticker: group.drop(columns="ticker_symbol")
                    for ticker, group in result_df.groupby(
                        "ticker_symbol", sort=False
                    )
                }
                empty = result_df.iloc[0:0].drop(columns="ticker_symbol")
                symbol_data = {
                    symbol: grouped[symbol].copy()
                    if symbol in grouped
                    else empty.copy()
                    for symbol in symbols
                }

            logger.info(
                f"Batch fetched data for {len(symbols)} symbols: "
                f"{sum(len(df) for df in symbol_data.values())} total records"
            )
            return symbol_data
        finally:
            if owns_session:
                session.close()
# Global instances for easy access. The batch executor is built on top of the
# shared optimizer, so queries issued through either one land in the same
# statistics.
_query_optimizer = QueryOptimizer()
_batch_executor = BatchQueryExecutor(optimizer=_query_optimizer)


def get_query_optimizer() -> QueryOptimizer:
    """Return the module-wide QueryOptimizer created at import time."""
    return _query_optimizer


def get_batch_executor() -> BatchQueryExecutor:
    """Return the module-wide BatchQueryExecutor, which wraps the shared optimizer."""
    return _batch_executor
def initialize_database_optimizations():
    """Initialize all database optimizations for backtesting.

    Creates the backtesting indexes, then attaches connection-pool monitoring
    to the module's engine. This is best-effort: any failure is logged with
    its traceback and swallowed, so startup is never blocked.
    """
    logger.info("Initializing database optimizations for parallel backtesting...")
    try:
        optimizer = get_query_optimizer()
        optimizer.create_backtesting_indexes(engine)
        optimizer.optimize_connection_pool(engine)
        logger.info("Database optimizations initialized successfully")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Failed to initialize database optimizations: {e}")
@contextmanager
def optimized_db_session():
    """Yield a ``SessionLocal`` session; commit on success, roll back on error.

    On SQLite, the connection is first tuned with ``synchronous = NORMAL``
    and WAL journaling. Those PRAGMAs are SQLite-only; on other backends,
    such as PostgreSQL, they are a syntax error that made every session
    fail, so they are skipped there. The session is always closed.
    """
    session = SessionLocal()
    try:
        if session.get_bind().dialect.name == "sqlite":
            session.execute(text("PRAGMA synchronous = NORMAL"))
            session.execute(text("PRAGMA journal_mode = WAL"))
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
# Performance monitoring decorator
def monitor_query_performance(query_name: str):
    """Decorator factory that times each call of the wrapped function.

    Each call is timed under ``query_name`` via the global optimizer's
    ``query_performance_monitor``. Coroutine functions get an ``async``
    wrapper, so the timing covers the awaited work rather than just the
    creation of the coroutine object. ``functools.wraps`` preserves the
    wrapped function's name, docstring and other metadata.

    Args:
        query_name: Key under which timings are aggregated.

    Returns:
        A decorator usable on both sync and async callables.
    """
    import functools
    import inspect

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                optimizer = get_query_optimizer()
                with optimizer.query_performance_monitor(query_name):
                    return await func(*args, **kwargs)

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            optimizer = get_query_optimizer()
            with optimizer.query_performance_monitor(query_name):
                return func(*args, **kwargs)

        return wrapper

    return decorator
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/tracing.py:
--------------------------------------------------------------------------------
```python
"""
OpenTelemetry distributed tracing integration for MaverickMCP.
This module provides comprehensive distributed tracing capabilities including:
- Automatic span creation for database queries, external API calls, and tool executions
- Integration with FastMCP and FastAPI
- Support for multiple tracing backends (Jaeger, Zipkin, OTLP)
- Correlation with structured logging
"""
import functools
import os
import time
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any
from maverick_mcp.config.settings import settings
from maverick_mcp.utils.logging import get_logger
# OpenTelemetry imports with graceful fallback
try:
from opentelemetry import trace # type: ignore[import-untyped]
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter, # type: ignore[import-untyped]
)
from opentelemetry.exporter.zipkin.json import (
ZipkinExporter, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.asyncio import (
AsyncioInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.asyncpg import (
AsyncPGInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.fastapi import (
FastAPIInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.httpx import (
HTTPXInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.redis import (
RedisInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.requests import (
RequestsInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.instrumentation.sqlalchemy import (
SQLAlchemyInstrumentor, # type: ignore[import-untyped]
)
from opentelemetry.propagate import (
set_global_textmap, # type: ignore[import-untyped]
)
from opentelemetry.propagators.b3 import (
B3MultiFormat, # type: ignore[import-untyped]
)
from opentelemetry.sdk.resources import Resource # type: ignore[import-untyped]
from opentelemetry.sdk.trace import TracerProvider # type: ignore[import-untyped]
from opentelemetry.sdk.trace.export import ( # type: ignore[import-untyped]
BatchSpanProcessor,
ConsoleSpanExporter,
)
from opentelemetry.semconv.resource import (
ResourceAttributes, # type: ignore[import-untyped]
)
from opentelemetry.trace import Status, StatusCode # type: ignore[import-untyped]
OTEL_AVAILABLE = True
except ImportError:
# Create stub classes for when OpenTelemetry is not available
class _TracerStub:
def start_span(self, name: str, **kwargs):
return _SpanStub()
def start_as_current_span(self, name: str, **kwargs):
return _SpanStub()
class _SpanStub:
def __enter__(self):
return self
def __exit__(self, *args):
pass
def set_attribute(self, key: str, value: Any):
pass
def set_status(self, status):
pass
def record_exception(self, exception: Exception):
pass
def add_event(self, name: str, attributes: dict[str, Any] | None = None):
pass
# Create stub types for type annotations
class TracerProvider:
pass
trace = type("trace", (), {"get_tracer": lambda name: _TracerStub()})()
OTEL_AVAILABLE = False
# Module logger obtained from the project's logging utilities.
logger = get_logger(__name__)
class TracingService:
"""Service for distributed tracing configuration and management."""
def __init__(self):
self.tracer = None
self.enabled = False
self._initialize_tracing()
def _initialize_tracing(self):
"""Initialize OpenTelemetry tracing."""
if not OTEL_AVAILABLE:
return
# Check if tracing is enabled
tracing_enabled = os.getenv("OTEL_TRACING_ENABLED", "false").lower() == "true"
if not tracing_enabled and settings.environment != "development":
logger.info("OpenTelemetry tracing disabled")
return
try:
# Create resource
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: settings.app_name,
ResourceAttributes.SERVICE_VERSION: os.getenv(
"RELEASE_VERSION", "unknown"
),
ResourceAttributes.SERVICE_NAMESPACE: "maverick-mcp",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: settings.environment,
}
)
# Configure tracer provider
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider)
# Configure exporters
self._configure_exporters(tracer_provider)
# Configure propagators
self._configure_propagators()
# Instrument libraries
self._instrument_libraries()
# Create tracer
self.tracer = trace.get_tracer(__name__)
self.enabled = True
logger.info("OpenTelemetry tracing initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize OpenTelemetry tracing: {e}")
def _configure_exporters(self, tracer_provider: TracerProvider):
"""Configure trace exporters based on environment variables."""
# Console exporter (for development)
if settings.environment == "development":
console_exporter = ConsoleSpanExporter()
tracer_provider.add_span_processor(BatchSpanProcessor(console_exporter)) # type: ignore[attr-defined]
# Jaeger exporter via OTLP (modern approach)
jaeger_endpoint = os.getenv("JAEGER_ENDPOINT")
if jaeger_endpoint:
# Modern Jaeger deployments accept OTLP on port 4317 (gRPC) or 4318 (HTTP)
# Convert legacy Jaeger collector endpoint to OTLP format if needed
if "14268" in jaeger_endpoint: # Legacy Jaeger HTTP port
otlp_endpoint = jaeger_endpoint.replace(":14268", ":4318").replace(
"/api/traces", ""
)
logger.info(
f"Converting legacy Jaeger endpoint {jaeger_endpoint} to OTLP: {otlp_endpoint}"
)
else:
otlp_endpoint = jaeger_endpoint
jaeger_otlp_exporter = OTLPSpanExporter(
endpoint=otlp_endpoint,
# Add Jaeger-specific headers if needed
headers={},
)
tracer_provider.add_span_processor(BatchSpanProcessor(jaeger_otlp_exporter)) # type: ignore[attr-defined]
logger.info(f"Jaeger OTLP exporter configured: {otlp_endpoint}")
# Zipkin exporter
zipkin_endpoint = os.getenv("ZIPKIN_ENDPOINT")
if zipkin_endpoint:
zipkin_exporter = ZipkinExporter(endpoint=zipkin_endpoint)
tracer_provider.add_span_processor(BatchSpanProcessor(zipkin_exporter)) # type: ignore[attr-defined]
logger.info(f"Zipkin exporter configured: {zipkin_endpoint}")
# OTLP exporter (for services like Honeycomb, New Relic, etc.)
otlp_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT")
if otlp_endpoint:
otlp_exporter = OTLPSpanExporter(
endpoint=otlp_endpoint,
headers={"x-honeycomb-team": os.getenv("HONEYCOMB_API_KEY", "")},
)
tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter)) # type: ignore[attr-defined]
logger.info(f"OTLP exporter configured: {otlp_endpoint}")
def _configure_propagators(self):
"""Configure trace propagators for cross-service communication."""
# Use B3 propagator for maximum compatibility
set_global_textmap(B3MultiFormat())
logger.info("B3 trace propagator configured")
def _instrument_libraries(self):
"""Automatically instrument common libraries."""
try:
# FastAPI instrumentation
FastAPIInstrumentor().instrument()
# Database instrumentation
SQLAlchemyInstrumentor().instrument()
AsyncPGInstrumentor().instrument()
# HTTP client instrumentation
RequestsInstrumentor().instrument()
HTTPXInstrumentor().instrument()
# Redis instrumentation
RedisInstrumentor().instrument()
# Asyncio instrumentation
AsyncioInstrumentor().instrument()
logger.info("Auto-instrumentation completed successfully")
except Exception as e:
logger.warning(f"Some auto-instrumentation failed: {e}")
@contextmanager
def trace_operation(
self,
operation_name: str,
attributes: dict[str, Any] | None = None,
record_exception: bool = True,
):
"""
Context manager for tracing operations.
Args:
operation_name: Name of the operation being traced
attributes: Additional attributes to add to the span
record_exception: Whether to record exceptions in the span
"""
if not self.enabled:
yield None
return
with self.tracer.start_as_current_span(operation_name) as span:
# Add attributes
if attributes:
for key, value in attributes.items():
span.set_attribute(key, str(value))
try:
yield span
span.set_status(Status(StatusCode.OK))
except Exception as e:
span.set_status(Status(StatusCode.ERROR, str(e)))
if record_exception:
span.record_exception(e)
raise
def trace_tool_execution(self, func: Callable) -> Callable:
"""
Decorator to trace tool execution.
Args:
func: The tool function to trace
Returns:
Decorated function with tracing
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
if not self.enabled:
return await func(*args, **kwargs)
tool_name = getattr(func, "__name__", "unknown_tool")
with self.trace_operation(
f"tool.{tool_name}",
attributes={
"tool.name": tool_name,
"tool.args_count": len(args),
"tool.kwargs_count": len(kwargs),
},
) as span:
# Add user context if available
for arg in args:
if hasattr(arg, "user_id"):
span.set_attribute("user.id", str(arg.user_id))
break
start_time = time.time()
result = await func(*args, **kwargs)
duration = time.time() - start_time
span.set_attribute("tool.duration_seconds", duration)
span.set_attribute("tool.success", True)
return result
return wrapper
def trace_database_query(
self, query_type: str, table: str | None = None, query: str | None = None
):
"""
Context manager for tracing database queries.
Args:
query_type: Type of query (SELECT, INSERT, UPDATE, DELETE)
table: Table name being queried
query: The actual SQL query (will be truncated for security)
"""
attributes = {
"db.operation": query_type,
"db.system": "postgresql",
}
if table:
attributes["db.table"] = table
if query:
# Truncate query for security and performance
attributes["db.statement"] = (
query[:200] + "..." if len(query) > 200 else query
)
return self.trace_operation(f"db.{query_type.lower()}", attributes)
def trace_external_api_call(self, service: str, endpoint: str, method: str = "GET"):
    """
    Build a tracing context manager for an outbound HTTP call.

    The span is named ``http.<method in lower case>`` and records the
    method, the endpoint (as ``http.url``) and the service name.

    Args:
        service: Name of the external service
        endpoint: API endpoint being called
        method: HTTP method

    Returns:
        The context manager returned by ``self.trace_operation``.
    """
    return self.trace_operation(
        "http." + method.lower(),
        {
            "http.method": method,
            "http.url": endpoint,
            "service.name": service,
        },
    )
def trace_cache_operation(self, operation: str, cache_type: str = "redis"):
    """
    Build a tracing context manager for a cache access.

    The span is named ``cache.<operation>`` (operation used verbatim) and
    records the operation and the cache backend type.

    Args:
        operation: Cache operation (get, set, delete, etc.)
        cache_type: Type of cache (redis, memory, etc.)

    Returns:
        The context manager returned by ``self.trace_operation``.
    """
    span_name = "cache." + operation
    cache_attributes = {"cache.operation": operation, "cache.type": cache_type}
    return self.trace_operation(span_name, cache_attributes)
def add_event(self, name: str, attributes: dict[str, Any] | None = None):
"""Add an event to the current span."""
if not self.enabled:
return
current_span = trace.get_current_span()
if current_span:
current_span.add_event(name, attributes or {})
def set_user_context(self, user_id: str, email: str | None = None):
"""Set user context on the current span."""
if not self.enabled:
return
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("user.id", user_id)
if email:
current_span.set_attribute("user.email", email)
# Process-wide TracingService, created lazily by get_tracing_service().
_tracing_service: TracingService | None = None


def get_tracing_service() -> TracingService:
    """Return the shared TracingService, constructing it on the first call."""
    global _tracing_service
    if _tracing_service is not None:
        return _tracing_service
    _tracing_service = TracingService()
    return _tracing_service
def trace_tool(func: Callable) -> Callable:
    """Wrap ``func`` with the shared TracingService's tool-tracing decorator."""
    return get_tracing_service().trace_tool_execution(func)
@contextmanager
def trace_operation(
operation_name: str,
attributes: dict[str, Any] | None = None,
record_exception: bool = True,
):
"""Context manager for tracing operations."""
tracing = get_tracing_service()
with tracing.trace_operation(operation_name, attributes, record_exception) as span:
yield span
@contextmanager
def trace_database_query(
query_type: str, table: str | None = None, query: str | None = None
):
"""Context manager for tracing database queries."""
tracing = get_tracing_service()
with tracing.trace_database_query(query_type, table, query) as span:
yield span
@contextmanager
def trace_external_api_call(service: str, endpoint: str, method: str = "GET"):
    """
    Trace an outbound HTTP call using the shared TracingService.

    Args:
        service: Name of the external service
        endpoint: API endpoint being called
        method: HTTP method
    """
    span_cm = get_tracing_service().trace_external_api_call(service, endpoint, method)
    with span_cm as active_span:
        yield active_span
@contextmanager
def trace_cache_operation(operation: str, cache_type: str = "redis"):
    """
    Trace a cache access using the shared TracingService.

    Args:
        operation: Cache operation (get, set, delete, etc.)
        cache_type: Type of cache (redis, memory, etc.)
    """
    span_cm = get_tracing_service().trace_cache_operation(operation, cache_type)
    with span_cm as active_span:
        yield active_span
def initialize_tracing():
    """Create the shared tracing service (if needed) and log whether it is active."""
    logger.info("Initializing distributed tracing...")
    if get_tracing_service().enabled:
        status_message = "Distributed tracing initialized successfully"
    else:
        status_message = "Distributed tracing disabled or unavailable"
    logger.info(status_message)
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/technical.py:
--------------------------------------------------------------------------------
```python
"""
Technical analysis router for MaverickMCP.
This module contains all technical analysis related tools including
indicators, chart patterns, and analysis functions.
DISCLAIMER: All technical analysis tools are for educational purposes only.
Technical indicators are mathematical calculations based on historical data and
do not predict future price movements. Results should not be considered as
investment advice. Always consult qualified financial professionals.
"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import UTC, datetime
from typing import Any
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_access_token
from maverick_mcp.core.technical_analysis import (
analyze_bollinger_bands,
analyze_macd,
analyze_rsi,
analyze_stochastic,
analyze_trend,
analyze_volume,
generate_outlook,
identify_chart_patterns,
identify_resistance_levels,
identify_support_levels,
)
from maverick_mcp.core.visualization import (
create_plotly_technical_chart,
plotly_fig_to_base64,
)
from maverick_mcp.providers.stock_data import StockDataProvider
from maverick_mcp.utils.logging import PerformanceMonitor, get_logger
from maverick_mcp.utils.mcp_logging import with_logging
from maverick_mcp.utils.stock_helpers import (
get_stock_dataframe_async,
)
logger = get_logger("maverick_mcp.routers.technical")

# Create the technical analysis router
technical_router: FastMCP = FastMCP("Technical_Analysis")

# Initialize data provider
stock_provider = StockDataProvider()

# Thread pool for blocking operations (indicator and chart computations in
# this module are run here so they do not block the event loop)
executor = ThreadPoolExecutor(max_workers=10)


@with_logging("rsi_analysis")
async def get_rsi_analysis(
    ticker: str, period: int = 14, days: int = 365
) -> dict[str, Any]:
    """
    Get RSI analysis for a given ticker.

    Args:
        ticker: Stock ticker symbol
        period: RSI period (default: 14). It is echoed in the response but is
            not passed to ``analyze_rsi``, which only receives the price
            DataFrame.
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        On success: ``{"ticker", "period", "analysis"}``.
        On failure: ``{"error": <message>, "status": "error"}`` (the
        exception is logged with its traceback).
    """
    try:
        # Log analysis parameters
        logger.info(
            "Starting RSI analysis",
            extra={"ticker": ticker, "period": period, "days": days},
        )

        # Fetch stock data with performance monitoring
        with PerformanceMonitor(f"fetch_data_{ticker}"):
            df = await get_stock_dataframe_async(ticker, days)

        # Perform RSI analysis with monitoring
        with PerformanceMonitor(f"rsi_calculation_{ticker}"):
            # get_running_loop() is the recommended way to reach the loop from
            # inside a coroutine; get_event_loop() depends on the loop policy.
            loop = asyncio.get_running_loop()
            analysis = await loop.run_in_executor(executor, analyze_rsi, df)

        # Log successful completion
        logger.info(
            "RSI analysis completed successfully",
            extra={
                "ticker": ticker,
                "rsi_current": analysis.get("current_rsi"),
                "signal": analysis.get("signal"),
            },
        )

        return {"ticker": ticker, "period": period, "analysis": analysis}
    except Exception as e:
        logger.error(
            "Error in RSI analysis",
            exc_info=True,
            extra={"ticker": ticker, "period": period, "error_type": type(e).__name__},
        )
        return {"error": str(e), "status": "error"}
async def get_macd_analysis(
    ticker: str,
    fast_period: int = 12,
    slow_period: int = 26,
    signal_period: int = 9,
    days: int = 365,
) -> dict[str, Any]:
    """
    Get MACD analysis for a given ticker.

    Args:
        ticker: Stock ticker symbol
        fast_period: Fast EMA period (default: 12)
        slow_period: Slow EMA period (default: 26)
        signal_period: Signal line period (default: 9)
        days: Number of days of historical data to analyze (default: 365)

    Note:
        The three period arguments are reported back under ``"parameters"``
        but are not passed to ``analyze_macd``, which only receives the price
        DataFrame.

    Returns:
        On success: ``{"ticker", "parameters", "analysis"}``.
        On failure: ``{"error": <message>, "status": "error"}``.
    """
    try:
        df = await get_stock_dataframe_async(ticker, days)

        # analyze_macd is a synchronous call; run it in the module thread pool
        # so the event loop is not blocked while it computes.
        loop = asyncio.get_running_loop()
        analysis = await loop.run_in_executor(executor, analyze_macd, df)

        return {
            "ticker": ticker,
            "parameters": {
                "fast_period": fast_period,
                "slow_period": slow_period,
                "signal_period": signal_period,
            },
            "analysis": analysis,
        }
    except Exception as e:
        logger.error(f"Error in MACD analysis for {ticker}: {str(e)}", exc_info=True)
        return {"error": str(e), "status": "error"}
def _support_resistance_levels(df) -> tuple[list, list]:
    """Compute ``(support, resistance)`` levels for ``df`` (synchronous helper)."""
    return identify_support_levels(df), identify_resistance_levels(df)


async def get_support_resistance(ticker: str, days: int = 365) -> dict[str, Any]:
    """
    Get support and resistance levels for a given ticker.

    Args:
        ticker: Stock ticker symbol
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        On success: ``{"ticker", "current_price", "support_levels",
        "resistance_levels"}`` with both level lists sorted ascending.
        On failure (including no price data): ``{"error": <message>,
        "status": "error"}``.
    """
    try:
        df = await get_stock_dataframe_async(ticker, days)
        # Without this check an empty frame fails later on .iloc[-1] with an
        # opaque "out-of-bounds" message.
        if df is None or df.empty:
            return {"error": f"No price data available for {ticker}", "status": "error"}

        # Level detection is synchronous; keep it off the event loop.
        loop = asyncio.get_running_loop()
        support, resistance = await loop.run_in_executor(
            executor, _support_resistance_levels, df
        )
        current_price = df["close"].iloc[-1]

        return {
            "ticker": ticker,
            "current_price": float(current_price),
            "support_levels": sorted(support),
            "resistance_levels": sorted(resistance),
        }
    except Exception as e:
        logger.error(
            f"Error in support/resistance analysis for {ticker}: {str(e)}",
            exc_info=True,
        )
        return {"error": str(e), "status": "error"}
async def get_full_technical_analysis(ticker: str, days: int = 365) -> dict[str, Any]:
"""
Get comprehensive technical analysis for a given ticker.
This tool provides a complete technical analysis including:
- Trend analysis
- All major indicators (RSI, MACD, Stochastic, Bollinger Bands)
- Support and resistance levels
- Volume analysis
- Chart patterns
- Overall outlook
Args:
ticker: Stock ticker symbol
days: Number of days of historical data to analyze (default: 365)
Returns:
Dictionary containing complete technical analysis
"""
try:
# Access authentication context if available (optional for this tool)
# This demonstrates optional authentication - tool works without auth
# but provides enhanced features for authenticated users
has_premium = False
try:
access_token = get_access_token()
if access_token is None:
raise ValueError("No access token available")
# Log authenticated user
logger.info(
f"Technical analysis requested by authenticated user: {access_token.client_id}",
extra={"scopes": access_token.scopes},
)
# Check for premium features based on scopes
has_premium = "premium:access" in access_token.scopes
logger.info(f"Has premium: {has_premium}")
except Exception:
# Authentication is optional for this tool
logger.debug("Technical analysis requested by unauthenticated user")
df = await get_stock_dataframe_async(ticker, days)
# Perform all analyses
trend = analyze_trend(df)
rsi_analysis = analyze_rsi(df)
macd_analysis = analyze_macd(df)
stoch_analysis = analyze_stochastic(df)
bb_analysis = analyze_bollinger_bands(df)
volume_analysis = analyze_volume(df)
patterns = identify_chart_patterns(df)
support = identify_support_levels(df)
resistance = identify_resistance_levels(df)
outlook = generate_outlook(
df, str(trend), rsi_analysis, macd_analysis, stoch_analysis
)
# Get current price and indicators
current_price = df["close"].iloc[-1]
# Compile results
return {
"ticker": ticker,
"current_price": float(current_price),
"trend": trend,
"outlook": outlook,
"indicators": {
"rsi": rsi_analysis,
"macd": macd_analysis,
"stochastic": stoch_analysis,
"bollinger_bands": bb_analysis,
"volume": volume_analysis,
},
"levels": {"support": sorted(support), "resistance": sorted(resistance)},
"patterns": patterns,
"last_updated": datetime.now(UTC).isoformat(),
}
except Exception as e:
logger.error(f"Error in technical analysis for {ticker}: {str(e)}")
return {"error": str(e), "status": "error"}
async def get_stock_chart_analysis(ticker: str) -> dict[str, Any]:
    """
    Generate a comprehensive technical analysis chart.

    This tool creates a visual technical analysis including:
    - Price action with candlesticks
    - Moving averages
    - Volume analysis
    - Technical indicators
    - Support and resistance levels

    One year (365 days) of data is fetched, and the chart is rendered in the
    module thread pool.

    Args:
        ticker: The ticker symbol of the stock to analyze

    Returns:
        On success: the MCP content dict produced by ``_generate_chart_mcp_format``.
        On failure: ``{"error": <message>, "status": "error"}``, the same
        error shape as the other technical tools.
    """
    try:
        # Use async data fetching
        df = await get_stock_dataframe_async(ticker, 365)

        # Chart rendering is synchronous; run it off the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor, _generate_chart_mcp_format, df, ticker
        )
    except Exception as e:
        logger.error(f"Error generating chart analysis for {ticker}: {e}", exc_info=True)
        return {"error": str(e), "status": "error"}
def _generate_chart_mcp_format(
    df, ticker: str, max_chars: int = 50000
) -> dict[str, Any]:
    """
    Render a technical-analysis chart as MCP content, shrinking it until it fits.

    Technical indicators are added to ``df``. The chart is then rendered at
    progressively smaller JPEG sizes. The first rendering whose base64
    payload is at most ``max_chars`` characters is returned through
    ``_return_image_with_claude_desktop_workaround``. A size that fails to
    render is logged and skipped.

    Args:
        df: Price DataFrame for ``ticker``
        ticker: Ticker symbol, used in the chart and its text description
        max_chars: Maximum accepted base64 length. Claude Desktop truncates
            responses at roughly 100k characters, so the 50k default leaves
            headroom for the rest of the response.

    Returns:
        MCP content dict containing the image. If no size fits, a text-only
        content dict pointing the user to the text-based analysis tool.
    """
    from maverick_mcp.core.technical_analysis import add_technical_indicators

    df = add_technical_indicators(df)

    # Largest first; each entry is the next fallback when the previous one
    # is too big or fails to render.
    chart_configs = [
        {"height": 300, "width": 500, "format": "jpeg"},  # Small primary
        {"height": 250, "width": 400, "format": "jpeg"},  # Smaller fallback
        {"height": 200, "width": 350, "format": "jpeg"},  # Tiny fallback
        {"height": 150, "width": 300, "format": "jpeg"},  # Last resort
    ]

    for config in chart_configs:
        width, height, image_format = config["width"], config["height"], config["format"]
        try:
            figure = create_plotly_technical_chart(
                df, ticker, height=height, width=width
            )
            data_uri = plotly_fig_to_base64(figure, format=image_format)

            # Strip the data-URI prefix when present; otherwise the value is
            # assumed to already be raw base64 data.
            base64_data = data_uri.removeprefix(f"data:image/{image_format};base64,")
            mime_type = f"image/{image_format}"

            logger.info(
                f"Generated chart for {ticker}: {width}x{height} "
                f"({len(base64_data):,} chars base64)"
            )

            if len(base64_data) > max_chars:
                logger.warning(
                    f"Chart for {ticker} too large at {width}x{height} "
                    f"({len(base64_data):,} chars > {max_chars}), trying smaller size..."
                )
                continue

            description = (
                f"Technical analysis chart for {ticker.upper()} "
                f"({width}x{height}) showing price action, "
                f"moving averages, volume, RSI, and MACD indicators."
            )
            return _return_image_with_claude_desktop_workaround(
                base64_data, mime_type, description, ticker
            )
        except Exception as e:
            logger.warning(f"Failed to generate chart with config {config}: {e}")
            continue

    # If all configs failed, return error
    return {
        "content": [
            {
                "type": "text",
                "text": (
                    f"Unable to generate suitable chart size for {ticker.upper()}. "
                    f"The chart image is too large for Claude Desktop display limits. "
                    f"Please use the text-based technical analysis tool instead: "
                    f"technical_get_full_technical_analysis"
                ),
            }
        ]
    }
def _return_image_with_claude_desktop_workaround(
base64_data: str, mime_type: str, description: str, ticker: str
) -> dict[str, Any]:
"""
Return image using multiple formats to work around Claude Desktop bugs.
Tries alternative MCP format first, fallback to file saving.
"""
import base64 as b64
import tempfile
from pathlib import Path
# Format 1: Alternative "source" structure (some reports of this working)
try:
return {
"content": [
{"type": "text", "text": description},
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data,
},
},
]
}
except Exception as e:
logger.warning(f"Alternative image format failed: {e}")
# Format 2: Try original format one more time with different structure
try:
return {
"content": [
{"type": "text", "text": description},
{"type": "image", "data": base64_data, "mimeType": mime_type},
]
}
except Exception as e:
logger.warning(f"Standard image format failed: {e}")
# Format 3: File-based fallback (most reliable for Claude Desktop)
try:
ext = mime_type.split("/")[-1] # jpeg, png, etc.
# Create temp file in a standard location
temp_dir = Path(tempfile.gettempdir()) / "maverick_mcp_charts"
temp_dir.mkdir(exist_ok=True)
chart_file = (
temp_dir
/ f"{ticker.lower()}_chart_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.{ext}"
)
# Decode and save base64 to file
image_data = b64.b64decode(base64_data)
chart_file.write_bytes(image_data)
logger.info(f"Saved chart to file: {chart_file}")
return {
"content": [
{
"type": "text",
"text": (
f"{description}\n\n"
f"📁 **Chart saved to file**: `{chart_file}`\n\n"
f"**To view this image:**\n"
f"1. Use the filesystem MCP server if configured, or\n"
f"2. Ask me to open the file location, or\n"
f"3. Navigate to the file manually\n\n"
f"*Note: Claude Desktop has a known issue with embedded images. "
f"File-based display is the current workaround.*"
),
}
]
}
except Exception as e:
logger.error(f"File fallback also failed: {e}")
return {
"content": [
{
"type": "text",
"text": (
f"Unable to display chart for {ticker.upper()} due to "
f"Claude Desktop image rendering limitations. "
f"Please use the text-based technical analysis instead: "
f"`technical_get_full_technical_analysis`"
),
}
]
}
```
--------------------------------------------------------------------------------
/maverick_mcp/domain/stock_analysis/stock_analysis_service.py:
--------------------------------------------------------------------------------
```python
"""
Stock Analysis Service - Domain service that orchestrates data fetching and caching.
"""
import logging
from datetime import UTC, datetime, timedelta
import pandas as pd
import pandas_market_calendars as mcal
import pytz
from sqlalchemy.orm import Session
from maverick_mcp.infrastructure.caching import CacheManagementService
from maverick_mcp.infrastructure.data_fetching import StockDataFetchingService
logger = logging.getLogger("maverick_mcp.stock_analysis")


class StockAnalysisService:
    """
    Domain service that orchestrates stock data retrieval with intelligent caching.

    This service:
    - Contains business logic for stock data retrieval
    - Orchestrates data fetching and caching services
    - Implements smart caching strategies
    - Uses dependency injection for service composition

    Trading-day decisions (end-date adjustment, cache gap detection and the
    market-open check) use the NYSE calendar from ``pandas_market_calendars``.
    """

    def __init__(
        self,
        data_fetching_service: StockDataFetchingService,
        cache_service: CacheManagementService,
        db_session: Session | None = None,
    ):
        """
        Initialize the stock analysis service.

        Args:
            data_fetching_service: Service for fetching data from external sources
            cache_service: Service for cache management
            db_session: Optional database session for dependency injection
                (stored on the instance; not used by this class's methods)
        """
        self.data_fetching_service = data_fetching_service
        self.cache_service = cache_service
        self.db_session = db_session

        # Initialize NYSE calendar for US stock market
        self.market_calendar = mcal.get_calendar("NYSE")

    def get_stock_data(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        period: str | None = None,
        interval: str = "1d",
        use_cache: bool = True,
    ) -> pd.DataFrame:
        """
        Get stock data with intelligent caching strategy.

        Requests with a non-daily ``interval`` or a ``period`` are always
        fetched fresh. For daily data with caching enabled:
        1. A non-trading end date is moved back to the last trading day
        2. Cached data for the range is loaded
        3. Only missing leading/trailing ranges are fetched (and cached)
        4. Cached and fetched data are combined and returned
        If the cached path raises, the data is fetched fresh instead.

        Args:
            symbol: Stock ticker symbol (upper-cased before use)
            start_date: Start date in YYYY-MM-DD format (default: 365 days
                before today, UTC)
            end_date: End date in YYYY-MM-DD format (default: today, UTC)
            period: Alternative to start/end dates (e.g., '1d', '5d', '1mo', etc.)
            interval: Data interval ('1d', '1wk', '1mo', '1m', '5m', etc.)
            use_cache: Whether to use cached data if available

        Returns:
            DataFrame with stock data
        """
        symbol = symbol.upper()

        # For non-daily intervals or periods, always fetch fresh data
        if interval != "1d" or period:
            logger.info(
                f"Non-daily interval or period specified, fetching fresh data for {symbol}"
            )
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

        # Set default dates if not provided
        if start_date is None:
            start_date = (datetime.now(UTC) - timedelta(days=365)).strftime("%Y-%m-%d")
        if end_date is None:
            end_date = datetime.now(UTC).strftime("%Y-%m-%d")

        # If cache is disabled, fetch directly (end date used as given)
        if not use_cache:
            logger.info(f"Cache disabled, fetching fresh data for {symbol}")
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

        # Daily data: adjust end date to last trading day if it's not a trading day
        end_dt = pd.to_datetime(end_date)
        if not self._is_trading_day(end_dt):
            last_trading = self._get_last_trading_day(end_dt)
            logger.debug(
                f"Adjusting end date from {end_date} to last trading day {last_trading.strftime('%Y-%m-%d')}"
            )
            end_date = last_trading.strftime("%Y-%m-%d")

        # Use smart caching strategy; fall back to a fresh fetch on any failure
        try:
            return self._get_data_with_smart_cache(
                symbol, start_date, end_date, interval
            )
        except Exception as e:
            logger.warning(
                f"Smart cache failed for {symbol}, falling back to fresh data: {e}"
            )
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

    def _get_data_with_smart_cache(
        self, symbol: str, start_date: str, end_date: str, interval: str
    ) -> pd.DataFrame:
        """
        Implement smart caching strategy for stock data retrieval.

        Only the gap before the first cached row and the gap after the last
        cached row are filled. Holes inside the cached span are not detected.
        Newly fetched ranges are written back to the cache.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            interval: Data interval

        Returns:
            DataFrame with stock data. When cached data exists, the result is
            filtered to [start_date, end_date]. When nothing is cached, the
            fresh fetch is returned as fetched.
        """
        logger.info(
            f"Using smart cache strategy for {symbol} from {start_date} to {end_date}"
        )

        # Step 1: Get available cached data
        cached_df = self.cache_service.get_cached_data(symbol, start_date, end_date)

        # Convert dates for comparison
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)

        if cached_df is None or cached_df.empty:
            return self._fetch_full_range(symbol, start_date, end_date, interval)

        logger.info(f"Found {len(cached_df)} cached records for {symbol}")

        # Step 2: Determine what data we need
        missing_ranges = self._find_missing_ranges(cached_df, start_dt, end_dt)

        if not missing_ranges:
            logger.info(
                f"Cache hit! Returning {len(cached_df)} cached records for {symbol}"
            )
            return self._filter_to_range(cached_df, start_dt, end_dt)

        # Step 3: Fetch only missing data
        logger.info(f"Cache partial hit. Missing ranges: {missing_ranges}")
        frames = [cached_df]
        for miss_start, miss_end in missing_ranges:
            logger.info(
                f"Fetching missing data for {symbol} from {miss_start} to {miss_end}"
            )
            missing_df = self.data_fetching_service.fetch_stock_data(
                symbol, miss_start, miss_end, None, interval
            )
            if not missing_df.empty:
                frames.append(missing_df)
                # Cache the new data
                self.cache_service.cache_data(symbol, missing_df)

        # Combine all data and remove any duplicate dates (keep first)
        combined_df = pd.concat(frames).sort_index()
        combined_df = combined_df[~combined_df.index.duplicated(keep="first")]
        return self._filter_to_range(combined_df, start_dt, end_dt)

    def _find_missing_ranges(
        self, cached_df: pd.DataFrame, start_dt: pd.Timestamp, end_dt: pd.Timestamp
    ) -> list[tuple[str, str]]:
        """
        Return (start, end) YYYY-MM-DD trading-day ranges not covered by ``cached_df``.

        Only a leading gap (before the first cached row) and a trailing gap
        (after the last cached row) are considered.
        """
        cached_start = pd.to_datetime(cached_df.index.min())
        cached_end = pd.to_datetime(cached_df.index.max())
        missing_ranges: list[tuple[str, str]] = []

        # Missing data at the beginning?
        if start_dt < cached_start:
            leading = self._get_trading_days(start_dt, cached_start - timedelta(days=1))
            if len(leading) > 0:
                missing_ranges.append(
                    (leading[0].strftime("%Y-%m-%d"), leading[-1].strftime("%Y-%m-%d"))
                )

        # Missing recent data?
        if end_dt > cached_end and self._is_trading_day_between(cached_end, end_dt):
            trailing = self._get_trading_days(cached_end + timedelta(days=1), end_dt)
            if len(trailing) > 0:
                missing_ranges.append(
                    (trailing[0].strftime("%Y-%m-%d"), trailing[-1].strftime("%Y-%m-%d"))
                )

        return missing_ranges

    @staticmethod
    def _filter_to_range(
        df: pd.DataFrame, start_dt: pd.Timestamp, end_dt: pd.Timestamp
    ) -> pd.DataFrame:
        """Return the rows of ``df`` whose index lies within [start_dt, end_dt]."""
        mask = (df.index >= start_dt) & (df.index <= end_dt)
        return df.loc[mask]

    def _fetch_full_range(
        self, symbol: str, start_date: str, end_date: str, interval: str
    ) -> pd.DataFrame:
        """
        Fetch (and cache) the whole trading-day span of the request.

        Used when nothing is cached. Returns an empty frame with the standard
        price columns when the range contains no trading days.
        """
        logger.info(f"No cached data found for {symbol}, fetching fresh data")

        # Adjust dates to trading days
        trading_days = self._get_trading_days(start_date, end_date)
        if len(trading_days) == 0:
            logger.warning(
                f"No trading days found between {start_date} and {end_date}"
            )
            return pd.DataFrame(
                columns=[
                    "Open",
                    "High",
                    "Low",
                    "Close",
                    "Volume",
                    "Dividends",
                    "Stock Splits",
                ]
            )

        # Fetch data only for the trading day range
        fetch_start = trading_days[0].strftime("%Y-%m-%d")
        fetch_end = trading_days[-1].strftime("%Y-%m-%d")
        logger.info(f"Fetching data for trading days: {fetch_start} to {fetch_end}")

        df = self.data_fetching_service.fetch_stock_data(
            symbol, fetch_start, fetch_end, None, interval
        )
        if not df.empty:
            # Cache the fetched data
            self.cache_service.cache_data(symbol, df)
        return df

    def get_stock_info(self, symbol: str) -> dict:
        """
        Get detailed stock information.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with stock information
        """
        return self.data_fetching_service.fetch_stock_info(symbol)

    def get_realtime_data(self, symbol: str) -> dict | None:
        """
        Get real-time data for a symbol.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with real-time data or None
        """
        return self.data_fetching_service.fetch_realtime_data(symbol)

    def get_multiple_realtime_data(self, symbols: list[str]) -> dict[str, dict]:
        """
        Get real-time data for multiple symbols.

        Args:
            symbols: List of stock ticker symbols

        Returns:
            Dictionary mapping symbols to their real-time data
        """
        return self.data_fetching_service.fetch_multiple_realtime_data(symbols)

    def is_market_open(self) -> bool:
        """
        Check if the US stock market (NYSE) is currently in its regular session.

        The session is looked up in the NYSE calendar for today's date
        (US/Eastern). Exchange holidays therefore return False, and the open
        and close times come from the calendar's schedule for today rather
        than a fixed 9:30-16:00 window.

        Returns:
            True if the current time is within today's session
        """
        now = datetime.now(pytz.timezone("US/Eastern"))

        # Weekends never have a session; skip the calendar lookup.
        if now.weekday() >= 5:  # 5 and 6 are Saturday and Sunday
            return False

        today = now.date()
        schedule = self.market_calendar.schedule(start_date=today, end_date=today)
        if schedule.empty:
            # Weekday exchange holiday
            return False

        session = schedule.iloc[0]
        current = pd.Timestamp(now)
        return session["market_open"] <= current <= session["market_close"]

    def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
        """
        Get news for a stock.

        Args:
            symbol: Stock ticker symbol
            limit: Maximum number of news items

        Returns:
            DataFrame with news data
        """
        return self.data_fetching_service.fetch_news(symbol, limit)

    def get_earnings(self, symbol: str) -> dict:
        """
        Get earnings information for a stock.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with earnings data
        """
        return self.data_fetching_service.fetch_earnings(symbol)

    def get_recommendations(self, symbol: str) -> pd.DataFrame:
        """
        Get analyst recommendations for a stock.

        Args:
            symbol: Stock ticker symbol

        Returns:
            DataFrame with recommendations
        """
        return self.data_fetching_service.fetch_recommendations(symbol)

    def is_etf(self, symbol: str) -> bool:
        """
        Check if a given symbol is an ETF.

        Args:
            symbol: Stock ticker symbol

        Returns:
            True if symbol is an ETF
        """
        return self.data_fetching_service.check_if_etf(symbol)

    def _get_trading_days(self, start_date, end_date) -> pd.DatetimeIndex:
        """
        Get all trading days between start and end dates (inclusive).

        Args:
            start_date: Start date (can be string or datetime)
            end_date: End date (can be string or datetime)

        Returns:
            DatetimeIndex of trading days
        """
        # Ensure dates are datetime objects
        if isinstance(start_date, str):
            start_date = pd.to_datetime(start_date)
        if isinstance(end_date, str):
            end_date = pd.to_datetime(end_date)

        # Get valid trading days from market calendar
        schedule = self.market_calendar.schedule(
            start_date=start_date, end_date=end_date
        )
        return schedule.index

    def _get_last_trading_day(self, date) -> pd.Timestamp:
        """
        Get the last trading day on or before the given date.

        Looks back at most 9 days; if none of them is a trading day, the
        input date is returned unchanged.

        Args:
            date: Date to check (can be string or datetime)

        Returns:
            Last trading day as pd.Timestamp
        """
        if isinstance(date, str):
            date = pd.to_datetime(date)

        # Check if the date itself is a trading day
        if self._is_trading_day(date):
            return date

        # Otherwise, find the previous trading day
        for i in range(1, 10):
            check_date = date - timedelta(days=i)
            if self._is_trading_day(check_date):
                return check_date

        # Fallback to the date itself if no trading day found
        return date

    def _is_trading_day(self, date) -> bool:
        """
        Check if a specific date is a trading day.

        Args:
            date: Date to check (string or datetime)

        Returns:
            True if it's a trading day
        """
        if isinstance(date, str):
            date = pd.to_datetime(date)

        schedule = self.market_calendar.schedule(start_date=date, end_date=date)
        return len(schedule) > 0

    def _is_trading_day_between(
        self, start_date: pd.Timestamp, end_date: pd.Timestamp
    ) -> bool:
        """
        Check if there's a trading day after ``start_date`` and up to ``end_date``.

        Args:
            start_date: Start date (exclusive)
            end_date: End date (inclusive)

        Returns:
            True if there's a trading day between the dates
        """
        # Add one day to start since we're checking "between"
        check_start = start_date + timedelta(days=1)
        if check_start > end_date:
            return False

        # Get trading days in the range
        trading_days = self._get_trading_days(check_start, end_date)
        return len(trading_days) > 0

    def invalidate_cache(self, symbol: str, start_date: str, end_date: str) -> bool:
        """
        Invalidate cached data for a symbol within a date range.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            True if invalidation was successful
        """
        return self.cache_service.invalidate_cache(symbol, start_date, end_date)

    def get_cache_stats(self, symbol: str) -> dict:
        """
        Get cache statistics for a symbol.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with cache statistics
        """
        return self.cache_service.get_cache_stats(symbol)
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/mcp_prompts.py:
--------------------------------------------------------------------------------
```python
"""MCP Prompts for better tool discovery and usage guidance."""
from fastmcp import FastMCP
def register_mcp_prompts(mcp: FastMCP):
"""Register MCP prompts to help clients understand how to use the tools."""
# Backtesting prompts
@mcp.prompt()
async def backtest_strategy_guide():
"""Guide for running backtesting strategies."""
return """
# Backtesting Strategy Guide
## Available Strategies (15 total)
### Traditional Strategies (9):
- `sma_cross`: Simple Moving Average Crossover
- `rsi`: RSI Mean Reversion (oversold/overbought)
- `macd`: MACD Signal Line Crossover
- `bollinger`: Bollinger Bands (buy low, sell high)
- `momentum`: Momentum-based trading
- `ema_cross`: Exponential Moving Average Crossover
- `mean_reversion`: Mean Reversion Strategy
- `breakout`: Channel Breakout Strategy
- `volume_momentum`: Volume-Weighted Momentum
### ML Strategies (6):
- `online_learning`: Adaptive learning with dynamic thresholds
- `regime_aware`: Market regime detection (trending vs ranging)
- `ensemble`: Multiple strategy voting system
## Example Usage:
### Traditional Strategy:
"Run a backtest on AAPL using the sma_cross strategy from 2024-01-01 to 2024-12-31"
### ML Strategy:
"Test the online_learning strategy on TSLA for the past year with a learning rate of 0.01"
### Parameters:
- Most strategies have default parameters that work well
- You can customize: fast_period, slow_period, threshold, etc.
"""
@mcp.prompt()
async def ml_strategy_examples():
"""Examples of ML strategy usage."""
return """
# ML Strategy Examples
## 1. Online Learning Strategy
"Run online_learning strategy on NVDA with parameters:
- lookback: 20 days
- learning_rate: 0.01
- start_date: 2024-01-01
- end_date: 2024-12-31"
## 2. Regime-Aware Strategy
"Test regime_aware strategy on SPY to detect market regimes:
- regime_window: 50 days
- threshold: 0.02
- Adapts between trending and ranging markets"
## 3. Ensemble Strategy
"Use ensemble strategy on AAPL combining multiple signals:
- Combines SMA, RSI, and Momentum
- Uses voting to generate signals
- More robust than single strategies"
## Important Notes:
- ML strategies work through the standard run_backtest tool
- Use strategy_type parameter: "online_learning", "regime_aware", or "ensemble"
- These are simplified ML strategies that don't require training
"""
@mcp.prompt()
async def optimization_guide():
"""Guide for parameter optimization."""
return """
# Parameter Optimization Guide
## How to Optimize Strategy Parameters
### Basic Optimization:
"Optimize sma_cross parameters for MSFT over the past 6 months"
This will test combinations like:
- fast_period: [5, 10, 15, 20]
- slow_period: [20, 30, 50, 100]
### Custom Parameter Ranges:
"Optimize RSI strategy for TSLA with:
- period: [7, 14, 21]
- oversold: [20, 25, 30]
- overbought: [70, 75, 80]"
### Optimization Metrics:
- sharpe_ratio (default): Risk-adjusted returns
- total_return: Raw returns
- win_rate: Percentage of winning trades
## Results Include:
- Best parameter combination
- Performance metrics for top combinations
- Comparison across all tested parameters
"""
@mcp.prompt()
async def available_tools_summary():
    """Summary of all available MCP tools, grouped by category.

    Keep this list in sync with the ``tools://categories`` resource; the
    research category was previously missing here even though that resource
    advertises it.
    """
    return """
# MaverickMCP Tools Summary
## 1. Backtesting Tools
- `run_backtest`: Run any strategy (traditional or ML)
- `optimize_parameters`: Find best parameters
- `compare_strategies`: Compare multiple strategies
- `get_strategy_info`: Get strategy details
## 2. Data Tools
- `get_stock_data`: Historical price data
- `get_stock_info`: Company information
- `get_multiple_stocks_data`: Batch data fetching
## 3. Technical Analysis
- `calculate_sma`, `calculate_ema`: Moving averages
- `calculate_rsi`: Relative Strength Index
- `calculate_macd`: MACD indicator
- `calculate_bollinger_bands`: Bollinger Bands
- `get_full_technical_analysis`: All indicators
## 4. Screening Tools
- `get_maverick_recommendations`: Bullish stocks
- `get_maverick_bear_recommendations`: Bearish setups
- `get_trending_breakout_recommendations`: Breakout candidates
## 5. Portfolio Tools
- `optimize_portfolio`: Portfolio optimization
- `analyze_portfolio_risk`: Risk assessment
- `calculate_correlation_matrix`: Asset correlations
## 6. Research Tools
- `research_comprehensive`: Comprehensive research
- `research_company`: Company research
- `analyze_market_sentiment`: Market sentiment analysis
- `coordinate_agents`: Multi-agent coordination
## Usage Tips:
- Start with simple strategies before trying ML
- Use default parameters initially
- Optimize parameters after initial testing
- Compare multiple strategies on same data
"""
@mcp.prompt()
async def troubleshooting_guide():
    """Troubleshooting common issues.

    Returns a markdown prompt covering unknown strategy names, missing data,
    ML strategy usage, parameter errors and connection problems.
    """
    # The connection step previously read "The white circle should be blue",
    # which contradicts itself; it now gives a plain, checkable instruction.
    return """
# Troubleshooting Guide
## Common Issues and Solutions
### 1. "Unknown strategy type"
**Solution**: Use one of these exact strategy names:
- Traditional: sma_cross, rsi, macd, bollinger, momentum, ema_cross, mean_reversion, breakout, volume_momentum
- ML: online_learning, regime_aware, ensemble
### 2. "No data available"
**Solution**:
- Check date range (use past dates, not future)
- Verify stock symbol (use standard tickers like AAPL, MSFT)
- Try shorter date ranges (1 year or less)
### 3. ML Strategy Issues
**Solution**: Use the standard run_backtest tool with:
```
strategy_type: "online_learning" # or "regime_aware", "ensemble"
```
Don't use the run_ml_backtest tool for these strategies.
### 4. Parameter Errors
**Solution**: Start with no parameters (uses defaults):
"Run backtest on AAPL using sma_cross strategy"
Then customize if needed:
"Run backtest on AAPL using sma_cross with fast_period=10 and slow_period=30"
### 5. Connection Issues
**Solution**:
- Restart Claude Desktop
- Check that the MaverickMCP server is running and shows as connected in Claude Desktop
- Try a simple test: "Get AAPL stock data"
"""
@mcp.prompt()
async def quick_start():
    """Return an onboarding prompt with first commands to try.

    Lists five starter requests, suggested next steps and a few practical
    tips for new users.
    """
    onboarding_text = """
# Quick Start Guide
## Test These Commands First:
### 1. Simple Backtest
"Run a backtest on AAPL using the sma_cross strategy for 2024"
### 2. Get Stock Data
"Get AAPL stock data for the last 3 months"
### 3. Technical Analysis
"Show me technical analysis for MSFT"
### 4. Stock Screening
"Show me bullish stock recommendations"
### 5. ML Strategy Test
"Test the online_learning strategy on TSLA for the past 6 months"
## Next Steps:
1. Try different strategies on your favorite stocks
2. Optimize parameters for better performance
3. Compare multiple strategies
4. Build a portfolio with top performers
## Pro Tips:
- Use 2024 dates for reliable data
- Start with liquid stocks (AAPL, MSFT, GOOGL)
- Default parameters usually work well
- ML strategies are experimental but fun to try
"""
    return onboarding_text
# Register a strategy-reference prompt for better discovery
@mcp.prompt()
async def strategy_reference():
    """Complete strategy reference with all parameters.

    Renders every available strategy (nine traditional, three ML) with a
    short description, its tunable parameters and defaults, and an example
    ``run_backtest`` call. The data is embedded in the prompt as a JSON block.
    Previously only five strategies were listed, even though the "Testing
    Order" text below refers to ``macd``.
    """
    import json

    strategies = {
        "sma_cross": {
            "description": "Buy when fast SMA crosses above slow SMA",
            "parameters": {
                "fast_period": "Fast moving average period (default: 10)",
                "slow_period": "Slow moving average period (default: 20)",
            },
            "example": "run_backtest(symbol='AAPL', strategy_type='sma_cross', fast_period=10, slow_period=20)",
        },
        "rsi": {
            "description": "Buy oversold (RSI < 30), sell overbought (RSI > 70)",
            "parameters": {
                "period": "RSI calculation period (default: 14)",
                "oversold": "Oversold threshold (default: 30)",
                "overbought": "Overbought threshold (default: 70)",
            },
            "example": "run_backtest(symbol='MSFT', strategy_type='rsi', period=14, oversold=30)",
        },
        "macd": {
            "description": "MACD signal line crossover",
            "parameters": {
                "fast_period": "Fast period (default: 12)",
                "slow_period": "Slow period (default: 26)",
                "signal_period": "Signal line period (default: 9)",
            },
            "example": "run_backtest(symbol='AAPL', strategy_type='macd', fast_period=12, slow_period=26)",
        },
        "bollinger": {
            "description": "Bollinger Bands",
            "parameters": {
                "period": "Moving average period (default: 20)",
                "std_dev": "Band width in standard deviations (default: 2)",
            },
            "example": "run_backtest(symbol='MSFT', strategy_type='bollinger', period=20, std_dev=2)",
        },
        "momentum": {
            "description": "Momentum trading",
            "parameters": {
                "period": "Momentum period (default: 10)",
                "threshold": "Momentum threshold (default: 0.02)",
            },
            "example": "run_backtest(symbol='NVDA', strategy_type='momentum', period=10)",
        },
        "ema_cross": {
            "description": "EMA crossover",
            "parameters": {
                "fast_period": "Fast EMA period (default: 12)",
                "slow_period": "Slow EMA period (default: 26)",
            },
            "example": "run_backtest(symbol='AAPL', strategy_type='ema_cross', fast_period=12, slow_period=26)",
        },
        "mean_reversion": {
            "description": "Mean reversion",
            "parameters": {
                "lookback": "Lookback window (default: 20)",
                "entry_z": "Entry z-score (default: -2)",
                "exit_z": "Exit z-score (default: 0)",
            },
            "example": "run_backtest(symbol='SPY', strategy_type='mean_reversion', lookback=20)",
        },
        "breakout": {
            "description": "Channel breakout",
            "parameters": {
                "lookback": "Channel lookback window (default: 20)",
                "breakout_factor": "Breakout factor (default: 1.5)",
            },
            "example": "run_backtest(symbol='TSLA', strategy_type='breakout', lookback=20)",
        },
        "volume_momentum": {
            "description": "Volume-weighted momentum",
            "parameters": {
                "period": "Momentum period (default: 10)",
                "volume_factor": "Volume factor (default: 1.5)",
            },
            "example": "run_backtest(symbol='MSFT', strategy_type='volume_momentum', period=10)",
        },
        "online_learning": {
            "description": "ML strategy with adaptive thresholds",
            "parameters": {
                "lookback": "Historical window (default: 20)",
                "learning_rate": "Adaptation rate (default: 0.01)",
            },
            "example": "run_backtest(symbol='TSLA', strategy_type='online_learning', lookback=20)",
        },
        "regime_aware": {
            "description": "Detects and adapts to market regimes",
            "parameters": {
                "regime_window": "Regime detection window (default: 50)",
                "threshold": "Regime change threshold (default: 0.02)",
            },
            "example": "run_backtest(symbol='SPY', strategy_type='regime_aware', regime_window=50)",
        },
        "ensemble": {
            "description": "Combines multiple strategies with voting",
            "parameters": {
                "fast_period": "Fast MA period (default: 10)",
                "slow_period": "Slow MA period (default: 20)",
                "rsi_period": "RSI period (default: 14)",
            },
            "example": "run_backtest(symbol='NVDA', strategy_type='ensemble')",
        },
    }
    return f"""
# Complete Strategy Reference
## All Available Strategies with Parameters
```json
{json.dumps(strategies, indent=2)}
```
## Usage Pattern:
All strategies use the same tool: `run_backtest`
Parameters:
- symbol: Stock ticker (required)
- strategy_type: Strategy name (required)
- start_date: YYYY-MM-DD format
- end_date: YYYY-MM-DD format
- initial_capital: Starting amount (default: 10000)
- Additional strategy-specific parameters
## Testing Order:
1. Start with sma_cross (simplest)
2. Try rsi or macd (intermediate)
3. Test online_learning (ML strategy)
4. Compare all with compare_strategies tool
"""
# Register resources for better discovery
@mcp.resource("strategies://list")
def list_strategies_resource():
    """List of all available backtesting strategies with parameters.

    Returns:
        A dict with three keys:
        - ``traditional_strategies`` and ``ml_strategies``: each maps a
          strategy key to its display name, parameter names and default
          values.
        - ``total_strategies``: the number of strategies across both groups.
    """
    traditional_strategies = {
        "sma_cross": {
            "name": "Simple Moving Average Crossover",
            "parameters": ["fast_period", "slow_period"],
            "default_values": {"fast_period": 10, "slow_period": 20},
        },
        "rsi": {
            "name": "RSI Mean Reversion",
            "parameters": ["period", "oversold", "overbought"],
            "default_values": {"period": 14, "oversold": 30, "overbought": 70},
        },
        "macd": {
            "name": "MACD Signal Line Crossover",
            "parameters": ["fast_period", "slow_period", "signal_period"],
            "default_values": {
                "fast_period": 12,
                "slow_period": 26,
                "signal_period": 9,
            },
        },
        "bollinger": {
            "name": "Bollinger Bands",
            "parameters": ["period", "std_dev"],
            "default_values": {"period": 20, "std_dev": 2},
        },
        "momentum": {
            "name": "Momentum Trading",
            "parameters": ["period", "threshold"],
            "default_values": {"period": 10, "threshold": 0.02},
        },
        "ema_cross": {
            "name": "EMA Crossover",
            "parameters": ["fast_period", "slow_period"],
            "default_values": {"fast_period": 12, "slow_period": 26},
        },
        "mean_reversion": {
            "name": "Mean Reversion",
            "parameters": ["lookback", "entry_z", "exit_z"],
            "default_values": {"lookback": 20, "entry_z": -2, "exit_z": 0},
        },
        "breakout": {
            "name": "Channel Breakout",
            "parameters": ["lookback", "breakout_factor"],
            "default_values": {"lookback": 20, "breakout_factor": 1.5},
        },
        "volume_momentum": {
            "name": "Volume-Weighted Momentum",
            "parameters": ["period", "volume_factor"],
            "default_values": {"period": 10, "volume_factor": 1.5},
        },
    }
    ml_strategies = {
        "online_learning": {
            "name": "Online Learning Adaptive Strategy",
            "parameters": ["lookback", "learning_rate"],
            "default_values": {"lookback": 20, "learning_rate": 0.01},
        },
        "regime_aware": {
            "name": "Market Regime Detection",
            "parameters": ["regime_window", "threshold"],
            "default_values": {"regime_window": 50, "threshold": 0.02},
        },
        "ensemble": {
            "name": "Ensemble Voting Strategy",
            "parameters": ["fast_period", "slow_period", "rsi_period"],
            "default_values": {
                "fast_period": 10,
                "slow_period": 20,
                "rsi_period": 14,
            },
        },
    }
    return {
        "traditional_strategies": traditional_strategies,
        "ml_strategies": ml_strategies,
        # Derived from the entries above instead of hard-coded: the previous
        # literal 15 disagreed with the 9 + 3 = 12 strategies actually listed.
        "total_strategies": len(traditional_strategies) + len(ml_strategies),
    }
@mcp.resource("tools://categories")
def tool_categories_resource():
    """Return MCP tool names grouped by functional category.

    Each key is a category name and each value is the ordered list of tool
    names in that category.
    """
    backtesting = [
        "run_backtest",
        "optimize_parameters",
        "compare_strategies",
        "get_strategy_info",
    ]
    data = ["get_stock_data", "get_stock_info", "get_multiple_stocks_data"]
    technical_analysis = [
        "calculate_sma",
        "calculate_ema",
        "calculate_rsi",
        "calculate_macd",
        "calculate_bollinger_bands",
        "get_full_technical_analysis",
    ]
    screening = [
        "get_maverick_recommendations",
        "get_maverick_bear_recommendations",
        "get_trending_breakout_recommendations",
    ]
    portfolio = [
        "optimize_portfolio",
        "analyze_portfolio_risk",
        "calculate_correlation_matrix",
    ]
    research = [
        "research_comprehensive",
        "research_company",
        "analyze_market_sentiment",
        "coordinate_agents",
    ]
    # Keyword order fixes the key order of the resulting dict.
    return dict(
        backtesting=backtesting,
        data=data,
        technical_analysis=technical_analysis,
        screening=screening,
        portfolio=portfolio,
        research=research,
    )
@mcp.resource("examples://backtesting")
def backtesting_examples_resource():
    """Return worked examples of the backtesting tools.

    Each entry holds a human-readable description, an example tool call,
    and a summary of the expected output.
    """

    def _entry(description, call, expected):
        # Shared shape for every example entry.
        return {
            "description": description,
            "example": call,
            "expected_output": expected,
        }

    return {
        "simple_backtest": _entry(
            "Basic backtest with default parameters",
            "run_backtest(symbol='AAPL', strategy_type='sma_cross')",
            "Performance metrics including total return, sharpe ratio, win rate",
        ),
        "custom_parameters": _entry(
            "Backtest with custom strategy parameters",
            "run_backtest(symbol='TSLA', strategy_type='rsi', period=21, oversold=25)",
            "Performance with adjusted RSI parameters",
        ),
        "ml_strategy": _entry(
            "Running ML-based strategy",
            "run_backtest(symbol='NVDA', strategy_type='online_learning', lookback=30)",
            "Adaptive strategy performance with online learning",
        ),
        "optimization": _entry(
            "Optimize strategy parameters",
            "optimize_parameters(symbol='MSFT', strategy_type='sma_cross')",
            "Best parameter combination and performance metrics",
        ),
        "comparison": _entry(
            "Compare multiple strategies",
            "compare_strategies(symbol='SPY', strategies=['sma_cross', 'rsi', 'online_learning'])",
            "Side-by-side comparison of strategy performance",
        ),
    }
return True
```