This is page 13 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/domain/screening/services.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Screening domain services.
3 |
4 | This module contains pure business logic services that operate on
5 | screening entities and value objects without any external dependencies.
6 | """
7 |
8 | from datetime import datetime
9 | from decimal import Decimal
10 | from typing import Any, Protocol
11 |
12 | from .entities import ScreeningResult, ScreeningResultCollection
13 | from .value_objects import (
14 | ScreeningCriteria,
15 | ScreeningLimits,
16 | ScreeningStrategy,
17 | SortingOptions,
18 | )
19 |
20 |
21 | class IStockRepository(Protocol):
22 | """Protocol defining the interface for stock data access."""
23 |
24 | def get_maverick_stocks(
25 | self, limit: int = 20, min_score: int | None = None
26 | ) -> list[dict[str, Any]]:
27 | """Get Maverick bullish stocks."""
28 | ...
29 |
30 | def get_maverick_bear_stocks(
31 | self, limit: int = 20, min_score: int | None = None
32 | ) -> list[dict[str, Any]]:
33 | """Get Maverick bearish stocks."""
34 | ...
35 |
36 | def get_trending_stocks(
37 | self,
38 | limit: int = 20,
39 | min_momentum_score: Decimal | None = None,
40 | filter_moving_averages: bool = False,
41 | ) -> list[dict[str, Any]]:
42 | """Get trending stocks."""
43 | ...
44 |
45 |
class ScreeningService:
    """
    Pure domain service for stock screening business logic.

    This service contains no external dependencies and focuses solely
    on the business rules and logic for screening operations.
    """

    # Result attributes that can be used directly as sort keys in
    # sort_screening_results ("quality_score" is computed separately).
    _SORTABLE_ATTRIBUTES = frozenset(
        {
            "combined_score",
            "bear_score",
            "momentum_score",
            "close_price",
            "volume",
            "avg_volume_30d",
            "adr_percentage",
        }
    )

    def __init__(self):
        """Initialize the screening service with default ``ScreeningLimits``."""
        self._default_limits = ScreeningLimits()

    @staticmethod
    def _to_decimal(value: Any, default: Any = 0) -> Decimal:
        """Convert ``value`` to Decimal, substituting ``default`` for None.

        Conversion goes through ``str`` so floats keep their printed value
        (``Decimal("0.1")``) rather than their binary expansion.
        """
        return Decimal(str(default if value is None else value))

    @staticmethod
    def _to_int(value: Any, default: int = 0) -> int:
        """Convert ``value`` to int, substituting ``default`` for None."""
        return int(default if value is None else value)

    @staticmethod
    def _to_optional_decimal(value: Any) -> Decimal | None:
        """Convert ``value`` to Decimal, passing None through unchanged."""
        return None if value is None else Decimal(str(value))

    def create_screening_result_from_raw_data(
        self, raw_data: dict[str, Any], screening_date: datetime | None = None
    ) -> ScreeningResult:
        """
        Create a ScreeningResult entity from raw database data.

        Numeric fields that are missing OR explicitly ``None`` (e.g. NULL
        columns) default to ``0``. The optional bearish indicators
        (``rsi_14``, ``macd``, ``macd_s``, ``macd_h``) stay ``None`` when
        absent, and the remaining optional fields are passed through as-is.

        Args:
            raw_data: Row dictionary keyed by the repository's column names
                (e.g. ``"stock"``, ``"close"``, ``"adr_pct"``, ``"sqz"``).
            screening_date: Timestamp for the result; defaults to the current
                time from ``datetime.utcnow()`` (naive UTC).

        Returns:
            A ScreeningResult populated from ``raw_data``.
        """
        if screening_date is None:
            # NOTE(review): naive UTC. datetime.utcnow() is deprecated since
            # Python 3.12; moving to datetime.now(UTC) would make this tz-aware,
            # so confirm the entities accept aware datetimes before switching.
            screening_date = datetime.utcnow()

        dec = self._to_decimal
        to_int = self._to_int
        opt_dec = self._to_optional_decimal

        # The 30-day average volume may arrive under either key; a None under
        # the first key falls back to the second instead of failing.
        avg_volume = raw_data.get("avg_vol_30d")
        if avg_volume is None:
            avg_volume = raw_data.get("avg_volume_30d")

        return ScreeningResult(
            stock_symbol=raw_data.get("stock", ""),
            screening_date=screening_date,
            open_price=dec(raw_data.get("open")),
            high_price=dec(raw_data.get("high")),
            low_price=dec(raw_data.get("low")),
            close_price=dec(raw_data.get("close")),
            volume=to_int(raw_data.get("volume")),
            ema_21=dec(raw_data.get("ema_21")),
            sma_50=dec(raw_data.get("sma_50")),
            sma_150=dec(raw_data.get("sma_150")),
            sma_200=dec(raw_data.get("sma_200")),
            momentum_score=dec(raw_data.get("momentum_score")),
            avg_volume_30d=dec(avg_volume),
            adr_percentage=dec(raw_data.get("adr_pct")),
            atr=dec(raw_data.get("atr")),
            pattern=raw_data.get("pat"),
            squeeze=raw_data.get("sqz"),
            vcp=raw_data.get("vcp"),
            entry_signal=raw_data.get("entry"),
            combined_score=to_int(raw_data.get("combined_score")),
            bear_score=to_int(raw_data.get("score")),  # Bear score uses 'score' field
            compression_score=to_int(raw_data.get("compression_score")),
            pattern_detected=to_int(raw_data.get("pattern_detected")),
            # Bearish-specific fields
            rsi_14=opt_dec(raw_data.get("rsi_14")),
            macd=opt_dec(raw_data.get("macd")),
            macd_signal=opt_dec(raw_data.get("macd_s")),
            macd_histogram=opt_dec(raw_data.get("macd_h")),
            distribution_days_20=raw_data.get("dist_days_20"),
            atr_contraction=raw_data.get("atr_contraction"),
            big_down_volume=raw_data.get("big_down_vol"),
        )

    def apply_screening_criteria(
        self, results: list[ScreeningResult], criteria: ScreeningCriteria
    ) -> list[ScreeningResult]:
        """
        Apply screening criteria to filter results.

        Every criterion that is set contributes one predicate, and a result is
        kept only if it satisfies all of them. The relative order of results is
        preserved. If ``criteria`` reports no filters, or none of the criteria
        below is set, ``results`` itself is returned.

        Args:
            results: Candidate results to filter.
            criteria: Filter thresholds and required flags.

        Returns:
            The results that pass every active criterion.
        """
        if not criteria.has_any_filters():
            return results

        checks = []

        # Momentum Score filters
        if criteria.min_momentum_score is not None:
            checks.append(lambda r: r.momentum_score >= criteria.min_momentum_score)
        if criteria.max_momentum_score is not None:
            checks.append(lambda r: r.momentum_score <= criteria.max_momentum_score)

        # Volume filters (compared against the 30-day average volume)
        if criteria.min_volume is not None:
            checks.append(lambda r: r.avg_volume_30d >= criteria.min_volume)
        if criteria.max_volume is not None:
            checks.append(lambda r: r.avg_volume_30d <= criteria.max_volume)

        # Price filters
        if criteria.min_price is not None:
            checks.append(lambda r: r.close_price >= criteria.min_price)
        if criteria.max_price is not None:
            checks.append(lambda r: r.close_price <= criteria.max_price)

        # Score filters
        if criteria.min_combined_score is not None:
            checks.append(lambda r: r.combined_score >= criteria.min_combined_score)
        if criteria.min_bear_score is not None:
            checks.append(lambda r: r.bear_score >= criteria.min_bear_score)

        # ADR filters
        if criteria.min_adr_percentage is not None:
            checks.append(lambda r: r.adr_percentage >= criteria.min_adr_percentage)
        if criteria.max_adr_percentage is not None:
            checks.append(lambda r: r.adr_percentage <= criteria.max_adr_percentage)

        # Pattern filters: string signals count only when non-blank
        if criteria.require_pattern_detected:
            checks.append(lambda r: r.pattern_detected > 0)
        if criteria.require_squeeze:
            checks.append(lambda r: r.squeeze is not None and r.squeeze.strip())
        if criteria.require_vcp:
            checks.append(lambda r: r.vcp is not None and r.vcp.strip())
        if criteria.require_entry_signal:
            checks.append(
                lambda r: r.entry_signal is not None and r.entry_signal.strip()
            )

        # Moving average filters
        if criteria.require_above_sma50:
            checks.append(lambda r: r.close_price > r.sma_50)
        if criteria.require_above_sma150:
            checks.append(lambda r: r.close_price > r.sma_150)
        if criteria.require_above_sma200:
            checks.append(lambda r: r.close_price > r.sma_200)
        if criteria.require_ma_alignment:
            checks.append(lambda r: r.sma_50 > r.sma_150 and r.sma_150 > r.sma_200)

        if not checks:
            return results

        # Single pass instead of one intermediate list per criterion; all()
        # evaluates predicates in the order above and stops at the first miss.
        return [r for r in results if all(check(r) for check in checks)]

    def sort_screening_results(
        self, results: list[ScreeningResult], sorting: SortingOptions
    ) -> list[ScreeningResult]:
        """
        Sort screening results according to the specified options.

        Results are ordered by ``sorting.field``. When
        ``sorting.secondary_field`` is set, ties are broken by that field.
        Both keys use the direction given by ``sorting.descending``. Unknown
        field names contribute a constant ``0`` key, and the sort is stable.

        Args:
            results: Results to order (not modified).
            sorting: Sort field(s) and direction.

        Returns:
            A new, sorted list.
        """
        sortable = self._SORTABLE_ATTRIBUTES
        primary = sorting.field
        secondary = sorting.secondary_field

        def get_sort_value(result: ScreeningResult, field: str) -> Any:
            """Get the value for sorting from a result."""
            if field == "quality_score":
                return result.get_quality_score()
            if field in sortable:
                return getattr(result, field)
            return 0

        def sort_key(result: ScreeningResult) -> Any:
            primary_value = get_sort_value(result, primary)
            if secondary:
                return (primary_value, get_sort_value(result, secondary))
            return primary_value

        # One stable sort on the (primary, secondary) tuple gives exactly the
        # order the former two-pass sort produced (sorted() stays stable with
        # reverse=True), without the redundant first pass.
        return sorted(results, key=sort_key, reverse=sorting.descending)

    def create_screening_collection(
        self,
        results: list[ScreeningResult],
        strategy: ScreeningStrategy,
        total_candidates: int,
    ) -> ScreeningResultCollection:
        """
        Create a ScreeningResultCollection from individual results.

        Args:
            results: Results to include in the collection.
            strategy: Strategy that produced the results; its ``.value`` is
                stored as ``strategy_used``.
            total_candidates: Number of candidates analyzed before filtering.

        Returns:
            The assembled collection, timestamped with ``datetime.utcnow()``
            (naive UTC).
        """
        return ScreeningResultCollection(
            results=results,
            strategy_used=strategy.value,
            screening_timestamp=datetime.utcnow(),
            total_candidates_analyzed=total_candidates,
        )

    def validate_screening_limits(self, requested_limit: int) -> int:
        """
        Validate and adjust the requested result limit.

        Business rule: limits must be within acceptable bounds. The check is
        delegated to the default ``ScreeningLimits.validate_limit``.
        """
        return self._default_limits.validate_limit(requested_limit)

    def calculate_screening_statistics(
        self, collection: ScreeningResultCollection
    ) -> dict[str, Any]:
        """
        Calculate comprehensive statistics for a screening collection.

        Starts from ``collection.get_statistics()`` and, for non-empty
        collections, adds:

        * a quality-score distribution and the average quality score,
        * risk/reward buckets, when any ratio is positive,
        * a strategy-specific section for the Maverick bullish, Maverick
          bearish and trending stage-2 strategies.

        Args:
            collection: The screening collection to summarize.

        Returns:
            The statistics dictionary.
        """
        base_stats = collection.get_statistics()

        # Add additional business metrics
        results = collection.results
        if not results:
            return base_stats

        # Quality distribution (buckets: >= 80, 50-79, < 50)
        quality_scores = [r.get_quality_score() for r in results]
        base_stats.update(
            {
                "quality_distribution": {
                    "high_quality": sum(1 for q in quality_scores if q >= 80),
                    "medium_quality": sum(1 for q in quality_scores if 50 <= q < 80),
                    "low_quality": sum(1 for q in quality_scores if q < 50),
                },
                "avg_quality_score": sum(quality_scores) / len(quality_scores),
            }
        )

        # Risk/reward analysis; non-positive ratios are excluded
        risk_rewards = [r.calculate_risk_reward_ratio() for r in results]
        valid_ratios = [rr for rr in risk_rewards if rr > 0]

        if valid_ratios:
            base_stats.update(
                {
                    "risk_reward_analysis": {
                        "avg_ratio": float(sum(valid_ratios) / len(valid_ratios)),
                        "favorable_setups": sum(1 for rr in valid_ratios if rr >= 2),
                        "conservative_setups": sum(
                            1 for rr in valid_ratios if 1 <= rr < 2
                        ),
                        "risky_setups": sum(1 for rr in valid_ratios if rr < 1),
                    }
                }
            )

        # Strategy-specific metrics
        if collection.strategy_used == ScreeningStrategy.MAVERICK_BULLISH.value:
            base_stats["momentum_analysis"] = self._calculate_momentum_metrics(results)
        elif collection.strategy_used == ScreeningStrategy.MAVERICK_BEARISH.value:
            base_stats["weakness_analysis"] = self._calculate_weakness_metrics(results)
        elif collection.strategy_used == ScreeningStrategy.TRENDING_STAGE2.value:
            base_stats["trend_analysis"] = self._calculate_trend_metrics(results)

        return base_stats

    def _calculate_momentum_metrics(
        self, results: list[ScreeningResult]
    ) -> dict[str, Any]:
        """Count bullish momentum signals (score and pattern thresholds)."""
        return {
            "high_momentum": sum(1 for r in results if r.combined_score >= 80),
            "pattern_breakouts": sum(1 for r in results if r.pattern_detected > 0),
            "strong_momentum": sum(1 for r in results if r.momentum_score >= 90),
        }

    def _calculate_weakness_metrics(
        self, results: list[ScreeningResult]
    ) -> dict[str, Any]:
        """Count bearish weakness signals (bear score, distribution, SMA-200)."""
        return {
            "severe_weakness": sum(1 for r in results if r.bear_score >= 80),
            "distribution_signals": sum(
                1
                for r in results
                if r.distribution_days_20 is not None and r.distribution_days_20 >= 5
            ),
            "breakdown_candidates": sum(
                1 for r in results if r.close_price < r.sma_200
            ),
        }

    def _calculate_trend_metrics(
        self, results: list[ScreeningResult]
    ) -> dict[str, Any]:
        """Count trend signals (stage-2 trend, SMA alignment, momentum >= 95)."""
        return {
            "strong_trends": sum(1 for r in results if r.is_trending_stage2()),
            "perfect_alignment": sum(
                1 for r in results if (r.sma_50 > r.sma_150 and r.sma_150 > r.sma_200)
            ),
            "elite_momentum": sum(1 for r in results if r.momentum_score >= 95),
        }
410 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/data.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Data fetching router for Maverick-MCP.
3 |
4 | This module contains all data retrieval tools including
5 | stock data, news, fundamentals, and caching operations.
6 |
7 | Updated to use separated services following Single Responsibility Principle.
8 | """
9 |
10 | import json
11 | import logging
12 | from concurrent.futures import ThreadPoolExecutor
13 | from datetime import UTC, datetime
14 | from typing import Any
15 |
16 | import requests
17 | import requests.exceptions
18 | from fastmcp import FastMCP
19 |
20 | from maverick_mcp.config.settings import settings
21 | from maverick_mcp.data.models import PriceCache
22 | from maverick_mcp.data.session_management import get_db_session_read_only
23 | from maverick_mcp.domain.stock_analysis import StockAnalysisService
24 | from maverick_mcp.infrastructure.caching import CacheManagementService
25 | from maverick_mcp.infrastructure.data_fetching import StockDataFetchingService
26 | from maverick_mcp.providers.stock_data import (
27 | StockDataProvider,
28 | ) # Kept for backward compatibility
29 |
# Module-level logger for this router's tools.
logger = logging.getLogger(__name__)

# Create the data router
# FastMCP instance named "Data_Operations" that groups the data-retrieval
# tools; presumably mounted on the main server elsewhere — verify.
data_router: FastMCP = FastMCP("Data_Operations")

# Thread pool for blocking operations
# NOTE(review): created at import time with 10 workers. It is not referenced,
# and not shut down, anywhere in the portion of this module visible here —
# confirm it is still used.
executor = ThreadPoolExecutor(max_workers=10)
37 |
38 |
def fetch_stock_data(
    ticker: str,
    start_date: str | None = None,
    end_date: str | None = None,
) -> dict[str, Any]:
    """
    Fetch historical stock data for a given ticker symbol.

    This is the primary tool for retrieving stock price data. It uses intelligent
    caching to minimize API calls and improve performance.

    Args:
        ticker: The ticker symbol of the stock (e.g., AAPL, MSFT)
        start_date: Start date for data in YYYY-MM-DD format (default: 1 year ago)
        end_date: End date for data in YYYY-MM-DD format (default: today)

    Returns:
        Dictionary containing the stock data in JSON format with:
        - data: OHLCV price data
        - columns: Column names
        - index: Date index
        plus "ticker" and "record_count". On failure, a dictionary with
        "error" and "ticker" is returned instead of raising.

    Examples:
        >>> fetch_stock_data(ticker="AAPL")
        >>> fetch_stock_data(
        ...     ticker="MSFT",
        ...     start_date="2024-01-01",
        ...     end_date="2024-12-31"
        ... )
    """
    try:
        fetcher = StockDataFetchingService()

        with get_db_session_read_only() as session:
            # The cache service and the analysis service share one read-only session.
            analysis = StockAnalysisService(
                data_fetching_service=fetcher,
                cache_service=CacheManagementService(db_session=session),
                db_session=session,
            )

            frame = analysis.get_stock_data(ticker, start_date, end_date)
            serialized = frame.to_json(orient="split", date_format="iso")
            payload: dict[str, Any] = {} if not serialized else json.loads(serialized)
            payload["ticker"] = ticker
            payload["record_count"] = len(frame)
            return payload
    except Exception as e:
        logger.error("Error fetching stock data for %s: %s", ticker, e)
        return {"error": str(e), "ticker": ticker}
92 |
93 |
def fetch_stock_data_batch(
    tickers: list[str],
    start_date: str | None = None,
    end_date: str | None = None,
) -> dict[str, Any]:
    """
    Fetch historical data for multiple tickers efficiently.

    This tool fetches data for multiple stocks in a single call,
    which is more efficient than calling fetch_stock_data multiple times:
    one read-only session and one set of services are shared by all tickers.
    A failure for one ticker is recorded under that ticker and does not
    stop the others.

    Args:
        tickers: List of ticker symbols (e.g., ["AAPL", "MSFT", "GOOGL"])
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format

    Returns:
        Dictionary with per-ticker "results" (data or error), "success_count",
        "error_count" and the requested "tickers"

    Examples:
        >>> fetch_stock_data_batch(
        ...     tickers=["AAPL", "MSFT", "GOOGL"],
        ...     start_date="2024-01-01"
        ... )
    """
    results: dict[str, dict[str, Any]] = {}
    fetcher = StockDataFetchingService()

    with get_db_session_read_only() as session:
        analysis = StockAnalysisService(
            data_fetching_service=fetcher,
            cache_service=CacheManagementService(db_session=session),
            db_session=session,
        )

        for symbol in tickers:
            try:
                frame = analysis.get_stock_data(symbol, start_date, end_date)
                serialized = frame.to_json(orient="split", date_format="iso") or "{}"
                results[symbol] = {
                    "status": "success",
                    "data": json.loads(serialized),
                    "record_count": len(frame),
                }
            except Exception as e:
                logger.error("Error fetching data for %s: %s", symbol, e)
                results[symbol] = {"status": "error", "error": str(e)}

    statuses = [entry["status"] for entry in results.values()]
    return {
        "results": results,
        "success_count": statuses.count("success"),
        "error_count": statuses.count("error"),
        "tickers": tickers,
    }
156 |
157 |
def _first_not_none(mapping: dict[str, Any], *keys: str) -> Any:
    """Return the value of the first of *keys* whose value in *mapping* is not None.

    ``mapping.get(a, mapping.get(b))`` only falls back to ``b`` when key ``a``
    is absent. If ``a`` is present but holds ``None``, that idiom returns
    ``None`` and the fallback is never used; this helper handles both cases.
    Falsy-but-real values such as ``0.0`` or ``""`` are returned as-is.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def get_stock_info(ticker: str) -> dict[str, Any]:
    """
    Get detailed fundamental information about a stock.

    This tool retrieves comprehensive stock information including:
    - Company description and sector
    - Market cap and valuation metrics
    - Financial ratios
    - Trading information

    Args:
        ticker: Stock ticker symbol

    Returns:
        Dictionary with ``ticker`` plus ``company``, ``market_data``,
        ``valuation``, ``financials`` and ``trading`` sections. Missing
        provider fields appear as None. On failure: ``{"error": str, "ticker": ticker}``.
    """
    try:
        # Use read-only context manager for automatic session management
        with get_db_session_read_only() as session:
            provider = StockDataProvider(db_session=session)
            info = provider.get_stock_info(ticker)

            # Extract key information
            return {
                "ticker": ticker,
                "company": {
                    # Fall back to shortName when longName is missing OR None.
                    "name": _first_not_none(info, "longName", "shortName"),
                    "sector": info.get("sector"),
                    "industry": info.get("industry"),
                    "website": info.get("website"),
                    "description": info.get("longBusinessSummary"),
                },
                "market_data": {
                    # Fall back to regularMarketPrice when currentPrice is missing OR None.
                    "current_price": _first_not_none(
                        info, "currentPrice", "regularMarketPrice"
                    ),
                    "market_cap": info.get("marketCap"),
                    "enterprise_value": info.get("enterpriseValue"),
                    "shares_outstanding": info.get("sharesOutstanding"),
                    "float_shares": info.get("floatShares"),
                },
                "valuation": {
                    "pe_ratio": info.get("trailingPE"),
                    "forward_pe": info.get("forwardPE"),
                    "peg_ratio": info.get("pegRatio"),
                    "price_to_book": info.get("priceToBook"),
                    "price_to_sales": info.get("priceToSalesTrailing12Months"),
                },
                "financials": {
                    "revenue": info.get("totalRevenue"),
                    "profit_margin": info.get("profitMargins"),
                    "operating_margin": info.get("operatingMargins"),
                    "roe": info.get("returnOnEquity"),
                    "roa": info.get("returnOnAssets"),
                },
                "trading": {
                    "avg_volume": info.get("averageVolume"),
                    "avg_volume_10d": info.get("averageVolume10days"),
                    "beta": info.get("beta"),
                    "52_week_high": info.get("fiftyTwoWeekHigh"),
                    "52_week_low": info.get("fiftyTwoWeekLow"),
                },
            }
    except Exception as e:
        logger.error(f"Error fetching stock info for {ticker}: {e}")
        return {"error": str(e), "ticker": ticker}
224 |
225 |
def get_news_sentiment(
    ticker: str,
    timeframe: str = "7d",
    limit: int = 10,
) -> dict[str, Any]:
    """
    Retrieve news sentiment analysis for a stock.

    Fetches sentiment data from the external API configured in
    ``settings.external_data``, providing insights into market sentiment
    based on recent news.

    Args:
        ticker: The ticker symbol of the stock to analyze
        timeframe: Time frame for news (1d, 7d, 30d, etc.). NOTE(review): not
            currently forwarded to the external API.
        limit: Maximum number of news articles to analyze. NOTE(review): not
            currently forwarded to the external API.

    Returns:
        The API's JSON payload on success. Otherwise a dict with ``ticker``,
        ``sentiment`` and ``status`` (plus ``error`` or ``message``):
        ``fallback_mode`` (no API key configured), ``not_found`` (404),
        ``unauthorized`` (401), ``rate_limited`` (429), ``timeout``,
        ``connection_error`` or ``error``.
    """
    from urllib.parse import quote

    try:
        api_key = settings.external_data.api_key
        base_url = settings.external_data.base_url
        if not api_key:
            logger.info(
                "External sentiment API not configured, providing basic response"
            )
            return {
                "ticker": ticker,
                "sentiment": "neutral",
                "message": "External sentiment API not configured - configure EXTERNAL_DATA_API_KEY for enhanced sentiment analysis",
                "status": "fallback_mode",
                "confidence": 0.5,
                "source": "fallback",
            }

        # Percent-encode the client-supplied ticker so characters such as
        # '/', '?' or '#' cannot alter the request path or inject a query.
        url = f"{base_url}/sentiment/{quote(ticker, safe='')}"
        headers = {"X-API-KEY": api_key}
        logger.info(f"Fetching sentiment for {ticker} from {url}")
        resp = requests.get(url, headers=headers, timeout=10)

        if resp.status_code == 404:
            return {
                "ticker": ticker,
                "sentiment": "unavailable",
                "message": f"No sentiment data available for {ticker}",
                "status": "not_found",
            }
        elif resp.status_code == 401:
            return {
                "error": "Invalid API key",
                "ticker": ticker,
                "sentiment": "unavailable",
                "status": "unauthorized",
            }
        elif resp.status_code == 429:
            return {
                "error": "Rate limit exceeded",
                "ticker": ticker,
                "sentiment": "unavailable",
                "status": "rate_limited",
            }

        # Any other 4xx/5xx raises here and is reported by the generic handler.
        resp.raise_for_status()
        return resp.json()

    except requests.exceptions.Timeout:
        return {
            "error": "Request timed out",
            "ticker": ticker,
            "sentiment": "unavailable",
            "status": "timeout",
        }
    except requests.exceptions.ConnectionError:
        return {
            "error": "Connection error",
            "ticker": ticker,
            "sentiment": "unavailable",
            "status": "connection_error",
        }
    except Exception as e:
        logger.error(f"Error fetching sentiment from External API for {ticker}: {e}")
        return {
            "error": str(e),
            "ticker": ticker,
            "sentiment": "unavailable",
            "status": "error",
        }
313 |
314 |
def get_cached_price_data(
    ticker: str,
    start_date: str,
    end_date: str | None = None,
) -> dict[str, Any]:
    """
    Get cached price data directly from the database.

    This tool retrieves data from the local cache without making external API calls.
    Useful for checking what data is available locally.

    Args:
        ticker: Stock ticker symbol
        start_date: Start date in YYYY-MM-DD format
        end_date: End date in YYYY-MM-DD format (optional, defaults to today in UTC)

    Returns:
        On success: ``status``, ``ticker``, ``start_date``, ``end_date``,
        ``count`` and ``data``. ``data`` is a list of JSON-native row dicts
        with dates as ISO-8601 strings. When nothing is cached, ``data`` is an
        empty list and ``message`` explains why. On failure: ``status="error"``,
        ``error`` and ``ticker``.
    """
    try:
        with get_db_session_read_only() as session:
            df = PriceCache.get_price_data(session, ticker, start_date, end_date)

            if df.empty:
                return {
                    "status": "success",
                    "ticker": ticker,
                    "message": "No cached data found for the specified date range",
                    "data": [],
                }

            # Round-trip through pandas' JSON writer (as fetch_stock_data does)
            # so pandas Timestamps and numpy scalars become plain JSON values
            # (ISO date strings, floats, ints) instead of library objects.
            records_json = df.reset_index().to_json(
                orient="records", date_format="iso"
            )
            data = json.loads(records_json) if records_json else []

            return {
                "status": "success",
                "ticker": ticker,
                "start_date": start_date,
                "end_date": end_date or datetime.now(UTC).strftime("%Y-%m-%d"),
                "count": len(data),
                "data": data,
            }
    except Exception as e:
        logger.error(f"Error fetching cached price data for {ticker}: {str(e)}")
        return {"error": str(e), "status": "error", "ticker": ticker}
360 |
361 |
def get_chart_links(ticker: str) -> dict[str, Any]:
    """
    Provide links to various financial charting websites.

    This tool generates URLs to popular financial websites where detailed
    stock charts can be viewed, including:
    - TradingView (advanced charting)
    - Finviz (visual screener)
    - Yahoo Finance (comprehensive data)
    - StockCharts (technical analysis)
    - Seeking Alpha and MarketWatch chart pages

    Args:
        ticker: The ticker symbol of the stock

    Returns:
        ``{"ticker": ticker, "charts": {provider: url}, "description": str}``.
        The ticker is percent-encoded inside the URLs but returned unchanged
        under ``ticker``. On failure: ``{"error": str, "ticker": ticker}``.
    """
    from urllib.parse import quote

    try:
        # Percent-encode the client-supplied ticker so characters such as
        # '&', '/', '?' or '#' cannot alter the path or query of the links.
        symbol = quote(ticker, safe="")
        links = {
            "trading_view": f"https://www.tradingview.com/symbols/{symbol}",
            "finviz": f"https://finviz.com/quote.ashx?t={symbol}",
            "yahoo_finance": f"https://finance.yahoo.com/quote/{symbol}/chart",
            "stock_charts": f"https://stockcharts.com/h-sc/ui?s={symbol}",
            "seeking_alpha": f"https://seekingalpha.com/symbol/{symbol}/charts",
            "marketwatch": f"https://www.marketwatch.com/investing/stock/{symbol}/charts",
        }

        return {
            "ticker": ticker,
            "charts": links,
            "description": "External chart resources for detailed analysis",
        }
    except Exception as e:
        logger.error(f"Error generating chart links for {ticker}: {e}")
        return {"error": str(e), "ticker": ticker}
397 |
398 |
def clear_cache(ticker: str | None = None) -> dict[str, Any]:
    """
    Remove cached entries so the next request fetches fresh data.

    Args:
        ticker: Clear only this ticker's entries (keys matching the pattern
            ``stock:<ticker>:*``). When None or empty, every entry is cleared.

    Returns:
        ``{"status": "success", "message": ..., "entries_cleared": <count>}``
        where ``<count>`` is what the cache layer reports, or
        ``{"error": ..., "status": "error"}`` on any failure.
    """
    try:
        # Importing inside the try means an import failure is also reported
        # through the error dict below.
        from maverick_mcp.data.cache import clear_cache as cache_clear

        if not ticker:
            count = cache_clear()
            message = "Cleared all cache entries"
        else:
            count = cache_clear(f"stock:{ticker}:*")
            message = f"Cleared cache for {ticker}"

        return {"status": "success", "message": message, "entries_cleared": count}
    except Exception as e:
        logger.error(f"Error clearing cache: {e}")
        return {"error": str(e), "status": "error"}
427 |
```
--------------------------------------------------------------------------------
/maverick_mcp/config/security.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive Security Configuration for Maverick MCP.
3 |
4 | This module provides centralized security configuration including CORS settings,
5 | security headers, rate limiting, and environment-specific security policies.
6 | All security settings are validated to prevent common misconfigurations.
7 | """
8 |
import os
from typing import Any

from pydantic import BaseModel, Field, model_validator
12 |
13 |
class CORSConfig(BaseModel):
    """Cross-Origin Resource Sharing settings, validated on construction.

    Construction fails with ``ValueError`` if a wildcard origin is combined
    with credentials. A wildcard without credentials is accepted but logged
    as a warning.
    """

    # Which origins may call the API. Resolved through a lambda because
    # _get_cors_origins is defined further down this module.
    allowed_origins: list[str] = Field(
        default_factory=lambda: _get_cors_origins(),
        description="List of allowed origins for CORS requests",
    )

    # Whether cross-origin requests may carry credentials.
    allow_credentials: bool = Field(
        default=True,
        description="Whether to allow credentials in CORS requests",
    )

    # HTTP verbs permitted for cross-origin requests.
    allowed_methods: list[str] = Field(
        default=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
        description="Allowed HTTP methods for CORS requests",
    )

    # Request headers a cross-origin client may send.
    allowed_headers: list[str] = Field(
        default=[
            "Authorization",
            "Content-Type",
            "X-API-Key",
            "X-Request-ID",
            "X-Requested-With",
            "Accept",
            "Origin",
            "User-Agent",
            "Cache-Control",
        ],
        description="Allowed headers for CORS requests",
    )

    # Response headers the browser may expose to client scripts.
    exposed_headers: list[str] = Field(
        default=[
            "X-Process-Time",
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
            "X-RateLimit-Reset",
            "X-Request-ID",
        ],
        description="Headers exposed to the client",
    )

    # Preflight cache lifetime: 86400 s = 24 hours.
    max_age: int = Field(
        default=86400,
        description="Maximum age for CORS preflight cache in seconds",
    )

    @model_validator(mode="after")
    def validate_cors_security(self):
        """Reject wildcard-with-credentials; warn on wildcard-without-credentials."""
        wildcard = "*" in self.allowed_origins

        if wildcard and self.allow_credentials:
            raise ValueError(
                "CORS Security Error: Cannot use wildcard origin ('*') with "
                "allow_credentials=True. This is a serious security vulnerability. "
                "Specify explicit origins instead."
            )

        if wildcard:
            # Credentials must be off here (the branch above would have raised),
            # so this is permitted but worth flagging.
            import logging

            logging.getLogger(__name__).warning(
                "CORS Warning: Using wildcard origin ('*') without credentials. "
                "Consider using specific origins for better security."
            )

        return self
89 |
90 |
class SecurityHeadersConfig(BaseModel):
    """Values for HTTP security response headers, plus helpers that render the
    composite HSTS and Content-Security-Policy header strings."""

    # ---- Simple single-value headers ---------------------------------------
    x_content_type_options: str = Field(
        default="nosniff",
        description="X-Content-Type-Options header value",
    )

    x_frame_options: str = Field(
        default="DENY",
        description="X-Frame-Options header value (DENY, SAMEORIGIN, or ALLOW-FROM)",
    )

    x_xss_protection: str = Field(
        default="1; mode=block",
        description="X-XSS-Protection header value",
    )

    referrer_policy: str = Field(
        default="strict-origin-when-cross-origin",
        description="Referrer-Policy header value",
    )

    permissions_policy: str = Field(
        default="geolocation=(), microphone=(), camera=(), usb=(), magnetometer=()",
        description="Permissions-Policy header value",
    )

    # ---- Strict-Transport-Security components ------------------------------
    # Default max-age is 31536000 s (one year).
    hsts_max_age: int = Field(
        default=31536000,
        description="HSTS max-age in seconds",
    )

    hsts_include_subdomains: bool = Field(
        default=True,
        description="Include subdomains in HSTS policy",
    )

    hsts_preload: bool = Field(
        default=False,
        description="Enable HSTS preload (requires manual submission to browser vendors)",
    )

    # ---- Content-Security-Policy source lists ------------------------------
    csp_default_src: list[str] = Field(
        default=["'self'"],
        description="CSP default-src directive",
    )

    csp_script_src: list[str] = Field(
        default=["'self'", "'unsafe-inline'"],
        description="CSP script-src directive",
    )

    csp_style_src: list[str] = Field(
        default=["'self'", "'unsafe-inline'"],
        description="CSP style-src directive",
    )

    csp_img_src: list[str] = Field(
        default=["'self'", "data:", "https:"],
        description="CSP img-src directive",
    )

    csp_connect_src: list[str] = Field(
        default=["'self'"],
        description="CSP connect-src directive",
    )

    csp_frame_src: list[str] = Field(
        default=["'none'"],
        description="CSP frame-src directive",
    )

    csp_object_src: list[str] = Field(
        default=["'none'"],
        description="CSP object-src directive",
    )

    @property
    def hsts_header_value(self) -> str:
        """Render Strict-Transport-Security, e.g. ``max-age=31536000; includeSubDomains``."""
        parts = [f"max-age={self.hsts_max_age}"]
        if self.hsts_include_subdomains:
            parts.append("includeSubDomains")
        if self.hsts_preload:
            parts.append("preload")
        return "; ".join(parts)

    @property
    def csp_header_value(self) -> str:
        """Render Content-Security-Policy from the csp_* lists.

        The configurable directives come first, followed by the fixed
        ``base-uri 'self'`` and ``form-action 'self'`` directives.
        """
        configurable = [
            ("default-src", self.csp_default_src),
            ("script-src", self.csp_script_src),
            ("style-src", self.csp_style_src),
            ("img-src", self.csp_img_src),
            ("connect-src", self.csp_connect_src),
            ("frame-src", self.csp_frame_src),
            ("object-src", self.csp_object_src),
        ]
        directives = [f"{name} {' '.join(sources)}" for name, sources in configurable]
        directives.extend(("base-uri 'self'", "form-action 'self'"))
        return "; ".join(directives)
189 |
190 |
class RateLimitConfig(BaseModel):
    """Request rate-limit settings.

    Limit strings use a "<count> per <period>" form such as "60 per minute".
    The Redis URL and the on/off switch are read from the environment when an
    instance is created without explicit values.
    """

    # ---- Global default ----------------------------------------------------
    default_rate_limit: str = Field(
        default="1000 per hour",
        description="Default rate limit for all endpoints",
    )

    # ---- Per-user limits (requests per minute) -----------------------------
    authenticated_limit_per_minute: int = Field(
        default=60,
        description="Rate limit for authenticated users per minute",
    )

    anonymous_limit_per_minute: int = Field(
        default=10,
        description="Rate limit for anonymous users per minute",
    )

    # ---- Per-endpoint-category limits --------------------------------------
    auth_endpoints_limit: str = Field(
        default="10 per hour",
        description="Rate limit for authentication endpoints (login, signup)",
    )

    api_endpoints_limit: str = Field(
        default="60 per minute",
        description="Rate limit for API endpoints",
    )

    sensitive_endpoints_limit: str = Field(
        default="5 per minute",
        description="Rate limit for sensitive operations",
    )

    webhook_endpoints_limit: str = Field(
        default="100 per minute",
        description="Rate limit for webhook endpoints",
    )

    # ---- Storage backend and master switch (environment-driven) ------------
    # AUTH_REDIS_URL, falling back to a local Redis, database 1.
    redis_url: str | None = Field(
        default_factory=lambda: os.environ.get(
            "AUTH_REDIS_URL", "redis://localhost:6379/1"
        ),
        description="Redis URL for rate limiting storage",
    )

    # RATE_LIMITING_ENABLED: enabled unless the variable is set to anything
    # other than "true" (case-insensitive).
    enabled: bool = Field(
        default_factory=lambda: (
            os.environ.get("RATE_LIMITING_ENABLED", "true").lower() == "true"
        ),
        description="Enable rate limiting",
    )
237 |
238 |
class TrustedHostsConfig(BaseModel):
    """Trusted-host allow-list settings."""

    # Host patterns that are accepted. Resolved through a lambda because
    # _get_trusted_hosts is defined further down this module.
    allowed_hosts: list[str] = Field(
        default_factory=lambda: _get_trusted_hosts(),
        description="List of trusted host patterns",
    )

    # Off by default, so development setups are not restricted.
    enforce_in_development: bool = Field(
        default=False,
        description="Whether to enforce trusted hosts in development",
    )
250 |
251 |
class SecurityConfig(BaseModel):
    """Top-level security settings for Maverick MCP.

    Groups the CORS, security-header, rate-limit and trusted-host
    sub-configurations, and adds environment-driven HTTPS/strict flags. In
    production, an unforced HTTPS setting and wildcard CORS origins are logged
    at construction time; they are not rejected.
    """

    # Deployment environment name, lower-cased (ENVIRONMENT, default "development").
    environment: str = Field(
        default_factory=lambda: os.environ.get("ENVIRONMENT", "development").lower(),
        description="Environment (development, staging, production)",
    )

    # ---- Sub-configurations ------------------------------------------------
    cors: CORSConfig = Field(
        default_factory=CORSConfig,
        description="CORS configuration",
    )

    headers: SecurityHeadersConfig = Field(
        default_factory=SecurityHeadersConfig,
        description="Security headers configuration",
    )

    rate_limiting: RateLimitConfig = Field(
        default_factory=RateLimitConfig,
        description="Rate limiting configuration",
    )

    trusted_hosts: TrustedHostsConfig = Field(
        default_factory=TrustedHostsConfig,
        description="Trusted hosts configuration",
    )

    # ---- Environment-driven flags (true only when the variable is "true") --
    force_https: bool = Field(
        default_factory=lambda: os.environ.get("FORCE_HTTPS", "false").lower()
        == "true",
        description="Force HTTPS in production",
    )

    strict_security: bool = Field(
        default_factory=lambda: os.environ.get("STRICT_SECURITY", "false").lower()
        == "true",
        description="Enable strict security mode",
    )

    @model_validator(mode="after")
    def validate_environment_security(self):
        """Log (never reject) risky production settings."""
        if self.environment != "production":
            return self

        import logging

        log = logging.getLogger(__name__)

        if not self.force_https:
            log.warning(
                "Production Warning: FORCE_HTTPS is disabled in production. "
                "Set FORCE_HTTPS=true for better security."
            )

        if "*" in self.cors.allowed_origins:
            log.error(
                "Production Error: Wildcard CORS origins detected in production. "
                "This is a security risk and should be fixed."
            )

        return self

    def get_cors_middleware_config(self) -> dict:
        """Return CORS settings as keyword arguments for a CORS middleware."""
        cors = self.cors
        return dict(
            allow_origins=cors.allowed_origins,
            allow_credentials=cors.allow_credentials,
            allow_methods=cors.allowed_methods,
            allow_headers=cors.allowed_headers,
            expose_headers=cors.exposed_headers,
            max_age=cors.max_age,
        )

    def get_security_headers(self) -> dict[str, str]:
        """Return the security response headers as a name -> value mapping.

        Strict-Transport-Security is included only in production or when
        ``force_https`` is set.
        """
        cfg = self.headers
        pairs = [
            ("X-Content-Type-Options", cfg.x_content_type_options),
            ("X-Frame-Options", cfg.x_frame_options),
            ("X-XSS-Protection", cfg.x_xss_protection),
            ("Referrer-Policy", cfg.referrer_policy),
            ("Permissions-Policy", cfg.permissions_policy),
            ("Content-Security-Policy", cfg.csp_header_value),
        ]
        if self.environment == "production" or self.force_https:
            pairs.append(("Strict-Transport-Security", cfg.hsts_header_value))
        return dict(pairs)

    def is_production(self) -> bool:
        """True only for the exact environment name "production"."""
        return self.environment == "production"

    def is_development(self) -> bool:
        """True for "development", "dev" or "local"."""
        return self.environment in ("development", "dev", "local")
351 |
352 |
def _get_cors_origins() -> list[str]:
    """Return the default CORS origin allow-list.

    ``CORS_ORIGINS`` (comma-separated) takes precedence when set. Otherwise the
    list depends on ``ENVIRONMENT`` (case-insensitive, default "development"):
    production domains, staging/test domains plus localhost, or localhost-only
    origins for development.
    """
    environment = os.getenv("ENVIRONMENT", "development").lower()
    cors_origins_env = os.getenv("CORS_ORIGINS")

    if cors_origins_env:
        # Parse comma-separated origins, dropping blank entries produced by a
        # trailing or doubled comma (e.g. "https://a.com,"); an empty string
        # is never a valid origin.
        origins = (origin.strip() for origin in cors_origins_env.split(","))
        return [origin for origin in origins if origin]

    if environment == "production":
        return [
            "https://app.maverick-mcp.com",
            "https://maverick-mcp.com",
            "https://www.maverick-mcp.com",
        ]
    elif environment in ["staging", "test"]:
        return [
            "https://staging.maverick-mcp.com",
            "https://test.maverick-mcp.com",
            "http://localhost:3000",
            "http://localhost:3001",
        ]
    else:
        # Development
        return [
            "http://localhost:3000",
            "http://localhost:3001",
            "http://127.0.0.1:3000",
            "http://127.0.0.1:3001",
            "http://localhost:8080",
            "http://localhost:5173",  # Vite default
        ]
385 |
386 |
def _get_trusted_hosts() -> list[str]:
    """Return the default trusted-host patterns.

    ``TRUSTED_HOSTS`` (comma-separated) takes precedence when set. Otherwise the
    list depends on ``ENVIRONMENT`` (case-insensitive, default "development");
    development allows any host via ``"*"``.
    """
    environment = os.getenv("ENVIRONMENT", "development").lower()
    trusted_hosts_env = os.getenv("TRUSTED_HOSTS")

    if trusted_hosts_env:
        # Parse comma-separated hosts, dropping blank entries produced by a
        # trailing or doubled comma; an empty host pattern is meaningless.
        hosts = (host.strip() for host in trusted_hosts_env.split(","))
        return [host for host in hosts if host]

    if environment == "production":
        return ["api.maverick-mcp.com", "*.maverick-mcp.com", "maverick-mcp.com"]
    elif environment in ["staging", "test"]:
        return [
            "staging.maverick-mcp.com",
            "test.maverick-mcp.com",
            "*.maverick-mcp.com",
            "localhost",
            "127.0.0.1",
        ]
    else:
        # Development - allow any host
        return ["*"]
409 |
410 |
# Module-wide SecurityConfig, built once at import time. It must stay below
# _get_cors_origins / _get_trusted_hosts, which its field defaults call.
security_config = SecurityConfig()


def get_security_config() -> SecurityConfig:
    """Return the shared, import-time ``SecurityConfig`` instance."""
    return security_config
418 |
419 |
def validate_security_config() -> dict[str, Any]:
    """Audit the active security configuration for risky settings.

    Returns:
        A dict with:
        - ``valid``: True when no critical issues were found (warnings do not
          affect validity).
        - ``issues``: critical problems.
        - ``warnings``: non-fatal recommendations.
        - ``environment``, ``cors_origins``, ``force_https``: the audited values.
    """
    config = get_security_config()
    issues: list[str] = []
    warnings: list[str] = []
    origins = config.cors.allowed_origins

    # Check for dangerous CORS configuration
    if config.cors.allow_credentials and "*" in origins:
        issues.append("CRITICAL: Wildcard CORS origins with credentials enabled")

    # Check production-specific requirements
    if config.is_production():
        if "*" in origins:
            issues.append("CRITICAL: Wildcard CORS origins in production")

        if not config.force_https:
            warnings.append("HTTPS not enforced in production")

        # Inspect each origin string individually, and treat the IPv4 loopback
        # address the same as "localhost".
        if any(
            "localhost" in origin.lower() or "127.0.0.1" in origin
            for origin in origins
        ):
            warnings.append("Localhost origins found in production CORS config")

    # X-Frame-Options values are case-insensitive, so "deny" is as safe as "DENY".
    if config.headers.x_frame_options.upper() not in ("DENY", "SAMEORIGIN"):
        warnings.append("X-Frame-Options not set to DENY or SAMEORIGIN")

    return {
        "valid": len(issues) == 0,
        "issues": issues,
        "warnings": warnings,
        "environment": config.environment,
        "cors_origins": config.cors.allowed_origins,
        "force_https": config.force_https,
    }
453 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/middleware/mcp_logging.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive MCP Logging Middleware for debugging tool calls and protocol communication.
3 |
4 | This middleware provides:
5 | - Tool call lifecycle logging
6 | - MCP protocol message logging
7 | - Request/response payload logging
8 | - Error tracking with full context
9 | - Performance metrics collection
10 | - Timeout detection and logging
11 | """
12 |
13 | import asyncio
14 | import functools
15 | import json
16 | import logging
17 | import time
18 | import traceback
19 | import uuid
20 | from typing import Any
21 |
22 | from fastmcp import FastMCP
23 |
24 | try:
25 | from fastmcp.server.middleware import Middleware, MiddlewareContext
26 |
27 | MIDDLEWARE_AVAILABLE = True
28 | except ImportError:
29 | # Fallback for older FastMCP versions
30 | MIDDLEWARE_AVAILABLE = False
31 |
32 | class Middleware: # type: ignore
33 | """Fallback Middleware class for older FastMCP versions."""
34 |
35 | pass
36 |
37 | MiddlewareContext = Any
38 |
39 | from maverick_mcp.utils.logging import (
40 | get_logger,
41 | request_id_var,
42 | request_start_var,
43 | tool_name_var,
44 | )
45 |
46 | logger = get_logger("maverick_mcp.middleware.mcp_logging")
47 |
48 |
class MCPLoggingMiddleware(Middleware if MIDDLEWARE_AVAILABLE else object):
    """
    MCP tool-call and resource-access logging middleware for FastMCP 2.0+.

    For each tool call it records, under a per-call request id:
    - start, success, timeout and error events as structured log records
    - execution time
    - optionally, truncated argument and result payloads

    Tool calls running longer than ``timeout_seconds`` are cancelled via
    ``asyncio.wait_for`` and logged as timeouts; the exception is re-raised.

    Human-readable progress lines go to stderr, not stdout. With the stdio
    transport, stdout carries the MCP JSON-RPC stream, and stray output there
    corrupts the protocol. Console output is best-effort and never fails a call.
    """

    def __init__(
        self,
        include_payloads: bool = True,
        max_payload_length: int = 2000,
        log_level: int = logging.INFO,
        timeout_seconds: float | None = 25.0,
    ):
        """
        Args:
            include_payloads: Attach truncated tool arguments/results to log records.
            max_payload_length: Maximum characters kept from any serialized payload.
            log_level: Stored as ``self.log_level``; not consulted by this
                class's own logging calls.
            timeout_seconds: Maximum run time for a tool call before it is
                cancelled and logged as a timeout (None disables the limit).
        """
        if MIDDLEWARE_AVAILABLE:
            super().__init__()
        self.include_payloads = include_payloads
        self.max_payload_length = max_payload_length
        self.log_level = log_level
        self.timeout_seconds = timeout_seconds
        self.logger = get_logger("maverick_mcp.mcp_protocol")

    @staticmethod
    def _console(message: str) -> None:
        """Write a human-readable progress line to stderr (best effort)."""
        import sys

        try:
            print(message, file=sys.stderr)
        except (OSError, ValueError):
            # Purely diagnostic output: a closed stream or a console encoding
            # that cannot represent the emoji (UnicodeEncodeError is a
            # ValueError) must never break the tool call being logged.
            pass

    def _truncate_json(self, value: Any) -> str:
        """Serialize *value* to JSON (non-JSON types via ``str``) and truncate it."""
        return json.dumps(value, default=str)[: self.max_payload_length]

    async def on_call_tool(self, context: MiddlewareContext, call_next) -> Any:
        """Log tool call lifecycle with comprehensive details."""
        if not MIDDLEWARE_AVAILABLE:
            return await call_next(context)

        request_id = str(uuid.uuid4())
        start_time = time.time()
        request_start_var.set(start_time)
        request_id_var.set(request_id)

        tool_name = getattr(context.message, "name", "unknown_tool")
        tool_name_var.set(tool_name)

        # Extract arguments if available; normalize a None attribute to {}.
        arguments = getattr(context.message, "arguments", None) or {}

        self._log_tool_call_start(request_id, tool_name, arguments)

        try:
            result = await asyncio.wait_for(
                call_next(context), timeout=self.timeout_seconds
            )
        except asyncio.TimeoutError:
            self._log_tool_call_timeout(
                request_id, tool_name, time.time() - start_time
            )
            raise
        except Exception as e:
            # Called inside the except clause so traceback.format_exc() sees e.
            self._log_tool_call_error(
                request_id, tool_name, e, time.time() - start_time, arguments
            )
            raise

        self._log_tool_call_success(
            request_id, tool_name, result, time.time() - start_time
        )
        return result

    async def on_read_resource(self, context: MiddlewareContext, call_next) -> Any:
        """Log resource access with its duration and any error raised."""
        if not MIDDLEWARE_AVAILABLE:
            return await call_next(context)

        resource_uri = getattr(context.message, "uri", "unknown_resource")
        start_time = time.time()

        self._console(f"🔗 RESOURCE ACCESS: {resource_uri}")

        try:
            result = await call_next(context)
        except Exception as e:
            execution_time = time.time() - start_time
            self._console(
                f"❌ RESOURCE ERROR: {resource_uri} ({execution_time:.2f}s) - {type(e).__name__}: {str(e)}"
            )
            raise

        execution_time = time.time() - start_time
        self._console(f"✅ RESOURCE SUCCESS: {resource_uri} ({execution_time:.2f}s)")
        return result

    def _log_tool_call_start(self, request_id: str, tool_name: str, arguments: dict):
        """Log tool call initiation."""
        log_data = {
            "request_id": request_id,
            "direction": "incoming",
            "tool_name": tool_name,
            "timestamp": time.time(),
        }

        # Add arguments if requested (debug mode)
        if self.include_payloads and arguments:
            try:
                log_data["arguments"] = self._truncate_json(arguments)
            except Exception as e:
                log_data["args_error"] = str(e)

        self.logger.info("TOOL_CALL_START", extra=log_data)

        # Console output for immediate visibility
        args_preview = ""
        if arguments:
            args_str = str(arguments)
            args_preview = f" with {args_str[:50]}{'...' if len(args_str) > 50 else ''}"
        self._console(f"🔧 TOOL CALL: {tool_name}{args_preview} [{request_id[:8]}]")

    def _log_tool_call_success(
        self, request_id: str, tool_name: str, result: Any, execution_time: float
    ):
        """Log successful tool completion."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "success",
            "timestamp": time.time(),
        }

        # Add result preview if requested (debug mode). Non-JSON results are
        # rendered with str() instead of failing serialization.
        if self.include_payloads and result is not None:
            try:
                log_data["result_preview"] = (
                    result[: self.max_payload_length]
                    if isinstance(result, str)
                    else self._truncate_json(result)
                )
                log_data["result_type"] = type(result).__name__
            except Exception as e:
                log_data["result_error"] = str(e)

        self.logger.info("TOOL_CALL_SUCCESS", extra=log_data)

        # Console marker by duration: <5s green, <15s yellow, otherwise orange.
        status_icon = (
            "🟢" if execution_time < 5.0 else "🟡" if execution_time < 15.0 else "🟠"
        )
        self._console(
            f"{status_icon} TOOL SUCCESS: {tool_name} [{request_id[:8]}] {execution_time:.2f}s"
        )

    def _log_tool_call_timeout(
        self, request_id: str, tool_name: str, execution_time: float
    ):
        """Log tool timeout."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "timeout",
            "timeout_seconds": self.timeout_seconds,
            "error_type": "timeout",
            "timestamp": time.time(),
        }

        self.logger.error("TOOL_CALL_TIMEOUT", extra=log_data)
        # :g renders the default 25.0 as "25", matching the historic message.
        self._console(
            f"⏰ TOOL TIMEOUT: {tool_name} [{request_id[:8]}] {execution_time:.2f}s "
            f"(exceeded {self.timeout_seconds:g}s limit)"
        )

    def _log_tool_call_error(
        self,
        request_id: str,
        tool_name: str,
        error: Exception,
        execution_time: float,
        arguments: dict,
    ):
        """Log tool error with full context (must be called from an except block)."""
        log_data = {
            "request_id": request_id,
            "direction": "outgoing",
            "tool_name": tool_name,
            "execution_time": execution_time,
            "status": "error",
            "error_type": type(error).__name__,
            "error_message": str(error),
            "traceback": traceback.format_exc(),
            "timestamp": time.time(),
        }

        # Add arguments for debugging
        if self.include_payloads and arguments:
            try:
                log_data["arguments"] = self._truncate_json(arguments)
            except Exception as e:
                log_data["args_error"] = str(e)

        self.logger.error("TOOL_CALL_ERROR", extra=log_data)

        # Console output with error details
        self._console(
            f"❌ TOOL ERROR: {tool_name} [{request_id[:8]}] {execution_time:.2f}s - {type(error).__name__}: {str(error)}"
        )
254 |
255 |
256 | class ToolExecutionLogger:
257 | """
258 | Specific logger for individual tool execution steps.
259 |
260 | Use this within tools to log execution progress and debug issues.
261 | """
262 |
263 | def __init__(self, tool_name: str, request_id: str | None = None):
264 | self.tool_name = tool_name
265 | self.request_id = request_id or request_id_var.get() or str(uuid.uuid4())
266 | self.logger = get_logger(f"maverick_mcp.tools.{tool_name}")
267 | self.start_time = time.time()
268 | self.step_times = {}
269 |
270 | def step(self, step_name: str, message: str | None = None):
271 | """Log a step in tool execution."""
272 | current_time = time.time()
273 | step_duration = current_time - self.start_time
274 | self.step_times[step_name] = step_duration
275 |
276 | log_message = message or f"Executing step: {step_name}"
277 |
278 | self.logger.info(
279 | log_message,
280 | extra={
281 | "request_id": self.request_id,
282 | "tool_name": self.tool_name,
283 | "step": step_name,
284 | "step_duration": step_duration,
285 | "total_duration": current_time - self.start_time,
286 | },
287 | )
288 |
289 | # Console progress indicator
290 | print(f" 📊 {self.tool_name} -> {step_name} ({step_duration:.2f}s)")
291 |
292 | def error(self, step_name: str, error: Exception, message: str | None = None):
293 | """Log an error in tool execution."""
294 | current_time = time.time()
295 | step_duration = current_time - self.start_time
296 |
297 | log_message = message or f"Error in step: {step_name}"
298 |
299 | self.logger.error(
300 | log_message,
301 | extra={
302 | "request_id": self.request_id,
303 | "tool_name": self.tool_name,
304 | "step": step_name,
305 | "step_duration": step_duration,
306 | "total_duration": current_time - self.start_time,
307 | "error_type": type(error).__name__,
308 | "error_message": str(error),
309 | "traceback": traceback.format_exc(),
310 | },
311 | )
312 |
313 | # Console error indicator
314 | print(
315 | f" ❌ {self.tool_name} -> {step_name} ERROR: {type(error).__name__}: {str(error)}"
316 | )
317 |
318 | def complete(self, result_summary: str | None = None):
319 | """Log completion of tool execution."""
320 | total_duration = time.time() - self.start_time
321 |
322 | log_message = result_summary or "Tool execution completed"
323 |
324 | self.logger.info(
325 | log_message,
326 | extra={
327 | "request_id": self.request_id,
328 | "tool_name": self.tool_name,
329 | "total_duration": total_duration,
330 | "step_times": self.step_times,
331 | "status": "completed",
332 | },
333 | )
334 |
335 | # Console completion
336 | print(f" ✅ {self.tool_name} completed ({total_duration:.2f}s)")
337 |
338 |
def add_mcp_logging_middleware(
    server: FastMCP,
    include_payloads: bool = True,
    max_payload_length: int = 2000,
    log_level: int = logging.INFO,
):
    """
    Add comprehensive MCP logging middleware to a FastMCP server.

    Registration is attempted, in order, via ``server.add_middleware``, via a
    ``server.middleware`` attribute, or by wrapping ``server.tool`` with
    ``_apply_middleware_as_decorators``. Registration failures are logged and
    printed but not raised.

    Args:
        server: FastMCP server instance
        include_payloads: Whether to log request/response payloads (debug mode)
        max_payload_length: Maximum length of logged payloads
        log_level: Minimum logging level
    """
    if not MIDDLEWARE_AVAILABLE:
        logger.warning("FastMCP middleware not available - requires FastMCP 2.9+")
        print("⚠️ FastMCP middleware not available - tool logging will be limited")
        return

    middleware = MCPLoggingMiddleware(
        include_payloads=include_payloads,
        max_payload_length=max_payload_length,
        log_level=log_level,
    )

    # Use the correct FastMCP 2.0 middleware registration method
    try:
        if hasattr(server, "add_middleware"):
            server.add_middleware(middleware)
            logger.info("✅ FastMCP 2.0 middleware registered successfully")
        elif hasattr(server, "middleware"):
            # Fallback for different API structure
            existing = server.middleware
            if isinstance(existing, list):
                existing.append(middleware)
            elif isinstance(existing, tuple):
                # Keep already-registered middleware instead of discarding it.
                server.middleware = [*existing, middleware]
            else:
                server.middleware = [middleware]
            logger.info("✅ FastMCP middleware registered via fallback method")
        else:
            # Manual middleware application as decorator
            logger.warning("Using decorator-style middleware registration")
            _apply_middleware_as_decorators(server, middleware)

    except Exception as e:
        logger.error(f"Failed to register FastMCP middleware: {e}", exc_info=True)
        print(f"⚠️ Middleware registration failed: {e}")
        # Nothing was registered, so do not report the setup as completed.
        return

    logger.info(
        "MCP logging middleware setup completed",
        extra={
            "include_payloads": include_payloads,
            "max_payload_length": max_payload_length,
            "log_level": logging.getLevelName(log_level),
        },
    )
394 |
395 |
def _apply_middleware_as_decorators(server: FastMCP, middleware: MCPLoggingMiddleware):
    """Apply middleware functionality via decorators if direct middleware isn't available.

    Replaces ``server.tool`` with a decorator that wraps each subsequently
    declared tool in console timing/error logging before registering it with
    the original ``server.tool``. Tools registered before this call are not
    affected. Both ``@server.tool(...)`` and bare ``@server.tool`` are supported.

    Args:
        server: FastMCP server whose ``tool`` decorator is replaced.
        middleware: Accepted for signature parity with the middleware path;
            this fallback only prints console lines and does not call it.
    """
    import inspect  # local: only needed on this fallback path

    # This is a fallback approach - wrap tool execution with logging
    original_tool_method = server.tool

    def _report_success(func_name, start_time):
        """Print the success line with elapsed seconds."""
        execution_time = time.time() - start_time
        print(f"🟢 TOOL SUCCESS: {func_name} ({execution_time:.2f}s)")

    def _report_error(func_name, start_time, error):
        """Print the failure line with elapsed seconds and the error."""
        execution_time = time.time() - start_time
        print(
            f"❌ TOOL ERROR: {func_name} ({execution_time:.2f}s) - {type(error).__name__}: {str(error)}"
        )

    def _wrap_with_logging(func):
        """Wrap ``func`` with console logging, keeping it sync or async as it was."""
        func_name = getattr(func, "__name__", "unknown_tool")

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*func_args, **func_kwargs):
                print(f"🔧 TOOL CALL: {func_name}")
                start_time = time.time()
                try:
                    result = await func(*func_args, **func_kwargs)
                except Exception as e:
                    _report_error(func_name, start_time, e)
                    raise
                _report_success(func_name, start_time)
                return result

            return async_wrapper

        # Sync tools need a sync wrapper: awaiting their plain return value would
        # raise "TypeError: object ... can't be used in 'await' expression".
        @functools.wraps(func)
        def sync_wrapper(*func_args, **func_kwargs):
            print(f"🔧 TOOL CALL: {func_name}")
            start_time = time.time()
            try:
                result = func(*func_args, **func_kwargs)
            except Exception as e:
                _report_error(func_name, start_time, e)
                raise
            _report_success(func_name, start_time)
            return result

        return sync_wrapper

    def logging_tool_decorator(*args, **kwargs):
        # Bare ``@server.tool`` passes the tool function as the sole argument;
        # wrap and register it right away so the tool is not silently dropped.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return original_tool_method(_wrap_with_logging(args[0]))

        def decorator(func):
            # Register the wrapped function
            return original_tool_method(*args, **kwargs)(_wrap_with_logging(func))

        return decorator

    # Replace the server's tool decorator
    server.tool = logging_tool_decorator
    logger.info("Applied middleware as tool decorators (fallback mode)")
430 |
431 |
# Convenience entry point for tool authors.
def get_tool_logger(tool_name: str) -> ToolExecutionLogger:
    """Create a ToolExecutionLogger bound to ``tool_name``.

    No request id is passed, so ToolExecutionLogger resolves one itself: the
    id from the current request context if set, otherwise a fresh UUID.
    """
    tool_logger = ToolExecutionLogger(tool_name=tool_name)
    return tool_logger
436 |
```
--------------------------------------------------------------------------------
/maverick_mcp/domain/portfolio.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Portfolio domain entities for MaverickMCP.
3 |
4 | This module implements pure business logic for portfolio management following
5 | Domain-Driven Design (DDD) principles. These entities are framework-independent
6 | and contain the core portfolio logic including cost basis averaging and P&L calculations.
7 |
8 | Cost Basis Method: Average Cost
9 | - Simplest for educational purposes
10 | - Total cost / total shares
11 | - Does not change on partial sales
12 | """
13 |
14 | from dataclasses import dataclass, field
15 | from datetime import UTC, datetime
16 | from decimal import ROUND_HALF_UP, Decimal
17 | from typing import Optional
18 |
19 |
20 | @dataclass
21 | class Position:
22 | """
23 | Value object representing a single portfolio position.
24 |
25 | A position tracks shares held in a specific ticker with cost basis information.
26 | Uses immutable operations - modifications return new Position instances.
27 |
28 | Attributes:
29 | ticker: Stock ticker symbol (e.g., "AAPL")
30 | shares: Number of shares owned (supports fractional shares)
31 | average_cost_basis: Average cost per share
32 | total_cost: Total capital invested (shares × average_cost_basis)
33 | purchase_date: Earliest purchase date for this position
34 | notes: Optional user notes about the position
35 | """
36 |
37 | ticker: str
38 | shares: Decimal
39 | average_cost_basis: Decimal
40 | total_cost: Decimal
41 | purchase_date: datetime
42 | notes: str | None = None
43 |
44 | def __post_init__(self) -> None:
45 | """Validate position invariants after initialization."""
46 | if self.shares <= 0:
47 | raise ValueError(f"Shares must be positive, got {self.shares}")
48 | if self.average_cost_basis <= 0:
49 | raise ValueError(
50 | f"Average cost basis must be positive, got {self.average_cost_basis}"
51 | )
52 | if self.total_cost <= 0:
53 | raise ValueError(f"Total cost must be positive, got {self.total_cost}")
54 |
55 | # Normalize ticker to uppercase
56 | object.__setattr__(self, "ticker", self.ticker.upper())
57 |
58 | def add_shares(self, shares: Decimal, price: Decimal, date: datetime) -> "Position":
59 | """
60 | Add shares to position with automatic cost basis averaging.
61 |
62 | This creates a new Position instance with updated shares and averaged cost basis.
63 | The average cost method is used: (total_cost + new_cost) / total_shares
64 |
65 | Args:
66 | shares: Number of shares to add (must be > 0)
67 | price: Purchase price per share (must be > 0)
68 | date: Purchase date
69 |
70 | Returns:
71 | New Position instance with averaged cost basis
72 |
73 | Raises:
74 | ValueError: If shares or price is not positive
75 |
76 | Example:
77 | >>> pos = Position("AAPL", Decimal("10"), Decimal("150"), Decimal("1500"), datetime.now())
78 | >>> pos = pos.add_shares(Decimal("10"), Decimal("170"), datetime.now())
79 | >>> pos.shares
80 | Decimal('20')
81 | >>> pos.average_cost_basis
82 | Decimal('160.00')
83 | """
84 | if shares <= 0:
85 | raise ValueError(f"Shares to add must be positive, got {shares}")
86 | if price <= 0:
87 | raise ValueError(f"Price must be positive, got {price}")
88 |
89 | new_total_shares = self.shares + shares
90 | new_total_cost = self.total_cost + (shares * price)
91 | new_avg_cost = (new_total_cost / new_total_shares).quantize(
92 | Decimal("0.0001"), rounding=ROUND_HALF_UP
93 | )
94 |
95 | return Position(
96 | ticker=self.ticker,
97 | shares=new_total_shares,
98 | average_cost_basis=new_avg_cost,
99 | total_cost=new_total_cost,
100 | purchase_date=min(self.purchase_date, date),
101 | notes=self.notes,
102 | )
103 |
104 | def remove_shares(self, shares: Decimal) -> Optional["Position"]:
105 | """
106 | Remove shares from position.
107 |
108 | Returns None if the removal would close the position entirely (sold_shares >= held_shares).
109 | For partial sales, average cost basis remains unchanged.
110 |
111 | Args:
112 | shares: Number of shares to remove (must be > 0)
113 |
114 | Returns:
115 | New Position instance with reduced shares, or None if position closed
116 |
117 | Raises:
118 | ValueError: If shares is not positive
119 |
120 | Example:
121 | >>> pos = Position("AAPL", Decimal("20"), Decimal("160"), Decimal("3200"), datetime.now())
122 | >>> pos = pos.remove_shares(Decimal("10"))
123 | >>> pos.shares
124 | Decimal('10')
125 | >>> pos.average_cost_basis # Unchanged
126 | Decimal('160.00')
127 | """
128 | if shares <= 0:
129 | raise ValueError(f"Shares to remove must be positive, got {shares}")
130 |
131 | if shares >= self.shares:
132 | # Full position close
133 | return None
134 |
135 | new_shares = self.shares - shares
136 | new_total_cost = new_shares * self.average_cost_basis
137 |
138 | return Position(
139 | ticker=self.ticker,
140 | shares=new_shares,
141 | average_cost_basis=self.average_cost_basis,
142 | total_cost=new_total_cost,
143 | purchase_date=self.purchase_date,
144 | notes=self.notes,
145 | )
146 |
147 | def calculate_current_value(self, current_price: Decimal) -> dict[str, Decimal]:
148 | """
149 | Calculate live position value and P&L metrics.
150 |
151 | Args:
152 | current_price: Current market price per share
153 |
154 | Returns:
155 | Dictionary containing:
156 | - current_value: Current market value (shares × price)
157 | - unrealized_pnl: Unrealized profit/loss (current_value - total_cost)
158 | - pnl_percentage: P&L as percentage of total cost
159 |
160 | Example:
161 | >>> pos = Position("AAPL", Decimal("20"), Decimal("160"), Decimal("3200"), datetime.now())
162 | >>> metrics = pos.calculate_current_value(Decimal("175.50"))
163 | >>> metrics["current_value"]
164 | Decimal('3510.00')
165 | >>> metrics["unrealized_pnl"]
166 | Decimal('310.00')
167 | >>> metrics["pnl_percentage"]
168 | Decimal('9.6875')
169 | """
170 | current_value = (self.shares * current_price).quantize(
171 | Decimal("0.01"), rounding=ROUND_HALF_UP
172 | )
173 | unrealized_pnl = (current_value - self.total_cost).quantize(
174 | Decimal("0.01"), rounding=ROUND_HALF_UP
175 | )
176 |
177 | if self.total_cost > 0:
178 | pnl_percentage = (unrealized_pnl / self.total_cost * 100).quantize(
179 | Decimal("0.01"), rounding=ROUND_HALF_UP
180 | )
181 | else:
182 | pnl_percentage = Decimal("0.00")
183 |
184 | return {
185 | "current_value": current_value,
186 | "unrealized_pnl": unrealized_pnl,
187 | "pnl_percentage": pnl_percentage,
188 | }
189 |
190 | def to_dict(self) -> dict:
191 | """
192 | Convert position to dictionary for serialization.
193 |
194 | Returns:
195 | Dictionary representation with float values for JSON compatibility
196 | """
197 | return {
198 | "ticker": self.ticker,
199 | "shares": float(self.shares),
200 | "average_cost_basis": float(self.average_cost_basis),
201 | "total_cost": float(self.total_cost),
202 | "purchase_date": self.purchase_date.isoformat(),
203 | "notes": self.notes,
204 | }
205 |
206 |
@dataclass
class Portfolio:
    """
    Aggregate root for user portfolio.

    Manages a collection of positions with operations for adding, removing, and analyzing
    holdings. Enforces business rules and maintains consistency. Every mutating
    operation refreshes ``updated_at``.

    Attributes:
        portfolio_id: Unique identifier (UUID as string)
        user_id: User identifier (default: "default" for single-user system)
        name: Portfolio display name
        positions: List of Position value objects (at most one per ticker)
        created_at: Portfolio creation timestamp (UTC)
        updated_at: Last modification timestamp (UTC)
    """

    portfolio_id: str
    user_id: str
    name: str
    positions: list[Position] = field(default_factory=list)
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    def add_position(
        self,
        ticker: str,
        shares: Decimal,
        price: Decimal,
        date: datetime,
        notes: str | None = None,
    ) -> None:
        """
        Add or update position with automatic cost basis averaging.

        If the ticker already exists, shares are added and cost basis is averaged.
        Otherwise, a new position is created.

        Args:
            ticker: Stock ticker symbol (case-insensitive)
            shares: Number of shares to add
            price: Purchase price per share
            date: Purchase date
            notes: Optional notes (only used for new positions)

        Raises:
            ValueError: If shares or price is not positive

        Example:
            >>> portfolio = Portfolio("id", "default", "My Portfolio")
            >>> portfolio.add_position("AAPL", Decimal("10"), Decimal("150"), datetime.now())
            >>> portfolio.add_position("AAPL", Decimal("10"), Decimal("170"), datetime.now())
            >>> portfolio.get_position("AAPL").shares
            Decimal('20')
        """
        ticker = ticker.upper()

        # Find existing position
        for i, pos in enumerate(self.positions):
            if pos.ticker == ticker:
                self.positions[i] = pos.add_shares(shares, price, date)
                self.updated_at = datetime.now(UTC)
                return

        # Create new position
        new_position = Position(
            ticker=ticker,
            shares=shares,
            average_cost_basis=price,
            total_cost=shares * price,
            purchase_date=date,
            notes=notes,
        )
        self.positions.append(new_position)
        self.updated_at = datetime.now(UTC)

    def remove_position(self, ticker: str, shares: Decimal | None = None) -> bool:
        """
        Remove position or reduce shares.

        Args:
            ticker: Stock ticker symbol (case-insensitive)
            shares: Number of shares to remove (None, or at least the held
                amount, removes the entire position)

        Returns:
            True if position was found and removed/reduced, False otherwise

        Raises:
            ValueError: If a partial ``shares`` amount is not positive

        Example:
            >>> portfolio.remove_position("AAPL", Decimal("10"))  # Partial
            True
            >>> portfolio.remove_position("AAPL")  # Full removal
            True
        """
        ticker = ticker.upper()

        for i, pos in enumerate(self.positions):
            if pos.ticker == ticker:
                if shares is None or shares >= pos.shares:
                    # Full position removal
                    self.positions.pop(i)
                else:
                    # Partial removal
                    updated_pos = pos.remove_shares(shares)
                    if updated_pos:
                        self.positions[i] = updated_pos
                    else:
                        self.positions.pop(i)

                self.updated_at = datetime.now(UTC)
                return True

        return False

    def get_position(self, ticker: str) -> Position | None:
        """
        Get position by ticker symbol.

        Args:
            ticker: Stock ticker symbol (case-insensitive)

        Returns:
            Position if found, None otherwise
        """
        ticker = ticker.upper()
        return next((pos for pos in self.positions if pos.ticker == ticker), None)

    def get_total_invested(self) -> Decimal:
        """
        Calculate total capital invested across all positions.

        Returns:
            Sum of all position total costs (Decimal("0") when empty)
        """
        return sum((pos.total_cost for pos in self.positions), Decimal("0"))

    def calculate_portfolio_metrics(self, current_prices: dict[str, Decimal]) -> dict:
        """
        Calculate comprehensive portfolio metrics with live prices.

        Price lookup is case-insensitive. A position whose ticker has no price
        (missing key or a ``None`` value) is valued at its average cost basis,
        i.e. zero unrealized P&L. Non-Decimal prices (e.g. floats) are
        converted via ``Decimal(str(price))``.

        Args:
            current_prices: Dictionary mapping ticker symbols to current prices

        Returns:
            Dictionary containing:
            - total_value: Current market value of all positions
            - total_invested: Total capital invested
            - total_pnl: Total unrealized profit/loss
            - total_pnl_percentage: Total P&L as percentage
            - position_count: Number of positions
            - positions: List of position details with current metrics

        Example:
            >>> portfolio = Portfolio("id", "default", "My Portfolio")
            >>> portfolio.add_position("AAPL", Decimal("20"), Decimal("160"), datetime.now(UTC))
            >>> metrics = portfolio.calculate_portfolio_metrics({"AAPL": Decimal("175.50")})
            >>> metrics["total_value"]
            3510.0
            >>> metrics["total_pnl_percentage"]
            9.69
        """
        # Position tickers are always upper case; normalize the price keys to match.
        prices = {
            str(symbol).upper(): price
            for symbol, price in (current_prices or {}).items()
        }

        total_value = Decimal("0")
        total_cost = Decimal("0")
        position_details = []

        for pos in self.positions:
            current_price = prices.get(pos.ticker)
            if current_price is None:
                # No quote available: fall back to cost basis (zero P&L).
                current_price = pos.average_cost_basis
            elif not isinstance(current_price, Decimal):
                # Decimal * float raises TypeError; str() avoids importing the
                # float's binary representation error into the Decimal.
                current_price = Decimal(str(current_price))
            metrics = pos.calculate_current_value(current_price)

            total_value += metrics["current_value"]
            total_cost += pos.total_cost

            position_details.append(
                {
                    "ticker": pos.ticker,
                    "shares": float(pos.shares),
                    "cost_basis": float(pos.average_cost_basis),
                    "current_price": float(current_price),
                    "current_value": float(metrics["current_value"]),
                    "unrealized_pnl": float(metrics["unrealized_pnl"]),
                    "pnl_percentage": float(metrics["pnl_percentage"]),
                    "purchase_date": pos.purchase_date.isoformat(),
                    "notes": pos.notes,
                }
            )

        total_pnl = total_value - total_cost
        total_pnl_pct = (
            (total_pnl / total_cost * 100).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
            if total_cost > 0
            else Decimal("0.00")
        )

        return {
            "total_value": float(total_value),
            "total_invested": float(total_cost),
            "total_pnl": float(total_pnl),
            "total_pnl_percentage": float(total_pnl_pct),
            "position_count": len(self.positions),
            "positions": position_details,
        }

    def clear_all_positions(self) -> None:
        """
        Remove all positions from the portfolio.

        ⚠️ WARNING: This operation cannot be undone.
        """
        self.positions.clear()
        self.updated_at = datetime.now(UTC)

    def to_dict(self) -> dict:
        """
        Convert portfolio to dictionary for serialization.

        Returns:
            Dictionary representation suitable for JSON serialization
            (floats for amounts, ISO-8601 strings for timestamps)
        """
        return {
            "portfolio_id": self.portfolio_id,
            "user_id": self.user_id,
            "name": self.name,
            "positions": [pos.to_dict() for pos in self.positions],
            "position_count": len(self.positions),
            "total_invested": float(self.get_total_invested()),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }
431 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/error_handling.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Enhanced error handling framework for MaverickMCP API.
3 |
4 | This module provides centralized error handling with structured responses,
5 | proper logging, monitoring integration, and client-friendly error messages.
6 | """
7 |
8 | import asyncio
9 | import uuid
10 | from collections.abc import Callable
11 | from typing import Any
12 |
13 | from fastapi import HTTPException, Request, status
14 | from fastapi.exceptions import RequestValidationError
15 | from fastapi.responses import JSONResponse
16 | from sqlalchemy.exc import IntegrityError, OperationalError
17 |
18 | from maverick_mcp.exceptions import (
19 | APIRateLimitError,
20 | AuthenticationError,
21 | AuthorizationError,
22 | CacheConnectionError,
23 | CircuitBreakerError,
24 | ConflictError,
25 | DatabaseConnectionError,
26 | DataIntegrityError,
27 | DataNotFoundError,
28 | ExternalServiceError,
29 | MaverickException,
30 | NotFoundError,
31 | RateLimitError,
32 | ValidationError,
33 | WebhookError,
34 | )
35 | from maverick_mcp.utils.logging import get_logger
36 | from maverick_mcp.utils.monitoring import get_monitoring_service
37 | from maverick_mcp.validation.responses import error_response, validation_error_response
38 |
# Module-level logger used by the error-handling code in this module.
logger = get_logger(__name__)
# Shared monitoring service; ErrorHandler._send_to_monitoring forwards
# exceptions to it via capture_exception().
monitoring = get_monitoring_service()
41 |
42 |
class ErrorHandler:
    """Centralized error handler with monitoring integration.

    Maps exception types to an HTTP status code, a machine-readable error code
    and a log level. It then logs the error, forwards it to the monitoring
    service and renders a structured JSON error response.
    """

    def __init__(self):
        self.error_mappings = self._build_error_mappings()

    def _build_error_mappings(self) -> dict[type[Exception], dict[str, Any]]:
        """Build mapping of exception types to response details.

        Each value holds ``status_code`` (HTTP status), ``code`` (error code
        returned to the client) and ``log_level`` ("error", "warning" or
        "info"). Only "error" and "warning" are forwarded to monitoring.
        """
        return {
            # MaverickMCP exceptions
            ValidationError: {
                "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
                "code": "VALIDATION_ERROR",
                "log_level": "warning",
            },
            AuthenticationError: {
                "status_code": status.HTTP_401_UNAUTHORIZED,
                "code": "AUTHENTICATION_ERROR",
                "log_level": "warning",
            },
            AuthorizationError: {
                "status_code": status.HTTP_403_FORBIDDEN,
                "code": "AUTHORIZATION_ERROR",
                "log_level": "warning",
            },
            DataNotFoundError: {
                "status_code": status.HTTP_404_NOT_FOUND,
                "code": "DATA_NOT_FOUND",
                "log_level": "info",
            },
            APIRateLimitError: {
                "status_code": status.HTTP_429_TOO_MANY_REQUESTS,
                "code": "RATE_LIMIT_EXCEEDED",
                "log_level": "warning",
            },
            CircuitBreakerError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "SERVICE_UNAVAILABLE",
                "log_level": "error",
            },
            DatabaseConnectionError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "DATABASE_CONNECTION_ERROR",
                "log_level": "error",
            },
            CacheConnectionError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "CACHE_CONNECTION_ERROR",
                "log_level": "error",
            },
            DataIntegrityError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "DATA_INTEGRITY_ERROR",
                "log_level": "error",
            },
            # API errors from validation module
            NotFoundError: {
                "status_code": status.HTTP_404_NOT_FOUND,
                "code": "NOT_FOUND",
                "log_level": "info",
            },
            ConflictError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "CONFLICT",
                "log_level": "warning",
            },
            RateLimitError: {
                "status_code": status.HTTP_429_TOO_MANY_REQUESTS,
                "code": "RATE_LIMIT_EXCEEDED",
                "log_level": "warning",
            },
            ExternalServiceError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "EXTERNAL_SERVICE_ERROR",
                "log_level": "error",
            },
            WebhookError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "WEBHOOK_ERROR",
                "log_level": "warning",
            },
            # SQLAlchemy exceptions
            IntegrityError: {
                "status_code": status.HTTP_409_CONFLICT,
                "code": "DATABASE_INTEGRITY_ERROR",
                "log_level": "error",
            },
            OperationalError: {
                "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
                "code": "DATABASE_OPERATIONAL_ERROR",
                "log_level": "error",
            },
            # Third-party API exceptions
            ValueError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "INVALID_REQUEST",
                "log_level": "warning",
            },
            KeyError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "MISSING_REQUIRED_FIELD",
                "log_level": "warning",
            },
            TypeError: {
                "status_code": status.HTTP_400_BAD_REQUEST,
                "code": "TYPE_ERROR",
                "log_level": "warning",
            },
        }

    def handle_exception(
        self,
        request: Request,
        exception: Exception,
        context: dict[str, Any] | None = None,
    ) -> JSONResponse:
        """
        Handle exception and return structured error response.

        Args:
            request: FastAPI request object
            exception: The exception to handle
            context: Additional context for logging

        Returns:
            JSONResponse with structured error. Headers attached to an
            HTTPException (e.g. WWW-Authenticate) are passed through.
        """
        # Generate trace ID for this error
        trace_id = str(uuid.uuid4())

        # Get error details from mapping (resolved once, reused below)
        error_info = self._get_error_info(exception)

        # Log the error with full context
        self._log_error(
            exception=exception,
            trace_id=trace_id,
            request=request,
            error_info=error_info,
            context=context,
        )

        # Send to monitoring service
        self._send_to_monitoring(
            exception=exception,
            trace_id=trace_id,
            request=request,
            context=context,
            error_info=error_info,
        )

        # Build client-friendly response
        response_data = self._build_error_response(
            exception=exception,
            error_info=error_info,
            trace_id=trace_id,
        )

        headers = (
            getattr(exception, "headers", None)
            if isinstance(exception, HTTPException)
            else None
        )

        return JSONResponse(
            status_code=error_info["status_code"],
            content=response_data,
            headers=headers,
        )

    def _get_error_info(self, exception: Exception) -> dict[str, Any]:
        """Get error information for the exception type.

        Resolution order:
        1. The nearest mapped class in the exception's MRO. An exact type match
           therefore wins, and a subclass uses its most specific mapped ancestor.
        2. Any mapped type accepted by ``isinstance``. This covers virtual
           subclasses registered via ``ABCMeta.register``, which are not in the MRO.
        3. ``HTTPException``, keeping the status code it was raised with.
        4. A generic 500 ``INTERNAL_ERROR``.
        """
        for klass in type(exception).__mro__:
            info = self.error_mappings.get(klass)
            if info is not None:
                return info

        for error_type, info in self.error_mappings.items():
            if isinstance(exception, error_type):
                return info

        if isinstance(exception, HTTPException):
            # The raiser chose an explicit status; don't turn it into a 500.
            http_status = exception.status_code
            return {
                "status_code": http_status,
                "code": "HTTP_ERROR",
                "log_level": "error" if http_status >= 500 else "warning",
            }

        # Default for unknown exceptions
        return {
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "code": "INTERNAL_ERROR",
            "log_level": "error",
        }

    def _log_error(
        self,
        exception: Exception,
        trace_id: str,
        request: Request,
        error_info: dict[str, Any],
        context: dict[str, Any] | None = None,
    ) -> None:
        """Log error with full context at the level given by ``error_info``."""
        log_data = {
            "trace_id": trace_id,
            "error_type": type(exception).__name__,
            "error_code": error_info["code"],
            "status_code": error_info["status_code"],
            "method": request.method,
            "path": request.url.path,
            "client_host": request.client.host if request.client else None,
            "user_agent": request.headers.get("user-agent"),
        }

        # Add exception details if available
        if isinstance(exception, MaverickException):
            log_data["error_details"] = exception.to_dict()

        # Add custom context
        if context:
            log_data["context"] = context

        # Log at appropriate level
        log_level = error_info["log_level"]
        if log_level == "error":
            logger.error(
                f"Error handling request: {str(exception)}",
                # Pass the exception itself: exc_info=True only works while the
                # exception is still being handled in an enclosing except block.
                exc_info=exception,
                extra=log_data,
            )
        elif log_level == "warning":
            logger.warning(
                f"Request failed: {str(exception)}",
                extra=log_data,
            )
        else:
            logger.info(
                f"Request rejected: {str(exception)}",
                extra=log_data,
            )

    def _send_to_monitoring(
        self,
        exception: Exception,
        trace_id: str,
        request: Request,
        context: dict[str, Any] | None = None,
        error_info: dict[str, Any] | None = None,
    ) -> None:
        """Send error to monitoring service (Sentry).

        Only errors whose log level is "error" or "warning" are sent.
        ``error_info`` may be supplied to avoid resolving it a second time.
        """
        monitoring_context = {
            "trace_id": trace_id,
            "request": {
                "method": request.method,
                "path": request.url.path,
                "query": str(request.url.query),
            },
        }

        if context:
            monitoring_context["custom_context"] = context

        if error_info is None:
            error_info = self._get_error_info(exception)
        if error_info["log_level"] in ["error", "warning"]:
            monitoring.capture_exception(exception, **monitoring_context)

    def _build_error_response(
        self,
        exception: Exception,
        error_info: dict[str, Any],
        trace_id: str,
    ) -> dict[str, Any]:
        """Build client-friendly error response.

        MaverickException and HTTPException messages are exposed as-is. Other
        exceptions get a generic safe message for known codes so internal
        details are not leaked.
        """
        # Extract error details
        if isinstance(exception, MaverickException):
            message = exception.message
            context = exception.context
        elif isinstance(exception, HTTPException):
            message = exception.detail
            context = None
        else:
            # Generic message for unknown errors
            message = self._get_safe_error_message(exception, error_info["code"])
            context = None

        return error_response(
            code=error_info["code"],
            message=message,
            status_code=error_info["status_code"],
            context=context,
            trace_id=trace_id,
        )

    def _get_safe_error_message(self, exception: Exception, code: str) -> str:
        """Get safe error message for client; falls back to ``str(exception)``."""
        safe_messages = {
            "INTERNAL_ERROR": "An unexpected error occurred. Please try again later.",
            "DATABASE_INTEGRITY_ERROR": "Data conflict detected. Please check your input.",
            "DATABASE_OPERATIONAL_ERROR": "Database temporarily unavailable.",
            "INVALID_REQUEST": "Invalid request format.",
            "MISSING_REQUIRED_FIELD": "Required field missing from request.",
            "TYPE_ERROR": "Invalid data type in request.",
        }

        return safe_messages.get(code, str(exception))
335 |
336 |
# Global error handler instance.
# Module-wide instance shared by handle_api_error, so every caller uses the
# same exception-to-response mappings.
error_handler = ErrorHandler()
339 |
340 |
def handle_api_error(
    request: Request,
    exception: Exception,
    context: dict[str, Any] | None = None,
) -> JSONResponse:
    """
    Turn an exception raised while serving ``request`` into a JSON error.

    This is a module-level entry point that delegates to the shared
    ``error_handler`` instance.

    Args:
        request: The FastAPI request being served.
        exception: The exception that was raised.
        context: Optional extra details attached to logs and monitoring.

    Returns:
        The structured JSON error response.
    """
    shared_handler = error_handler
    return shared_handler.handle_exception(
        request=request,
        exception=exception,
        context=context,
    )
358 |
359 |
def _jsonable_validation_input(value: Any) -> Any:
    """Return a JSON-serializable form of a validation error's ``input`` value.

    The raw offending input may be bytes, an upload object, or another type
    the JSON response cannot serialize. Encode it with FastAPI's
    ``jsonable_encoder`` and fall back to ``repr`` when that fails.
    """
    from fastapi.encoders import jsonable_encoder

    try:
        return jsonable_encoder(value)
    except (ValueError, TypeError):
        # jsonable_encoder raises ValueError when it cannot encode an object
        # (UnicodeDecodeError for non-UTF-8 bytes is a ValueError subclass).
        return repr(value)


async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Handle FastAPI validation errors.

    Converts each validation error into ``{code, field, message, context}``,
    logs them with a fresh trace id, and returns a 422 response.
    """
    errors = [
        {
            "code": "VALIDATION_ERROR",
            # loc is a sequence such as ("body", "ticker"); join it into a dotted path
            "field": ".".join(str(loc) for loc in error["loc"]),
            "message": error["msg"],
            "context": {"input": _jsonable_validation_input(error.get("input"))},
        }
        for error in exc.errors()
    ]

    trace_id = str(uuid.uuid4())

    # Log validation errors
    logger.warning(
        "Request validation failed",
        extra={
            "trace_id": trace_id,
            "path": request.url.path,
            "errors": errors,
        },
    )

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=validation_error_response(errors, trace_id),
    )
391 |
392 |
def create_error_handlers() -> dict[Any, Callable]:
    """Build the exception-handler mapping to register on a FastAPI app.

    Returns:
        A dict mapping ``RequestValidationError`` to the async validation
        handler and ``Exception`` to a synchronous catch-all that delegates
        to ``handle_api_error`` without extra context.
    """

    def _handle_any_exception(request, exc):
        return handle_api_error(request, exc)

    handlers: dict[Any, Callable] = {}
    handlers[RequestValidationError] = validation_exception_handler
    handlers[Exception] = _handle_any_exception
    return handlers
399 |
400 |
401 | # Decorator for wrapping functions with error handling
402 | def with_error_handling(context_fn: Callable[[Any], dict[str, Any]] | None = None):
403 | """
404 | Decorator to wrap functions with proper error handling.
405 |
406 | Args:
407 | context_fn: Optional function to extract context from arguments
408 | """
409 |
410 | def decorator(func: Callable) -> Callable:
411 | async def async_wrapper(*args, **kwargs):
412 | try:
413 | return await func(*args, **kwargs)
414 | except Exception as e:
415 | # Extract context if function provided
416 | context = context_fn(*args, **kwargs) if context_fn else {}
417 |
418 | # Get request from args/kwargs
419 | request = None
420 | for arg in args:
421 | if isinstance(arg, Request):
422 | request = arg
423 | break
424 | if not request and "request" in kwargs:
425 | request = kwargs["request"]
426 |
427 | if request:
428 | return handle_api_error(request, e, context)
429 | else:
430 | # Re-raise if no request object
431 | raise
432 |
433 | def sync_wrapper(*args, **kwargs):
434 | try:
435 | return func(*args, **kwargs)
436 | except Exception as e:
437 | # Extract context if function provided
438 | context = context_fn(*args, **kwargs) if context_fn else {}
439 |
440 | # Get request from args/kwargs
441 | request = None
442 | for arg in args:
443 | if isinstance(arg, Request):
444 | request = arg
445 | break
446 | if not request and "request" in kwargs:
447 | request = kwargs["request"]
448 |
449 | if request:
450 | return handle_api_error(request, e, context)
451 | else:
452 | # Re-raise if no request object
453 | raise
454 |
455 | # Return appropriate wrapper based on function type
456 | if asyncio.iscoroutinefunction(func):
457 | return async_wrapper
458 | else:
459 | return sync_wrapper
460 |
461 | return decorator
462 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/logging_init.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Logging initialization module for the backtesting system.
3 |
4 | This module provides a centralized initialization point for all logging
5 | components including structured logging, performance monitoring, debug
6 | utilities, and log aggregation.
7 | """
8 |
9 | import logging
10 | import os
11 | from typing import Any
12 |
13 | from maverick_mcp.config.logging_settings import (
14 | LoggingSettings,
15 | configure_logging_for_environment,
16 | get_logging_settings,
17 | validate_logging_settings,
18 | )
19 | from maverick_mcp.utils.debug_utils import (
20 | disable_debug_mode,
21 | enable_debug_mode,
22 | )
23 | from maverick_mcp.utils.debug_utils import (
24 | print_debug_summary as debug_print_summary,
25 | )
26 | from maverick_mcp.utils.structured_logger import (
27 | StructuredLoggerManager,
28 | get_logger_manager,
29 | )
30 |
31 |
class LoggingInitializer:
    """Comprehensive logging system initializer.

    Resolves ``LoggingSettings`` for an environment and sets up structured
    logging through the ``StructuredLoggerManager``. Depending on the
    settings, it also wires debug loggers, performance loggers and
    log-directory management. A process-wide instance is available via
    ``get_logging_initializer()``.
    """

    # Loggers that get DEBUG level and the shared debug file handler.
    _DEBUG_LOGGER_NAMES = (
        "maverick_mcp.debug",
        "maverick_mcp.requests",
        "maverick_mcp.errors",
    )

    def __init__(self):
        self._initialized = False
        self._settings: LoggingSettings | None = None
        self._manager: StructuredLoggerManager | None = None
        # Debug FileHandler attached by _setup_debug_logging. It is tracked so
        # that a re-run (force_reinit, runtime re-enable) replaces it instead
        # of stacking duplicates, and so disable/cleanup can close it.
        self._debug_handler: logging.Handler | None = None

    def initialize_logging_system(
        self,
        environment: str | None = None,
        custom_settings: dict[str, Any] | None = None,
        force_reinit: bool = False,
    ) -> LoggingSettings:
        """
        Initialize the complete logging system.

        Steps, in order:
        1. Resolve settings for the environment and apply overrides.
        2. Validate them (problems are printed as warnings, never fatal).
        3. Set up structured logging, then optional debug logging.
        4. Set up performance loggers and log-directory management.
        5. Print a summary.

        Later calls return the existing settings unless ``force_reinit`` is
        true.

        Args:
            environment: Environment name (development, testing, production).
                Falls back to MAVERICK_ENVIRONMENT, then "development". Any
                other name uses ``get_logging_settings()``.
            custom_settings: Overrides applied as attributes on the settings
                object. Keys that are not existing attributes are ignored.
            force_reinit: Force reinitialization even if already initialized.

        Returns:
            LoggingSettings: The final logging configuration.
        """
        if self._initialized and not force_reinit:
            return self._settings

        # Determine environment
        if not environment:
            environment = os.getenv("MAVERICK_ENVIRONMENT", "development")

        # Get base settings for environment
        if environment in ["development", "testing", "production"]:
            self._settings = configure_logging_for_environment(environment)
        else:
            self._settings = get_logging_settings()

        # Apply custom settings; keys that are not settings attributes are skipped.
        if custom_settings:
            for key, value in custom_settings.items():
                if hasattr(self._settings, key):
                    setattr(self._settings, key, value)

        # Validation problems are reported but do not abort initialization.
        config_warnings = validate_logging_settings(self._settings)
        if config_warnings:
            print("⚠️ Logging configuration warnings:")
            for warning in config_warnings:
                print(f" - {warning}")

        self._initialize_structured_logging()

        if self._settings.debug_enabled:
            enable_debug_mode()
            self._setup_debug_logging()

        self._initialize_performance_monitoring()
        self._setup_log_management()
        self._print_initialization_summary(environment)

        self._initialized = True
        return self._settings

    def _initialize_structured_logging(self):
        """Create the structured logging manager and apply the current settings."""
        self._manager = get_logger_manager()
        settings = self._settings

        self._manager.setup_structured_logging(
            log_level=settings.log_level,
            log_format=settings.log_format,
            log_file=settings.log_file_path if settings.enable_file_logging else None,
            enable_async=settings.enable_async_logging,
            enable_rotation=settings.enable_log_rotation,
            max_log_size=settings.max_log_size_mb * 1024 * 1024,  # MB -> bytes
            backup_count=settings.backup_count,
            console_output=settings.console_output,
        )

        # Debug filters are only configured when debug mode is enabled.
        if not settings.debug_enabled:
            return

        for module in settings.get_debug_modules():
            self._manager.debug_manager.enable_verbose_logging(module)

        if settings.log_request_response:
            self._manager.debug_manager.add_debug_filter(
                "backtesting_requests",
                {
                    "log_request_response": True,
                    "operations": [
                        "run_backtest",
                        "optimize_parameters",
                        "get_historical_data",
                        "calculate_technical_indicators",
                    ],
                },
            )

    def _setup_debug_logging(self):
        """Set DEBUG level on the debug loggers and attach a debug file handler.

        Safe to call repeatedly. A handler installed by an earlier call is
        detached and closed first, so records are not written twice and file
        descriptors are not leaked.
        """
        for name in self._DEBUG_LOGGER_NAMES:
            logging.getLogger(name).setLevel(logging.DEBUG)

        # Replace (rather than stack) a handler from a previous call.
        self._remove_debug_handler()

        if not self._settings.enable_file_logging:
            return

        # Build the formatter first so a failing import cannot leak an open file.
        from maverick_mcp.utils.structured_logger import EnhancedStructuredFormatter

        debug_formatter = EnhancedStructuredFormatter(
            include_performance=True, include_resources=True
        )

        # The debug log sits next to the main log file; assumed to create that
        # file's directory if missing — confirm against LoggingSettings.
        self._settings.ensure_log_directory()
        debug_handler = logging.FileHandler(self._debug_log_path(), encoding="utf-8")
        debug_handler.setLevel(logging.DEBUG)
        debug_handler.setFormatter(debug_formatter)

        for name in self._DEBUG_LOGGER_NAMES:
            logging.getLogger(name).addHandler(debug_handler)
        self._debug_handler = debug_handler

    def _debug_log_path(self) -> str:
        """Return the debug log path derived from the main log file path.

        ``logs/app.log`` becomes ``logs/app_debug.log``. Only the file
        extension is considered, so ``.log`` inside a directory name is left
        alone. A path without an extension gets ``_debug.log`` appended
        instead of mapping onto the main log file itself.
        """
        root, ext = os.path.splitext(os.fspath(self._settings.log_file_path))
        return f"{root}_debug{ext or '.log'}"

    def _remove_debug_handler(self):
        """Detach and close the debug file handler installed earlier, if any."""
        handler = self._debug_handler
        if handler is None:
            return
        for name in self._DEBUG_LOGGER_NAMES:
            logging.getLogger(name).removeHandler(handler)
        handler.close()
        self._debug_handler = None

    def _initialize_performance_monitoring(self):
        """Create a performance logger for each key component, if enabled."""
        if not self._settings.enable_performance_logging:
            return

        components = [
            "vectorbt_engine",
            "data_provider",
            "cache_manager",
            "technical_analysis",
            "portfolio_optimization",
            "strategy_execution",
        ]

        for component in components:
            perf_logger = self._manager.get_performance_logger(
                f"performance.{component}"
            )
            perf_logger.logger.info(
                f"Performance monitoring initialized for {component}"
            )

    def _setup_log_management(self):
        """Ensure the log directory exists when rotated file logging is on."""
        if (
            not self._settings.enable_file_logging
            or not self._settings.enable_log_rotation
        ):
            return

        # Log rotation is handled by RotatingFileHandler; cleanup of old log
        # files could be added here.
        self._settings.ensure_log_directory()

    @staticmethod
    def _status(enabled: bool) -> str:
        """Render a feature flag for the summary banner."""
        return "✅ Enabled" if enabled else "❌ Disabled"

    @staticmethod
    def _mark(enabled: bool) -> str:
        """Render a feature flag as a bare check/cross mark."""
        return "✅" if enabled else "❌"

    def _print_initialization_summary(self, environment: str):
        """Print logging initialization summary."""
        settings = self._settings

        print("\n" + "=" * 80)
        print("MAVERICK MCP LOGGING SYSTEM INITIALIZED")
        print("=" * 80)
        print(f"Environment: {environment}")
        print(f"Log Level: {settings.log_level}")
        print(f"Log Format: {settings.log_format}")
        print(f"Debug Mode: {self._status(settings.debug_enabled)}")
        print(
            f"Performance Monitoring: {self._status(settings.enable_performance_logging)}"
        )
        print(f"File Logging: {self._status(settings.enable_file_logging)}")

        if settings.enable_file_logging:
            print(f"Log File: {settings.log_file_path}")
            print(f"Log Rotation: {self._status(settings.enable_log_rotation)}")

        print(f"Async Logging: {self._status(settings.enable_async_logging)}")
        print(f"Resource Tracking: {self._status(settings.enable_resource_tracking)}")

        if settings.debug_enabled:
            print("\n🐛 DEBUG MODE FEATURES:")
            print(
                f" - Request/Response Logging: {self._mark(settings.log_request_response)}"
            )
            print(f" - Verbose Modules: {len(settings.get_debug_modules())}")
            print(f" - Max Payload Size: {settings.max_payload_length} chars")

        if settings.enable_performance_logging:
            print("\n📊 PERFORMANCE MONITORING:")
            print(f" - Threshold: {settings.performance_log_threshold_ms}ms")
            print(
                f" - Business Metrics: {self._mark(settings.enable_business_metrics)}"
            )

        print("\n" + "=" * 80 + "\n")

    def get_settings(self) -> LoggingSettings | None:
        """Get current logging settings (None before initialization)."""
        return self._settings

    def get_manager(self) -> StructuredLoggerManager | None:
        """Get logging manager instance (None before initialization)."""
        return self._manager

    def enable_debug_mode_runtime(self):
        """Enable debug mode at runtime (no-op before initialization)."""
        if self._settings:
            self._settings.debug_enabled = True
            enable_debug_mode()
            self._setup_debug_logging()
            print("🐛 Debug mode enabled at runtime")

    def disable_debug_mode_runtime(self):
        """Disable debug mode at runtime and detach the debug file handler."""
        if self._settings:
            self._settings.debug_enabled = False
            disable_debug_mode()
            # Stop writing to the debug log once debug mode is off.
            self._remove_debug_handler()
            print("🐛 Debug mode disabled at runtime")

    def print_debug_summary_if_enabled(self):
        """Print debug summary if debug mode is enabled."""
        if self._settings and self._settings.debug_enabled:
            debug_print_summary()

    def reconfigure_log_level(self, new_level: str):
        """Change the root logger's level at runtime.

        Args:
            new_level: A standard level name (case-insensitive).

        Raises:
            RuntimeError: If the logging system has not been initialized.
            ValueError: If ``new_level`` is not a standard level name.
        """
        if not self._settings:
            raise RuntimeError("Logging system not initialized")

        level_name = new_level.upper()
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if level_name not in valid_levels:
            raise ValueError(f"Invalid log level: {new_level}")

        self._settings.log_level = level_name

        # Only the root logger changes; loggers with an explicit level of their
        # own (e.g. the maverick_mcp.* debug loggers) keep it.
        logging.getLogger().setLevel(getattr(logging, level_name))

        print(f"📊 Log level changed to: {level_name}")

    def get_performance_summary(self) -> dict[str, Any]:
        """Get comprehensive performance summary."""
        if not self._manager:
            return {"error": "Logging system not initialized"}

        return self._manager.create_dashboard_metrics()

    def cleanup_logging_system(self):
        """Release logging resources and mark the system as uninitialized.

        Always detaches and closes the debug file handler. When a manager was
        set up, every root-logger handler is closed and also removed. Leaving
        closed handlers attached would keep routing records to them.
        """
        self._remove_debug_handler()

        if self._manager:
            root_logger = logging.getLogger()
            # Iterate over a copy: removeHandler mutates root_logger.handlers.
            for handler in list(root_logger.handlers):
                root_logger.removeHandler(handler)
                if hasattr(handler, "close"):
                    handler.close()

        self._initialized = False
        print("🧹 Logging system cleaned up")
316 |
317 |
# Process-wide LoggingInitializer; created lazily by get_logging_initializer().
_logging_initializer: LoggingInitializer | None = None
320 |
321 |
def get_logging_initializer() -> LoggingInitializer:
    """Return the process-wide ``LoggingInitializer``, creating it on first use."""
    global _logging_initializer
    if _logging_initializer is not None:
        return _logging_initializer
    _logging_initializer = LoggingInitializer()
    return _logging_initializer
328 |
329 |
def initialize_for_environment(environment: str, **custom_settings) -> LoggingSettings:
    """Initialize the shared logging system for ``environment``.

    Extra keyword arguments are forwarded as setting overrides. If logging
    has already been initialized, the existing settings are returned.
    """
    shared = get_logging_initializer()
    settings = shared.initialize_logging_system(environment, custom_settings)
    return settings
334 |
335 |
def initialize_for_development(**custom_settings) -> LoggingSettings:
    """Initialize logging with the ``development`` profile plus any overrides."""
    target_environment = "development"
    return initialize_for_environment(target_environment, **custom_settings)
339 |
340 |
def initialize_for_testing(**custom_settings) -> LoggingSettings:
    """Initialize logging with the ``testing`` profile plus any overrides."""
    target_environment = "testing"
    return initialize_for_environment(target_environment, **custom_settings)
344 |
345 |
def initialize_for_production(**custom_settings) -> LoggingSettings:
    """Initialize logging with the ``production`` profile plus any overrides."""
    target_environment = "production"
    return initialize_for_environment(target_environment, **custom_settings)
349 |
350 |
def initialize_backtesting_logging(
    environment: str | None = None, debug_mode: bool = False, **custom_settings
) -> LoggingSettings:
    """
    Convenient function to initialize logging specifically for backtesting.

    Args:
        environment: Target environment. When None, the initializer falls
            back to MAVERICK_ENVIRONMENT (default "development").
        debug_mode: Enable debug mode. This forces ``debug_enabled`` on. It
            also turns on request/response logging and a 100 ms
            performance-logging threshold, unless the caller passed explicit
            values for those settings.
        **custom_settings: Additional custom settings. Explicit values take
            precedence over the debug-mode defaults above.

    Returns:
        LoggingSettings: Final logging configuration
    """
    if debug_mode:
        custom_settings["debug_enabled"] = True
        # Debug-mode defaults must not clobber values the caller set explicitly.
        custom_settings.setdefault("log_request_response", True)
        custom_settings.setdefault("performance_log_threshold_ms", 100.0)

    return initialize_for_environment(environment, **custom_settings)
371 |
372 |
373 | # Convenience functions for runtime control
def enable_debug_mode_runtime():
    """Turn debug mode on for the shared logging initializer."""
    initializer = get_logging_initializer()
    initializer.enable_debug_mode_runtime()
377 |
378 |
def disable_debug_mode_runtime():
    """Turn debug mode off for the shared logging initializer."""
    initializer = get_logging_initializer()
    initializer.disable_debug_mode_runtime()
382 |
383 |
def change_log_level(new_level: str):
    """Change the root log level through the shared logging initializer.

    Raises:
        RuntimeError: If logging has not been initialized yet.
        ValueError: If ``new_level`` is not a standard level name.
    """
    initializer = get_logging_initializer()
    initializer.reconfigure_log_level(new_level)
387 |
388 |
def get_performance_summary() -> dict[str, Any]:
    """Return the dashboard metrics from the shared logging initializer.

    Returns ``{"error": ...}`` when logging has not been initialized.
    """
    initializer = get_logging_initializer()
    summary = initializer.get_performance_summary()
    return summary
392 |
393 |
def print_debug_summary():
    """Print the debug summary, but only when debug mode is enabled."""
    initializer = get_logging_initializer()
    initializer.print_debug_summary_if_enabled()
397 |
398 |
def cleanup_logging():
    """Release resources held by the shared logging initializer."""
    initializer = get_logging_initializer()
    initializer.cleanup_logging_system()
402 |
403 |
404 | # Environment detection and auto-initialization
def auto_initialize_logging() -> LoggingSettings:
    """
    Initialize logging from environment variables.

    Reads two variables and passes them to ``initialize_backtesting_logging()``:
    - MAVERICK_ENVIRONMENT (default "development"),
    - MAVERICK_DEBUG (debug mode only when it equals "true", case-insensitively).

    This module calls it automatically only when executed as a script or when
    MAVERICK_AUTO_INIT_LOGGING is "true". Otherwise, call it explicitly.
    """
    env_name = os.getenv("MAVERICK_ENVIRONMENT", "development")
    debug_flag = os.getenv("MAVERICK_DEBUG", "false")
    wants_debug = debug_flag.lower() == "true"

    return initialize_backtesting_logging(
        environment=env_name, debug_mode=wants_debug
    )
418 |
419 |
# Entry points: running the module directly initializes logging and prints the
# debug summary. On a normal import, logging is auto-initialized only when
# MAVERICK_AUTO_INIT_LOGGING is "true" (case-insensitive).
if __name__ == "__main__":
    settings = auto_initialize_logging()
    print("Logging system initialized from command line")
    print_debug_summary()
elif os.getenv("MAVERICK_AUTO_INIT_LOGGING", "false").lower() == "true":
    auto_initialize_logging()
427 |
```
--------------------------------------------------------------------------------
/tests/test_database_pool_config_simple.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Simplified tests for DatabasePoolConfig focusing on core functionality.
3 |
4 | This module tests the essential features of the enhanced database pool configuration:
5 | - Basic configuration and validation
6 | - Pool validation logic
7 | - Factory methods
8 | - Monitoring thresholds
9 | - Environment variable integration
10 | """
11 |
12 | import os
13 | import warnings
14 | from unittest.mock import patch
15 |
16 | import pytest
17 | from sqlalchemy.pool import QueuePool
18 |
19 | from maverick_mcp.config.database import (
20 | DatabasePoolConfig,
21 | get_default_pool_config,
22 | get_development_pool_config,
23 | get_high_concurrency_pool_config,
24 | validate_production_config,
25 | )
26 | from maverick_mcp.providers.interfaces.persistence import DatabaseConfig
27 |
28 |
class TestDatabasePoolConfigBasics:
    """Test basic DatabasePoolConfig functionality.

    Covers:
    - defaults,
    - the constructor's capacity and expected-load validation,
    - the conversion helpers (SQLAlchemy pool kwargs, monitoring thresholds,
      legacy DatabaseConfig round-trips).
    """

    def test_default_configuration(self):
        """Test default configuration values."""
        config = DatabasePoolConfig()

        # Should have reasonable defaults
        assert config.pool_size >= 5
        assert config.max_overflow >= 0
        assert config.pool_timeout > 0
        assert config.pool_recycle > 0
        assert config.max_database_connections > 0

    def test_valid_configuration(self):
        """Test a valid configuration passes validation."""
        config = DatabasePoolConfig(
            pool_size=10,
            max_overflow=5,
            max_database_connections=50,
            reserved_superuser_connections=3,
            expected_concurrent_users=10,
            connections_per_user=1.2,
        )

        assert config.pool_size == 10
        assert config.max_overflow == 5

        # The pool (base + overflow) must fit in the non-reserved connections.
        total_app_connections = config.pool_size + config.max_overflow
        available_connections = (
            config.max_database_connections - config.reserved_superuser_connections
        )
        assert total_app_connections <= available_connections

    def test_validation_exceeds_database_capacity(self):
        """Test validation failure when pool exceeds database capacity."""
        with pytest.raises(
            ValueError, match="Pool configuration exceeds database capacity"
        ):
            DatabasePoolConfig(
                pool_size=50,
                max_overflow=30,  # Total = 80
                max_database_connections=70,  # Available = 67 (70-3)
                reserved_superuser_connections=3,
                expected_concurrent_users=60,  # Adjust to avoid other validation errors
                connections_per_user=1.0,
            )

    def test_validation_insufficient_for_expected_load(self):
        """Test validation failure when pool is insufficient for expected load."""
        with pytest.raises(
            ValueError, match="Total connection capacity .* is insufficient"
        ):
            DatabasePoolConfig(
                pool_size=5,
                max_overflow=0,  # Total capacity = 5
                expected_concurrent_users=10,
                connections_per_user=1.0,  # Expected demand = 10
                max_database_connections=50,
            )

    def test_validation_warning_for_small_pool(self):
        """Test warning when pool size may be insufficient."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            DatabasePoolConfig(
                pool_size=5,  # Small pool
                max_overflow=15,  # But enough overflow to meet demand
                expected_concurrent_users=10,
                connections_per_user=1.5,  # Expected demand = 15
                max_database_connections=50,
            )

        messages = [str(item.message) for item in caught]
        # Unrelated warnings (e.g. deprecations raised by dependencies) may be
        # recorded before ours, so search all of them, not just the first.
        assert any(
            "Pool size (5) may be insufficient" in message for message in messages
        ), messages

    def test_get_pool_kwargs(self):
        """Test SQLAlchemy pool configuration generation."""
        config = DatabasePoolConfig(
            pool_size=15,
            max_overflow=8,
            pool_timeout=45,
            pool_recycle=1800,
            pool_pre_ping=True,
            echo_pool=True,
            expected_concurrent_users=18,
            connections_per_user=1.0,
        )

        kwargs = config.get_pool_kwargs()

        expected = {
            "poolclass": QueuePool,
            "pool_size": 15,
            "max_overflow": 8,
            "pool_timeout": 45,
            "pool_recycle": 1800,
            "pool_pre_ping": True,
            "echo_pool": True,
        }

        assert kwargs == expected

    def test_get_monitoring_thresholds(self):
        """Test monitoring threshold calculation."""
        config = DatabasePoolConfig(
            pool_size=20,
            max_overflow=10,
            expected_concurrent_users=25,
            connections_per_user=1.0,
        )
        thresholds = config.get_monitoring_thresholds()

        expected = {
            "warning_threshold": int(20 * 0.8),  # 16
            "critical_threshold": int(20 * 0.95),  # 19
            "pool_size": 20,
            "max_overflow": 10,
            "total_capacity": 30,
        }

        assert thresholds == expected

    def test_to_legacy_config(self):
        """Test conversion to legacy DatabaseConfig."""
        config = DatabasePoolConfig(
            pool_size=15,
            max_overflow=8,
            pool_timeout=45,
            pool_recycle=1800,
            echo_pool=True,
            expected_concurrent_users=20,
            connections_per_user=1.0,
        )

        database_url = "postgresql://user:pass@localhost/test"
        legacy_config = config.to_legacy_config(database_url)

        assert isinstance(legacy_config, DatabaseConfig)
        assert legacy_config.database_url == database_url
        assert legacy_config.pool_size == 15
        assert legacy_config.max_overflow == 8
        assert legacy_config.pool_timeout == 45
        assert legacy_config.pool_recycle == 1800
        assert legacy_config.echo is True  # echo_pool maps onto legacy `echo`

    def test_from_legacy_config(self):
        """Test creation from legacy DatabaseConfig."""
        legacy_config = DatabaseConfig(
            database_url="postgresql://user:pass@localhost/test",
            pool_size=12,
            max_overflow=6,
            pool_timeout=60,
            pool_recycle=2400,
            echo=False,
        )

        enhanced_config = DatabasePoolConfig.from_legacy_config(
            legacy_config,
            expected_concurrent_users=15,
            max_database_connections=80,
        )

        assert enhanced_config.pool_size == 12
        assert enhanced_config.max_overflow == 6
        assert enhanced_config.pool_timeout == 60
        assert enhanced_config.pool_recycle == 2400
        assert enhanced_config.echo_pool is False
        assert enhanced_config.expected_concurrent_users == 15
        assert enhanced_config.max_database_connections == 80
202 |
203 |
class TestFactoryMethods:
    """Factory helpers and production validation for DatabasePoolConfig."""

    def test_get_default_pool_config(self):
        """The default factory returns a usable DatabasePoolConfig."""
        default_cfg = get_default_pool_config()

        assert isinstance(default_cfg, DatabasePoolConfig)
        assert default_cfg.pool_size > 0

    def test_get_development_pool_config(self):
        """The development factory uses a small pool with pool echo turned on."""
        dev_cfg = get_development_pool_config()

        assert isinstance(dev_cfg, DatabasePoolConfig)
        assert dev_cfg.pool_size == 5
        assert dev_cfg.max_overflow == 2
        assert dev_cfg.echo_pool is True  # debug echo is on for development

    def test_get_high_concurrency_pool_config(self):
        """The high-concurrency factory sizes the pool for ~60 users."""
        busy_cfg = get_high_concurrency_pool_config()

        assert isinstance(busy_cfg, DatabasePoolConfig)
        assert busy_cfg.pool_size == 50
        assert busy_cfg.max_overflow == 30
        assert busy_cfg.expected_concurrent_users == 60

    def test_validate_production_config_valid(self):
        """A well-sized configuration passes production validation and logs info."""
        prod_cfg = DatabasePoolConfig(
            pool_size=25,
            max_overflow=15,
            pool_timeout=30,
            pool_recycle=3600,
            expected_concurrent_users=35,
            connections_per_user=1.0,
        )

        with patch("maverick_mcp.config.database.logger") as fake_logger:
            outcome = validate_production_config(prod_cfg)

        assert outcome is True
        fake_logger.info.assert_called()

    def test_validate_production_config_warnings(self):
        """A small base pool only produces warnings; validation still succeeds."""
        small_cfg = DatabasePoolConfig(
            pool_size=5,  # small for production
            max_overflow=10,  # overflow still covers the expected demand
            pool_timeout=30,
            pool_recycle=3600,
            expected_concurrent_users=10,
            connections_per_user=1.0,
        )

        with patch("maverick_mcp.config.database.logger") as fake_logger:
            outcome = validate_production_config(small_cfg)

        assert outcome is True  # warnings are not fatal
        assert fake_logger.warning.called

    def test_validate_production_config_errors(self):
        """A config accepted by the constructor can still fail production checks."""
        # pool_timeout=5 is accepted by DatabasePoolConfig itself; the
        # production validator applies stricter rules and rejects this config.
        strict_fail_cfg = DatabasePoolConfig(
            pool_size=15,
            max_overflow=5,
            pool_timeout=5,
            pool_recycle=3600,
            expected_concurrent_users=18,
            connections_per_user=1.0,
        )

        with pytest.raises(
            ValueError, match="Production configuration validation failed"
        ):
            validate_production_config(strict_fail_cfg)
284 |
285 |
class TestEnvironmentVariables:
    """Environment variables feed DatabasePoolConfig's defaults."""

    _NUMERIC_ENV = {
        "DB_POOL_SIZE": "25",
        "DB_MAX_OVERFLOW": "10",
        "DB_EXPECTED_CONCURRENT_USERS": "25",
        "DB_CONNECTIONS_PER_USER": "1.2",
    }

    _BOOLEAN_ENV = {
        "DB_ECHO_POOL": "true",
        "DB_POOL_PRE_PING": "false",
    }

    def test_environment_variable_overrides(self):
        """Numeric DB_* variables replace the built-in defaults."""
        with patch.dict(os.environ, self._NUMERIC_ENV):
            env_cfg = DatabasePoolConfig()

            assert env_cfg.pool_size == 25
            assert env_cfg.max_overflow == 10
            assert env_cfg.expected_concurrent_users == 25
            assert env_cfg.connections_per_user == 1.2

    def test_boolean_environment_variables(self):
        """"true"/"false" strings are parsed into booleans."""
        with patch.dict(os.environ, self._BOOLEAN_ENV):
            env_cfg = DatabasePoolConfig()

            assert env_cfg.echo_pool is True
            assert env_cfg.pool_pre_ping is False
321 |
322 |
class TestValidationScenarios:
    """validate_against_database_limits against various server limits."""

    def test_database_limits_validation(self):
        """Matching the configured limit passes and leaves it unchanged."""
        pool_cfg = DatabasePoolConfig(
            pool_size=10,
            max_overflow=5,
            max_database_connections=100,
            expected_concurrent_users=12,
            connections_per_user=1.0,
        )

        pool_cfg.validate_against_database_limits(100)

        assert pool_cfg.max_database_connections == 100

    def test_database_limits_higher_actual(self):
        """A higher real limit is adopted and reported via logger.info."""
        pool_cfg = DatabasePoolConfig(
            pool_size=10,
            max_overflow=5,
            max_database_connections=50,
            expected_concurrent_users=12,
            connections_per_user=1.0,
        )

        with patch("maverick_mcp.config.database.logger") as fake_logger:
            pool_cfg.validate_against_database_limits(100)

        assert pool_cfg.max_database_connections == 100
        fake_logger.info.assert_called()

    def test_database_limits_too_low(self):
        """A real limit below the pool's total demand is rejected."""
        pool_cfg = DatabasePoolConfig(
            pool_size=30,
            max_overflow=20,  # needs 50 connections in total
            max_database_connections=100,
            expected_concurrent_users=40,
            connections_per_user=1.0,
        )

        # Real limit 40 leaves 37 usable connections, fewer than the 50 needed.
        with pytest.raises(
            ValueError, match="Configuration invalid for actual database limits"
        ):
            pool_cfg.validate_against_database_limits(40)
372 |
373 |
class TestRealWorldScenarios:
    """End-to-end usage patterns for DatabasePoolConfig."""

    def test_microservice_configuration(self):
        """A modest pool for a single microservice is accepted."""
        svc_cfg = DatabasePoolConfig(
            pool_size=8,
            max_overflow=4,
            expected_concurrent_users=10,
            connections_per_user=1.0,
            max_database_connections=50,
        )

        assert svc_cfg.pool_size == 8
        assert svc_cfg.get_monitoring_thresholds()["total_capacity"] == 12

    def test_development_to_production_migration(self):
        """A development config can be upgraded to a production-ready one."""
        dev_cfg = get_development_pool_config()
        assert dev_cfg.echo_pool is True
        assert dev_cfg.pool_size == 5

        # Round-trip through the legacy type, overriding the sizing on the way.
        legacy = dev_cfg.to_legacy_config("postgresql://localhost/test")
        prod_cfg = DatabasePoolConfig.from_legacy_config(
            legacy,
            pool_size=30,
            max_overflow=20,
            expected_concurrent_users=40,
            echo_pool=False,
        )

        assert validate_production_config(prod_cfg) is True
        assert prod_cfg.echo_pool is False
        assert prod_cfg.pool_size == 30

    def test_connection_exhaustion_prevention(self):
        """Oversized pools are rejected; right-sized ones leave headroom."""
        # 45 + 35 = 80 connections vs. 75 - reserved = 72 available.
        with pytest.raises(ValueError, match="exceeds database capacity"):
            DatabasePoolConfig(
                pool_size=45,
                max_overflow=35,
                max_database_connections=75,
                expected_concurrent_users=60,
                connections_per_user=1.0,
            )

        # 30 + 20 = 50 connections fits within the 72 available.
        safe_cfg = DatabasePoolConfig(
            pool_size=30,
            max_overflow=20,
            max_database_connections=75,
            expected_concurrent_users=45,
            connections_per_user=1.0,
        )

        used = safe_cfg.pool_size + safe_cfg.max_overflow
        headroom_limit = (
            safe_cfg.max_database_connections
            - safe_cfg.reserved_superuser_connections
        )
        assert used < headroom_limit
444 |
```
--------------------------------------------------------------------------------
/alembic/versions/013_add_backtest_persistence_models.py:
--------------------------------------------------------------------------------
```python
1 | """Add backtest persistence models
2 |
3 | Revision ID: 013_add_backtest_persistence_models
4 | Revises: fix_database_integrity_issues
5 | Create Date: 2025-01-16 12:00:00.000000
6 |
7 | This migration adds comprehensive backtesting persistence models:
8 | 1. BacktestResult - Main backtest results with comprehensive metrics
9 | 2. BacktestTrade - Individual trade records from backtests
10 | 3. OptimizationResult - Parameter optimization results
11 | 4. WalkForwardTest - Walk-forward validation test results
12 | 5. BacktestPortfolio - Portfolio-level backtests with multiple symbols
13 |
14 | All tables include proper indexes for common query patterns and foreign key
15 | relationships for data integrity.
16 | """
17 |
18 | import sqlalchemy as sa
19 |
20 | from alembic import op
21 |
# revision identifiers, used by Alembic.
# This revision is applied directly on top of "fix_database_integrity_issues";
# it declares no branch labels and no cross-branch dependencies.
revision = "013_add_backtest_persistence_models"
down_revision = "fix_database_integrity_issues"
branch_labels = None
depends_on = None
27 |
28 |
def upgrade() -> None:
    """Create the five backtest persistence tables and their indexes.

    Tables created, in dependency order:
      * mcp_backtest_results      - one row per backtest run (parent table)
      * mcp_backtest_trades       - trades, FK -> mcp_backtest_results (CASCADE)
      * mcp_optimization_results  - parameter sets, FK -> mcp_backtest_results
      * mcp_walk_forward_tests    - walk-forward windows, FK -> mcp_backtest_results
      * mcp_backtest_portfolios   - multi-symbol portfolio runs (no FK)

    Boolean server defaults use ``sa.false()`` rather than the string
    ``"false"``. A string default is emitted as the quoted literal
    ``'false'``, which backends without a native boolean (SQLite) store
    verbatim and SQLAlchemy later reads back as a truthy value.
    ``sa.false()`` renders as ``false`` on PostgreSQL and ``0`` on SQLite.
    """
    # Create BacktestResult table
    op.create_table(
        "mcp_backtest_results",
        sa.Column("backtest_id", sa.Uuid(), nullable=False, primary_key=True),
        # Basic metadata
        sa.Column("symbol", sa.String(length=10), nullable=False),
        sa.Column("strategy_type", sa.String(length=50), nullable=False),
        sa.Column("backtest_date", sa.DateTime(timezone=True), nullable=False),
        # Date range and setup
        sa.Column("start_date", sa.Date(), nullable=False),
        sa.Column("end_date", sa.Date(), nullable=False),
        sa.Column(
            "initial_capital",
            sa.Numeric(precision=15, scale=2),
            server_default="10000.0",
        ),
        # Trading costs
        sa.Column("fees", sa.Numeric(precision=6, scale=4), server_default="0.001"),
        sa.Column("slippage", sa.Numeric(precision=6, scale=4), server_default="0.001"),
        # Strategy parameters
        sa.Column("parameters", sa.JSON()),
        # Performance metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("annualized_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("sortino_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("calmar_ratio", sa.Numeric(precision=8, scale=4)),
        # Risk metrics
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown_duration", sa.Integer()),
        sa.Column("volatility", sa.Numeric(precision=8, scale=4)),
        sa.Column("downside_volatility", sa.Numeric(precision=8, scale=4)),
        # Trade statistics
        sa.Column("total_trades", sa.Integer(), server_default="0"),
        sa.Column("winning_trades", sa.Integer(), server_default="0"),
        sa.Column("losing_trades", sa.Integer(), server_default="0"),
        sa.Column("win_rate", sa.Numeric(precision=5, scale=4)),
        # P&L statistics
        sa.Column("profit_factor", sa.Numeric(precision=8, scale=4)),
        sa.Column("average_win", sa.Numeric(precision=12, scale=4)),
        sa.Column("average_loss", sa.Numeric(precision=12, scale=4)),
        sa.Column("largest_win", sa.Numeric(precision=12, scale=4)),
        sa.Column("largest_loss", sa.Numeric(precision=12, scale=4)),
        # Portfolio values
        sa.Column("final_portfolio_value", sa.Numeric(precision=15, scale=2)),
        sa.Column("peak_portfolio_value", sa.Numeric(precision=15, scale=2)),
        # Market analysis
        sa.Column("beta", sa.Numeric(precision=8, scale=4)),
        sa.Column("alpha", sa.Numeric(precision=8, scale=4)),
        # Time series data
        sa.Column("equity_curve", sa.JSON()),
        sa.Column("drawdown_series", sa.JSON()),
        # Execution metadata
        sa.Column("execution_time_seconds", sa.Numeric(precision=8, scale=3)),
        sa.Column("data_points", sa.Integer()),
        # Status and notes
        sa.Column("status", sa.String(length=20), server_default="completed"),
        sa.Column("error_message", sa.Text()),
        sa.Column("notes", sa.Text()),
        # Timestamps (no server default: values must be supplied on insert)
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )

    # Create indexes for BacktestResult
    op.create_index(
        "mcp_backtest_results_symbol_idx", "mcp_backtest_results", ["symbol"]
    )
    op.create_index(
        "mcp_backtest_results_strategy_idx", "mcp_backtest_results", ["strategy_type"]
    )
    op.create_index(
        "mcp_backtest_results_date_idx", "mcp_backtest_results", ["backtest_date"]
    )
    op.create_index(
        "mcp_backtest_results_sharpe_idx", "mcp_backtest_results", ["sharpe_ratio"]
    )
    op.create_index(
        "mcp_backtest_results_total_return_idx",
        "mcp_backtest_results",
        ["total_return"],
    )
    op.create_index(
        "mcp_backtest_results_symbol_strategy_idx",
        "mcp_backtest_results",
        ["symbol", "strategy_type"],
    )

    # Create BacktestTrade table
    op.create_table(
        "mcp_backtest_trades",
        sa.Column("trade_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("backtest_id", sa.Uuid(), nullable=False),
        # Trade identification
        sa.Column("trade_number", sa.Integer(), nullable=False),
        # Entry details
        sa.Column("entry_date", sa.Date(), nullable=False),
        sa.Column("entry_price", sa.Numeric(precision=12, scale=4), nullable=False),
        sa.Column("entry_time", sa.DateTime(timezone=True)),
        # Exit details
        sa.Column("exit_date", sa.Date()),
        sa.Column("exit_price", sa.Numeric(precision=12, scale=4)),
        sa.Column("exit_time", sa.DateTime(timezone=True)),
        # Position details
        sa.Column("position_size", sa.Numeric(precision=15, scale=2)),
        sa.Column("direction", sa.String(length=5), nullable=False),
        # P&L
        sa.Column("pnl", sa.Numeric(precision=12, scale=4)),
        sa.Column("pnl_percent", sa.Numeric(precision=8, scale=4)),
        # Risk metrics
        sa.Column("mae", sa.Numeric(precision=8, scale=4)),  # Maximum Adverse Excursion
        sa.Column(
            "mfe", sa.Numeric(precision=8, scale=4)
        ),  # Maximum Favorable Excursion
        # Duration
        sa.Column("duration_days", sa.Integer()),
        sa.Column("duration_hours", sa.Numeric(precision=8, scale=2)),
        # Exit reason and trading costs
        sa.Column("exit_reason", sa.String(length=50)),
        sa.Column("fees_paid", sa.Numeric(precision=10, scale=4), server_default="0"),
        sa.Column(
            "slippage_cost", sa.Numeric(precision=10, scale=4), server_default="0"
        ),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint: trades are removed with their parent backtest
        sa.ForeignKeyConstraint(
            ["backtest_id"], ["mcp_backtest_results.backtest_id"], ondelete="CASCADE"
        ),
    )

    # Create indexes for BacktestTrade
    op.create_index(
        "mcp_backtest_trades_backtest_idx", "mcp_backtest_trades", ["backtest_id"]
    )
    op.create_index(
        "mcp_backtest_trades_entry_date_idx", "mcp_backtest_trades", ["entry_date"]
    )
    op.create_index(
        "mcp_backtest_trades_exit_date_idx", "mcp_backtest_trades", ["exit_date"]
    )
    op.create_index("mcp_backtest_trades_pnl_idx", "mcp_backtest_trades", ["pnl"])
    op.create_index(
        "mcp_backtest_trades_backtest_entry_idx",
        "mcp_backtest_trades",
        ["backtest_id", "entry_date"],
    )

    # Create OptimizationResult table
    op.create_table(
        "mcp_optimization_results",
        sa.Column("optimization_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("backtest_id", sa.Uuid(), nullable=False),
        # Optimization metadata
        sa.Column("optimization_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("parameter_set", sa.Integer(), nullable=False),
        # Parameters and results
        sa.Column("parameters", sa.JSON(), nullable=False),
        sa.Column("objective_function", sa.String(length=50)),
        sa.Column("objective_value", sa.Numeric(precision=12, scale=6)),
        # Key metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("win_rate", sa.Numeric(precision=5, scale=4)),
        sa.Column("profit_factor", sa.Numeric(precision=8, scale=4)),
        sa.Column("total_trades", sa.Integer()),
        # Ranking
        sa.Column("rank", sa.Integer()),
        # Statistical significance (portable boolean default, see docstring)
        sa.Column(
            "is_statistically_significant", sa.Boolean(), server_default=sa.false()
        ),
        sa.Column("p_value", sa.Numeric(precision=8, scale=6)),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint
        sa.ForeignKeyConstraint(
            ["backtest_id"], ["mcp_backtest_results.backtest_id"], ondelete="CASCADE"
        ),
    )

    # Create indexes for OptimizationResult
    op.create_index(
        "mcp_optimization_results_backtest_idx",
        "mcp_optimization_results",
        ["backtest_id"],
    )
    op.create_index(
        "mcp_optimization_results_param_set_idx",
        "mcp_optimization_results",
        ["parameter_set"],
    )
    op.create_index(
        "mcp_optimization_results_objective_idx",
        "mcp_optimization_results",
        ["objective_value"],
    )

    # Create WalkForwardTest table
    op.create_table(
        "mcp_walk_forward_tests",
        sa.Column("walk_forward_id", sa.Uuid(), nullable=False, primary_key=True),
        sa.Column("parent_backtest_id", sa.Uuid(), nullable=False),
        # Test configuration
        sa.Column("test_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("window_size_months", sa.Integer(), nullable=False),
        sa.Column("step_size_months", sa.Integer(), nullable=False),
        # Time periods
        sa.Column("training_start", sa.Date(), nullable=False),
        sa.Column("training_end", sa.Date(), nullable=False),
        sa.Column("test_period_start", sa.Date(), nullable=False),
        sa.Column("test_period_end", sa.Date(), nullable=False),
        # Training results
        sa.Column("optimal_parameters", sa.JSON()),
        sa.Column("training_performance", sa.Numeric(precision=10, scale=4)),
        # Out-of-sample results
        sa.Column("out_of_sample_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("out_of_sample_sharpe", sa.Numeric(precision=8, scale=4)),
        sa.Column("out_of_sample_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("out_of_sample_trades", sa.Integer()),
        # Performance analysis
        sa.Column("performance_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("degradation_factor", sa.Numeric(precision=8, scale=4)),
        # Validation (portable boolean default, see docstring)
        sa.Column("is_profitable", sa.Boolean()),
        sa.Column(
            "is_statistically_significant", sa.Boolean(), server_default=sa.false()
        ),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Foreign key constraint
        sa.ForeignKeyConstraint(
            ["parent_backtest_id"],
            ["mcp_backtest_results.backtest_id"],
            ondelete="CASCADE",
        ),
    )

    # Create indexes for WalkForwardTest
    op.create_index(
        "mcp_walk_forward_tests_parent_idx",
        "mcp_walk_forward_tests",
        ["parent_backtest_id"],
    )
    op.create_index(
        "mcp_walk_forward_tests_period_idx",
        "mcp_walk_forward_tests",
        ["test_period_start"],
    )
    op.create_index(
        "mcp_walk_forward_tests_performance_idx",
        "mcp_walk_forward_tests",
        ["out_of_sample_return"],
    )

    # Create BacktestPortfolio table (standalone: no foreign keys)
    op.create_table(
        "mcp_backtest_portfolios",
        sa.Column("portfolio_backtest_id", sa.Uuid(), nullable=False, primary_key=True),
        # Portfolio identification
        sa.Column("portfolio_name", sa.String(length=100), nullable=False),
        sa.Column("description", sa.Text()),
        # Test metadata
        sa.Column("backtest_date", sa.DateTime(timezone=True), nullable=False),
        sa.Column("start_date", sa.Date(), nullable=False),
        sa.Column("end_date", sa.Date(), nullable=False),
        # Portfolio composition
        sa.Column("symbols", sa.JSON(), nullable=False),
        sa.Column("weights", sa.JSON()),
        sa.Column("rebalance_frequency", sa.String(length=20)),
        # Portfolio parameters
        sa.Column(
            "initial_capital",
            sa.Numeric(precision=15, scale=2),
            server_default="100000.0",
        ),
        sa.Column("max_positions", sa.Integer()),
        sa.Column("position_sizing_method", sa.String(length=50)),
        # Risk management
        sa.Column("portfolio_stop_loss", sa.Numeric(precision=6, scale=4)),
        sa.Column("max_sector_allocation", sa.Numeric(precision=5, scale=4)),
        sa.Column("correlation_threshold", sa.Numeric(precision=5, scale=4)),
        # Performance metrics
        sa.Column("total_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("annualized_return", sa.Numeric(precision=10, scale=4)),
        sa.Column("sharpe_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("sortino_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("max_drawdown", sa.Numeric(precision=8, scale=4)),
        sa.Column("volatility", sa.Numeric(precision=8, scale=4)),
        # Portfolio-specific metrics
        sa.Column("diversification_ratio", sa.Numeric(precision=8, scale=4)),
        sa.Column("concentration_index", sa.Numeric(precision=8, scale=4)),
        sa.Column("turnover_rate", sa.Numeric(precision=8, scale=4)),
        # References and time series
        sa.Column("component_backtest_ids", sa.JSON()),
        sa.Column("portfolio_equity_curve", sa.JSON()),
        sa.Column("portfolio_weights_history", sa.JSON()),
        # Status
        sa.Column("status", sa.String(length=20), server_default="completed"),
        sa.Column("notes", sa.Text()),
        # Timestamps
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )

    # Create indexes for BacktestPortfolio
    op.create_index(
        "mcp_backtest_portfolios_name_idx",
        "mcp_backtest_portfolios",
        ["portfolio_name"],
    )
    op.create_index(
        "mcp_backtest_portfolios_date_idx", "mcp_backtest_portfolios", ["backtest_date"]
    )
    op.create_index(
        "mcp_backtest_portfolios_return_idx",
        "mcp_backtest_portfolios",
        ["total_return"],
    )
349 |
350 |
def downgrade() -> None:
    """Remove every table created by ``upgrade``.

    The tables are dropped in the reverse of their creation order, so the
    child tables referencing ``mcp_backtest_results`` are gone before the
    parent itself is dropped.
    """
    for table_name in (
        "mcp_backtest_portfolios",
        "mcp_walk_forward_tests",
        "mcp_optimization_results",
        "mcp_backtest_trades",
        "mcp_backtest_results",  # parent table last
    ):
        op.drop_table(table_name)
358 |
```
--------------------------------------------------------------------------------
/maverick_mcp/tools/risk_management.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Risk management tools for position sizing, stop loss calculation, and portfolio risk analysis.
3 | """
4 |
5 | import logging
6 | from datetime import datetime, timedelta
7 | from typing import Any
8 |
9 | import numpy as np
10 | import pandas as pd
11 | from pydantic import BaseModel, Field
12 |
13 | from maverick_mcp.agents.base import PersonaAwareTool
14 | from maverick_mcp.core.technical_analysis import calculate_atr
15 | from maverick_mcp.providers.stock_data import StockDataProvider
16 |
# Module-level logger; the tools below report calculation failures through it.
logger = logging.getLogger(__name__)
18 |
19 |
class PositionSizeInput(BaseModel):
    """Arguments accepted by the position sizing tool."""

    # Dollar value of the whole trading account.
    account_size: float = Field(..., description="Total account size in dollars")
    # Price at which the position is expected to be opened.
    entry_price: float = Field(..., description="Planned entry price")
    # Price at which the position would be closed to cap the loss.
    stop_loss_price: float = Field(..., description="Stop loss price")
    # Expressed in percent, not as a fraction: 2.0 means 2% of the account.
    risk_percentage: float = Field(
        2.0, description="Percentage of account to risk (default 2%)"
    )
29 |
30 |
class TechnicalStopsInput(BaseModel):
    """Arguments accepted by the technical stop-loss tool."""

    # Ticker to analyse.
    symbol: str = Field(..., description="Stock symbol")
    # Number of most recent bars used for the support / swing-low level.
    lookback_days: int = Field(20, description="Days to look back for analysis")
    # How many ATRs below the current price the ATR stop is placed.
    atr_multiplier: float = Field(
        2.0, description="ATR multiplier for stop distance"
    )
39 |
40 |
class RiskMetricsInput(BaseModel):
    """Arguments accepted by the portfolio risk metrics tool."""

    # Portfolio constituents.
    symbols: list[str] = Field(..., description="List of symbols in portfolio")
    # Optional per-symbol weights in the same order as ``symbols``.
    weights: list[float] | None = Field(
        None, description="Portfolio weights (equal weight if not provided)"
    )
    # History window for correlation / volatility estimates.
    lookback_days: int = Field(
        252, description="Days for correlation calculation"
    )
51 |
52 |
class PositionSizeTool(PersonaAwareTool):
    """Calculate position size based on risk management rules.

    Sizing combines three views: fixed-fractional risk (risk amount divided
    by per-share risk), the persona's risk adjustment, and a Kelly-fraction
    scaled size. The smaller of the persona and Kelly sizes is taken and then
    capped at the maximum allowed allocation.
    """

    name: str = "calculate_position_size"
    description: str = (
        "Calculate position size based on account risk, with Kelly Criterion "
        "and persona adjustments"
    )
    args_schema: type[BaseModel] = PositionSizeInput

    def _run(
        self,
        account_size: float,
        entry_price: float,
        stop_loss_price: float,
        risk_percentage: float = 2.0,
    ) -> str:
        """Calculate position size synchronously.

        Args:
            account_size: Total account value in dollars; must be positive.
            entry_price: Planned entry price; must be positive.
            stop_loss_price: Stop loss price; must be non-negative and differ
                from ``entry_price``.
            risk_percentage: Percent of the account to risk, in (0, 100].

        Returns:
            The persona-formatted result rendered with ``str()``, or an
            ``"Error: ..."`` string when inputs are invalid or the
            calculation fails.
        """
        try:
            # Validate up front. Without these checks, a zero account surfaced
            # as an opaque "float division by zero", and negative sizes or
            # prices silently produced negative share counts.
            if account_size <= 0:
                return "Error: Account size must be positive"
            if entry_price <= 0:
                return "Error: Entry price must be positive"
            if stop_loss_price < 0:
                return "Error: Stop loss price cannot be negative"
            if not 0 < risk_percentage <= 100:
                return "Error: Risk percentage must be greater than 0 and at most 100"

            # Basic risk calculation
            risk_amount = account_size * (risk_percentage / 100)
            price_risk = abs(entry_price - stop_loss_price)

            if price_risk == 0:
                return "Error: Entry and stop loss prices cannot be the same"

            # Calculate base position size
            base_shares = risk_amount / price_risk
            base_position_value = base_shares * entry_price

            # Apply persona adjustments
            adjusted_shares = self.adjust_for_risk(base_shares, "position_size")
            adjusted_value = adjusted_shares * entry_price

            # Calculate Kelly fraction if persona is set
            kelly_fraction = 0.25  # Default conservative Kelly
            if self.persona:
                risk_factor = sum(self.persona.risk_tolerance) / 100
                kelly_fraction = self._calculate_kelly_fraction(risk_factor)

            kelly_shares = base_shares * kelly_fraction
            kelly_value = kelly_shares * entry_price

            # Ensure position doesn't exceed max allocation
            max_position_pct = self.persona.position_size_max if self.persona else 0.10
            max_position_value = account_size * max_position_pct

            final_shares = min(adjusted_shares, kelly_shares)
            final_value = final_shares * entry_price

            if final_value > max_position_value:
                final_shares = max_position_value / entry_price
                final_value = max_position_value

            result = {
                "status": "success",
                "position_sizing": {
                    "recommended_shares": int(final_shares),
                    "position_value": round(final_value, 2),
                    "position_percentage": round((final_value / account_size) * 100, 2),
                    "risk_amount": round(risk_amount, 2),
                    "price_risk_per_share": round(price_risk, 2),
                    # Distance to a 2R profit target, as a percent of entry price.
                    "r_multiple_target": round(2.0 * price_risk / entry_price * 100, 2),
                },
                "calculations": {
                    "base_shares": int(base_shares),
                    "base_position_value": round(base_position_value, 2),
                    "kelly_shares": int(kelly_shares),
                    "kelly_value": round(kelly_value, 2),
                    "persona_adjusted_shares": int(adjusted_shares),
                    "persona_adjusted_value": round(adjusted_value, 2),
                    "kelly_fraction": round(kelly_fraction, 3),
                    "max_allowed_value": round(max_position_value, 2),
                },
            }

            # Add persona insights if available
            if self.persona:
                result["persona_insights"] = {
                    "investor_type": self.persona.name,
                    "risk_tolerance": self.persona.risk_tolerance,
                    "max_position_size": f"{self.persona.position_size_max * 100:.1f}%",
                    # NOTE(review): final_value was capped at max_position_value
                    # above, so this flag is always True as written.
                    "suitable_for_profile": final_value <= max_position_value,
                }

            # Format for return
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.exception(f"Error calculating position size: {e}")
            return f"Error calculating position size: {str(e)}"
147 |
148 |
class TechnicalStopsTool(PersonaAwareTool):
    """Calculate stop loss levels based on technical analysis.

    Produces ATR-based, window-low (support / swing-low) and moving-average
    stop levels for a long position, plus a persona-dependent recommendation.
    """

    name: str = "calculate_technical_stops"
    description: str = (
        "Calculate stop loss levels using ATR, support levels, and moving averages"
    )
    args_schema: type[BaseModel] = TechnicalStopsInput

    def _run(
        self, symbol: str, lookback_days: int = 20, atr_multiplier: float = 2.0
    ) -> str:
        """Calculate technical stops synchronously.

        Args:
            symbol: Ticker to analyse.
            lookback_days: Number of most recent bars used for the support /
                swing-low level; must be at least 1.
            atr_multiplier: Number of ATRs below the current price for the ATR
                stop (adjusted through ``adjust_for_risk`` when a persona is set).

        Returns:
            The persona-formatted result rendered with ``str()``, or an
            ``"Error: ..."`` string.
        """
        try:
            # iloc[-0:] would silently select the whole history and a zero-size
            # rolling window is invalid, so reject non-positive lookbacks.
            if lookback_days < 1:
                return "Error: lookback_days must be at least 1"

            provider = StockDataProvider()

            # Fetch at least 100 calendar days of history (more for long lookbacks)
            end_date = datetime.now()
            start_date = end_date - timedelta(days=max(lookback_days * 2, 100))

            df = provider.get_stock_data(
                symbol,
                start_date.strftime("%Y-%m-%d"),
                end_date.strftime("%Y-%m-%d"),
                use_cache=True,
            )

            if df.empty:
                return f"Error: No price data available for {symbol}"

            close = df["Close"]
            current_price = float(close.iloc[-1])

            # ATR-based stop (persona may widen/tighten the multiplier)
            atr_value = float(calculate_atr(df, period=14).iloc[-1])
            if not np.isfinite(atr_value):
                return f"Error: Not enough price data to compute ATR for {symbol}"
            if self.persona:
                atr_multiplier = self.adjust_for_risk(atr_multiplier, "stop_loss")
            atr_stop = current_price - (atr_value * atr_multiplier)

            # Lowest low over the lookback window. tail() also works when the
            # history is shorter than the window (a full-window rolling min would
            # be NaN there). Support and swing-low stops both use this level.
            window_low = float(df["Low"].tail(lookback_days).min())
            support_level = window_low
            swing_low = window_low

            # Moving average stops (None when there is not enough history)
            ma_20 = (
                float(close.rolling(window=20).mean().iloc[-1])
                if len(df) >= 20
                else None
            )
            ma_50 = (
                float(close.rolling(window=50).mean().iloc[-1])
                if len(df) >= 50
                else None
            )

            def pct_below(level: float) -> float:
                """Distance from the current price down to ``level``, in percent."""
                return round(((current_price - level) / current_price) * 100, 2)

            stops = {
                "current_price": round(current_price, 2),
                "atr_stop": round(atr_stop, 2),
                "support_stop": round(support_level, 2),
                "swing_low_stop": round(swing_low, 2),
                "ma_20_stop": round(ma_20, 2) if ma_20 is not None else None,
                "ma_50_stop": round(ma_50, 2) if ma_50 is not None else None,
                "atr_value": round(atr_value, 2),
                "stop_distances": {
                    "atr_stop_pct": pct_below(atr_stop),
                    "support_stop_pct": pct_below(support_level),
                    "swing_low_pct": pct_below(swing_low),
                },
            }

            # Recommend stop based on persona. A long-position stop at or above
            # the current price would trigger immediately, so the 20-day MA is
            # only considered while price is trading above it.
            if self.persona and self.persona.name == "Conservative":
                candidates = [atr_stop]
                if ma_20 is not None and ma_20 < current_price:
                    candidates.append(ma_20)
                recommended = max(candidates)  # Tighter stop
            elif self.persona and self.persona.name == "Day Trader":
                recommended = atr_stop  # ATR-based for volatility
            elif self.persona:
                recommended = min(support_level, atr_stop)  # Balance
            else:
                recommended = atr_stop

            stops["recommended_stop"] = round(recommended, 2)
            stops["recommended_stop_pct"] = pct_below(recommended)

            result = {
                "status": "success",
                "symbol": symbol,
                "technical_stops": stops,
                "analysis_period": lookback_days,
                "atr_multiplier": atr_multiplier,
            }

            # Format for persona
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.exception(f"Error calculating technical stops for {symbol}: {e}")
            return f"Error calculating technical stops: {str(e)}"
259 |
260 |
class RiskMetricsTool(PersonaAwareTool):
    """Calculate portfolio risk metrics including correlations and VaR.

    Weights are keyed by symbol and renormalized over the symbols that
    actually returned price data. Previously they were sliced positionally,
    so a symbol with no data shifted the remaining weights onto the wrong
    return columns.
    """

    name: str = "calculate_risk_metrics"
    description: str = (
        "Calculate portfolio risk metrics including correlation, beta, and VaR"
    )
    args_schema: type[BaseModel] = RiskMetricsInput  # type: ignore[assignment]

    def _run(
        self,
        symbols: list[str],
        weights: list[float] | None = None,
        lookback_days: int = 252,
    ) -> str:
        """Calculate risk metrics synchronously.

        Args:
            symbols: Portfolio symbols.
            weights: Per-symbol weights in the same order as ``symbols``;
                equal weights when omitted. Weights are renormalized to sum
                to 1 over the symbols that have price data.
            lookback_days: Calendar days of history to fetch (plus a 30-day
                buffer).

        Returns:
            The persona-formatted result rendered with ``str()``, or an
            ``"Error: ..."`` string.
        """
        try:
            if not symbols:
                return "Error: No symbols provided"

            # If no weights provided, use equal weight
            if weights is None:
                weights = [1.0 / len(symbols)] * len(symbols)
            elif len(weights) != len(symbols):
                return "Error: Number of weights must match number of symbols"

            # Key weights by symbol (summing duplicates) so they are matched to
            # return columns by name rather than by position.
            weight_by_symbol: dict[str, float] = {}
            for symbol, weight in zip(symbols, weights, strict=True):
                weight_by_symbol[symbol] = weight_by_symbol.get(symbol, 0.0) + float(
                    weight
                )

            provider = StockDataProvider()

            # Get price data for all symbols
            end_date = datetime.now()
            start_date = end_date - timedelta(days=lookback_days + 30)
            start_str = start_date.strftime("%Y-%m-%d")
            end_str = end_date.strftime("%Y-%m-%d")

            returns_data: dict[str, pd.Series] = {}
            for symbol in weight_by_symbol:
                df = provider.get_stock_data(
                    symbol, start_str, end_str, use_cache=True
                )
                if not df.empty:
                    returns_data[symbol] = df["Close"].pct_change().dropna()

            if not returns_data:
                return "Error: No price data available for any symbols"

            # Keep only dates on which every symbol has a return.
            returns_df = pd.DataFrame(returns_data).dropna()
            if returns_df.empty:
                return "Error: No overlapping price history for the requested symbols"

            # Weights for the symbols that actually have data, renormalized.
            active_weights = np.array(
                [weight_by_symbol[symbol] for symbol in returns_df.columns],
                dtype=float,
            )
            total_weight = active_weights.sum()
            if total_weight == 0:
                return "Error: Portfolio weights must not sum to zero"
            active_weights = active_weights / total_weight

            # Calculate correlation matrix
            correlation_matrix = returns_df.corr()

            # Daily portfolio returns as the weighted sum of constituent returns
            portfolio_returns = returns_df.dot(active_weights)
            portfolio_std = float(portfolio_returns.std() * np.sqrt(252))  # Annualized

            # Historical VaR (95%): 5th percentile of daily returns, scaled by
            # sqrt(252) as in the volatility figure above.
            var_95 = float(np.percentile(portfolio_returns, 5) * np.sqrt(252))

            # Calculate portfolio beta (vs SPY)
            portfolio_beta = self._portfolio_beta(
                provider, portfolio_returns, start_str, end_str
            )
            avg_corr = self._average_correlation(correlation_matrix)

            # Build result
            result = {
                "status": "success",
                "portfolio_metrics": {
                    "annualized_volatility": round(portfolio_std * 100, 2),
                    "value_at_risk_95": round(var_95 * 100, 2),
                    "portfolio_beta": round(portfolio_beta, 2)
                    if portfolio_beta is not None
                    else None,
                    "avg_correlation": round(avg_corr, 3)
                    if avg_corr is not None
                    else None,
                },
                "correlations": correlation_matrix.to_dict(),
                # Report the weights actually applied (after renormalization).
                "weights": {
                    symbol: round(float(weight), 3)
                    for symbol, weight in zip(
                        returns_df.columns, active_weights, strict=True
                    )
                },
                "risk_assessment": self._assess_portfolio_risk(
                    portfolio_std, var_95, correlation_matrix
                ),
            }

            # Format for persona
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.exception(f"Error calculating risk metrics: {e}")
            return f"Error calculating risk metrics: {str(e)}"

    @staticmethod
    def _portfolio_beta(
        provider: StockDataProvider,
        portfolio_returns: pd.Series,
        start_str: str,
        end_str: str,
    ) -> float | None:
        """Return the portfolio beta versus SPY, or None when it cannot be estimated."""
        spy_df = provider.get_stock_data("SPY", start_str, end_str, use_cache=True)
        if spy_df.empty:
            return None
        spy_returns = spy_df["Close"].pct_change().dropna()
        # Align dates; a variance needs at least two observations.
        common_dates = portfolio_returns.index.intersection(spy_returns.index)
        if len(common_dates) < 2:
            return None
        spy_aligned = spy_returns.loc[common_dates]
        spy_var = spy_aligned.var()
        if not np.isfinite(spy_var) or spy_var == 0:
            return None
        beta = portfolio_returns.loc[common_dates].cov(spy_aligned) / spy_var
        return float(beta) if np.isfinite(beta) else None

    @staticmethod
    def _average_correlation(correlation_matrix: pd.DataFrame) -> float | None:
        """Mean pairwise (upper-triangle) correlation; None with fewer than 2 assets."""
        values = correlation_matrix.to_numpy()
        if values.shape[0] < 2:
            return None
        average = float(values[np.triu_indices_from(values, k=1)].mean())
        return average if np.isfinite(average) else None

    def _assess_portfolio_risk(
        self, volatility: float, var: float, correlation_matrix: pd.DataFrame
    ) -> dict[str, Any]:
        """Assess portfolio risk level.

        Args:
            volatility: Annualized volatility as a fraction (0.25 == 25%).
            var: VaR figure as a fraction, as computed in ``_run``.
            correlation_matrix: Pairwise return correlations.

        Returns:
            Dict with ``risk_level``, ``warnings`` and ``diversification_score``.
            The score is None when fewer than two assets have data.
        """
        risk_level = "Low"
        warnings = []

        # Check volatility
        if volatility > 0.25:  # 25% annual vol
            risk_level = "High"
            warnings.append("High portfolio volatility")
        elif volatility > 0.15:
            risk_level = "Moderate"

        # Check VaR
        if abs(var) > 0.10:  # 10% VaR
            warnings.append("High Value at Risk")

        # Check correlation
        avg_corr = self._average_correlation(correlation_matrix)
        if avg_corr is not None and avg_corr > 0.7:
            warnings.append("High correlation between holdings")

        return {
            "risk_level": risk_level,
            "warnings": warnings,
            "diversification_score": round(1 - avg_corr, 2)
            if avg_corr is not None
            else None,
        }
416 |
```