This is page 12 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/providers/optimized_screening.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Optimized screening operations with eager loading and batch processing.
3 |
4 | This module demonstrates proper eager loading patterns and optimizations
5 | for database queries to prevent N+1 query issues.
6 | """
7 |
8 | from datetime import datetime, timedelta
9 | from typing import Any
10 |
11 | from sqlalchemy import and_
12 | from sqlalchemy.orm import Session, selectinload
13 |
14 | from maverick_mcp.data.models import (
15 | MaverickBearStocks,
16 | MaverickStocks,
17 | PriceCache,
18 | Stock,
19 | SupplyDemandBreakoutStocks,
20 | )
21 | from maverick_mcp.data.session_management import get_db_session
22 | from maverick_mcp.utils.logging import get_logger
23 |
24 | logger = get_logger(__name__)
25 |
26 |
class OptimizedScreeningProvider:
    """
    Screening provider that keeps query counts flat as result sizes grow.

    Related data (``Stock`` metadata and recent ``PriceCache`` rows) is loaded
    with explicit joins and ``selectinload`` so that building N
    recommendations costs a fixed number of queries instead of 1 + N.
    """

    def __init__(self, session: Session | None = None):
        """
        Create the provider.

        Args:
            session: Optional caller-managed session. When supplied it is
                reused for every call and never closed here; otherwise each
                public method opens its own session and closes it when done.
        """
        self._session = session

    def _get_session(self) -> tuple[Session, bool]:
        """
        Return ``(session, should_close)``.

        ``should_close`` is True only when the session was created by this
        call, i.e. no session was injected in ``__init__``.
        """
        # Explicit None check: an injected session must be reused regardless
        # of how the Session object evaluates in a boolean context.
        if self._session is not None:
            return self._session, False
        return next(get_db_session()), True

    @staticmethod
    def _recent_price_caches_option(days: int = 30):
        """
        Build a loader option that eager-loads ``Stock.price_caches``.

        Only cache rows dated within the last ``days`` days are loaded; the
        selectin strategy fetches them for all stocks in the result with
        separate ``SELECT ... IN`` queries instead of one query per stock.
        """
        cutoff = datetime.now() - timedelta(days=days)
        return selectinload(Stock.price_caches.and_(PriceCache.date >= cutoff))

    @staticmethod
    def _serialize_recent_prices(
        price_caches: list[PriceCache] | None,
        count: int,
        include_range: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Serialize the ``count`` most recent price cache rows, oldest first.

        Rows are sorted by ``date`` before slicing so the result does not
        depend on the order in which the relationship collection was loaded.

        Args:
            price_caches: Loaded ``PriceCache`` rows (may be empty or None).
            count: Maximum number of rows to return; non-positive yields [].
            include_range: Also include the ``high``/``low`` prices.

        Returns:
            List of dicts with ``date``/``close``/``volume`` keys (plus
            ``high``/``low`` when ``include_range`` is True).
        """
        if not price_caches or count <= 0:
            return []
        recent = sorted(price_caches, key=lambda pc: pc.date)[-count:]
        serialized: list[dict[str, Any]] = []
        for pc in recent:
            entry: dict[str, Any] = {
                "date": pc.date.isoformat(),
                "close": pc.close_price,
                "volume": pc.volume,
            }
            if include_range:
                entry["high"] = pc.high_price
                entry["low"] = pc.low_price
            serialized.append(entry)
        return serialized

    def get_enhanced_maverick_recommendations(
        self,
        limit: int = 20,
        min_score: int | None = None,
        include_stock_details: bool = True,
    ) -> list[dict[str, Any]]:
        """
        Get Maverick recommendations, optionally enriched with stock details.

        With ``include_stock_details`` the Maverick rows are joined to
        ``Stock`` in a single query and recent price cache rows are
        eager-loaded, so no per-row queries are issued while building output.

        Args:
            limit: Maximum number of recommendations.
            min_score: Minimum ``combined_score``; ``None`` disables the filter.
            include_stock_details: Whether to include company metadata and
                the five most recent cached prices (requires a join).

        Returns:
            Recommendations ordered by ``combined_score`` descending, or an
            empty list if the database query fails (the error is logged).
        """
        session, should_close = self._get_session()
        try:
            if include_stock_details:
                # Select both entities so each row is a (MaverickStocks, Stock)
                # tuple; the loader option on Stock.price_caches also requires
                # Stock to be one of the query's root entities.
                query = (
                    session.query(MaverickStocks, Stock)
                    .join(Stock, Stock.ticker_symbol == MaverickStocks.stock)
                    .options(self._recent_price_caches_option())
                )
            else:
                query = session.query(MaverickStocks)

            if min_score is not None:
                query = query.filter(MaverickStocks.combined_score >= min_score)

            rows = (
                query.order_by(MaverickStocks.combined_score.desc())
                .limit(limit)
                .all()
            )

            recommendations: list[dict[str, Any]] = []
            for row in rows:
                if include_stock_details:
                    maverick_stock, stock_details = row
                    rec = {
                        **maverick_stock.to_dict(),
                        "recommendation_type": "maverick_bullish",
                        "reason": self._generate_reason(maverick_stock),
                        # Enhanced details from the joined Stock row
                        "company_name": stock_details.company_name,
                        "sector": stock_details.sector,
                        "industry": stock_details.industry,
                        "exchange": stock_details.exchange,
                        # Already eager loaded; no extra query here
                        "recent_prices": self._serialize_recent_prices(
                            stock_details.price_caches, 5
                        ),
                    }
                else:
                    rec = {
                        **row.to_dict(),
                        "recommendation_type": "maverick_bullish",
                        "reason": self._generate_reason(row),
                    }
                recommendations.append(rec)

            return recommendations

        except Exception as e:
            logger.error(f"Error getting enhanced maverick recommendations: {e}")
            return []
        finally:
            if should_close:
                session.close()

    def get_batch_stock_details(self, symbols: list[str]) -> dict[str, dict[str, Any]]:
        """
        Fetch details for many symbols with one batched query.

        Args:
            symbols: Ticker symbols to look up. Unknown symbols are omitted
                from the result.

        Returns:
            Mapping of ticker symbol to company metadata plus the ten most
            recent cached prices (last 30 days). Empty if ``symbols`` is
            empty or the query fails (the error is logged).
        """
        if not symbols:
            # Nothing to look up; skip opening a session and issuing a query.
            return {}

        session, should_close = self._get_session()
        try:
            stocks = (
                session.query(Stock)
                .options(self._recent_price_caches_option())
                .filter(Stock.ticker_symbol.in_(symbols))
                .all()
            )

            return {
                stock.ticker_symbol: {
                    "company_name": stock.company_name,
                    "sector": stock.sector,
                    "industry": stock.industry,
                    "exchange": stock.exchange,
                    "country": stock.country,
                    "currency": stock.currency,
                    "recent_prices": self._serialize_recent_prices(
                        stock.price_caches, 10, include_range=True
                    ),
                }
                for stock in stocks
            }

        except Exception as e:
            logger.error(f"Error getting batch stock details: {e}")
            return {}
        finally:
            if should_close:
                session.close()

    @staticmethod
    def _build_entries(
        rows: list[Any],
        recommendation_type: str,
        reason_builder: Any,
        stock_details: dict[str, dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Turn screening rows into recommendation dicts merged with details."""
        return [
            {
                **row.to_dict(),
                "recommendation_type": recommendation_type,
                "reason": reason_builder(row),
                **stock_details.get(row.stock, {}),
            }
            for row in rows
        ]

    def get_comprehensive_screening_results(
        self, include_details: bool = False, limit: int = 20
    ) -> dict[str, list[dict[str, Any]]]:
        """
        Get the top rows of every screening type in a fixed number of queries.

        Args:
            include_details: Whether to merge in company metadata and recent
                prices (one additional batched lookup for the returned rows).
            limit: Maximum rows per screening type (default 20).

        Returns:
            Dict with ``maverick_bullish``, ``maverick_bearish`` and
            ``supply_demand_breakouts`` lists, or an empty dict if a query
            fails (the error is logged).
        """
        session, should_close = self._get_session()
        try:
            maverick_stocks = (
                session.query(MaverickStocks)
                .order_by(MaverickStocks.combined_score.desc())
                .limit(limit)
                .all()
            )

            bear_stocks = (
                session.query(MaverickBearStocks)
                .order_by(MaverickBearStocks.score.desc())
                .limit(limit)
                .all()
            )

            supply_demand_stocks = (
                session.query(SupplyDemandBreakoutStocks)
                .filter(
                    and_(
                        SupplyDemandBreakoutStocks.close_price
                        > SupplyDemandBreakoutStocks.sma_50,
                        SupplyDemandBreakoutStocks.close_price
                        > SupplyDemandBreakoutStocks.sma_150,
                        SupplyDemandBreakoutStocks.close_price
                        > SupplyDemandBreakoutStocks.sma_200,
                    )
                )
                .order_by(SupplyDemandBreakoutStocks.momentum_score.desc())
                .limit(limit)
                .all()
            )

            stock_details: dict[str, dict[str, Any]] = {}
            if include_details:
                # Look up details only for symbols that are actually returned,
                # instead of every symbol present in the three screening tables.
                symbols = list(
                    dict.fromkeys(
                        row.stock
                        for row in (*maverick_stocks, *bear_stocks, *supply_demand_stocks)
                    )
                )
                stock_details = self.get_batch_stock_details(symbols)

            return {
                "maverick_bullish": self._build_entries(
                    maverick_stocks,
                    "maverick_bullish",
                    self._generate_reason,
                    stock_details,
                ),
                "maverick_bearish": self._build_entries(
                    bear_stocks,
                    "maverick_bearish",
                    self._generate_bear_reason,
                    stock_details,
                ),
                "supply_demand_breakouts": self._build_entries(
                    supply_demand_stocks,
                    "supply_demand_breakout",
                    self._generate_supply_demand_reason,
                    stock_details,
                ),
            }

        except Exception as e:
            logger.error(f"Error getting comprehensive screening results: {e}")
            return {}
        finally:
            if should_close:
                session.close()

    def _generate_reason(self, stock: MaverickStocks) -> str:
        """
        Describe why a Maverick stock is recommended.

        Missing or NULL score attributes are skipped rather than compared.
        """
        reasons: list[str] = []

        combined = getattr(stock, "combined_score", None)
        if combined is not None:
            if combined >= 90:
                reasons.append("Exceptional combined score")
            elif combined >= 80:
                reasons.append("Strong combined score")

        momentum = getattr(stock, "momentum_score", None)
        if momentum is not None:
            if momentum >= 90:
                reasons.append("outstanding relative strength")
            elif momentum >= 80:
                reasons.append("strong relative strength")

        pattern = getattr(stock, "pat", None)
        if pattern:
            reasons.append(f"{pattern} pattern detected")

        return (
            "Bullish setup with " + ", ".join(reasons)
            if reasons
            else "Strong technical setup"
        )

    def _generate_bear_reason(self, stock: MaverickBearStocks) -> str:
        """
        Describe why a bear-screen stock is flagged.

        Missing or NULL score attributes are skipped rather than compared.
        """
        reasons: list[str] = []

        score = getattr(stock, "score", None)
        if score is not None and score >= 80:
            reasons.append("Strong bear signals")

        momentum = getattr(stock, "momentum_score", None)
        if momentum is not None and momentum <= 30:
            reasons.append("weak relative strength")

        return (
            "Bearish setup with " + ", ".join(reasons)
            if reasons
            else "Weak technical setup"
        )

    def _generate_supply_demand_reason(self, stock: SupplyDemandBreakoutStocks) -> str:
        """
        Describe why a supply/demand breakout stock is recommended.

        Missing or NULL attributes are skipped rather than compared.
        """
        reasons: list[str] = []

        momentum = getattr(stock, "momentum_score", None)
        if momentum is not None and momentum >= 90:
            reasons.append("exceptional relative strength")

        # The model's closing price column is ``close_price`` (the screening
        # query filters on it); ``close`` is accepted as a fallback name.
        close = getattr(stock, "close_price", None)
        if close is None:
            close = getattr(stock, "close", None)
        sma_200 = getattr(stock, "sma_200", None)
        # Compare as floats: Numeric columns may load as Decimal, and
        # Decimal * float raises TypeError.
        if close is not None and sma_200 is not None:
            if float(close) > float(sma_200) * 1.1:  # 10% above 200 SMA
                reasons.append("strong uptrend")

        return (
            "Supply/demand breakout with " + ", ".join(reasons)
            if reasons
            else "Supply absorption and demand expansion"
        )
384 |
```
--------------------------------------------------------------------------------
/maverick_mcp/application/queries/get_technical_analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Application query for getting technical analysis.
3 |
4 | This query orchestrates the domain services and repositories
5 | to provide technical analysis functionality.
6 | """
7 |
8 | from datetime import UTC, datetime, timedelta
9 | from typing import Protocol
10 |
11 | import pandas as pd
12 |
13 | from maverick_mcp.application.dto.technical_analysis_dto import (
14 | BollingerBandsDTO,
15 | CompleteTechnicalAnalysisDTO,
16 | MACDAnalysisDTO,
17 | PriceLevelDTO,
18 | RSIAnalysisDTO,
19 | StochasticDTO,
20 | TrendAnalysisDTO,
21 | VolumeAnalysisDTO,
22 | )
23 | from maverick_mcp.domain.entities.stock_analysis import StockAnalysis
24 | from maverick_mcp.domain.services.technical_analysis_service import (
25 | TechnicalAnalysisService,
26 | )
27 | from maverick_mcp.domain.value_objects.technical_indicators import (
28 | Signal,
29 | )
30 |
31 |
class StockDataRepository(Protocol):
    """
    Structural interface for any source of historical stock prices.

    Any object with a compatible ``get_price_data`` method satisfies this
    protocol; no inheritance is required.
    """

    def get_price_data(
        self, symbol: str, start_date: str, end_date: str
    ) -> pd.DataFrame:
        """
        Return historical price rows for ``symbol`` between two dates.

        Args:
            symbol: Ticker symbol to look up.
            start_date: Start of the window as a date string (the query in
                this module passes ``YYYY-MM-DD``).
            end_date: End of the window, same format as ``start_date``.

        Returns:
            A DataFrame of prices. The query in this module reads its
            "close", "high" and "low" columns, plus "volume" when present.
        """
40 |
41 |
42 | class GetTechnicalAnalysisQuery:
43 | """
44 | Application query for retrieving technical analysis.
45 |
46 | This query coordinates between the domain layer and infrastructure
47 | to provide technical analysis without exposing domain complexity.
48 | """
49 |
def __init__(
    self,
    stock_repository: StockDataRepository,
    technical_service: TechnicalAnalysisService,
):
    """
    Wire the query to its collaborators.

    Args:
        stock_repository: Source of historical price data.
        technical_service: Domain service that performs the indicator math.
    """
    self.technical_service = technical_service
    self.stock_repository = stock_repository
64 |
async def execute(
    self,
    symbol: str,
    days: int = 365,
    indicators: list[str] | None = None,
    rsi_period: int = 14,
) -> CompleteTechnicalAnalysisDTO:
    """
    Run the technical analysis query for one symbol.

    Args:
        symbol: Stock ticker symbol.
        days: Number of calendar days of history to request from the
            repository.
        indicators: Indicators to calculate: any of "rsi", "macd",
            "bollinger" and "stochastic", matched case-insensitively.
            ``None`` or an empty list calculates all of them.
        rsi_period: Period for the RSI calculation (default: 14).

    Returns:
        Complete technical analysis DTO.

    Raises:
        IndexError: If the repository returns an empty frame (there is no
            last close to read).
        KeyError: If the frame lacks a required column ("close", or
            "high"/"low" when stochastic is calculated).
    """
    # One timestamp for both the requested window and the analysis date.
    now = datetime.now(UTC)
    start_date = now - timedelta(days=days)

    df = self.stock_repository.get_price_data(
        symbol,
        start_date.strftime("%Y-%m-%d"),
        now.strftime("%Y-%m-%d"),
    )
    close = pd.Series(df["close"])

    # Normalize requested names once; None means "calculate everything".
    requested = {name.lower() for name in indicators} if indicators else None

    analysis = StockAnalysis(
        symbol=symbol,
        analysis_date=now,
        current_price=float(close.iloc[-1]),
        trend_direction=self.technical_service.identify_trend(close),
        trend_strength=self._calculate_trend_strength(df),
        analysis_period_days=days,
        indicators_used=[],
    )

    # indicators_used was initialized to [] above; the assert narrows its
    # type for static checkers.
    assert analysis.indicators_used is not None

    if requested is None or "rsi" in requested:
        analysis.rsi = self.technical_service.calculate_rsi(close, period=rsi_period)
        analysis.indicators_used.append("RSI")

    if requested is None or "macd" in requested:
        analysis.macd = self.technical_service.calculate_macd(close)
        analysis.indicators_used.append("MACD")

    if requested is None or "bollinger" in requested:
        analysis.bollinger_bands = self.technical_service.calculate_bollinger_bands(
            close
        )
        analysis.indicators_used.append("Bollinger Bands")

    if requested is None or "stochastic" in requested:
        analysis.stochastic = self.technical_service.calculate_stochastic(
            pd.Series(df["high"]), pd.Series(df["low"]), close
        )
        analysis.indicators_used.append("Stochastic")

    # Volume is optional in the repository's frame.
    if "volume" in df.columns:
        analysis.volume_profile = self.technical_service.analyze_volume(
            pd.Series(df["volume"])
        )

    analysis.support_levels = self.technical_service.find_support_levels(df)
    analysis.resistance_levels = self.technical_service.find_resistance_levels(df)

    analysis.composite_signal = self.technical_service.calculate_composite_signal(
        analysis.rsi,
        analysis.macd,
        analysis.bollinger_bands,
        analysis.stochastic,
    )

    analysis.confidence_score = self._calculate_confidence_score(analysis)

    return self._map_to_dto(analysis)
159 |
def _calculate_trend_strength(self, df: pd.DataFrame, lookback: int = 20) -> float:
    """
    Estimate trend strength as the absolute percentage price change.

    Compares the last close with the close ``lookback`` rows from the end of
    ``df``. With the default of 20, that is the close 19 bars earlier. The
    result is capped at 100.

    Args:
        df: Price frame with a "close" column.
        lookback: Number of trailing rows spanned by the comparison
            (default 20). Values below 2 always yield 0.0.

    Returns:
        A value in [0, 100]. Returns 0.0 when there are fewer than
        ``lookback`` rows, or when the reference or latest close is zero or
        NaN (which would otherwise produce inf or NaN).
    """
    if lookback < 2 or len(df) < lookback:
        return 0.0

    closes = df["close"]
    reference = closes.iloc[-lookback]
    latest = closes.iloc[-1]

    if pd.isna(reference) or pd.isna(latest) or reference == 0:
        return 0.0

    price_change = (latest - reference) / reference
    return float(min(abs(price_change) * 100, 100.0))
170 |
def _calculate_confidence_score(self, analysis: StockAnalysis) -> float:
    """
    Score (0-100) how strongly the computed indicators agree.

    The score is the share of indicator signals matching the most common
    signal, as a percentage. Unusual volume activity adds a 10-point bonus,
    capped at 100. Returns 0.0 when no indicator was computed.
    """
    present = [
        indicator
        for indicator in (
            analysis.rsi,
            analysis.macd,
            analysis.bollinger_bands,
            analysis.stochastic,
        )
        if indicator
    ]
    if not present:
        return 0.0

    votes: list[Signal] = [indicator.signal for indicator in present]
    top_share = max(votes.count(vote) for vote in set(votes))
    confidence = (top_share / len(votes)) * 100

    volume = analysis.volume_profile
    if volume and volume.unusual_activity:
        confidence = min(100, confidence + 10)

    return float(confidence)
200 |
def _map_to_dto(self, analysis: StockAnalysis) -> CompleteTechnicalAnalysisDTO:
    """Translate a ``StockAnalysis`` domain entity into its DTO.

    The DTO starts with every optional indicator section set to ``None``.
    Each section is filled in only when ``analysis`` carries that indicator.
    Support and resistance levels are always assigned (empty lists when
    absent). Their ``distance_from_current`` is the gap to
    ``analysis.current_price`` as a percentage of that price: a support is
    measured downward from the price and a resistance upward.
    """
    current_price = analysis.current_price

    def to_level_dto(level, *, is_resistance: bool) -> PriceLevelDTO:
        # Resistance: level minus price; support: price minus level.
        gap = (
            level.price - current_price
            if is_resistance
            else current_price - level.price
        )
        return PriceLevelDTO(
            price=level.price,
            strength=level.strength,
            touches=level.touches,
            distance_from_current=gap / current_price * 100,
        )

    dto = CompleteTechnicalAnalysisDTO(
        symbol=analysis.symbol,
        analysis_date=analysis.analysis_date,
        current_price=current_price,
        trend=TrendAnalysisDTO(
            direction=analysis.trend_direction.value,
            strength=analysis.trend_strength,
            interpretation=self._interpret_trend(analysis),
        ),
        composite_signal=analysis.composite_signal.value,
        confidence_score=analysis.confidence_score,
        risk_reward_ratio=analysis.risk_reward_ratio,
        summary=self._generate_summary(analysis),
        key_levels=analysis.get_key_levels(),
        rsi=None,
        macd=None,
        bollinger_bands=None,
        stochastic=None,
        volume_analysis=None,
    )

    rsi = analysis.rsi
    if rsi:
        dto.rsi = RSIAnalysisDTO(
            current_value=rsi.value,
            period=rsi.period,
            signal=rsi.signal.value,
            is_overbought=rsi.is_overbought,
            is_oversold=rsi.is_oversold,
            interpretation=self._interpret_rsi(rsi),
        )

    macd = analysis.macd
    if macd:
        dto.macd = MACDAnalysisDTO(
            macd_line=macd.macd_line,
            signal_line=macd.signal_line,
            histogram=macd.histogram,
            signal=macd.signal.value,
            is_bullish_crossover=macd.is_bullish_crossover,
            is_bearish_crossover=macd.is_bearish_crossover,
            interpretation=self._interpret_macd(macd),
        )

    bands = analysis.bollinger_bands
    if bands:
        dto.bollinger_bands = BollingerBandsDTO(
            upper_band=bands.upper_band,
            middle_band=bands.middle_band,
            lower_band=bands.lower_band,
            current_price=bands.current_price,
            bandwidth=bands.bandwidth,
            percent_b=bands.percent_b,
            signal=bands.signal.value,
            interpretation=self._interpret_bollinger(bands),
        )

    stochastic = analysis.stochastic
    if stochastic:
        dto.stochastic = StochasticDTO(
            k_value=stochastic.k_value,
            d_value=stochastic.d_value,
            signal=stochastic.signal.value,
            is_overbought=stochastic.is_overbought,
            is_oversold=stochastic.is_oversold,
            interpretation=self._interpret_stochastic(stochastic),
        )

    dto.support_levels = [
        to_level_dto(level, is_resistance=False)
        for level in (analysis.support_levels or [])
    ]
    dto.resistance_levels = [
        to_level_dto(level, is_resistance=True)
        for level in (analysis.resistance_levels or [])
    ]

    volume = analysis.volume_profile
    if volume:
        dto.volume_analysis = VolumeAnalysisDTO(
            current_volume=volume.current_volume,
            average_volume=volume.average_volume,
            relative_volume=volume.relative_volume,
            volume_trend=volume.volume_trend.value,
            unusual_activity=volume.unusual_activity,
            interpretation=self._interpret_volume(volume),
        )

    return dto
309 |
def _generate_summary(self, analysis: StockAnalysis) -> str:
    """Build a short, human-readable executive summary.

    The summary covers the composite signal, the confidence score and the
    trend direction. It appends the risk/reward ratio only when that value is
    truthy.
    """
    label_by_signal = {
        Signal.STRONG_BUY: "strong buy signal",
        Signal.BUY: "buy signal",
        Signal.NEUTRAL: "neutral stance",
        Signal.SELL: "sell signal",
        Signal.STRONG_SELL: "strong sell signal",
    }

    opening = f"{analysis.symbol} shows a {label_by_signal[analysis.composite_signal]}"
    confidence = f"with {analysis.confidence_score:.0f}% confidence."
    trend = f"The stock is in a {analysis.trend_direction.value.replace('_', ' ')}."
    sentences = [opening, confidence, trend]

    ratio = analysis.risk_reward_ratio
    if ratio:
        sentences.append(f"Risk/reward ratio is {ratio:.2f}.")

    return " ".join(sentences)
332 |
def _interpret_trend(self, analysis: StockAnalysis) -> str:
    """Describe the trend direction and its strength in one sentence."""
    direction_label = analysis.trend_direction.value.replace("_", " ")
    strength = analysis.trend_strength
    return (
        f"The stock is showing a {direction_label} "
        f"with {strength:.0f}% strength."
    )
339 |
340 | def _interpret_rsi(self, rsi) -> str:
341 | """Generate RSI interpretation."""
342 | if rsi.is_overbought:
343 | return f"RSI at {rsi.value:.1f} indicates overbought conditions."
344 | elif rsi.is_oversold:
345 | return f"RSI at {rsi.value:.1f} indicates oversold conditions."
346 | else:
347 | return f"RSI at {rsi.value:.1f} is in neutral territory."
348 |
349 | def _interpret_macd(self, macd) -> str:
350 | """Generate MACD interpretation."""
351 | if macd.is_bullish_crossover:
352 | return "MACD shows bullish crossover - potential buy signal."
353 | elif macd.is_bearish_crossover:
354 | return "MACD shows bearish crossover - potential sell signal."
355 | else:
356 | return "MACD is neutral, no clear signal."
357 |
358 | def _interpret_bollinger(self, bb) -> str:
359 | """Generate Bollinger Bands interpretation."""
360 | if bb.is_squeeze:
361 | return "Bollinger Bands are squeezing - expect volatility breakout."
362 | elif bb.percent_b > 1:
363 | return "Price above upper band - potential overbought."
364 | elif bb.percent_b < 0:
365 | return "Price below lower band - potential oversold."
366 | else:
367 | return f"Price at {bb.percent_b:.1%} of bands range."
368 |
369 | def _interpret_stochastic(self, stoch) -> str:
370 | """Generate Stochastic interpretation."""
371 | if stoch.is_overbought:
372 | return f"Stochastic at {stoch.k_value:.1f} indicates overbought."
373 | elif stoch.is_oversold:
374 | return f"Stochastic at {stoch.k_value:.1f} indicates oversold."
375 | else:
376 | return f"Stochastic at {stoch.k_value:.1f} is neutral."
377 |
378 | def _interpret_volume(self, volume) -> str:
379 | """Generate volume interpretation."""
380 | if volume.unusual_activity:
381 | return f"Unusual volume at {volume.relative_volume:.1f}x average!"
382 | elif volume.is_high_volume:
383 | return f"High volume at {volume.relative_volume:.1f}x average."
384 | elif volume.is_low_volume:
385 | return f"Low volume at {volume.relative_volume:.1f}x average."
386 | else:
387 | return "Normal trading volume."
388 |
```
--------------------------------------------------------------------------------
/tests/providers/test_stock_data_simple.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Simplified unit tests for maverick_mcp.providers.stock_data module.
3 |
4 | This module contains focused tests for the Enhanced Stock Data Provider
5 | with proper mocking to avoid external dependencies.
6 | """
7 |
8 | from datetime import datetime, timedelta
9 | from unittest.mock import Mock, patch
10 |
11 | import pandas as pd
12 | import pytest
13 | from sqlalchemy.orm import Session
14 |
15 | from maverick_mcp.providers.stock_data import EnhancedStockDataProvider
16 |
17 |
class TestEnhancedStockDataProviderCore:
    """Test core functionality of the Enhanced Stock Data Provider."""

    @staticmethod
    def _build_provider(**provider_kwargs):
        """Construct a provider while its external dependencies are patched.

        ``get_db_session_read_only`` and ``mcal.get_calendar`` are patched only
        for the duration of the constructor call. Methods invoked on the
        returned provider afterwards run outside these patches.
        """
        with (
            patch("maverick_mcp.providers.stock_data.get_db_session_read_only"),
            patch("maverick_mcp.providers.stock_data.mcal.get_calendar"),
        ):
            return EnhancedStockDataProvider(**provider_kwargs)

    @pytest.fixture
    def mock_db_session(self):
        """Create a mock database session."""
        session = Mock(spec=Session)
        session.execute.return_value.fetchone.return_value = [1]
        return session

    @pytest.fixture
    def provider(self, mock_db_session):
        """Create a stock data provider with mocked dependencies."""
        return self._build_provider(db_session=mock_db_session)

    def test_provider_initialization(self, mock_db_session):
        """Test provider initialization."""
        provider = self._build_provider(db_session=mock_db_session)

        assert provider.timeout == 30
        assert provider.max_retries == 3
        assert provider.cache_days == 1
        assert provider._db_session == mock_db_session

    def test_provider_initialization_without_session(self):
        """Test provider initialization without database session."""
        provider = self._build_provider()

        assert provider._db_session is None

    def test_get_stock_data_returns_dataframe(self, provider):
        """Test that get_stock_data returns a DataFrame."""
        # Test with use_cache=False to avoid database dependency
        result = provider.get_stock_data(
            "AAPL", "2024-01-01", "2024-01-31", use_cache=False
        )

        assert isinstance(result, pd.DataFrame)
        # Note: May be empty due to mocking, but should be DataFrame

    def test_get_maverick_recommendations_no_session(self):
        """Test getting Maverick recommendations without database session."""
        provider = self._build_provider(db_session=None)

        result = provider.get_maverick_recommendations()

        assert isinstance(result, list)
        assert len(result) == 0

    def test_get_maverick_bear_recommendations_no_session(self):
        """Test getting Maverick bear recommendations without database session."""
        provider = self._build_provider(db_session=None)

        result = provider.get_maverick_bear_recommendations()

        assert isinstance(result, list)
        assert len(result) == 0

    def test_get_trending_recommendations_no_session(self):
        """Test getting trending recommendations without database session."""
        provider = self._build_provider(db_session=None)

        result = provider.get_trending_recommendations()

        # The provider now falls back to using the default database connection
        # when no session is provided, so the list may or may not be empty.
        # Only the return type can be asserted; a `len(result) >= 0` check
        # would always pass and verify nothing.
        assert isinstance(result, list)

    @patch("maverick_mcp.providers.stock_data.get_latest_maverick_screening")
    def test_get_all_screening_recommendations(self, mock_screening, provider):
        """Test getting all screening recommendations."""
        mock_screening.return_value = {
            "maverick_stocks": [],
            "maverick_bear_stocks": [],
            "trending_stocks": [],
        }

        result = provider.get_all_screening_recommendations()

        assert isinstance(result, dict)
        assert "maverick_stocks" in result
        assert "maverick_bear_stocks" in result
        assert "trending_stocks" in result

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_stock_info_success(self, mock_ticker, provider):
        """Test getting stock information successfully."""
        mock_info = {
            "symbol": "AAPL",
            "longName": "Apple Inc.",
            "sector": "Technology",
            "industry": "Consumer Electronics",
        }

        mock_ticker.return_value.info = mock_info

        result = provider.get_stock_info("AAPL")

        assert isinstance(result, dict)
        assert result.get("symbol") == "AAPL"

    @pytest.mark.skip(reason="Flaky test with external dependencies")
    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_stock_info_exception(self, mock_ticker, provider):
        """Test getting stock information with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.get_stock_info("INVALID")

        assert isinstance(result, dict)
        assert result == {}

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_realtime_data_success(self, mock_ticker, provider):
        """Test getting real-time data successfully."""
        # Create mock data that matches the expected format
        mock_data = pd.DataFrame(
            {
                "Open": [150.0],
                "High": [155.0],
                "Low": [149.0],
                "Close": [153.0],
                "Volume": [1000000],
            },
            index=pd.DatetimeIndex([datetime.now()]),
        )

        mock_ticker.return_value.history.return_value = mock_data
        mock_ticker.return_value.info = {"previousClose": 151.0}

        result = provider.get_realtime_data("AAPL")

        assert isinstance(result, dict)
        assert "symbol" in result
        assert "price" in result

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_realtime_data_empty(self, mock_ticker, provider):
        """Test getting real-time data with empty result."""
        mock_ticker.return_value.history.return_value = pd.DataFrame()

        result = provider.get_realtime_data("INVALID")

        assert result is None

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_realtime_data_exception(self, mock_ticker, provider):
        """Test getting real-time data with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.get_realtime_data("INVALID")

        assert result is None

    def test_get_all_realtime_data(self, provider):
        """Test getting real-time data for multiple symbols."""
        with patch.object(provider, "get_realtime_data") as mock_single:
            mock_single.side_effect = [
                {"symbol": "AAPL", "price": 153.0},
                {"symbol": "MSFT", "price": 420.0},
            ]

            result = provider.get_all_realtime_data(["AAPL", "MSFT"])

            assert isinstance(result, dict)
            assert "AAPL" in result
            assert "MSFT" in result

    def test_is_market_open(self, provider):
        """Test market open check."""
        with patch.object(provider.market_calendar, "open_at_time") as mock_open:
            mock_open.return_value = True

            result = provider.is_market_open()

            assert isinstance(result, bool)

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_news_success(self, mock_ticker, provider):
        """Test getting news successfully."""
        mock_news = [
            {
                "title": "Apple Reports Strong Q4 Earnings",
                "link": "https://example.com/news1",
                "providerPublishTime": datetime.now().timestamp(),
                "type": "STORY",
            },
        ]

        mock_ticker.return_value.news = mock_news

        result = provider.get_news("AAPL", limit=5)

        assert isinstance(result, pd.DataFrame)

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_news_exception(self, mock_ticker, provider):
        """Test getting news with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.get_news("INVALID")

        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_earnings_success(self, mock_ticker, provider):
        """Test getting earnings data successfully."""
        mock_ticker.return_value.calendar = pd.DataFrame()
        mock_ticker.return_value.earnings_dates = {}
        mock_ticker.return_value.earnings_trend = {}

        result = provider.get_earnings("AAPL")

        assert isinstance(result, dict)
        assert "earnings" in result or "earnings_dates" in result

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_earnings_exception(self, mock_ticker, provider):
        """Test getting earnings with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.get_earnings("INVALID")

        assert isinstance(result, dict)

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_recommendations_success(self, mock_ticker, provider):
        """Test getting analyst recommendations successfully."""
        mock_recommendations = pd.DataFrame(
            {
                "period": ["0m", "-1m"],
                "strongBuy": [5, 4],
                "buy": [10, 12],
                "hold": [3, 3],
                "sell": [1, 1],
                "strongSell": [0, 0],
            }
        )

        mock_ticker.return_value.recommendations = mock_recommendations

        result = provider.get_recommendations("AAPL")

        assert isinstance(result, pd.DataFrame)

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_get_recommendations_exception(self, mock_ticker, provider):
        """Test getting recommendations with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.get_recommendations("INVALID")

        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_is_etf_true(self, mock_ticker, provider):
        """Test ETF detection for actual ETF."""
        mock_ticker.return_value.info = {"quoteType": "ETF"}

        result = provider.is_etf("SPY")

        assert result is True

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_is_etf_false(self, mock_ticker, provider):
        """Test ETF detection for stock."""
        mock_ticker.return_value.info = {"quoteType": "EQUITY"}

        result = provider.is_etf("AAPL")

        assert result is False

    @patch("maverick_mcp.providers.stock_data.yf.Ticker")
    def test_is_etf_exception(self, mock_ticker, provider):
        """Test ETF detection with exception."""
        mock_ticker.side_effect = Exception("API Error")

        result = provider.is_etf("INVALID")

        assert result is False
314 |
315 |
class TestStockDataProviderErrorHandling:
    """Test error handling and edge cases."""

    @staticmethod
    def _build_provider(**provider_kwargs):
        """Construct a provider while its external dependencies are patched.

        ``get_db_session_read_only`` and ``mcal.get_calendar`` are patched only
        for the duration of the constructor call. Methods invoked on the
        returned provider afterwards run outside these patches.
        """
        with (
            patch("maverick_mcp.providers.stock_data.get_db_session_read_only"),
            patch("maverick_mcp.providers.stock_data.mcal.get_calendar"),
        ):
            return EnhancedStockDataProvider(**provider_kwargs)

    def test_invalid_date_range(self):
        """Test with invalid date range."""
        provider = self._build_provider()

        # Test with end date before start date
        result = provider.get_stock_data(
            "AAPL", "2024-12-31", "2024-01-01", use_cache=False
        )

        assert isinstance(result, pd.DataFrame)

    @pytest.mark.skip(reason="Flaky test with external dependencies")
    def test_empty_symbol(self):
        """Test with empty symbol."""
        provider = self._build_provider()

        result = provider.get_stock_data(
            "", "2024-01-01", "2024-01-31", use_cache=False
        )

        assert isinstance(result, pd.DataFrame)

    def test_future_date_range(self):
        """Test with future dates."""
        provider = self._build_provider()

        future_date = (datetime.now() + timedelta(days=365)).strftime(
            "%Y-%m-%d"
        )
        result = provider.get_stock_data(
            "AAPL", future_date, future_date, use_cache=False
        )

        assert isinstance(result, pd.DataFrame)

    def test_database_connection_failure(self):
        """Test graceful handling of database connection failure."""
        mock_session = Mock(spec=Session)
        mock_session.execute.side_effect = Exception("Connection failed")

        # Should not raise exception, just log warning
        provider = self._build_provider(db_session=mock_session)
        assert provider is not None
370 |
371 |
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))
374 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/fallback_strategies.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Fallback strategies for circuit breakers to provide graceful degradation.
3 | """
4 |
5 | import logging
6 | from abc import ABC, abstractmethod
7 | from datetime import UTC, datetime, timedelta
8 | from typing import TypeVar
9 |
10 | import pandas as pd
11 |
12 | from maverick_mcp.data.models import PriceCache, Stock
13 | from maverick_mcp.data.session_management import get_db_session_read_only as get_session
14 | from maverick_mcp.exceptions import DataNotFoundError
15 |
# Module-scoped logger used by every strategy and chain in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Module-level type variable. The generic classes in this module declare their
# own type parameter with PEP 695 syntax (``class Name[T]``).
T = TypeVar("T")
19 |
20 |
21 | class FallbackStrategy[T](ABC):
22 | """Base class for fallback strategies."""
23 |
24 | @abstractmethod
25 | async def execute_async(self, *args, **kwargs) -> T:
26 | """Execute the fallback strategy asynchronously."""
27 | pass
28 |
29 | @abstractmethod
30 | def execute_sync(self, *args, **kwargs) -> T:
31 | """Execute the fallback strategy synchronously."""
32 | pass
33 |
34 |
35 | class FallbackChain[T]:
36 | """
37 | Chain of fallback strategies to execute in order.
38 | Stops at the first successful strategy.
39 | """
40 |
41 | def __init__(self, strategies: list[FallbackStrategy[T]]):
42 | """Initialize fallback chain with ordered strategies."""
43 | self.strategies = strategies
44 |
45 | async def execute_async(self, *args, **kwargs) -> T:
46 | """Execute strategies asynchronously until one succeeds."""
47 | last_error = None
48 |
49 | for i, strategy in enumerate(self.strategies):
50 | try:
51 | logger.info(
52 | f"Executing fallback strategy {i + 1}/{len(self.strategies)}: {strategy.__class__.__name__}"
53 | )
54 | result = await strategy.execute_async(*args, **kwargs)
55 | if result is not None: # Success
56 | return result
57 | except Exception as e:
58 | logger.warning(
59 | f"Fallback strategy {strategy.__class__.__name__} failed: {e}"
60 | )
61 | last_error = e
62 | continue
63 |
64 | # All strategies failed
65 | if last_error:
66 | raise last_error
67 | raise DataNotFoundError("All fallback strategies failed")
68 |
69 | def execute_sync(self, *args, **kwargs) -> T:
70 | """Execute strategies synchronously until one succeeds."""
71 | last_error = None
72 |
73 | for i, strategy in enumerate(self.strategies):
74 | try:
75 | logger.info(
76 | f"Executing fallback strategy {i + 1}/{len(self.strategies)}: {strategy.__class__.__name__}"
77 | )
78 | result = strategy.execute_sync(*args, **kwargs)
79 | if result is not None: # Success
80 | return result
81 | except Exception as e:
82 | logger.warning(
83 | f"Fallback strategy {strategy.__class__.__name__} failed: {e}"
84 | )
85 | last_error = e
86 | continue
87 |
88 | # All strategies failed
89 | if last_error:
90 | raise last_error
91 | raise DataNotFoundError("All fallback strategies failed")
92 |
93 |
class CachedStockDataFallback(FallbackStrategy[pd.DataFrame]):
    """Serve OHLCV history from ``PriceCache``, limited to recently refreshed rows."""

    def __init__(self, max_age_days: int = 7):
        """
        Configure the freshness window.

        Args:
            max_age_days: Only cache rows whose ``updated_at`` lies within this
                many days of the current UTC time are used.
        """
        self.max_age_days = max_age_days

    async def execute_async(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> pd.DataFrame:
        """Async entry point; currently runs the synchronous lookup inline."""
        return self.execute_sync(symbol, start_date, end_date, **kwargs)

    def execute_sync(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> pd.DataFrame:
        """Load cached prices for ``symbol`` between ``start_date`` and ``end_date``.

        Returns:
            DataFrame indexed by ``Date`` in ascending order, with Open, High,
            Low, Close and Volume columns.

        Raises:
            DataNotFoundError: If the symbol is unknown or no sufficiently
                fresh rows exist in the range. Every exception is logged
                before being re-raised.
        """
        try:
            with get_session() as session:
                stock = session.query(Stock).filter_by(symbol=symbol).first()
                if not stock:
                    raise DataNotFoundError(f"Stock {symbol} not found in database")

                freshness_cutoff = datetime.now(UTC) - timedelta(
                    days=self.max_age_days
                )
                rows = (
                    session.query(PriceCache)
                    .filter(
                        PriceCache.stock_id == stock.id,
                        PriceCache.date >= start_date,
                        PriceCache.date <= end_date,
                        # Skip rows that were not refreshed within the window.
                        PriceCache.updated_at >= freshness_cutoff,
                    )
                    .all()
                )
                if not rows:
                    raise DataNotFoundError(f"No cached data found for {symbol}")

                frame = (
                    pd.DataFrame(
                        [
                            {
                                "Date": pd.to_datetime(row.date),
                                "Open": float(row.open),
                                "High": float(row.high),
                                "Low": float(row.low),
                                "Close": float(row.close),
                                "Volume": int(row.volume),
                            }
                            for row in rows
                        ]
                    )
                    .set_index("Date")
                    .sort_index()
                )

                logger.info(
                    f"Returned {len(frame)} rows of cached data for {symbol} "
                    f"(may be stale up to {self.max_age_days} days)"
                )
                return frame

        except Exception as e:
            logger.error(f"Failed to get cached data for {symbol}: {e}")
            raise
167 |
168 |
class StaleDataFallback(FallbackStrategy[pd.DataFrame]):
    """Return any available cached data regardless of age."""

    async def execute_async(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> pd.DataFrame:
        """Get stale stock data asynchronously (delegates to the sync path)."""
        return self.execute_sync(symbol, start_date, end_date, **kwargs)

    @staticmethod
    def _oldest_update_age_days(rows) -> int | None:
        """Return the age in whole days of the least recently updated row.

        Rows whose ``updated_at`` is ``None`` are ignored, and ``None`` is
        returned when no row carries a timestamp. Naive timestamps are treated
        as UTC, because subtracting a naive datetime from ``datetime.now(UTC)``
        raises ``TypeError``.
        NOTE(review): confirm ``PriceCache.updated_at`` is recorded in UTC.
        """
        timestamps = [
            stamp if stamp.utcoffset() is not None else stamp.replace(tzinfo=UTC)
            for stamp in (row.updated_at for row in rows)
            if stamp is not None
        ]
        if not timestamps:
            return None
        return (datetime.now(UTC) - min(timestamps)).days

    def execute_sync(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> pd.DataFrame:
        """Get stale stock data synchronously.

        Returns every cached row for ``symbol`` in the date range, indexed by
        ``Date`` in ascending order. ``df.attrs`` is annotated with
        ``is_stale`` (always True), ``max_age_days`` (age of the oldest row in
        days, or None if unknown) and a human-readable ``warning``.

        Raises:
            DataNotFoundError: If the symbol is unknown or no cached rows
                exist in the range. Every exception is logged before being
                re-raised.
        """
        try:
            with get_session() as session:
                # Check if stock exists
                stock = session.query(Stock).filter_by(symbol=symbol).first()
                if not stock:
                    raise DataNotFoundError(f"Stock {symbol} not found in database")

                # Get any cached prices, ignoring how old they are
                query = session.query(PriceCache).filter(
                    PriceCache.stock_id == stock.id,
                    PriceCache.date >= start_date,
                    PriceCache.date <= end_date,
                )

                results = query.all()

                if not results:
                    raise DataNotFoundError(f"No cached data found for {symbol}")

                data = [
                    {
                        "Date": pd.to_datetime(row.date),
                        "Open": float(row.open),
                        "High": float(row.high),
                        "Low": float(row.low),
                        "Close": float(row.close),
                        "Volume": int(row.volume),
                    }
                    for row in results
                ]

                df = pd.DataFrame(data)
                df.set_index("Date", inplace=True)
                df.sort_index(inplace=True)

                # Work out how stale the data is without letting naive/None
                # timestamps abort this last-resort fallback.
                age_days = self._oldest_update_age_days(results)

                if age_days is None:
                    logger.warning(
                        f"Returning {len(df)} rows of STALE cached data for {symbol} "
                        "(data age unknown: no updated_at timestamps)"
                    )
                    warning = "Data age is unknown"
                else:
                    logger.warning(
                        f"Returning {len(df)} rows of STALE cached data for {symbol} "
                        f"(data is up to {age_days} days old)"
                    )
                    warning = f"Data may be up to {age_days} days old"

                # Add metadata to indicate stale data
                df.attrs["is_stale"] = True
                df.attrs["max_age_days"] = age_days
                df.attrs["warning"] = warning

                return df

        except Exception as e:
            logger.error(f"Failed to get stale cached data for {symbol}: {e}")
            raise
238 |
239 |
class DefaultMarketDataFallback(FallbackStrategy[dict]):
    """Placeholder market-movers payload used when market data APIs fail."""

    async def execute_async(self, mover_type: str = "gainers", **kwargs) -> dict:
        """Async entry point; currently delegates to :meth:`execute_sync`."""
        return self.execute_sync(mover_type, **kwargs)

    def execute_sync(self, mover_type: str = "gainers", **kwargs) -> dict:
        """Return an empty ``movers`` list plus metadata flagging it as a fallback."""
        logger.warning(f"Returning default {mover_type} data due to API failure")

        metadata = {
            "source": "fallback",
            "timestamp": datetime.now(UTC).isoformat(),
            "is_fallback": True,
            "message": f"Market {mover_type} data temporarily unavailable",
        }
        # Empty but structurally valid, so callers need no special-casing.
        return {"movers": [], "metadata": metadata}
261 |
262 |
class CachedEconomicDataFallback(FallbackStrategy[pd.Series]):
    """Fallback for economic indicator series.

    No cache is read at present: a constant default value for the requested
    series id is spread over a daily date range.
    """

    def __init__(self, default_values: dict[str, float] | None = None):
        """
        Initialize economic data fallback.

        Args:
            default_values: Mapping of series id to the constant value to
                return. When omitted or empty, built-in defaults for a few
                common indicator ids are used.
        """
        if default_values:
            self.default_values = default_values
        else:
            self.default_values = {
                "GDP": 2.5,  # Default GDP growth %
                "UNRATE": 4.0,  # Default unemployment rate %
                "CPIAUCSL": 2.0,  # Default inflation rate %
                "DFF": 5.0,  # Default federal funds rate %
                "DGS10": 4.0,  # Default 10-year treasury yield %
                "VIXCLS": 20.0,  # Default VIX
            }

    async def execute_async(
        self, series_id: str, start_date: str, end_date: str, **kwargs
    ) -> pd.Series:
        """Async entry point; currently delegates to :meth:`execute_sync`."""
        return self.execute_sync(series_id, start_date, end_date, **kwargs)

    def execute_sync(
        self, series_id: str, start_date: str, end_date: str, **kwargs
    ) -> pd.Series:
        """Build a daily series filled with the default value for ``series_id``.

        Unknown series ids get 0.0. The series' ``attrs`` mark it as a
        fallback and record which value was used.
        """
        logger.warning(f"Returning default value for {series_id} due to API failure")

        fill_value = self.default_values.get(series_id, 0.0)
        daily_index = pd.date_range(start=start_date, end=end_date, freq="D")
        fallback_series = pd.Series(fill_value, index=daily_index, name=series_id)

        fallback_series.attrs.update(
            is_fallback=True,
            source="default",
            warning=f"Using default value of {fill_value}",
        )
        return fallback_series
307 |
308 |
class EmptyNewsFallback(FallbackStrategy[dict]):
    """Placeholder news payload used when news APIs are unavailable."""

    async def execute_async(self, symbol: str, **kwargs) -> dict:
        """Async entry point; currently delegates to :meth:`execute_sync`."""
        return self.execute_sync(symbol, **kwargs)

    def execute_sync(self, symbol: str, **kwargs) -> dict:
        """Return an empty ``articles`` list plus metadata flagging it as a fallback."""
        logger.warning(f"News API unavailable for {symbol}, returning empty news")

        metadata = {
            "symbol": symbol,
            "source": "fallback",
            "timestamp": datetime.now(UTC).isoformat(),
            "is_fallback": True,
            "message": "News sentiment analysis temporarily unavailable",
        }
        return {"articles": [], "metadata": metadata}
330 |
331 |
class LastKnownQuoteFallback(FallbackStrategy[dict]):
    """Serve the most recent cached price row for a symbol as a quote."""

    async def execute_async(self, symbol: str, **kwargs) -> dict:
        """Async entry point; currently delegates to :meth:`execute_sync`."""
        return self.execute_sync(symbol, **kwargs)

    def execute_sync(self, symbol: str, **kwargs) -> dict:
        """Build a quote dict from the newest ``PriceCache`` row for ``symbol``.

        The quote is flagged with ``is_fallback`` and its age in days. If any
        exception occurs (unknown symbol, no cached prices, database or
        conversion error), it is logged and a minimal quote with ``price``
        0.0 and an ``error`` message is returned instead of raising.
        """
        try:
            with get_session() as session:
                stock = session.query(Stock).filter_by(symbol=symbol).first()
                if not stock:
                    raise DataNotFoundError(f"Stock {symbol} not found")

                newest_row = (
                    session.query(PriceCache)
                    .filter_by(stock_id=stock.id)
                    .order_by(PriceCache.date.desc())
                    .first()
                )
                if not newest_row:
                    raise DataNotFoundError(f"No cached prices for {symbol}")

                quote_date = newest_row.date
                age_days = (datetime.now(UTC).date() - quote_date).days

                logger.warning(
                    f"Returning cached quote for {symbol} from {quote_date} "
                    f"({age_days} days old)"
                )

                quote = {
                    "symbol": symbol,
                    "price": float(newest_row.close),
                    "open": float(newest_row.open),
                    "high": float(newest_row.high),
                    "low": float(newest_row.low),
                    "close": float(newest_row.close),
                    "volume": int(newest_row.volume),
                    "date": quote_date.isoformat(),
                    "is_fallback": True,
                    "age_days": age_days,
                    "warning": f"Quote is {age_days} days old",
                }
                return quote

        except Exception as e:
            logger.error(f"Failed to get cached quote for {symbol}: {e}")
            # Minimal placeholder so callers always receive a dict.
            return {
                "symbol": symbol,
                "price": 0.0,
                "is_fallback": True,
                "error": str(e),
                "warning": "No quote data available",
            }
390 |
391 |
# Pre-configured fallback chains for common use cases.
# All of these are module-level instances, created once at import time and
# shared by every importer of this module.

# Stock price data: per the inline comments the strategies are meant to be
# tried in list order, from the freshest cache to any cached data at all.
# NOTE(review): the actual ordering/short-circuit semantics live in
# FallbackChain, which is not shown here — confirm there.
STOCK_DATA_FALLBACK_CHAIN = FallbackChain[pd.DataFrame](
    [
        CachedStockDataFallback(max_age_days=1),  # Try recent cache first
        CachedStockDataFallback(max_age_days=7),  # Then older cache
        StaleDataFallback(),  # Finally any cache
    ]
)

# Single-strategy fallbacks for the other data categories.
MARKET_DATA_FALLBACK = DefaultMarketDataFallback()

ECONOMIC_DATA_FALLBACK = CachedEconomicDataFallback()

NEWS_FALLBACK = EmptyNewsFallback()

QUOTE_FALLBACK = LastKnownQuoteFallback()
408 |
```
--------------------------------------------------------------------------------
/maverick_mcp/config/logging_settings.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Structured logging configuration settings for the backtesting system.
3 |
4 | This module provides centralized configuration for all logging-related settings
5 | including debug mode, log levels, output formats, and performance monitoring.
6 | """
7 |
8 | import os
9 | from dataclasses import dataclass
10 | from pathlib import Path
11 | from typing import Any
12 |
13 |
14 | @dataclass
15 | class LoggingSettings:
16 | """Comprehensive logging configuration settings."""
17 |
18 | # Basic logging configuration
19 | log_level: str = "INFO"
20 | log_format: str = "json" # json or text
21 | enable_async_logging: bool = True
22 | console_output: str = "stderr" # stdout or stderr
23 |
24 | # File logging configuration
25 | enable_file_logging: bool = True
26 | log_file_path: str = "logs/backtesting.log"
27 | enable_log_rotation: bool = True
28 | max_log_size_mb: int = 10
29 | backup_count: int = 5
30 |
31 | # Debug mode configuration
32 | debug_enabled: bool = False
33 | verbose_modules: list[str] = None
34 | log_request_response: bool = False
35 | max_payload_length: int = 1000
36 |
37 | # Performance monitoring
38 | enable_performance_logging: bool = True
39 | performance_log_threshold_ms: float = 1000.0
40 | enable_resource_tracking: bool = True
41 | enable_business_metrics: bool = True
42 |
43 | # Async logging configuration
44 | async_log_queue_size: int = 10000
45 | async_log_flush_interval: float = 1.0
46 |
47 | # Sensitive data handling
48 | mask_sensitive_data: bool = True
49 | sensitive_field_patterns: list[str] = None
50 |
51 | # Remote logging (for future log aggregation)
52 | enable_remote_logging: bool = False
53 | remote_endpoint: str | None = None
54 | remote_api_key: str | None = None
55 |
56 | # Correlation and tracing
57 | enable_correlation_tracking: bool = True
58 | correlation_id_header: str = "X-Correlation-ID"
59 | enable_request_tracing: bool = True
60 |
61 | def __post_init__(self):
62 | """Initialize default values for mutable fields."""
63 | if self.verbose_modules is None:
64 | self.verbose_modules = []
65 |
66 | if self.sensitive_field_patterns is None:
67 | self.sensitive_field_patterns = [
68 | "password",
69 | "token",
70 | "key",
71 | "secret",
72 | "auth",
73 | "credential",
74 | "bearer",
75 | "session",
76 | "cookie",
77 | "api_key",
78 | "access_token",
79 | "refresh_token",
80 | "private",
81 | "confidential",
82 | ]
83 |
84 | @classmethod
85 | def from_env(cls) -> "LoggingSettings":
86 | """Create logging settings from environment variables."""
87 | return cls(
88 | log_level=os.getenv("MAVERICK_LOG_LEVEL", "INFO").upper(),
89 | log_format=os.getenv("MAVERICK_LOG_FORMAT", "json").lower(),
90 | enable_async_logging=os.getenv("MAVERICK_ASYNC_LOGGING", "true").lower()
91 | == "true",
92 | console_output=os.getenv("MAVERICK_CONSOLE_OUTPUT", "stderr").lower(),
93 | # File logging
94 | enable_file_logging=os.getenv("MAVERICK_FILE_LOGGING", "true").lower()
95 | == "true",
96 | log_file_path=os.getenv("MAVERICK_LOG_FILE", "logs/backtesting.log"),
97 | enable_log_rotation=os.getenv("MAVERICK_LOG_ROTATION", "true").lower()
98 | == "true",
99 | max_log_size_mb=int(os.getenv("MAVERICK_LOG_SIZE_MB", "10")),
100 | backup_count=int(os.getenv("MAVERICK_LOG_BACKUPS", "5")),
101 | # Debug configuration
102 | debug_enabled=os.getenv("MAVERICK_DEBUG", "false").lower() == "true",
103 | log_request_response=os.getenv("MAVERICK_LOG_REQUESTS", "false").lower()
104 | == "true",
105 | max_payload_length=int(os.getenv("MAVERICK_MAX_PAYLOAD", "1000")),
106 | # Performance monitoring
107 | enable_performance_logging=os.getenv(
108 | "MAVERICK_PERF_LOGGING", "true"
109 | ).lower()
110 | == "true",
111 | performance_log_threshold_ms=float(
112 | os.getenv("MAVERICK_PERF_THRESHOLD", "1000.0")
113 | ),
114 | enable_resource_tracking=os.getenv(
115 | "MAVERICK_RESOURCE_TRACKING", "true"
116 | ).lower()
117 | == "true",
118 | enable_business_metrics=os.getenv(
119 | "MAVERICK_BUSINESS_METRICS", "true"
120 | ).lower()
121 | == "true",
122 | # Async logging
123 | async_log_queue_size=int(os.getenv("MAVERICK_LOG_QUEUE_SIZE", "10000")),
124 | async_log_flush_interval=float(
125 | os.getenv("MAVERICK_LOG_FLUSH_INTERVAL", "1.0")
126 | ),
127 | # Sensitive data
128 | mask_sensitive_data=os.getenv("MAVERICK_MASK_SENSITIVE", "true").lower()
129 | == "true",
130 | # Remote logging
131 | enable_remote_logging=os.getenv("MAVERICK_REMOTE_LOGGING", "false").lower()
132 | == "true",
133 | remote_endpoint=os.getenv("MAVERICK_REMOTE_LOG_ENDPOINT"),
134 | remote_api_key=os.getenv("MAVERICK_REMOTE_LOG_API_KEY"),
135 | # Correlation and tracing
136 | enable_correlation_tracking=os.getenv(
137 | "MAVERICK_CORRELATION_TRACKING", "true"
138 | ).lower()
139 | == "true",
140 | correlation_id_header=os.getenv(
141 | "MAVERICK_CORRELATION_HEADER", "X-Correlation-ID"
142 | ),
143 | enable_request_tracing=os.getenv("MAVERICK_REQUEST_TRACING", "true").lower()
144 | == "true",
145 | )
146 |
147 | def to_dict(self) -> dict[str, Any]:
148 | """Convert settings to dictionary for serialization."""
149 | return {
150 | "log_level": self.log_level,
151 | "log_format": self.log_format,
152 | "enable_async_logging": self.enable_async_logging,
153 | "console_output": self.console_output,
154 | "enable_file_logging": self.enable_file_logging,
155 | "log_file_path": self.log_file_path,
156 | "enable_log_rotation": self.enable_log_rotation,
157 | "max_log_size_mb": self.max_log_size_mb,
158 | "backup_count": self.backup_count,
159 | "debug_enabled": self.debug_enabled,
160 | "verbose_modules": self.verbose_modules,
161 | "log_request_response": self.log_request_response,
162 | "max_payload_length": self.max_payload_length,
163 | "enable_performance_logging": self.enable_performance_logging,
164 | "performance_log_threshold_ms": self.performance_log_threshold_ms,
165 | "enable_resource_tracking": self.enable_resource_tracking,
166 | "enable_business_metrics": self.enable_business_metrics,
167 | "async_log_queue_size": self.async_log_queue_size,
168 | "async_log_flush_interval": self.async_log_flush_interval,
169 | "mask_sensitive_data": self.mask_sensitive_data,
170 | "sensitive_field_patterns": self.sensitive_field_patterns,
171 | "enable_remote_logging": self.enable_remote_logging,
172 | "remote_endpoint": self.remote_endpoint,
173 | "enable_correlation_tracking": self.enable_correlation_tracking,
174 | "correlation_id_header": self.correlation_id_header,
175 | "enable_request_tracing": self.enable_request_tracing,
176 | }
177 |
178 | def ensure_log_directory(self):
179 | """Ensure the log directory exists."""
180 | if self.enable_file_logging and self.log_file_path:
181 | log_path = Path(self.log_file_path)
182 | log_path.parent.mkdir(parents=True, exist_ok=True)
183 |
184 | def get_debug_modules(self) -> list[str]:
185 | """Get list of modules for debug logging."""
186 | if not self.debug_enabled:
187 | return []
188 |
189 | if not self.verbose_modules:
190 | # Default debug modules for backtesting
191 | return [
192 | "maverick_mcp.backtesting",
193 | "maverick_mcp.api.tools.backtesting",
194 | "maverick_mcp.providers",
195 | "maverick_mcp.data.cache",
196 | ]
197 |
198 | return self.verbose_modules
199 |
200 | def should_log_performance(self, duration_ms: float) -> bool:
201 | """Check if operation should be logged based on performance threshold."""
202 | if not self.enable_performance_logging:
203 | return False
204 | return duration_ms >= self.performance_log_threshold_ms
205 |
206 | def get_log_file_config(self) -> dict[str, Any]:
207 | """Get file logging configuration."""
208 | if not self.enable_file_logging:
209 | return {}
210 |
211 | config = {
212 | "filename": self.log_file_path,
213 | "mode": "a",
214 | "encoding": "utf-8",
215 | }
216 |
217 | if self.enable_log_rotation:
218 | config.update(
219 | {
220 | "maxBytes": self.max_log_size_mb * 1024 * 1024,
221 | "backupCount": self.backup_count,
222 | }
223 | )
224 |
225 | return config
226 |
227 | def get_performance_config(self) -> dict[str, Any]:
228 | """Get performance monitoring configuration."""
229 | return {
230 | "enabled": self.enable_performance_logging,
231 | "threshold_ms": self.performance_log_threshold_ms,
232 | "resource_tracking": self.enable_resource_tracking,
233 | "business_metrics": self.enable_business_metrics,
234 | }
235 |
236 | def get_debug_config(self) -> dict[str, Any]:
237 | """Get debug configuration."""
238 | return {
239 | "enabled": self.debug_enabled,
240 | "verbose_modules": self.get_debug_modules(),
241 | "log_request_response": self.log_request_response,
242 | "max_payload_length": self.max_payload_length,
243 | }
244 |
245 |
246 | # Environment-specific configurations
class EnvironmentLogSettings:
    """Named ``LoggingSettings`` presets, one per deployment environment.

    Each factory returns a fresh instance; fields not mentioned keep the
    ``LoggingSettings`` dataclass defaults.
    """

    @staticmethod
    def development() -> LoggingSettings:
        """Verbose local preset: DEBUG text logs to stdout and a dev log file.

        Debug and request/response logging are on, and the performance
        threshold is lowered to 100 ms.
        """
        preset = dict(
            log_level="DEBUG",
            log_format="text",
            console_output="stdout",
            debug_enabled=True,
            log_request_response=True,
            enable_performance_logging=True,
            performance_log_threshold_ms=100.0,  # Lower threshold for development
            enable_file_logging=True,
            log_file_path="logs/dev_backtesting.log",
        )
        return LoggingSettings(**preset)

    @staticmethod
    def testing() -> LoggingSettings:
        """Quiet test preset: WARNING text logs to stdout, nothing written to disk.

        Debug, performance and file logging are off, and logging is
        synchronous.
        """
        preset = dict(
            log_level="WARNING",
            log_format="text",
            console_output="stdout",
            debug_enabled=False,
            enable_performance_logging=False,
            enable_file_logging=False,
            enable_async_logging=False,  # Synchronous for tests
        )
        return LoggingSettings(**preset)

    @staticmethod
    def production() -> LoggingSettings:
        """Production preset: INFO JSON logs to stderr with rotated files.

        Logs go to /var/log/maverick (50 MB x 10 backups), use a 2000 ms
        performance threshold, and have remote logging enabled.
        """
        # NOTE(review): remote logging is switched on without a remote_endpoint;
        # one must be supplied via MAVERICK_REMOTE_LOG_ENDPOINT.
        preset = dict(
            log_level="INFO",
            log_format="json",
            console_output="stderr",
            debug_enabled=False,
            log_request_response=False,
            enable_performance_logging=True,
            performance_log_threshold_ms=2000.0,  # Higher threshold for production
            enable_file_logging=True,
            log_file_path="/var/log/maverick/backtesting.log",
            enable_log_rotation=True,
            max_log_size_mb=50,  # Larger files in production
            backup_count=10,
            enable_remote_logging=True,  # Enable for log aggregation
        )
        return LoggingSettings(**preset)
296 |
297 |
# Global logging settings instance (created lazily by get_logging_settings)
_logging_settings: LoggingSettings | None = None

# Field name -> environment variable read by LoggingSettings.from_env().
# Keep in sync with from_env. Used to decide which fields the user explicitly
# overrode, independently of whether the value equals a dataclass default.
_ENV_OVERRIDE_VARS: dict[str, str] = {
    "log_level": "MAVERICK_LOG_LEVEL",
    "log_format": "MAVERICK_LOG_FORMAT",
    "enable_async_logging": "MAVERICK_ASYNC_LOGGING",
    "console_output": "MAVERICK_CONSOLE_OUTPUT",
    "enable_file_logging": "MAVERICK_FILE_LOGGING",
    "log_file_path": "MAVERICK_LOG_FILE",
    "enable_log_rotation": "MAVERICK_LOG_ROTATION",
    "max_log_size_mb": "MAVERICK_LOG_SIZE_MB",
    "backup_count": "MAVERICK_LOG_BACKUPS",
    "debug_enabled": "MAVERICK_DEBUG",
    "log_request_response": "MAVERICK_LOG_REQUESTS",
    "max_payload_length": "MAVERICK_MAX_PAYLOAD",
    "enable_performance_logging": "MAVERICK_PERF_LOGGING",
    "performance_log_threshold_ms": "MAVERICK_PERF_THRESHOLD",
    "enable_resource_tracking": "MAVERICK_RESOURCE_TRACKING",
    "enable_business_metrics": "MAVERICK_BUSINESS_METRICS",
    "async_log_queue_size": "MAVERICK_LOG_QUEUE_SIZE",
    "async_log_flush_interval": "MAVERICK_LOG_FLUSH_INTERVAL",
    "mask_sensitive_data": "MAVERICK_MASK_SENSITIVE",
    "enable_remote_logging": "MAVERICK_REMOTE_LOGGING",
    "remote_endpoint": "MAVERICK_REMOTE_LOG_ENDPOINT",
    "remote_api_key": "MAVERICK_REMOTE_LOG_API_KEY",
    "enable_correlation_tracking": "MAVERICK_CORRELATION_TRACKING",
    "correlation_id_header": "MAVERICK_CORRELATION_HEADER",
    "enable_request_tracing": "MAVERICK_REQUEST_TRACING",
}


def get_logging_settings() -> LoggingSettings:
    """Return the global logging settings, creating them on first call.

    ``MAVERICK_ENVIRONMENT`` (default "development", case-insensitive) selects
    a preset from ``EnvironmentLogSettings``; unknown values fall back to
    ``LoggingSettings.from_env()``. Every variable in ``_ENV_OVERRIDE_VARS``
    that is set then overrides the matching field, even when its value equals
    the dataclass default. Finally the log directory is created (which may
    raise if it cannot be).
    """
    global _logging_settings

    if _logging_settings is None:
        environment = os.getenv("MAVERICK_ENVIRONMENT", "development").lower()

        if environment == "development":
            _logging_settings = EnvironmentLogSettings.development()
        elif environment == "testing":
            _logging_settings = EnvironmentLogSettings.testing()
        elif environment == "production":
            _logging_settings = EnvironmentLogSettings.production()
        else:
            # Default to environment variables
            _logging_settings = LoggingSettings.from_env()

        # Apply explicit environment overrides. Decide by "is the variable
        # set", not "does the value differ from the dataclass default": the
        # latter made it impossible to override a preset back to a default
        # (e.g. MAVERICK_DEBUG=false in development) and silently ignored
        # fields that to_dict() omits (remote_api_key).
        env_overrides = LoggingSettings.from_env()
        for field_name, env_var in _ENV_OVERRIDE_VARS.items():
            if os.getenv(env_var) is not None:
                setattr(
                    _logging_settings, field_name, getattr(env_overrides, field_name)
                )

        # Ensure log directory exists
        _logging_settings.ensure_log_directory()

    return _logging_settings
329 |
330 |
331 | def configure_logging_for_environment(environment: str) -> LoggingSettings:
332 | """Configure logging for specific environment."""
333 | global _logging_settings
334 |
335 | if environment.lower() == "development":
336 | _logging_settings = EnvironmentLogSettings.development()
337 | elif environment.lower() == "testing":
338 | _logging_settings = EnvironmentLogSettings.testing()
339 | elif environment.lower() == "production":
340 | _logging_settings = EnvironmentLogSettings.production()
341 | else:
342 | raise ValueError(f"Unknown environment: {environment}")
343 |
344 | _logging_settings.ensure_log_directory()
345 | return _logging_settings
346 |
347 |
348 | # Logging configuration validation
def validate_logging_settings(settings: "LoggingSettings") -> list[str]:
    """Validate logging settings and return a list of warnings.

    ``settings`` is only inspected, never modified. The "using ..." suffixes
    in some messages name a fallback, but this function does not apply it.
    Side effect: when file logging is enabled, the log directory is created
    (as ``ensure_log_directory`` would).

    Returns:
        Human-readable warning strings; empty when nothing looks wrong.
    """
    warnings: list[str] = []

    # Validate log level
    valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    if settings.log_level not in valid_levels:
        warnings.append(f"Invalid log level '{settings.log_level}', using INFO")

    # Validate log format
    valid_formats = ["json", "text"]
    if settings.log_format not in valid_formats:
        warnings.append(f"Invalid log format '{settings.log_format}', using json")

    # Validate console output
    valid_outputs = ["stdout", "stderr"]
    if settings.console_output not in valid_outputs:
        warnings.append(
            f"Invalid console output '{settings.console_output}', using stderr"
        )

    # Validate file logging
    if settings.enable_file_logging:
        try:
            log_path = Path(settings.log_file_path)
            log_path.parent.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            warnings.append(f"Cannot create log directory: {e}")

        # Rotation values feed maxBytes/backupCount (see get_log_file_config).
        # For logging.handlers.RotatingFileHandler, a zero value for either
        # means rollover never occurs.
        if settings.enable_log_rotation:
            if settings.max_log_size_mb <= 0:
                warnings.append(
                    "Log rotation enabled but max_log_size_mb is not positive; "
                    "rotation will never occur"
                )
            if settings.backup_count <= 0:
                warnings.append(
                    "Log rotation enabled but backup_count is not positive; "
                    "rotation will never occur"
                )

    # Validate performance settings
    if settings.performance_log_threshold_ms < 0:
        warnings.append("Performance threshold cannot be negative, using 1000ms")

    # Validate async settings
    if settings.async_log_queue_size < 100:
        warnings.append("Async log queue size too small, using 1000")

    # Validate remote logging: enabling it without a target is a misconfiguration
    if settings.enable_remote_logging and not settings.remote_endpoint:
        warnings.append("Remote logging enabled but no remote endpoint configured")

    return warnings
387 |
```
--------------------------------------------------------------------------------
/maverick_mcp/exceptions.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Custom exception classes for MaverickMCP with comprehensive error handling.
3 |
4 | This module provides a unified exception hierarchy with proper error codes,
5 | HTTP status codes, and standardized error responses.
6 | """
7 |
8 | from typing import Any
9 |
10 |
11 | class MaverickException(Exception):
12 | """Base exception for all Maverick errors."""
13 |
14 | # Default values can be overridden by subclasses
15 | error_code: str = "INTERNAL_ERROR"
16 | status_code: int = 500
17 |
18 | def __init__(
19 | self,
20 | message: str,
21 | error_code: str | None = None,
22 | status_code: int | None = None,
23 | field: str | None = None,
24 | context: dict[str, Any] | None = None,
25 | recoverable: bool = True,
26 | ):
27 | super().__init__(message)
28 | self.message = message
29 | self.error_code = error_code or self.__class__.error_code
30 | self.status_code = status_code or self.__class__.status_code
31 | self.field = field
32 | self.context = context or {}
33 | self.recoverable = recoverable
34 |
35 | def to_dict(self) -> dict[str, Any]:
36 | """Convert exception to dictionary for API responses."""
37 | result: dict[str, Any] = {
38 | "code": self.error_code,
39 | "message": self.message,
40 | }
41 | if self.field:
42 | result["field"] = self.field
43 | if self.context:
44 | result["context"] = self.context
45 | return result
46 |
47 | def __repr__(self) -> str:
48 | """String representation of the exception."""
49 | return f"{self.__class__.__name__}('{self.message}', code='{self.error_code}')"
50 |
51 |
52 | # Validation exceptions
class ValidationError(MaverickException):
    """Raised when input validation fails.

    Reports the ``VALIDATION_ERROR`` code (listed in ``ERROR_CODES``) with
    HTTP 400, instead of inheriting the base class's INTERNAL_ERROR/500,
    which misreported client input problems as server faults.
    """

    error_code = "VALIDATION_ERROR"
    status_code = 400
55 |
56 |
57 | # Research and agent exceptions
class ResearchError(MaverickException):
    """Failure raised by research operations.

    On top of the base fields it records which ``research_type`` and which
    ``provider`` were involved; both are surfaced by :meth:`to_dict` when set.
    """

    error_code = "RESEARCH_ERROR"
    status_code = 500

    def __init__(
        self,
        message: str,
        research_type: str | None = None,
        provider: str | None = None,
        error_code: str | None = None,
        status_code: int | None = None,
        field: str | None = None,
        context: dict[str, Any] | None = None,
        recoverable: bool = True,
    ):
        super().__init__(
            message,
            error_code=error_code,
            status_code=status_code,
            field=field,
            context=context,
            recoverable=recoverable,
        )
        self.research_type = research_type
        self.provider = provider

    def to_dict(self) -> dict[str, Any]:
        """Base payload plus ``research_type`` / ``provider`` when truthy."""
        payload = super().to_dict()
        extras = {"research_type": self.research_type, "provider": self.provider}
        payload.update({key: value for key, value in extras.items() if value})
        return payload
94 |
95 |
class WebSearchError(ResearchError):
    """Research failure that happened while querying a web search backend."""

    error_code = "WEB_SEARCH_ERROR"
100 |
101 |
class ContentAnalysisError(ResearchError):
    """Research failure that happened while analysing retrieved content."""

    error_code = "CONTENT_ANALYSIS_ERROR"
106 |
107 |
class AgentExecutionError(MaverickException):
    """Signals that an agent failed while executing (HTTP 500)."""

    status_code = 500
    error_code = "AGENT_EXECUTION_ERROR"
113 |
114 |
115 | # Authentication/Authorization exceptions
class AuthenticationError(MaverickException):
    """Signals that the caller could not be authenticated (HTTP 401).

    The message defaults to "Authentication failed"; other keyword arguments
    are forwarded unchanged to :class:`MaverickException`.
    """

    status_code = 401
    error_code = "AUTHENTICATION_ERROR"

    def __init__(self, message: str = "Authentication failed", **kwargs):
        super().__init__(message=message, **kwargs)
124 |
125 |
class AuthorizationError(MaverickException):
    """Signals that an authenticated caller lacks permission (HTTP 403).

    When both ``resource`` and ``action`` are given, a descriptive message
    replaces ``message``. Each one that is truthy is also recorded in
    ``context``.
    """

    status_code = 403
    error_code = "AUTHORIZATION_ERROR"

    def __init__(
        self,
        message: str = "Insufficient permissions",
        resource: str | None = None,
        action: str | None = None,
        **kwargs,
    ):
        if resource and action:
            message = f"Unauthorized access to {resource} for action '{action}'"
        super().__init__(message, **kwargs)
        for key, value in (("resource", resource), ("action", action)):
            if value:
                self.context[key] = value
146 |
147 |
148 | # Resource exceptions
class NotFoundError(MaverickException):
    """Signals that a requested resource does not exist (HTTP 404).

    The message reads "<resource> not found", with ": <identifier>" appended
    when an identifier is supplied. Both values are stored in ``context``.
    """

    status_code = 404
    error_code = "NOT_FOUND"

    def __init__(self, resource: str, identifier: str | None = None, **kwargs):
        suffix = f": {identifier}" if identifier else ""
        super().__init__(f"{resource} not found{suffix}", **kwargs)
        self.context["resource"] = resource
        if identifier:
            self.context["identifier"] = identifier
163 |
164 |
class ConflictError(MaverickException):
    """Signals that the request clashes with existing data (HTTP 409).

    ``field`` optionally names the conflicting attribute.
    """

    status_code = 409
    error_code = "CONFLICT"

    def __init__(self, message: str, field: str | None = None, **kwargs):
        kwargs["field"] = field
        super().__init__(message, **kwargs)
173 |
174 |
175 | # Rate limiting exceptions
class RateLimitError(MaverickException):
    """Signals that the caller exceeded a rate limit (HTTP 429).

    A truthy ``retry_after`` (seconds) is stored in ``context``.
    """

    status_code = 429
    error_code = "RATE_LIMIT_EXCEEDED"

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        retry_after: int | None = None,
        **kwargs,
    ):
        super().__init__(message, **kwargs)
        if not retry_after:
            return
        self.context["retry_after"] = retry_after
191 |
192 |
193 | # External service exceptions
class ExternalServiceError(MaverickException):
    """Signals that a third-party service call failed (HTTP 503).

    The service name is always stored in ``context``; a truthy
    ``original_error`` description is stored too.
    """

    status_code = 503
    error_code = "EXTERNAL_SERVICE_ERROR"

    def __init__(
        self, service: str, message: str, original_error: str | None = None, **kwargs
    ):
        super().__init__(message, **kwargs)
        self.context["service"] = service
        if not original_error:
            return
        self.context["original_error"] = original_error
207 |
208 |
209 | # Data provider exceptions
class DataProviderError(MaverickException):
    """Common parent for data-provider failures (HTTP 503 by default).

    The provider name is recorded under ``context["provider"]``.
    """

    status_code = 503
    error_code = "DATA_PROVIDER_ERROR"

    def __init__(self, provider: str, message: str, **kwargs):
        super().__init__(message, **kwargs)
        self.context.update(provider=provider)
219 |
220 |
class DataNotFoundError(DataProviderError):
    """Signals that no data exists for a symbol (HTTP 404).

    The provider is always reported as "cache". ``date_range``, when given,
    is a ``(start, end)`` pair that is appended to the message and stored in
    ``context``.
    """

    status_code = 404
    error_code = "DATA_NOT_FOUND"

    def __init__(self, symbol: str, date_range: tuple | None = None, **kwargs):
        parts = [f"Data not found for symbol '{symbol}'"]
        if date_range:
            parts.append(f" in range {date_range[0]} to {date_range[1]}")
        super().__init__("cache", "".join(parts), **kwargs)
        self.context["symbol"] = symbol
        if date_range:
            self.context["date_range"] = date_range
235 |
236 |
class APIRateLimitError(DataProviderError):
    """Raised when an upstream data provider's API rate limit is exceeded.

    ``recoverable`` defaults to True. Callers may override it through
    ``kwargs``; previously that raised TypeError ("got multiple values for
    keyword argument 'recoverable'").
    """

    error_code = "RATE_LIMIT_EXCEEDED"
    status_code = 429

    def __init__(self, provider: str, retry_after: int | None = None, **kwargs):
        message = f"Rate limit exceeded for {provider}"
        if retry_after:
            message += f". Retry after {retry_after} seconds"
        kwargs.setdefault("recoverable", True)
        super().__init__(provider, message, **kwargs)
        if retry_after:
            self.context["retry_after"] = retry_after
250 |
251 |
class APIConnectionError(DataProviderError):
    """Raised when connecting to a data provider endpoint fails.

    The endpoint and failure reason are stored in ``context``. ``recoverable``
    defaults to True, and callers may override it via ``kwargs`` without
    triggering a duplicate-keyword TypeError.
    """

    error_code = "API_CONNECTION_ERROR"
    status_code = 503

    def __init__(self, provider: str, endpoint: str, reason: str, **kwargs):
        message = f"Failed to connect to {provider} at {endpoint}: {reason}"
        kwargs.setdefault("recoverable", True)
        super().__init__(provider, message, **kwargs)
        self.context["endpoint"] = endpoint
        self.context["connection_reason"] = reason
263 |
264 |
265 | # Database exceptions
class DatabaseError(MaverickException):
    """Common parent for database failures (HTTP 500 by default).

    The failing operation name is recorded under ``context["operation"]``.
    """

    status_code = 500
    error_code = "DATABASE_ERROR"

    def __init__(self, operation: str, message: str, **kwargs):
        super().__init__(message, **kwargs)
        self.context.update(operation=operation)
275 |
276 |
class DatabaseConnectionError(DatabaseError):
    """Raised when the database connection fails (operation "connect").

    ``recoverable`` defaults to True, and callers may override it via
    ``kwargs`` without triggering a duplicate-keyword TypeError.
    """

    error_code = "DATABASE_CONNECTION_ERROR"
    status_code = 503

    def __init__(self, reason: str, **kwargs):
        message = f"Database connection failed: {reason}"
        kwargs.setdefault("recoverable", True)
        super().__init__("connect", message, **kwargs)
286 |
287 |
class DataIntegrityError(DatabaseError):
    """Raised when a data integrity check fails (operation "integrity_check").

    The optional table and constraint names are stored in ``context``.
    ``recoverable`` defaults to False, and callers may override it via
    ``kwargs`` without triggering a duplicate-keyword TypeError.
    """

    error_code = "DATA_INTEGRITY_ERROR"
    status_code = 422

    def __init__(
        self,
        message: str,
        table: str | None = None,
        constraint: str | None = None,
        **kwargs,
    ):
        kwargs.setdefault("recoverable", False)
        super().__init__("integrity_check", message, **kwargs)
        if table:
            self.context["table"] = table
        if constraint:
            self.context["constraint"] = constraint
306 |
307 |
308 | # Cache exceptions
class CacheError(MaverickException):
    """Common parent for cache-layer failures (HTTP 503 by default).

    The failing operation name is recorded under ``context["operation"]``.
    """

    status_code = 503
    error_code = "CACHE_ERROR"

    def __init__(self, operation: str, message: str, **kwargs):
        super().__init__(message, **kwargs)
        self.context.update(operation=operation)
318 |
319 |
class CacheConnectionError(CacheError):
    """Raised when connecting to a cache backend fails (operation "connect").

    The cache type is stored in ``context``. ``recoverable`` defaults to True,
    and callers may override it via ``kwargs`` without triggering a
    duplicate-keyword TypeError.
    """

    error_code = "CACHE_CONNECTION_ERROR"
    status_code = 503

    def __init__(self, cache_type: str, reason: str, **kwargs):
        message = f"{cache_type} cache connection failed: {reason}"
        kwargs.setdefault("recoverable", True)
        super().__init__("connect", message, **kwargs)
        self.context["cache_type"] = cache_type
330 |
331 |
332 | # Configuration exceptions
class ConfigurationError(MaverickException):
    """Signals an invalid or missing configuration value (HTTP 500).

    A truthy ``config_key`` is recorded in ``context``.
    """

    status_code = 500
    error_code = "CONFIGURATION_ERROR"

    def __init__(self, message: str, config_key: str | None = None, **kwargs):
        super().__init__(message, **kwargs)
        if not config_key:
            return
        self.context["config_key"] = config_key
343 |
344 |
345 | # Webhook exceptions
class WebhookError(MaverickException):
    """Signals that processing an incoming webhook failed (HTTP 400).

    Truthy ``event_type`` / ``event_id`` values are recorded in ``context``.
    """

    status_code = 400
    error_code = "WEBHOOK_ERROR"

    def __init__(
        self,
        message: str,
        event_type: str | None = None,
        event_id: str | None = None,
        **kwargs,
    ):
        super().__init__(message, **kwargs)
        for key, value in (("event_type", event_type), ("event_id", event_id)):
            if value:
                self.context[key] = value
364 |
365 |
366 | # Agent-specific exceptions
class AgentInitializationError(MaverickException):
    """Signals that an agent could not be constructed (HTTP 500).

    Both the agent type and the reason are kept in ``context``.
    """

    status_code = 500
    error_code = "AGENT_INIT_ERROR"

    def __init__(self, agent_type: str, reason: str, **kwargs):
        super().__init__(f"Failed to initialize {agent_type}: {reason}", **kwargs)
        self.context.update(agent_type=agent_type, reason=reason)
378 |
379 |
class PersonaConfigurationError(MaverickException):
    """Raised when an unknown persona is requested (HTTP 400).

    The invalid persona and the list of valid ones are stored in ``context``.
    Valid personas are converted with ``str`` for the message, so non-string
    entries (e.g. enum members) no longer make ``str.join`` raise TypeError.
    """

    error_code = "PERSONA_CONFIG_ERROR"
    status_code = 400

    def __init__(self, persona: str, valid_personas: list, **kwargs):
        options = ", ".join(map(str, valid_personas))
        message = f"Invalid persona '{persona}'. Valid options: {options}"
        super().__init__(message, **kwargs)
        self.context["invalid_persona"] = persona
        self.context["valid_personas"] = valid_personas
393 |
394 |
class ToolRegistrationError(MaverickException):
    """Signals that registering a tool failed (HTTP 500).

    The tool name and the reason are kept in ``context``.
    """

    status_code = 500
    error_code = "TOOL_REGISTRATION_ERROR"

    def __init__(self, tool_name: str, reason: str, **kwargs):
        super().__init__(f"Failed to register tool '{tool_name}': {reason}", **kwargs)
        self.context.update(tool_name=tool_name, reason=reason)
406 |
407 |
408 | # Circuit breaker exceptions
class CircuitBreakerError(MaverickException):
    """Raised when a circuit breaker is open for a service (HTTP 503).

    The service, failure count and threshold are stored in ``context``.
    ``recoverable`` defaults to True, and callers may override it via
    ``kwargs`` without triggering a duplicate-keyword TypeError.
    """

    error_code = "CIRCUIT_BREAKER_OPEN"
    status_code = 503

    def __init__(self, service: str, failure_count: int, threshold: int, **kwargs):
        message = (
            f"Circuit breaker open for {service}: {failure_count}/{threshold} failures"
        )
        kwargs.setdefault("recoverable", True)
        super().__init__(message, **kwargs)
        self.context["service"] = service
        self.context["failure_count"] = failure_count
        self.context["threshold"] = threshold
423 |
424 |
425 | # Parameter validation exceptions
class ParameterValidationError(ValidationError):
    """Signals that a function argument had the wrong type (HTTP 400).

    The parameter name becomes ``field``; the expected and actual type names
    are kept in ``context``.
    """

    status_code = 400
    error_code = "PARAMETER_VALIDATION_ERROR"

    def __init__(self, param_name: str, expected_type: str, actual_type: str, **kwargs):
        super().__init__(
            f"Validation failed for '{param_name}': "
            f"Expected {expected_type}, got {actual_type}",
            field=param_name,
            **kwargs,
        )
        self.context.update(expected_type=expected_type, actual_type=actual_type)
438 |
439 |
# Error code constants: maps every ``error_code`` used by the exception classes
# in this module to a generic human-readable description.
ERROR_CODES = {
    "VALIDATION_ERROR": "Request validation failed",
    # Research/agent codes were previously missing, so get_error_message()
    # returned "Unknown error" for them.
    "RESEARCH_ERROR": "Research operation failed",
    "WEB_SEARCH_ERROR": "Web search failed",
    "CONTENT_ANALYSIS_ERROR": "Content analysis failed",
    "AGENT_EXECUTION_ERROR": "Agent execution failed",
    "AUTHENTICATION_ERROR": "Authentication failed",
    "AUTHORIZATION_ERROR": "Insufficient permissions",
    "NOT_FOUND": "Resource not found",
    "CONFLICT": "Resource conflict",
    "RATE_LIMIT_EXCEEDED": "Too many requests",
    "EXTERNAL_SERVICE_ERROR": "External service unavailable",
    "DATA_PROVIDER_ERROR": "Data provider error",
    "DATA_NOT_FOUND": "Data not found",
    "API_CONNECTION_ERROR": "API connection failed",
    "DATABASE_ERROR": "Database error",
    "DATABASE_CONNECTION_ERROR": "Database connection failed",
    "DATA_INTEGRITY_ERROR": "Data integrity violation",
    "CACHE_ERROR": "Cache error",
    "CACHE_CONNECTION_ERROR": "Cache connection failed",
    "CONFIGURATION_ERROR": "Configuration error",
    "WEBHOOK_ERROR": "Webhook processing failed",
    "AGENT_INIT_ERROR": "Agent initialization failed",
    "PERSONA_CONFIG_ERROR": "Invalid persona configuration",
    "TOOL_REGISTRATION_ERROR": "Tool registration failed",
    "CIRCUIT_BREAKER_OPEN": "Service unavailable - circuit breaker open",
    "PARAMETER_VALIDATION_ERROR": "Invalid parameter",
    "INTERNAL_ERROR": "Internal server error",
}


def get_error_message(code: str) -> str:
    """Return the human-readable message for ``code``, or "Unknown error" if unmapped."""
    return ERROR_CODES.get(code, "Unknown error")
471 |
472 |
# Backward compatibility alias: the same class object as MaverickException, so
# `except MaverickMCPError` and `except MaverickException` are interchangeable.
MaverickMCPError = MaverickException
475 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/logging.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Structured logging with request context for MaverickMCP.
3 |
4 | This module provides structured logging capabilities that:
5 | - Capture request context (request ID, user, tool name)
6 | - Track performance metrics (duration, memory usage)
7 | - Support JSON output for log aggregation
8 | - Integrate with FastMCP's context system
9 | """
10 |
11 | import functools
12 | import json
13 | import logging
14 | import sys
15 | import time
16 | import traceback
17 | import uuid
18 | from collections.abc import Callable
19 | from contextvars import ContextVar
20 | from datetime import UTC, datetime
21 | from typing import Any
22 |
23 | import psutil
24 | from fastmcp import Context as MCPContext
25 |
# Context variables for request tracking. Each defaults to None, meaning "no
# request context bound". log_tool_execution binds request_id, tool_name and
# request_start around a tool call; StructuredFormatter reads all four and adds
# whichever are set to every JSON log record.
request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)  # type: ignore[assignment]
user_id_var: ContextVar[str | None] = ContextVar("user_id", default=None)  # type: ignore[assignment]
tool_name_var: ContextVar[str | None] = ContextVar("tool_name", default=None)  # type: ignore[assignment]
# Start time of the current request as a time.time() reading (epoch seconds).
request_start_var: ContextVar[float | None] = ContextVar("request_start", default=None)  # type: ignore[assignment]
31 |
32 |
class StructuredFormatter(logging.Formatter):
    """Custom formatter that outputs structured JSON logs.

    Each record is rendered as one JSON object holding the core record fields,
    any request context currently bound in this module's context variables,
    exception details (when ``exc_info`` is set), and every extra attribute
    attached to the record (for example via ``extra=``).
    """

    # Attributes present on every LogRecord, or added to it by
    # logging.Formatter.format(). Anything else came from ``extra=`` and is
    # emitted verbatim.
    _RESERVED_ATTRS = frozenset(
        {
            "name",
            "msg",
            "args",
            "created",
            "filename",
            "funcName",
            "levelname",
            "levelno",
            "lineno",
            "module",
            "msecs",
            "pathname",
            "process",
            "processName",
            "relativeCreated",
            "thread",
            "threadName",
            "exc_info",
            "exc_text",
            "stack_info",
            # Present on every record from Python 3.12 onwards.
            "taskName",
            # Set by logging.Formatter.format() if another handler formatted
            # the same record first.
            "message",
            "asctime",
        }
    )

    def format(self, record: logging.LogRecord) -> str:
        """Format *record* as a single-line JSON string."""
        log_data: dict[str, Any] = {
            # Derived from record.created so the timestamp reflects when the
            # event was logged, not when it happened to be formatted.
            "timestamp": datetime.fromtimestamp(record.created, UTC).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Request context from the module's context variables, when bound.
        request_id = request_id_var.get()
        if request_id:
            log_data["request_id"] = request_id

        user_id = user_id_var.get()
        if user_id:
            log_data["user_id"] = user_id

        tool_name = tool_name_var.get()
        if tool_name:
            log_data["tool_name"] = tool_name

        # Elapsed time since the current request started.
        request_start = request_start_var.get()
        if request_start is not None:
            log_data["duration_ms"] = int((time.time() - request_start) * 1000)

        # Exception details, if the record carries exc_info.
        if record.exc_info:
            exc_type, exc_value, _ = record.exc_info
            log_data["exception"] = {
                "type": exc_type.__name__ if exc_type else "Unknown",
                "message": str(exc_value),
                "traceback": traceback.format_exception(*record.exc_info),
            }

        # Anything that is not a standard LogRecord attribute came from extra=.
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                log_data[key] = value

        # default=str: a single non-JSON-serializable extra (datetime, framework
        # object, DataFrame, ...) must not make json.dumps raise inside the
        # handler and drop the whole record.
        return json.dumps(log_data, default=str)
104 |
105 |
class RequestContextLogger:
    """Logger wrapper that attaches process metrics to every record.

    Each call adds ``memory_mb`` (this process's RSS in MB) and ``cpu_percent``
    (this process's CPU usage since the previous measurement) to the record's
    ``extra``. Request identifiers are picked up separately by the formatter
    from the module's context variables.
    """

    # One psutil.Process shared by every wrapper instance:
    # cpu_percent(interval=None) reports usage since the previous call on the
    # *same* object, so a fresh object per logger would always report 0.0.
    # NOTE(review): the cached object keeps the PID of the process that first
    # logged; confirm this is acceptable if workers are forked after logging.
    _process = None

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    @classmethod
    def _get_process(cls):
        """Return the shared psutil.Process for the current process."""
        if cls._process is None:
            cls._process = psutil.Process()
        return cls._process

    def _log_with_context(self, level: int, msg: str, *args, **kwargs):
        """Log *msg* at *level* with process metrics merged into ``extra``.

        The caller's ``extra`` dict is copied, never mutated. ``stacklevel`` is
        raised by two so the record's module/funcName/lineno identify the code
        that called debug()/info()/... rather than this wrapper.
        """
        # Skip the psutil calls entirely when the record would be dropped anyway.
        if not self.logger.isEnabledFor(level):
            return

        extra = dict(kwargs.get("extra") or {})

        process = self._get_process()
        extra["memory_mb"] = process.memory_info().rss / 1024 / 1024
        # interval=None is non-blocking; a positive interval would sleep on
        # every log call (and stall the event loop in async code).
        extra["cpu_percent"] = process.cpu_percent(interval=None)

        kwargs["extra"] = extra
        # +2 skips this method and the public level method that called it.
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2
        self.logger.log(level, msg, *args, **kwargs)

    def debug(self, msg: str, *args, **kwargs):
        self._log_with_context(logging.DEBUG, msg, *args, **kwargs)

    def info(self, msg: str, *args, **kwargs):
        self._log_with_context(logging.INFO, msg, *args, **kwargs)

    def warning(self, msg: str, *args, **kwargs):
        self._log_with_context(logging.WARNING, msg, *args, **kwargs)

    def error(self, msg: str, *args, **kwargs):
        self._log_with_context(logging.ERROR, msg, *args, **kwargs)

    def exception(self, msg: str, *args, exc_info=True, **kwargs):
        """Log at ERROR including the active exception's traceback."""
        self._log_with_context(logging.ERROR, msg, *args, exc_info=exc_info, **kwargs)

    def critical(self, msg: str, *args, **kwargs):
        self._log_with_context(logging.CRITICAL, msg, *args, **kwargs)
138 |
139 |
140 | def setup_structured_logging(
141 | log_level: str = "INFO",
142 | log_format: str = "json",
143 | log_file: str | None = None,
144 | use_stderr: bool = False,
145 | ) -> None:
146 | """
147 | Set up structured logging for the application.
148 |
149 | Args:
150 | log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
151 | log_format: Output format ("json" or "text")
152 | log_file: Optional log file path
153 | use_stderr: If True, send console logs to stderr instead of stdout
154 | """
155 | # Configure warnings filter to suppress known deprecation warnings
156 | import warnings
157 |
158 | # Suppress pandas_ta pkg_resources deprecation warning
159 | warnings.filterwarnings(
160 | "ignore",
161 | message="pkg_resources is deprecated as an API.*",
162 | category=UserWarning,
163 | module="pandas_ta.*",
164 | )
165 |
166 | # Suppress passlib crypt deprecation warning
167 | warnings.filterwarnings(
168 | "ignore",
169 | message="'crypt' is deprecated and slated for removal.*",
170 | category=DeprecationWarning,
171 | module="passlib.*",
172 | )
173 |
174 | # Suppress LangChain Pydantic v1 deprecation warnings
175 | warnings.filterwarnings(
176 | "ignore",
177 | message=".*pydantic.* is deprecated.*",
178 | category=DeprecationWarning,
179 | module="langchain.*",
180 | )
181 |
182 | # Suppress Starlette cookie deprecation warnings
183 | warnings.filterwarnings(
184 | "ignore",
185 | message=".*cookie.*deprecated.*",
186 | category=DeprecationWarning,
187 | module="starlette.*",
188 | )
189 |
190 | root_logger = logging.getLogger()
191 | root_logger.setLevel(getattr(logging, log_level.upper()))
192 |
193 | # Remove existing handlers
194 | for handler in root_logger.handlers[:]:
195 | root_logger.removeHandler(handler)
196 |
197 | # Console handler - use stderr for stdio transport to avoid interfering with JSON-RPC
198 | console_handler = logging.StreamHandler(sys.stderr if use_stderr else sys.stdout)
199 |
200 | if log_format == "json":
201 | console_handler.setFormatter(StructuredFormatter())
202 | else:
203 | console_handler.setFormatter(
204 | logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
205 | )
206 |
207 | root_logger.addHandler(console_handler)
208 |
209 | # File handler if specified
210 | if log_file:
211 | file_handler = logging.FileHandler(log_file)
212 | file_handler.setFormatter(StructuredFormatter())
213 | root_logger.addHandler(file_handler)
214 |
215 |
def get_logger(name: str) -> RequestContextLogger:
    """Wrap the standard-library logger called *name* in a RequestContextLogger."""
    underlying = logging.getLogger(name)
    return RequestContextLogger(underlying)
219 |
220 |
def log_tool_execution(func: Callable) -> Callable:
    """
    Decorator to log tool execution with context.

    Automatically captures:
    - Tool name
    - Request ID
    - Execution time
    - Success/failure status
    - Input parameters (sanitized)

    The returned wrapper is always a coroutine function. *func* may be either
    ``async def`` or a plain function: its result is awaited only if it is
    awaitable. Request context variables are restored to their previous values
    on exit, so a decorated tool calling another decorated tool does not wipe
    the outer call's context.

    Raises:
        Whatever *func* raises, re-raised unchanged after being logged.
    """
    import inspect

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # Generate request ID; keep each token so the previous value can be
        # restored (not just cleared) when this call finishes.
        request_id = str(uuid.uuid4())
        request_id_token = request_id_var.set(request_id)

        # Set tool name
        tool_name = getattr(func, "__name__", "unknown_function")
        tool_name_token = tool_name_var.set(tool_name)

        # Set start time
        start_time = time.time()
        request_start_token = request_start_var.set(start_time)

        try:
            logger = get_logger(f"maverick_mcp.tools.{tool_name}")

            # MCP Context objects are framework plumbing rather than tool input,
            # and are not JSON-serializable, so keep them out of the logged
            # parameters. Remaining values are sanitized (secrets redacted).
            loggable_kwargs = {
                key: value
                for key, value in kwargs.items()
                if not isinstance(value, MCPContext)
            }
            safe_kwargs = _sanitize_params(loggable_kwargs)

            logger.info(
                "Tool execution started",
                extra={
                    "tool_name": tool_name,
                    "request_id": request_id,
                    "parameters": safe_kwargs,
                },
            )

            try:
                # Execute the tool; await only if it produced an awaitable so
                # plain (sync) functions work too.
                result = func(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result

                duration_ms = int((time.time() - start_time) * 1000)
                logger.info(
                    "Tool execution completed successfully",
                    extra={
                        "tool_name": tool_name,
                        "request_id": request_id,
                        "duration_ms": duration_ms,
                        "status": "success",
                    },
                )
                return result

            except Exception as e:
                duration_ms = int((time.time() - start_time) * 1000)
                logger.error(
                    f"Tool execution failed: {str(e)}",
                    exc_info=True,
                    extra={
                        "tool_name": tool_name,
                        "request_id": request_id,
                        "duration_ms": duration_ms,
                        "status": "error",
                        "error_type": type(e).__name__,
                    },
                )
                raise

        finally:
            # Restore context vars in reverse order of setting.
            request_start_var.reset(request_start_token)
            tool_name_var.reset(tool_name_token)
            request_id_var.reset(request_id_token)

    return wrapper
308 |
309 |
310 | def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
311 | """
312 | Sanitize parameters for logging by hiding sensitive data.
313 |
314 | Args:
315 | params: Original parameters
316 |
317 | Returns:
318 | Sanitized parameters safe for logging
319 | """
320 | sensitive_keys = {"password", "api_key", "secret", "token", "auth"}
321 | sanitized = {}
322 |
323 | for key, value in params.items():
324 | if any(sensitive in key.lower() for sensitive in sensitive_keys):
325 | sanitized[key] = "***REDACTED***"
326 | elif isinstance(value, dict):
327 | sanitized[key] = _sanitize_params(value)
328 | elif isinstance(value, list) and len(value) > 10:
329 | # Truncate long lists
330 | sanitized[key] = f"[{len(value)} items]"
331 | elif isinstance(value, str) and len(value) > 1000:
332 | # Truncate long strings
333 | sanitized[key] = value[:100] + f"... ({len(value)} chars total)"
334 | else:
335 | sanitized[key] = value
336 |
337 | return sanitized
338 |
339 |
def log_database_query(
    query: str, params: dict | None = None, duration_ms: int | None = None
):
    """Log one executed database query.

    Emits an INFO summary (query type, length, optional timing and parameter
    count), then the SQL text itself at DEBUG, clipped to 200 characters.

    Args:
        query: The SQL text that was executed.
        params: Bound parameters, if any. Only their count appears at INFO; a
            sanitized copy is attached at DEBUG.
        duration_ms: Execution time. Values above 1000 are flagged with
            ``slow_query``.
    """
    db_logger = get_logger("maverick_mcp.database")

    summary: dict[str, Any] = {
        "query_type": _get_query_type(query),
        "query_length": len(query),
    }
    if duration_ms is not None:
        # Anything slower than one second counts as a slow query.
        summary.update(duration_ms=duration_ms, slow_query=duration_ms > 1000)
    if params:
        summary["param_count"] = len(params)

    db_logger.info("Database query executed", extra=summary)

    preview = query if len(query) <= 200 else f"{query[:200]}..."
    db_logger.debug(
        f"Query details: {preview}",
        extra={"params": _sanitize_params(params) if params else None},
    )
364 |
365 |
366 | def _get_query_type(query: str) -> str:
367 | """Extract query type from SQL query."""
368 | query_upper = query.strip().upper()
369 | if query_upper.startswith("SELECT"):
370 | return "SELECT"
371 | elif query_upper.startswith("INSERT"):
372 | return "INSERT"
373 | elif query_upper.startswith("UPDATE"):
374 | return "UPDATE"
375 | elif query_upper.startswith("DELETE"):
376 | return "DELETE"
377 | elif query_upper.startswith("CREATE"):
378 | return "CREATE"
379 | elif query_upper.startswith("DROP"):
380 | return "DROP"
381 | else:
382 | return "OTHER"
383 |
384 |
def log_cache_operation(
    operation: str, key: str, hit: bool = False, duration_ms: int | None = None
):
    """Log a cache operation and whether it hit, at INFO level.

    Args:
        operation: Name of the cache operation (e.g. "get").
        key: Cache key involved.
        hit: True for a cache hit, False for a miss.
        duration_ms: Optional operation time, attached when provided.
    """
    cache_logger = get_logger("maverick_mcp.cache")

    details = {"operation": operation, "cache_key": key, "cache_hit": hit}
    if duration_ms is not None:
        details["duration_ms"] = duration_ms

    outcome = "hit" if hit else "miss"
    cache_logger.info(f"Cache {operation}: {outcome} for {key}", extra=details)
397 |
398 |
def log_external_api_call(
    service: str,
    endpoint: str,
    method: str = "GET",
    status_code: int | None = None,
    duration_ms: int | None = None,
    error: str | None = None,
):
    """Log an outbound API call: ERROR if *error* is truthy, INFO otherwise.

    Args:
        service: Name of the external service.
        endpoint: Endpoint that was called.
        method: HTTP method.
        status_code: Response status; when given, ``success`` records whether
            it was 2xx.
        duration_ms: Optional call duration.
        error: Error description; a truthy value switches the log to ERROR.
    """
    api_logger = get_logger("maverick_mcp.external_api")

    details: dict[str, Any] = {
        "service": service,
        "endpoint": endpoint,
        "method": method,
    }
    if status_code is not None:
        details["status_code"] = status_code
        details["success"] = 200 <= status_code < 300
    if duration_ms is not None:
        details["duration_ms"] = duration_ms

    call = f"{service} {method} {endpoint}"
    if not error:
        api_logger.info(f"External API call: {call}", extra=details)
        return

    details["error"] = error
    api_logger.error(f"External API call failed: {call}", extra=details)
426 |
427 |
# Performance monitoring context manager
class PerformanceMonitor:
    """Context manager for monitoring performance of code blocks.

    On exit it logs the block's duration and the change in this process's
    resident memory (MB): at INFO when the block succeeded, at ERROR with
    ``error_type`` when it raised. Exceptions are never suppressed. Usable with
    both ``with`` and ``async with``.
    """

    def __init__(self, operation_name: str, logger: RequestContextLogger | None = None):
        self.operation_name = operation_name
        self.logger = logger or get_logger("maverick_mcp.performance")
        # Wall-clock entry time (epoch seconds); public attribute.
        self.start_time: float | None = None
        # Resident set size in MB at entry.
        self.start_memory: float | None = None
        # Monotonic clock reading at entry; used to compute the duration.
        self._start_perf: float | None = None
        # psutil.Process captured on entry and reused on exit.
        self._process = None

    def __enter__(self):
        self.start_time = time.time()
        self._start_perf = time.perf_counter()
        self._process = psutil.Process()
        self.start_memory = self._process.memory_info().rss / 1024 / 1024
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # perf_counter is monotonic, so the duration cannot go negative or jump
        # if the system clock is adjusted while the block runs.
        if self._start_perf is not None:
            elapsed = time.perf_counter() - self._start_perf
        else:
            elapsed = 0.0
        duration_ms = int(elapsed * 1000)

        process = self._process if self._process is not None else psutil.Process()
        end_memory = process.memory_info().rss / 1024 / 1024
        memory_delta = end_memory - (self.start_memory or 0)

        extra = {
            "operation": self.operation_name,
            "duration_ms": duration_ms,
            "memory_delta_mb": round(memory_delta, 2),
            "success": exc_type is None,
        }

        if exc_type:
            extra["error_type"] = exc_type.__name__
            self.logger.error(
                f"Operation '{self.operation_name}' failed after {duration_ms}ms",
                extra=extra,
            )
        else:
            self.logger.info(
                f"Operation '{self.operation_name}' completed in {duration_ms}ms",
                extra=extra,
            )
        # Implicit None: never suppress the block's exception.

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return self.__exit__(exc_type, exc_val, exc_tb)
468 |
```
--------------------------------------------------------------------------------
/tests/test_circuit_breaker.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for the circuit breaker system.
3 | """
4 |
5 | import asyncio
6 | import time
7 | from unittest.mock import patch
8 |
9 | import pytest
10 |
11 | from maverick_mcp.exceptions import CircuitBreakerError, ExternalServiceError
12 | from maverick_mcp.utils.circuit_breaker import (
13 | CircuitBreakerConfig,
14 | CircuitBreakerMetrics,
15 | CircuitState,
16 | EnhancedCircuitBreaker,
17 | FailureDetectionStrategy,
18 | circuit_breaker,
19 | get_all_circuit_breakers,
20 | get_circuit_breaker,
21 | get_circuit_breaker_status,
22 | reset_all_circuit_breakers,
23 | )
24 |
25 |
class TestCircuitBreakerMetrics:
    """Unit tests for CircuitBreakerMetrics' rolling-window statistics."""

    def test_metrics_initialization(self):
        """A fresh metrics window reports neutral statistics."""
        stats = CircuitBreakerMetrics(window_size=10).get_stats()

        expected = {
            "total_calls": 0,
            "success_rate": 1.0,
            "failure_rate": 0.0,
            "avg_duration": 0.0,
            "timeout_rate": 0.0,
        }
        for field, value in expected.items():
            assert stats[field] == value

    def test_record_successful_call(self):
        """Two successes: full success rate and the mean of their durations."""
        metrics = CircuitBreakerMetrics()
        for duration in (0.5, 1.0):
            metrics.record_call(True, duration)

        stats = metrics.get_stats()
        assert stats["total_calls"] == 2
        assert (stats["success_rate"], stats["failure_rate"]) == (1.0, 0.0)
        assert stats["avg_duration"] == 0.75

    def test_record_failed_call(self):
        """One failure plus one success splits the rates evenly."""
        metrics = CircuitBreakerMetrics()
        for succeeded, duration in ((False, 2.0), (True, 1.0)):
            metrics.record_call(succeeded, duration)

        stats = metrics.get_stats()
        assert stats["total_calls"] == 2
        assert (stats["success_rate"], stats["failure_rate"]) == (0.5, 0.5)
        assert stats["avg_duration"] == 1.5

    def test_window_cleanup(self):
        """Samples older than the window are discarded."""
        metrics = CircuitBreakerMetrics(window_size=1)  # one-second window

        metrics.record_call(True, 0.5)
        time.sleep(1.1)  # let the first sample age out of the window
        metrics.record_call(True, 1.0)

        assert metrics.get_stats()["total_calls"] == 1
76 |
77 |
class TestEnhancedCircuitBreaker:
    """Test enhanced circuit breaker functionality."""

    @staticmethod
    def _fail_once(breaker: EnhancedCircuitBreaker) -> None:
        """Send one failing call through *breaker* and assert it surfaced.

        pytest.raises (rather than try/except/pass) fails the test if the call
        does not raise, or is rejected with a different error such as
        CircuitBreakerError.
        """
        with pytest.raises(ZeroDivisionError):
            breaker.call_sync(lambda: 1 / 0)

    def test_circuit_breaker_initialization(self):
        """Test circuit breaker is initialized correctly."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=3,
            recovery_timeout=5,
        )
        breaker = EnhancedCircuitBreaker(config)

        assert breaker.state == CircuitState.CLOSED
        assert breaker.is_closed
        assert not breaker.is_open

    def test_consecutive_failures_opens_circuit(self):
        """Test circuit opens after consecutive failures."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=3,
            detection_strategy=FailureDetectionStrategy.CONSECUTIVE_FAILURES,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Fail 3 times
        for _ in range(3):
            self._fail_once(breaker)

        assert breaker.state == CircuitState.OPEN
        assert breaker.is_open

    def test_failure_rate_opens_circuit(self):
        """Test circuit opens based on failure rate."""
        config = CircuitBreakerConfig(
            name="test",
            failure_rate_threshold=0.5,
            detection_strategy=FailureDetectionStrategy.FAILURE_RATE,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Need minimum calls for rate calculation
        for i in range(10):
            try:
                if i % 2 == 0:  # 50% failure rate
                    breaker.call_sync(lambda: 1 / 0)
                else:
                    breaker.call_sync(lambda: "success")
            except (ZeroDivisionError, CircuitBreakerError):
                # Deliberate: failing calls raise, and once the breaker opens
                # partway through the loop the remaining calls are rejected.
                pass

        stats = breaker._metrics.get_stats()
        assert stats["failure_rate"] >= 0.5
        assert breaker.state == CircuitState.OPEN

    def test_circuit_breaker_blocks_calls_when_open(self):
        """Test circuit breaker blocks calls when open."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=1,
            recovery_timeout=60,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Open the circuit
        self._fail_once(breaker)

        # Next call should be blocked
        with pytest.raises(CircuitBreakerError) as exc_info:
            breaker.call_sync(lambda: "success")

        assert "Circuit breaker open for test:" in str(exc_info.value)
        assert exc_info.value.context["state"] == "open"

    def test_circuit_breaker_recovery(self):
        """Test circuit breaker recovery to half-open then closed."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=1,
            recovery_timeout=1,  # 1 second
            success_threshold=2,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Open the circuit
        self._fail_once(breaker)

        assert breaker.state == CircuitState.OPEN

        # Wait for recovery timeout
        time.sleep(1.1)

        # First successful call should move to half-open
        result = breaker.call_sync(lambda: "success1")
        assert result == "success1"
        assert breaker.state == CircuitState.HALF_OPEN

        # Second successful call should close the circuit
        result = breaker.call_sync(lambda: "success2")
        assert result == "success2"
        assert breaker.state == CircuitState.CLOSED

    def test_half_open_failure_reopens(self):
        """Test failure in half-open state reopens circuit."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=1,
            recovery_timeout=1,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Open the circuit
        self._fail_once(breaker)

        # Wait for recovery
        time.sleep(1.1)

        # Fail in half-open state: the call must execute (and raise), not be rejected
        self._fail_once(breaker)

        assert breaker.state == CircuitState.OPEN

    def test_manual_reset(self):
        """Test manual circuit breaker reset."""
        config = CircuitBreakerConfig(
            name="test",
            failure_threshold=1,
        )
        breaker = EnhancedCircuitBreaker(config)

        # Open the circuit
        self._fail_once(breaker)

        assert breaker.state == CircuitState.OPEN

        # Manual reset
        breaker.reset()
        assert breaker.state == CircuitState.CLOSED
        assert breaker._consecutive_failures == 0

    @pytest.mark.asyncio
    async def test_async_circuit_breaker(self):
        """Test circuit breaker with async functions."""
        config = CircuitBreakerConfig(
            name="test_async",
            failure_threshold=2,
        )
        breaker = EnhancedCircuitBreaker(config)

        async def failing_func():
            raise ValueError("Async failure")

        async def success_func():
            return "async success"

        # Test failures
        for _ in range(2):
            with pytest.raises(ValueError):
                await breaker.call_async(failing_func)

        assert breaker.state == CircuitState.OPEN

        # Test blocking
        with pytest.raises(CircuitBreakerError):
            await breaker.call_async(success_func)

    @pytest.mark.asyncio
    async def test_async_timeout(self):
        """Test async timeout handling."""
        config = CircuitBreakerConfig(
            name="test_timeout",
            timeout_threshold=0.1,  # 100ms
            failure_threshold=1,
        )
        breaker = EnhancedCircuitBreaker(config)

        async def slow_func():
            await asyncio.sleep(0.5)  # 500ms
            return "done"

        with pytest.raises(ExternalServiceError) as exc_info:
            await breaker.call_async(slow_func)

        assert "timed out" in str(exc_info.value)
        assert breaker.state == CircuitState.OPEN
281 |
282 |
class TestCircuitBreakerDecorator:
    """Exercise the @circuit_breaker decorator on sync and async callables."""

    def test_sync_decorator(self):
        """Two failures open the breaker; the following call is rejected."""
        invocations = []

        @circuit_breaker(name="test_decorator", failure_threshold=2)
        def test_func(should_fail=False):
            invocations.append(should_fail)
            if should_fail:
                raise ValueError("Test failure")
            return "success"

        # Healthy calls pass straight through.
        for _ in range(2):
            assert test_func() == "success"

        # Reach the failure threshold.
        for _ in range(2):
            with pytest.raises(ValueError):
                test_func(should_fail=True)

        # Breaker is now open.
        with pytest.raises(CircuitBreakerError):
            test_func()

    @pytest.mark.asyncio
    async def test_async_decorator(self):
        """A single failure opens a threshold-1 breaker around a coroutine."""

        @circuit_breaker(name="test_async_decorator", failure_threshold=1)
        async def async_test_func(should_fail=False):
            if should_fail:
                raise ValueError("Async test failure")
            return "async success"

        assert await async_test_func() == "async success"

        with pytest.raises(ValueError):
            await async_test_func(should_fail=True)

        with pytest.raises(CircuitBreakerError):
            await async_test_func()
332 |
333 |
class TestCircuitBreakerRegistry:
    """Test global circuit breaker registry."""

    def test_get_circuit_breaker(self):
        """Test getting circuit breaker by name."""

        # Create a breaker via decorator
        @circuit_breaker(name="registry_test")
        def test_func():
            return "test"

        # Call to initialize
        test_func()

        # Get from registry
        breaker = get_circuit_breaker("registry_test")
        assert breaker is not None
        assert breaker.config.name == "registry_test"

    def test_get_all_circuit_breakers(self):
        """Test getting all circuit breakers."""
        # Start from an empty registry so the count below is exact, but restore
        # the previous contents afterwards. Otherwise breakers registered by
        # other modules/tests would silently disappear for the rest of the run.
        # NOTE(review): assumes _breakers is a dict-like name -> breaker mapping.
        from maverick_mcp.utils.circuit_breaker import _breakers

        saved = dict(_breakers)
        _breakers.clear()
        try:
            # Create multiple breakers
            @circuit_breaker(name="breaker1")
            def func1():
                pass

            @circuit_breaker(name="breaker2")
            def func2():
                pass

            # Initialize
            func1()
            func2()

            all_breakers = get_all_circuit_breakers()
            assert len(all_breakers) == 2
            assert "breaker1" in all_breakers
            assert "breaker2" in all_breakers
        finally:
            _breakers.clear()
            _breakers.update(saved)

    def test_reset_all_circuit_breakers(self):
        """Test resetting all circuit breakers."""

        # Create and open a breaker
        @circuit_breaker(name="reset_test", failure_threshold=1)
        def failing_func():
            raise ValueError("Fail")

        with pytest.raises(ValueError):
            failing_func()

        breaker = get_circuit_breaker("reset_test")
        assert breaker.state == CircuitState.OPEN

        # Reset all
        reset_all_circuit_breakers()
        assert breaker.state == CircuitState.CLOSED

    def test_circuit_breaker_status(self):
        """Test getting status of all circuit breakers."""

        # Create a breaker
        @circuit_breaker(name="status_test")
        def test_func():
            return "test"

        test_func()

        status = get_circuit_breaker_status()
        assert "status_test" in status
        assert status["status_test"]["state"] == "closed"
        assert status["status_test"]["name"] == "status_test"
410 |
411 |
class TestServiceSpecificCircuitBreakers:
    """Service-specific breakers return fallback data when the source fails."""

    def test_stock_data_circuit_breaker(self):
        """A failing stock fetch is answered by the fallback chain."""
        import pandas as pd

        from maverick_mcp.utils.circuit_breaker_services import StockDataCircuitBreaker

        breaker = StockDataCircuitBreaker()

        def failing_fetch(symbol, start, end):
            raise Exception("API Error")

        fallback_frame = pd.DataFrame({"Close": [100, 101, 102]})
        with patch.object(
            breaker.fallback_chain, "execute_sync", return_value=fallback_frame
        ) as mock_fallback:
            result = breaker.fetch_with_fallback(
                failing_fetch, "AAPL", "2024-01-01", "2024-01-31"
            )

        assert not result.empty
        assert len(result) == 3
        mock_fallback.assert_called_once()

    def test_market_data_circuit_breaker(self):
        """A failing market-movers fetch yields an empty, flagged payload."""
        from maverick_mcp.utils.circuit_breaker_services import MarketDataCircuitBreaker

        breaker = MarketDataCircuitBreaker("finviz")

        def failing_fetch(mover_type):
            raise Exception("Finviz Error")

        payload = breaker.fetch_with_fallback(failing_fetch, "gainers")

        assert isinstance(payload, dict)
        assert "movers" in payload
        assert payload["movers"] == []
        assert payload["metadata"]["is_fallback"] is True

    def test_economic_data_circuit_breaker(self):
        """A failing FRED fetch yields a Series of default values."""
        import pandas as pd

        from maverick_mcp.utils.circuit_breaker_services import (
            EconomicDataCircuitBreaker,
        )

        breaker = EconomicDataCircuitBreaker()

        def failing_fetch(series_id, start, end):
            raise Exception("FRED API Error")

        series = breaker.fetch_with_fallback(
            failing_fetch, "GDP", "2024-01-01", "2024-01-31"
        )

        assert isinstance(series, pd.Series)
        assert series.attrs["is_fallback"] is True
        assert all(series == 2.5)  # Default GDP value
480 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/news_sentiment_enhanced.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Enhanced news sentiment analysis using Tiingo News API or LLM-based analysis.
3 |
4 | This module provides reliable news sentiment analysis by:
5 | 1. Using Tiingo's get_news method (if available)
6 | 2. Falling back to LLM-based sentiment analysis using existing research tools
7 | 3. Never relying on undefined EXTERNAL_DATA_API_KEY
8 | """
9 |
10 | import asyncio
11 | import logging
12 | import os
13 | import uuid
14 | from datetime import datetime, timedelta
15 | from typing import Any
16 |
17 | from tiingo import TiingoClient
18 |
19 | from maverick_mcp.api.middleware.mcp_logging import get_tool_logger
20 | from maverick_mcp.config.settings import get_settings
21 |
# Module logger.
logger = logging.getLogger(__name__)
# Application settings, obtained once when this module is imported.
settings = get_settings()
24 |
25 |
def get_tiingo_client() -> TiingoClient | None:
    """Build a new session-enabled Tiingo client from ``TIINGO_API_KEY``.

    A fresh client is constructed on every call.

    Returns:
        The client, or None when the key is unset or empty, or when client
        construction raises (the failure is logged as a warning).
    """
    api_key = os.getenv("TIINGO_API_KEY")
    if not api_key:
        return None

    try:
        return TiingoClient({"session": True, "api_key": api_key})
    except Exception as e:
        logger.warning(f"Failed to initialize Tiingo client: {e}")
        return None
36 |
37 |
def get_llm():
    """Return an LLM for sentiment analysis, preferring fast and cheap models."""
    # Imported inside the function, presumably to avoid loading the provider
    # stack at module import time.
    from maverick_mcp.providers.llm_factory import get_llm as build_llm
    from maverick_mcp.providers.openrouter_provider import TaskType

    options = {
        "task_type": TaskType.SENTIMENT_ANALYSIS,
        "prefer_fast": True,
        "prefer_cheap": True,
    }
    return build_llm(**options)
47 |
48 |
def _parse_timeframe_days(timeframe: str, default: int = 7) -> int:
    """Convert a ``"<N>d"`` timeframe (e.g. ``"7d"``, ``"30d"``) to a day count.

    Any other form falls back to ``default`` instead of raising. This
    includes ``"1w"``, malformed numbers, ``None`` and non-positive counts.
    """
    text = str(timeframe or "").strip().lower()
    if not text.endswith("d"):
        return default
    try:
        days = int(text[:-1])
    except ValueError:
        return default
    return days if days > 0 else default


async def get_news_sentiment_enhanced(
    ticker: str, timeframe: str = "7d", limit: int = 10
) -> dict[str, Any]:
    """
    Enhanced news sentiment analysis using Tiingo News API or LLM analysis.

    This tool provides reliable sentiment analysis by:
    1. First attempting to use Tiingo's news API (if available)
    2. Analyzing news sentiment using LLM if news is retrieved
    3. Falling back to research-based sentiment if Tiingo unavailable
    4. Providing guaranteed responses with appropriate fallbacks

    Args:
        ticker: Stock ticker symbol
        timeframe: Time frame for news in ``"<N>d"`` form (1d, 7d, 30d, ...).
            Any other form is treated as 7 days.
        limit: Maximum number of news articles to analyze

    Returns:
        Dictionary containing news sentiment analysis with confidence scores.
        ``status`` is "success" (Tiingo + LLM), "fallback_success"
        (research-based) or "all_methods_failed" (neutral baseline).
    """
    tool_logger = get_tool_logger("data_get_news_sentiment_enhanced")
    request_id = str(uuid.uuid4())
    # Parse the window once, before any provider call. Both the Tiingo query
    # and the research fallback need it, and the fallback must still work
    # when no Tiingo client is configured (previously `days` was unbound
    # there, raising NameError and skipping the research fallback entirely).
    days = _parse_timeframe_days(timeframe)

    try:
        # Step 1: Try Tiingo News API
        tool_logger.step("tiingo_check", f"Checking Tiingo News API for {ticker}")

        tiingo_client = get_tiingo_client()
        if tiingo_client:
            try:
                # Calculate date range from timeframe
                end_date = datetime.now()
                start_date = end_date - timedelta(days=days)

                tool_logger.step(
                    "tiingo_fetch", f"Fetching news from Tiingo for {ticker}"
                )

                # Fetch news using Tiingo's get_news method (blocking client,
                # so run it in a worker thread with a hard timeout).
                news_articles = await asyncio.wait_for(
                    asyncio.to_thread(
                        tiingo_client.get_news,
                        tickers=[ticker],
                        startDate=start_date.strftime("%Y-%m-%d"),
                        endDate=end_date.strftime("%Y-%m-%d"),
                        limit=limit,
                        sortBy="publishedDate",
                        onlyWithTickers=True,
                    ),
                    timeout=10.0,
                )

                if news_articles:
                    tool_logger.step(
                        "llm_analysis",
                        f"Analyzing {len(news_articles)} articles with LLM",
                    )

                    # Analyze sentiment using LLM
                    sentiment_result = await _analyze_news_sentiment_with_llm(
                        news_articles, ticker, tool_logger
                    )

                    tool_logger.complete(
                        f"Tiingo news sentiment analysis completed for {ticker}"
                    )

                    return {
                        "ticker": ticker,
                        "sentiment": sentiment_result["overall_sentiment"],
                        "confidence": sentiment_result["confidence"],
                        "source": "tiingo_news_with_llm_analysis",
                        "status": "success",
                        "analysis": {
                            "articles_analyzed": len(news_articles),
                            "sentiment_breakdown": sentiment_result["breakdown"],
                            "key_themes": sentiment_result["themes"],
                            "recent_headlines": sentiment_result["headlines"][:3],
                        },
                        "timeframe": timeframe,
                        "request_id": request_id,
                        "timestamp": datetime.now().isoformat(),
                    }

            except TimeoutError:
                tool_logger.step(
                    "tiingo_timeout", "Tiingo API timed out, using fallback"
                )
            except Exception as e:
                # Check if it's a permissions issue (free tier doesn't include news)
                if (
                    "403" in str(e)
                    or "permission" in str(e).lower()
                    or "unauthorized" in str(e).lower()
                ):
                    tool_logger.step(
                        "tiingo_no_permission",
                        "Tiingo news not available (requires paid plan)",
                    )
                else:
                    tool_logger.step("tiingo_error", f"Tiingo error: {str(e)}")

        # Step 2: Fallback to research-based sentiment
        tool_logger.step("research_fallback", "Using research-based sentiment analysis")

        from maverick_mcp.api.routers.research import analyze_market_sentiment

        # Use research tools to gather sentiment
        result = await asyncio.wait_for(
            analyze_market_sentiment(
                topic=f"{ticker} stock news sentiment recent {timeframe}",
                timeframe="1w" if days <= 7 else "1m",
                persona="moderate",
            ),
            timeout=15.0,
        )

        if result.get("success", False):
            sentiment_data = result.get("sentiment_analysis", {})
            return {
                "ticker": ticker,
                "sentiment": _extract_sentiment_from_research(sentiment_data),
                "confidence": sentiment_data.get("sentiment_confidence", 0.5),
                "source": "research_based_sentiment",
                "status": "fallback_success",
                "analysis": {
                    "overall_sentiment": sentiment_data.get("overall_sentiment", {}),
                    "key_themes": sentiment_data.get("sentiment_themes", [])[:3],
                    "market_insights": sentiment_data.get("market_insights", [])[:2],
                },
                "timeframe": timeframe,
                "request_id": request_id,
                "timestamp": datetime.now().isoformat(),
                "message": "Using research-based sentiment (Tiingo news unavailable on free tier)",
            }

        # Step 3: Basic fallback
        return _provide_basic_sentiment_fallback(ticker, request_id)

    except Exception as e:
        tool_logger.error("sentiment_error", e, f"Sentiment analysis failed: {str(e)}")
        return _provide_basic_sentiment_fallback(ticker, request_id, str(e))
192 |
193 |
def _extract_json_text(text: str) -> str:
    """Pull the JSON object text out of an LLM reply.

    Handles replies wrapped in Markdown code fences (with or without a
    language tag, in any case) and replies with prose around a bare
    ``{...}`` object. When no braces are present, the text is returned
    unchanged so that ``json.loads`` reports the failure.
    """
    candidate = text
    if "```" in candidate:
        # Keep only the body of the first fenced block.
        candidate = candidate.split("```")[1]
    if "{" in candidate and "}" in candidate:
        start = candidate.index("{")
        end = candidate.rindex("}") + 1
        candidate = candidate[start:end]
    return candidate.strip()


async def _analyze_news_sentiment_with_llm(
    news_articles: list, ticker: str, tool_logger, llm_timeout: float = 15.0
) -> dict[str, Any]:
    """Analyze news article sentiment with an LLM, falling back to keywords.

    Args:
        news_articles: Article dicts read via ``title``, ``description``,
            ``publishedDate`` and ``source`` keys. The first 10 are
            summarized and the first 5 titles are placed in the prompt.
        ticker: Ticker symbol interpolated into the prompt.
        tool_logger: Step logger used to record timeout/parse/LLM failures.
        llm_timeout: Seconds to wait for the LLM reply before falling back.

    Returns:
        A dict with these keys:

        - ``overall_sentiment``: normalized to "bullish", "bearish" or
          "neutral".
        - ``confidence``: clamped to [0, 1].
        - ``breakdown`` and ``themes``.
        - ``headlines``: the first three article titles.

        If no LLM is available, or on any timeout, LLM or parse failure,
        the result of ``_basic_news_analysis`` is returned instead.
    """
    import json

    llm = get_llm()
    if not llm:
        # No LLM available, do basic analysis
        return _basic_news_analysis(news_articles)

    try:
        # Prepare news summary for LLM
        news_summary = []
        for article in news_articles[:10]:  # Limit to the first 10 articles
            news_summary.append(
                {
                    "title": article.get("title") or "",
                    "description": article.get("description", "")[:200]
                    if article.get("description")
                    else "",
                    "publishedDate": article.get("publishedDate", ""),
                    "source": article.get("source", ""),
                }
            )

        # Create sentiment analysis prompt
        prompt = f"""Analyze the sentiment of these recent news articles about {ticker} stock.

News Articles:
{chr(10).join([f"- {a['title']} ({a['source']}, {a['publishedDate'][:10] if a['publishedDate'] else 'Unknown date'})" for a in news_summary[:5]])}

Provide a JSON response with:
1. overall_sentiment: "bullish", "bearish", or "neutral"
2. confidence: 0.0 to 1.0
3. breakdown: dict with counts of positive, negative, neutral articles
4. themes: list of 3 key themes from the news
5. headlines: list of 3 most important headlines

Response format:
{{"overall_sentiment": "...", "confidence": 0.X, "breakdown": {{"positive": X, "negative": Y, "neutral": Z}}, "themes": ["...", "...", "..."], "headlines": ["...", "...", "..."]}}"""

        # Bound how long this request waits on the LLM. wait_for cannot stop
        # the worker thread itself; it may finish in the background.
        response = await asyncio.wait_for(
            asyncio.to_thread(lambda: llm.invoke(prompt).content),
            timeout=llm_timeout,
        )
        if not isinstance(response, str):
            response = str(response)

        try:
            result = json.loads(_extract_json_text(response))
            if not isinstance(result, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )

            # Normalize the label so callers only ever see the three
            # documented values.
            sentiment = str(result.get("overall_sentiment", "neutral")).strip().lower()
            sentiment = {"positive": "bullish", "negative": "bearish"}.get(
                sentiment, sentiment
            )
            if sentiment not in ("bullish", "bearish", "neutral"):
                sentiment = "neutral"

            confidence = min(max(float(result.get("confidence", 0.5)), 0.0), 1.0)

            # Ensure all required fields
            return {
                "overall_sentiment": sentiment,
                "confidence": confidence,
                "breakdown": result.get(
                    "breakdown",
                    {"positive": 0, "negative": 0, "neutral": len(news_articles)},
                ),
                "themes": result.get(
                    "themes",
                    ["Market movement", "Company performance", "Industry trends"],
                ),
                "headlines": [a.get("title", "") for a in news_summary[:3]],
            }

        except (json.JSONDecodeError, ValueError, TypeError) as e:
            tool_logger.step("llm_parse_error", f"Failed to parse LLM response: {e}")
            return _basic_news_analysis(news_articles)

    except TimeoutError:
        tool_logger.step("llm_timeout", f"LLM analysis timed out after {llm_timeout}s")
        return _basic_news_analysis(news_articles)
    except Exception as e:
        tool_logger.step("llm_error", f"LLM analysis failed: {e}")
        return _basic_news_analysis(news_articles)
279 |
280 |
281 | def _basic_news_analysis(news_articles: list) -> dict[str, Any]:
282 | """Basic sentiment analysis without LLM."""
283 |
284 | # Simple keyword-based sentiment
285 | positive_keywords = [
286 | "gain",
287 | "rise",
288 | "up",
289 | "beat",
290 | "exceed",
291 | "strong",
292 | "bull",
293 | "buy",
294 | "upgrade",
295 | "positive",
296 | ]
297 | negative_keywords = [
298 | "loss",
299 | "fall",
300 | "down",
301 | "miss",
302 | "below",
303 | "weak",
304 | "bear",
305 | "sell",
306 | "downgrade",
307 | "negative",
308 | ]
309 |
310 | positive_count = 0
311 | negative_count = 0
312 | neutral_count = 0
313 |
314 | for article in news_articles:
315 | title = (
316 | article.get("title", "") + " " + article.get("description", "")
317 | ).lower()
318 |
319 | pos_score = sum(1 for keyword in positive_keywords if keyword in title)
320 | neg_score = sum(1 for keyword in negative_keywords if keyword in title)
321 |
322 | if pos_score > neg_score:
323 | positive_count += 1
324 | elif neg_score > pos_score:
325 | negative_count += 1
326 | else:
327 | neutral_count += 1
328 |
329 | total = len(news_articles)
330 | if total == 0:
331 | return {
332 | "overall_sentiment": "neutral",
333 | "confidence": 0.0,
334 | "breakdown": {"positive": 0, "negative": 0, "neutral": 0},
335 | "themes": [],
336 | "headlines": [],
337 | }
338 |
339 | # Determine overall sentiment
340 | if positive_count > negative_count * 1.5:
341 | overall = "bullish"
342 | elif negative_count > positive_count * 1.5:
343 | overall = "bearish"
344 | else:
345 | overall = "neutral"
346 |
347 | # Calculate confidence based on consensus
348 | max_count = max(positive_count, negative_count, neutral_count)
349 | confidence = max_count / total if total > 0 else 0.0
350 |
351 | return {
352 | "overall_sentiment": overall,
353 | "confidence": confidence,
354 | "breakdown": {
355 | "positive": positive_count,
356 | "negative": negative_count,
357 | "neutral": neutral_count,
358 | },
359 | "themes": ["Recent news", "Market activity", "Company updates"],
360 | "headlines": [a.get("title", "") for a in news_articles[:3]],
361 | }
362 |
363 |
364 | def _extract_sentiment_from_research(sentiment_data: dict) -> str:
365 | """Extract simple sentiment direction from research data."""
366 |
367 | overall = sentiment_data.get("overall_sentiment", {})
368 |
369 | # Check for sentiment keywords
370 | if isinstance(overall, dict):
371 | sentiment_str = str(overall).lower()
372 | else:
373 | sentiment_str = str(overall).lower()
374 |
375 | if "bullish" in sentiment_str or "positive" in sentiment_str:
376 | return "bullish"
377 | elif "bearish" in sentiment_str or "negative" in sentiment_str:
378 | return "bearish"
379 |
380 | # Check confidence for direction
381 | confidence = sentiment_data.get("sentiment_confidence", 0.5)
382 | if confidence > 0.6:
383 | return "bullish"
384 | elif confidence < 0.4:
385 | return "bearish"
386 |
387 | return "neutral"
388 |
389 |
390 | def _provide_basic_sentiment_fallback(
391 | ticker: str, request_id: str, error_detail: str = None
392 | ) -> dict[str, Any]:
393 | """Provide basic fallback when all methods fail."""
394 |
395 | response = {
396 | "ticker": ticker,
397 | "sentiment": "neutral",
398 | "confidence": 0.0,
399 | "source": "fallback",
400 | "status": "all_methods_failed",
401 | "message": "Unable to fetch news sentiment - returning neutral baseline",
402 | "analysis": {
403 | "note": "Consider using a paid Tiingo plan for news access or check API keys"
404 | },
405 | "request_id": request_id,
406 | "timestamp": datetime.now().isoformat(),
407 | }
408 |
409 | if error_detail:
410 | response["error_detail"] = error_detail[:200] # Limit error message length
411 |
412 | return response
413 |
```
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_macro_data_provider.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for the MacroDataProvider class.
3 | """
4 |
5 | import unittest
6 | from datetime import datetime
7 | from unittest.mock import MagicMock, patch
8 |
9 | import pandas as pd
10 |
11 | from maverick_mcp.providers.macro_data import MacroDataProvider
12 |
13 |
class TestMacroDataProvider(unittest.TestCase):
    """Test suite for MacroDataProvider.

    FRED access is mocked throughout. Date indexes use the "MS"/"QS"
    (period-start) frequency aliases, which are valid on all pandas versions,
    instead of the "M"/"Q" aliases deprecated in pandas 2.2. Empty series
    are created as float to mirror numeric FRED output.
    """

    @patch("fredapi.Fred")
    def setUp(self, mock_fred_class):
        """Set up test fixtures."""
        mock_fred = MagicMock()
        mock_fred_class.return_value = mock_fred
        # Create provider with mocked FRED
        self.provider = MacroDataProvider()
        self.provider.fred = mock_fred

    @patch("fredapi.Fred")
    def test_init_with_fred_api(self, mock_fred_class):
        """Test initialization with FRED API."""
        mock_fred = MagicMock()
        mock_fred_class.return_value = mock_fred

        provider = MacroDataProvider(window_days=180)

        self.assertEqual(provider.window_days, 180)
        self.assertIsNotNone(provider.scaler)
        self.assertIsNotNone(provider.weights)
        mock_fred_class.assert_called_once()

    def test_calculate_weighted_rolling_performance(self):
        """Test weighted rolling performance calculation."""
        # Mock FRED data
        mock_data = pd.Series(
            [100, 102, 104, 106, 108],
            index=pd.date_range(end=datetime.now(), periods=5, freq="D"),
        )

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider._calculate_weighted_rolling_performance(  # type: ignore[attr-defined]
                "SP500", [30, 90, 180], [0.5, 0.3, 0.2]
            )

            self.assertIsInstance(result, float)
            # One FRED fetch per rolling window.
            self.assertEqual(mock_get_series.call_count, 3)

    def test_calculate_weighted_rolling_performance_empty_data(self):
        """Test weighted rolling performance with empty data."""
        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = pd.Series([], dtype=float)

            result = self.provider._calculate_weighted_rolling_performance(  # type: ignore[attr-defined]
                "SP500", [30], [1.0]
            )

            self.assertEqual(result, 0.0)

    def test_get_sp500_performance(self):
        """Test S&P 500 performance calculation."""
        with patch.object(
            self.provider, "_calculate_weighted_rolling_performance"
        ) as mock_calc:
            mock_calc.return_value = 5.5

            result = self.provider.get_sp500_performance()

            self.assertEqual(result, 5.5)
            mock_calc.assert_called_once_with("SP500", [30, 90, 180], [0.5, 0.3, 0.2])

    def test_get_nasdaq_performance(self):
        """Test NASDAQ performance calculation."""
        with patch.object(
            self.provider, "_calculate_weighted_rolling_performance"
        ) as mock_calc:
            mock_calc.return_value = 7.2

            result = self.provider.get_nasdaq_performance()

            self.assertEqual(result, 7.2)
            mock_calc.assert_called_once_with(
                "NASDAQ100", [30, 90, 180], [0.5, 0.3, 0.2]
            )

    def test_get_gdp_growth_rate(self):
        """Test GDP growth rate fetching."""
        mock_data = pd.Series(
            [2.5, 2.8], index=pd.date_range(end=datetime.now(), periods=2, freq="QS")
        )

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_gdp_growth_rate()

            self.assertIsInstance(result, dict)
            self.assertEqual(result["current"], 2.8)
            self.assertEqual(result["previous"], 2.5)

    def test_get_gdp_growth_rate_empty_data(self):
        """Test GDP growth rate with no data."""
        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = pd.Series([], dtype=float)

            result = self.provider.get_gdp_growth_rate()

            self.assertEqual(result["current"], 0.0)
            self.assertEqual(result["previous"], 0.0)

    def test_get_unemployment_rate(self):
        """Test unemployment rate fetching."""
        mock_data = pd.Series(
            [3.5, 3.6, 3.7],
            index=pd.date_range(end=datetime.now(), periods=3, freq="MS"),
        )

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_unemployment_rate()

            self.assertIsInstance(result, dict)
            self.assertEqual(result["current"], 3.7)
            self.assertEqual(result["previous"], 3.6)

    def test_get_inflation_rate(self):
        """Test inflation rate calculation."""
        # Create CPI data for 24 months
        dates = pd.date_range(end=datetime.now(), periods=24, freq="MS")
        cpi_values = [100 + i * 0.2 for i in range(24)]  # Gradual increase
        mock_data = pd.Series(cpi_values, index=dates)

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_inflation_rate()

            self.assertIsInstance(result, dict)
            self.assertIn("current", result)
            self.assertIn("previous", result)
            self.assertIn("bounds", result)
            self.assertIsInstance(result["bounds"], tuple)

    def test_get_inflation_rate_insufficient_data(self):
        """Test inflation rate with insufficient data."""
        # Only 6 months of data (need 13+ for YoY)
        dates = pd.date_range(end=datetime.now(), periods=6, freq="MS")
        mock_data = pd.Series([100, 101, 102, 103, 104, 105], index=dates)

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_inflation_rate()

            self.assertEqual(result["current"], 0.0)
            self.assertEqual(result["previous"], 0.0)

    def test_get_vix(self):
        """Test VIX fetching."""
        # Test with yfinance first
        with patch("yfinance.Ticker") as mock_ticker_class:
            mock_ticker = MagicMock()
            mock_ticker_class.return_value = mock_ticker
            mock_ticker.history.return_value = pd.DataFrame(
                {"Close": [18.5]}, index=[datetime.now()]
            )

            result = self.provider.get_vix()

            self.assertEqual(result, 18.5)

    def test_get_vix_fallback_to_fred(self):
        """Test VIX fetching with FRED fallback."""
        with patch("yfinance.Ticker") as mock_ticker_class:
            mock_ticker = MagicMock()
            mock_ticker_class.return_value = mock_ticker
            mock_ticker.history.return_value = pd.DataFrame()  # Empty yfinance data

            mock_fred_data = pd.Series([20.5], index=[datetime.now()])
            with patch.object(self.provider.fred, "get_series") as mock_get_series:
                mock_get_series.return_value = mock_fred_data

                result = self.provider.get_vix()

                self.assertEqual(result, 20.5)

    def test_get_sp500_momentum(self):
        """Test S&P 500 momentum calculation."""
        # Create mock data with upward trend
        dates = pd.date_range(end=datetime.now(), periods=15, freq="D")
        values = [3000 + i * 10 for i in range(15)]
        mock_data = pd.Series(values, index=dates)

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_sp500_momentum()

            self.assertIsInstance(result, float)
            self.assertGreater(result, 0)  # Should be positive for upward trend

    def test_get_nasdaq_momentum(self):
        """Test NASDAQ momentum calculation."""
        dates = pd.date_range(end=datetime.now(), periods=15, freq="D")
        values = [15000 + i * 50 for i in range(15)]
        mock_data = pd.Series(values, index=dates)

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_nasdaq_momentum()

            self.assertIsInstance(result, float)
            self.assertGreater(result, 0)

    def test_get_usd_momentum(self):
        """Test USD momentum calculation."""
        dates = pd.date_range(end=datetime.now(), periods=15, freq="D")
        values = [100 + i * 0.1 for i in range(15)]
        mock_data = pd.Series(values, index=dates)

        with patch.object(self.provider.fred, "get_series") as mock_get_series:
            mock_get_series.return_value = mock_data

            result = self.provider.get_usd_momentum()

            self.assertIsInstance(result, float)

    def test_update_historical_bounds(self):
        """Test updating historical bounds."""
        # Mock data for different indicators
        gdp_data = pd.Series([1.5, 2.0, 2.5, 3.0])
        unemployment_data = pd.Series([3.5, 4.0, 4.5, 5.0])

        with patch.object(self.provider.fred, "get_series") as mock_get_series:

            def side_effect(series_id, *args, **kwargs):
                if series_id == "A191RL1Q225SBEA":
                    return gdp_data
                elif series_id == "UNRATE":
                    return unemployment_data
                else:
                    return pd.Series([], dtype=float)

            mock_get_series.side_effect = side_effect

            self.provider.update_historical_bounds()

            self.assertIn("gdp_growth_rate", self.provider.historical_data_bounds)
            self.assertIn("unemployment_rate", self.provider.historical_data_bounds)

    def test_default_bounds(self):
        """Test default bounds for indicators."""
        bounds = self.provider.default_bounds("vix")
        self.assertEqual(bounds["min"], 10.0)
        self.assertEqual(bounds["max"], 50.0)

        bounds = self.provider.default_bounds("unknown_indicator")
        self.assertEqual(bounds["min"], 0.0)
        self.assertEqual(bounds["max"], 1.0)

    def test_normalize_indicators(self):
        """Test indicator normalization."""
        indicators = {
            "vix": 30.0,  # Middle of 10-50 range
            "sp500_momentum": 0.0,  # Middle of -15 to 15 range
            "unemployment_rate": 6.0,  # Middle of 2-10 range
            "gdp_growth_rate": 2.0,  # In -2 to 6 range
        }

        normalized = self.provider.normalize_indicators(indicators)

        # VIX should be inverted (lower is better)
        self.assertAlmostEqual(normalized["vix"], 0.5, places=1)
        # SP500 momentum at 0 should normalize to 0.5
        self.assertAlmostEqual(normalized["sp500_momentum"], 0.5, places=1)
        # Unemployment should be inverted
        self.assertAlmostEqual(normalized["unemployment_rate"], 0.5, places=1)

    def test_normalize_indicators_with_none_values(self):
        """Test normalization with None values."""
        indicators = {
            "vix": None,
            "sp500_momentum": 5.0,
        }

        normalized = self.provider.normalize_indicators(indicators)

        self.assertEqual(normalized["vix"], 0.5)  # Default for None
        self.assertGreater(normalized["sp500_momentum"], 0.5)

    def test_get_historical_data(self):
        """Test fetching historical data."""
        # Mock different data series
        sp500_data = pd.Series(
            [3000, 3050, 3100],
            index=pd.date_range(end=datetime.now(), periods=3, freq="D"),
        )
        vix_data = pd.Series(
            [15, 16, 17], index=pd.date_range(end=datetime.now(), periods=3, freq="D")
        )

        with patch.object(self.provider.fred, "get_series") as mock_get_series:

            def side_effect(series_id, *args, **kwargs):
                if series_id == "SP500":
                    return sp500_data
                elif series_id == "VIXCLS":
                    return vix_data
                else:
                    return pd.Series([], dtype=float)

            mock_get_series.side_effect = side_effect

            result = self.provider.get_historical_data()

            self.assertIsInstance(result, dict)
            self.assertIn("sp500_performance", result)
            self.assertIn("vix", result)
            self.assertIsInstance(result["sp500_performance"], list)
            self.assertIsInstance(result["vix"], list)

    def test_get_macro_statistics(self):
        """Test comprehensive macro statistics."""
        # Mock all the individual indicator methods in one flat context.
        with (
            patch.object(self.provider, "get_gdp_growth_rate") as mock_gdp,
            patch.object(
                self.provider, "get_unemployment_rate"
            ) as mock_unemployment,
            patch.object(self.provider, "get_inflation_rate") as mock_inflation,
            patch.object(self.provider, "get_vix") as mock_vix,
        ):
            mock_gdp.return_value = {"current": 2.5, "previous": 2.3}
            mock_unemployment.return_value = {"current": 3.7, "previous": 3.8}
            mock_inflation.return_value = {
                "current": 2.1,
                "previous": 2.0,
                "bounds": (1.5, 3.0),
            }
            mock_vix.return_value = 18.5

            result = self.provider.get_macro_statistics()

            self.assertIsInstance(result, dict)
            self.assertEqual(result["gdp_growth_rate"], 2.5)
            self.assertEqual(result["unemployment_rate"], 3.7)
            self.assertEqual(result["inflation_rate"], 2.1)
            self.assertEqual(result["vix"], 18.5)
            self.assertIn("sentiment_score", result)
            self.assertIsInstance(result["sentiment_score"], float)
            self.assertTrue(1 <= result["sentiment_score"] <= 100)

    def test_get_macro_statistics_error_handling(self):
        """Test macro statistics with errors."""
        with patch.object(self.provider, "update_historical_bounds") as mock_update:
            mock_update.side_effect = Exception("Update error")

            result = self.provider.get_macro_statistics()

            # Should return safe defaults
            self.assertEqual(result["gdp_growth_rate"], 0.0)
            self.assertEqual(result["unemployment_rate"], 0.0)
            self.assertEqual(result["sentiment_score"], 50.0)


if __name__ == "__main__":
    unittest.main()
381 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/technical_enhanced.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Enhanced technical analysis router with comprehensive logging and timeout handling.
3 |
4 | This module fixes the "No result received from client-side tool execution" issues by:
5 | - Adding comprehensive logging for each step of tool execution
6 | - Implementing proper timeout handling (under 25 seconds)
7 | - Breaking down complex operations into logged steps
8 | - Providing detailed error context and debugging information
9 | - Ensuring JSON-RPC responses are always sent
10 | """
11 |
12 | import asyncio
13 | from concurrent.futures import ThreadPoolExecutor
14 | from datetime import UTC, datetime
15 | from typing import Any
16 |
17 | from fastmcp import FastMCP
18 | from fastmcp.server.dependencies import get_access_token
19 |
20 | from maverick_mcp.api.middleware.mcp_logging import get_tool_logger
21 | from maverick_mcp.core.technical_analysis import (
22 | analyze_bollinger_bands,
23 | analyze_macd,
24 | analyze_rsi,
25 | analyze_stochastic,
26 | analyze_trend,
27 | analyze_volume,
28 | generate_outlook,
29 | identify_chart_patterns,
30 | identify_resistance_levels,
31 | identify_support_levels,
32 | )
33 | from maverick_mcp.utils.logging import get_logger
34 | from maverick_mcp.utils.stock_helpers import get_stock_dataframe_async
35 | from maverick_mcp.validation.technical import TechnicalAnalysisRequest
36 |
# Module-level logger for this router.
logger = get_logger("maverick_mcp.routers.technical_enhanced")

# Create the enhanced technical analysis router
technical_enhanced_router: FastMCP = FastMCP("Technical_Analysis_Enhanced")

# Thread pool for blocking operations.
# Shared for the module's lifetime; no shutdown is visible in this chunk.
# The largest batch submitted at once in the visible code is 4 indicator
# tasks, matching max_workers. Presumably consumed by `_run_in_executor`
# (defined elsewhere in this file) — TODO confirm.
executor = ThreadPoolExecutor(max_workers=4)
44 |
45 |
class TechnicalAnalysisError(Exception):
    """Base exception for technical analysis errors."""


class TechnicalAnalysisTimeoutError(TechnicalAnalysisError):
    """Raised when technical analysis times out.

    Subclasses :class:`TechnicalAnalysisError` so that handlers catching the
    documented base exception also receive timeouts.
    """
56 |
57 |
async def get_full_technical_analysis_enhanced(
    request: TechnicalAnalysisRequest,
) -> dict[str, Any]:
    """
    Enhanced technical analysis with comprehensive logging and timeout handling.

    This version:
    - Logs every step of execution for debugging
    - Uses proper timeout handling (25 seconds max)
    - Breaks complex operations into chunks
    - Always returns a JSON-RPC compatible response
    - Provides detailed error context

    Args:
        request: Validated technical analysis request

    Returns:
        On success, a dictionary containing the complete technical analysis.

        Failures during the analysis (the overall timeout or any other
        exception) are not propagated. Instead, a dictionary is returned
        with ``error``, ``error_type``, ``ticker``, ``status="failed"`` and
        ``timestamp`` keys. On timeout it also carries ``execution_time``.
    """
    # Overall budget: 25s keeps us under Claude Desktop's 30s tool limit.
    timeout_seconds = 25.0

    tool_logger = get_tool_logger("get_full_technical_analysis_enhanced")
    ticker = request.ticker
    days = request.days

    try:
        return await asyncio.wait_for(
            _execute_technical_analysis_with_logging(tool_logger, ticker, days),
            timeout=timeout_seconds,
        )

    except TimeoutError:
        error_msg = (
            f"Technical analysis for {ticker} timed out after "
            f"{timeout_seconds:g} seconds"
        )
        tool_logger.error("timeout", TimeoutError(error_msg))
        logger.error(error_msg, extra={"ticker": ticker, "days": days})

        return {
            "error": error_msg,
            "error_type": "timeout",
            "ticker": ticker,
            "status": "failed",
            "execution_time": timeout_seconds,
            "timestamp": datetime.now(UTC).isoformat(),
        }

    except Exception as e:
        error_msg = f"Technical analysis for {ticker} failed: {str(e)}"
        tool_logger.error("general_error", e)
        logger.error(
            error_msg,
            extra={"ticker": ticker, "days": days, "error_type": type(e).__name__},
        )

        return {
            "error": error_msg,
            "error_type": type(e).__name__,
            "ticker": ticker,
            "status": "failed",
            "timestamp": datetime.now(UTC).isoformat(),
        }
121 |
122 |
123 | async def _execute_technical_analysis_with_logging(
124 | tool_logger, ticker: str, days: int
125 | ) -> dict[str, Any]:
126 | """Execute technical analysis with comprehensive step-by-step logging."""
127 |
128 | # Step 1: Check authentication (optional)
129 | tool_logger.step("auth_check", "Checking authentication context")
130 | has_premium = False
131 | try:
132 | access_token = get_access_token()
133 | if access_token and "premium:access" in access_token.scopes:
134 | has_premium = True
135 | logger.info(
136 | f"Premium user accessing technical analysis: {access_token.client_id}"
137 | )
138 | except Exception:
139 | logger.debug("Unauthenticated user accessing technical analysis")
140 |
141 | # Step 2: Fetch stock data
142 | tool_logger.step("data_fetch", f"Fetching {days} days of data for {ticker}")
143 | try:
144 | df = await asyncio.wait_for(
145 | get_stock_dataframe_async(ticker, days),
146 | timeout=8.0, # Data fetch should be fast
147 | )
148 |
149 | if df.empty:
150 | raise TechnicalAnalysisError(f"No data available for {ticker}")
151 |
152 | logger.info(f"Retrieved {len(df)} data points for {ticker}")
153 | tool_logger.step("data_validation", f"Retrieved {len(df)} data points")
154 |
155 | except TimeoutError:
156 | raise TechnicalAnalysisError(f"Data fetch for {ticker} timed out")
157 | except Exception as e:
158 | raise TechnicalAnalysisError(f"Failed to fetch data for {ticker}: {str(e)}")
159 |
160 | # Step 3: Calculate basic indicators (parallel execution)
161 | tool_logger.step("basic_indicators", "Calculating RSI, MACD, Stochastic")
162 | try:
163 | # Run basic indicators in parallel with timeouts
164 | basic_tasks = [
165 | asyncio.wait_for(_run_in_executor(analyze_rsi, df), timeout=3.0),
166 | asyncio.wait_for(_run_in_executor(analyze_macd, df), timeout=3.0),
167 | asyncio.wait_for(_run_in_executor(analyze_stochastic, df), timeout=3.0),
168 | asyncio.wait_for(_run_in_executor(analyze_trend, df), timeout=2.0),
169 | ]
170 |
171 | rsi_analysis, macd_analysis, stoch_analysis, trend = await asyncio.gather(
172 | *basic_tasks
173 | )
174 | tool_logger.step(
175 | "basic_indicators_complete", "Basic indicators calculated successfully"
176 | )
177 |
178 | except TimeoutError:
179 | raise TechnicalAnalysisError("Basic indicator calculation timed out")
180 | except Exception as e:
181 | raise TechnicalAnalysisError(f"Basic indicator calculation failed: {str(e)}")
182 |
183 | # Step 4: Calculate advanced indicators
184 | tool_logger.step(
185 | "advanced_indicators", "Calculating Bollinger Bands, Volume analysis"
186 | )
187 | try:
188 | advanced_tasks = [
189 | asyncio.wait_for(
190 | _run_in_executor(analyze_bollinger_bands, df), timeout=3.0
191 | ),
192 | asyncio.wait_for(_run_in_executor(analyze_volume, df), timeout=3.0),
193 | ]
194 |
195 | bb_analysis, volume_analysis = await asyncio.gather(*advanced_tasks)
196 | tool_logger.step(
197 | "advanced_indicators_complete", "Advanced indicators calculated"
198 | )
199 |
200 | except TimeoutError:
201 | raise TechnicalAnalysisError("Advanced indicator calculation timed out")
202 | except Exception as e:
203 | raise TechnicalAnalysisError(f"Advanced indicator calculation failed: {str(e)}")
204 |
205 | # Step 5: Pattern recognition and levels
206 | tool_logger.step(
207 | "pattern_analysis", "Identifying patterns and support/resistance levels"
208 | )
209 | try:
210 | pattern_tasks = [
211 | asyncio.wait_for(
212 | _run_in_executor(identify_chart_patterns, df), timeout=4.0
213 | ),
214 | asyncio.wait_for(
215 | _run_in_executor(identify_support_levels, df), timeout=3.0
216 | ),
217 | asyncio.wait_for(
218 | _run_in_executor(identify_resistance_levels, df), timeout=3.0
219 | ),
220 | ]
221 |
222 | patterns, support, resistance = await asyncio.gather(*pattern_tasks)
223 | tool_logger.step("pattern_analysis_complete", f"Found {len(patterns)} patterns")
224 |
225 | except TimeoutError:
226 | raise TechnicalAnalysisError("Pattern analysis timed out")
227 | except Exception as e:
228 | raise TechnicalAnalysisError(f"Pattern analysis failed: {str(e)}")
229 |
230 | # Step 6: Generate outlook
231 | tool_logger.step("outlook_generation", "Generating market outlook")
232 | try:
233 | outlook = await asyncio.wait_for(
234 | _run_in_executor(
235 | generate_outlook,
236 | df,
237 | str(trend),
238 | rsi_analysis,
239 | macd_analysis,
240 | stoch_analysis,
241 | ),
242 | timeout=3.0,
243 | )
244 | tool_logger.step("outlook_complete", "Market outlook generated")
245 |
246 | except TimeoutError:
247 | raise TechnicalAnalysisError("Outlook generation timed out")
248 | except Exception as e:
249 | raise TechnicalAnalysisError(f"Outlook generation failed: {str(e)}")
250 |
251 | # Step 7: Compile final results
252 | tool_logger.step("result_compilation", "Compiling final analysis results")
253 | try:
254 | current_price = float(df["close"].iloc[-1])
255 |
256 | result = {
257 | "ticker": ticker,
258 | "current_price": current_price,
259 | "trend": trend,
260 | "outlook": outlook,
261 | "indicators": {
262 | "rsi": rsi_analysis,
263 | "macd": macd_analysis,
264 | "stochastic": stoch_analysis,
265 | "bollinger_bands": bb_analysis,
266 | "volume": volume_analysis,
267 | },
268 | "levels": {
269 | "support": sorted(support) if support else [],
270 | "resistance": sorted(resistance) if resistance else [],
271 | },
272 | "patterns": patterns,
273 | "analysis_metadata": {
274 | "data_points": len(df),
275 | "period_days": days,
276 | "has_premium": has_premium,
277 | "timestamp": datetime.now(UTC).isoformat(),
278 | },
279 | "status": "completed",
280 | }
281 |
282 | tool_logger.complete(
283 | f"Analysis completed for {ticker} with {len(df)} data points"
284 | )
285 | return result
286 |
287 | except Exception as e:
288 | raise TechnicalAnalysisError(f"Result compilation failed: {str(e)}")
289 |
290 |
async def _run_in_executor(func, /, *args, **kwargs) -> Any:
    """
    Run a synchronous callable on the module's thread pool executor and await it.

    Args:
        func: Synchronous callable to execute. It is positional-only, so a
            keyword argument literally named ``func`` can still be forwarded.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``. They are bound with
            ``functools.partial`` because ``loop.run_in_executor`` only accepts
            positional arguments for the target callable.

    Returns:
        Whatever ``func`` returns. An exception raised by ``func`` is re-raised
        in the awaiting coroutine.
    """
    from functools import partial

    # Inside a coroutine a loop is always running; asyncio recommends
    # get_running_loop() here over the legacy, policy-dependent get_event_loop().
    loop = asyncio.get_running_loop()
    if kwargs:
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return await loop.run_in_executor(executor, func, *args)
295 |
296 |
async def get_stock_chart_analysis_enhanced(ticker: str) -> dict[str, Any]:
    """
    Build a technical chart for ``ticker`` within a fixed 15-second budget.

    Chart rendering is delegated to ``_generate_chart_with_logging``. This
    wrapper only enforces the time budget and turns every failure into a
    structured payload instead of raising.

    Args:
        ticker: Stock ticker symbol

    Returns:
        The chart payload on success. On failure, a dict with ``error``,
        ``error_type`` (``"timeout"`` or the exception class name), ``ticker``
        and ``status="failed"``.
    """
    chart_logger = get_tool_logger("get_stock_chart_analysis_enhanced")

    def _failure(message: str, kind: str) -> dict[str, Any]:
        # Shared shape for both failure paths.
        return {
            "error": message,
            "error_type": kind,
            "ticker": ticker,
            "status": "failed",
        }

    try:
        # Charts get a tighter budget than the full technical analysis.
        return await asyncio.wait_for(
            _generate_chart_with_logging(chart_logger, ticker), timeout=15.0
        )
    except TimeoutError:
        message = f"Chart generation for {ticker} timed out after 15 seconds"
        chart_logger.error("timeout", TimeoutError(message))
        return _failure(message, "timeout")
    except Exception as exc:
        message = f"Chart generation for {ticker} failed: {str(exc)}"
        chart_logger.error("general_error", exc)
        return _failure(message, type(exc).__name__)
339 |
340 |
341 | async def _generate_chart_with_logging(tool_logger, ticker: str) -> dict[str, Any]:
342 | """Generate chart with step-by-step logging."""
343 | from maverick_mcp.core.technical_analysis import add_technical_indicators
344 | from maverick_mcp.core.visualization import (
345 | create_plotly_technical_chart,
346 | plotly_fig_to_base64,
347 | )
348 |
349 | # Step 1: Fetch data
350 | tool_logger.step("chart_data_fetch", f"Fetching chart data for {ticker}")
351 | df = await get_stock_dataframe_async(ticker, 365)
352 |
353 | if df.empty:
354 | raise TechnicalAnalysisError(
355 | f"No data available for chart generation: {ticker}"
356 | )
357 |
358 | # Step 2: Add technical indicators
359 | tool_logger.step("chart_indicators", "Adding technical indicators to chart")
360 | df_with_indicators = await _run_in_executor(add_technical_indicators, df)
361 |
362 | # Step 3: Generate chart configurations (progressive sizing)
363 | chart_configs = [
364 | {"height": 400, "width": 600, "format": "png", "quality": 85},
365 | {"height": 300, "width": 500, "format": "jpeg", "quality": 75},
366 | {"height": 250, "width": 400, "format": "jpeg", "quality": 65},
367 | ]
368 |
369 | for i, config in enumerate(chart_configs):
370 | try:
371 | tool_logger.step(
372 | f"chart_generation_{i + 1}", f"Generating chart (attempt {i + 1})"
373 | )
374 |
375 | # Generate chart
376 | chart = await _run_in_executor(
377 | create_plotly_technical_chart,
378 | df_with_indicators,
379 | ticker,
380 | config["height"],
381 | config["width"],
382 | )
383 |
384 | # Convert to base64
385 | data_uri = await _run_in_executor(
386 | plotly_fig_to_base64, chart, config["format"]
387 | )
388 |
389 | # Validate size (Claude Desktop has limits)
390 | if len(data_uri) < 200000: # ~200KB limit for safety
391 | tool_logger.complete(
392 | f"Chart generated successfully (size: {len(data_uri)} chars)"
393 | )
394 |
395 | return {
396 | "ticker": ticker,
397 | "chart_data": data_uri,
398 | "chart_format": config["format"],
399 | "chart_size": {
400 | "height": config["height"],
401 | "width": config["width"],
402 | },
403 | "data_points": len(df),
404 | "status": "completed",
405 | "timestamp": datetime.now(UTC).isoformat(),
406 | }
407 | else:
408 | logger.warning(
409 | f"Chart too large ({len(data_uri)} chars), trying smaller config"
410 | )
411 |
412 | except Exception as e:
413 | logger.warning(f"Chart generation attempt {i + 1} failed: {e}")
414 | if i == len(chart_configs) - 1: # Last attempt
415 | raise TechnicalAnalysisError(
416 | f"All chart generation attempts failed: {e}"
417 | )
418 |
419 | raise TechnicalAnalysisError(
420 | "Chart generation failed - all size configurations exceeded limits"
421 | )
422 |
423 |
# Public API of this module: the names exposed for FastMCP tool registration.
__all__ = [
    "technical_enhanced_router",
    "get_full_technical_analysis_enhanced",
    "get_stock_chart_analysis_enhanced",
]
430 |
```