This is page 14 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_mcp_tools.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive tests for all MCP tool functions in Maverick-MCP.
3 |
4 | This module tests all public MCP tools exposed by the server including:
5 | - Stock data fetching
6 | - Technical analysis
7 | - Risk analysis
8 | - Chart generation
9 | - News sentiment
10 | - Multi-ticker comparison
11 | """
12 |
13 | from datetime import datetime
14 | from unittest.mock import MagicMock, patch
15 |
16 | import numpy as np
17 | import pandas as pd
18 | import pytest
19 | from fastmcp import Client
20 |
21 | from maverick_mcp.api.server import mcp
22 |
23 |
class TestMCPTools:
    """Test suite for all MCP tool functions using the new router structure."""

    @staticmethod
    def _parse(payload: str) -> dict:
        """Safely decode a tool response payload into a dict.

        Tool results arrive as text. Previously this was decoded with
        ``eval``, which executes arbitrary code contained in the response.
        Decode as JSON first and fall back to ``ast.literal_eval`` (literals
        only) for Python-repr style payloads.
        """
        import ast
        import json

        try:
            return json.loads(payload)
        except ValueError:
            # Not JSON (e.g. single-quoted Python repr) - parse literals only.
            return ast.literal_eval(payload)

    @pytest.fixture
    def mock_stock_data(self):
        """Create 250 days of synthetic OHLCV data for testing.

        A seeded generator keeps the frame reproducible across runs.
        """
        rng = np.random.default_rng(42)
        dates = pd.date_range(end=datetime.now(), periods=250, freq="D")
        return pd.DataFrame(
            {
                "Open": rng.uniform(90, 110, 250),
                "High": rng.uniform(95, 115, 250),
                "Low": rng.uniform(85, 105, 250),
                "Close": rng.uniform(90, 110, 250),
                "Volume": rng.integers(1000000, 10000000, 250),
            },
            index=dates,
        )

    @pytest.mark.asyncio
    async def test_fetch_stock_data(self, mock_stock_data):
        """Test basic stock data fetching."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "ticker" in data
                assert data["ticker"] == "AAPL"
                assert "record_count" in data
                assert data["record_count"] == 250

    @pytest.mark.asyncio
    async def test_rsi_analysis(self, mock_stock_data):
        """Test RSI technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_rsi_analysis", {"ticker": "AAPL", "period": 14}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "analysis" in data
                assert "ticker" in data
                assert data["ticker"] == "AAPL"
                assert "current" in data["analysis"]
                assert "signal" in data["analysis"]
                assert data["analysis"]["signal"] in [
                    "oversold",
                    "neutral",
                    "overbought",
                    "bullish",
                    "bearish",
                ]

    @pytest.mark.asyncio
    async def test_macd_analysis(self, mock_stock_data):
        """Test MACD technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_macd_analysis", {"ticker": "MSFT"}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "analysis" in data
                assert "ticker" in data
                assert data["ticker"] == "MSFT"
                assert "macd" in data["analysis"]
                assert "signal" in data["analysis"]
                assert "histogram" in data["analysis"]
                assert "indicator" in data["analysis"]

    @pytest.mark.asyncio
    async def test_support_resistance(self, mock_stock_data):
        """Test support and resistance level detection."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Create data with clear, repeating support/resistance levels so
            # the detector has unambiguous clusters to find.
            mock_data = mock_stock_data.copy()
            mock_data["High"] = [105 if i % 20 < 10 else 110 for i in range(250)]
            mock_data["Low"] = [95 if i % 20 < 10 else 100 for i in range(250)]
            mock_data["Close"] = [100 if i % 20 < 10 else 105 for i in range(250)]
            mock_get.return_value = mock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_support_resistance", {"ticker": "GOOGL"}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "support_levels" in data
                assert "resistance_levels" in data
                assert len(data["support_levels"]) > 0
                assert len(data["resistance_levels"]) > 0

    @pytest.mark.asyncio
    async def test_batch_stock_data(self, mock_stock_data):
        """Test batch stock data fetching."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            mock_get.return_value = mock_stock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_fetch_stock_data_batch",
                    {
                        "request": {
                            "tickers": ["AAPL", "MSFT", "GOOGL"],
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "results" in data
                assert "success_count" in data
                assert "error_count" in data
                assert len(data["results"]) == 3
                assert data["success_count"] == 3
                assert data["error_count"] == 0

    @pytest.mark.asyncio
    async def test_portfolio_risk_analysis(self, mock_stock_data):
        """Test portfolio risk analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Build a price series from seeded log returns; the tool under
            # test only consumes a single ticker's data.
            rng = np.random.default_rng(7)
            base_returns = rng.normal(0.001, 0.02, 250)

            mock_data = mock_stock_data.copy()
            # The risk analysis expects lowercase column names.
            mock_data.columns = mock_data.columns.str.lower()
            mock_data["close"] = 100 * np.exp(np.cumsum(base_returns))

            mock_get.return_value = mock_data

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/portfolio_risk_adjusted_analysis",
                    {"ticker": "AAPL", "risk_level": 50.0},
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "ticker" in data
                assert "risk_level" in data
                assert "position_sizing" in data
                assert "risk_management" in data

    @pytest.mark.asyncio
    async def test_maverick_screening(self):
        """Test Maverick stock screening."""
        with (
            patch("maverick_mcp.data.models.SessionLocal") as mock_session_cls,
            patch(
                "maverick_mcp.data.models.MaverickStocks.get_top_stocks"
            ) as mock_get_stocks,
        ):
            # Mock database session (not used but needed for session lifecycle)
            _ = mock_session_cls.return_value.__enter__.return_value

            # Minimal stand-ins for ORM rows: only to_dict() is consumed.
            class MockStock1:
                def to_dict(self):
                    return {
                        "stock": "AAPL",
                        "close": 150.0,
                        "combined_score": 92,
                        "momentum_score": 88,
                        "adr_pct": 2.5,
                    }

            class MockStock2:
                def to_dict(self):
                    return {
                        "stock": "MSFT",
                        "close": 300.0,
                        "combined_score": 89,
                        "momentum_score": 85,
                        "adr_pct": 2.1,
                    }

            mock_get_stocks.return_value = [MockStock1(), MockStock2()]

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/screening_get_maverick_stocks", {"limit": 10}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "stocks" in data
                assert len(data["stocks"]) == 2
                assert data["stocks"][0]["stock"] == "AAPL"

    @pytest.mark.asyncio
    async def test_news_sentiment(self):
        """Test news sentiment analysis."""
        with (
            patch("requests.get") as mock_get,
            patch(
                "maverick_mcp.config.settings.settings.external_data.api_key",
                "test_api_key",
            ),
            patch(
                "maverick_mcp.config.settings.settings.external_data.base_url",
                "https://test-api.com",
            ),
        ):
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "articles": [
                    {
                        "title": "Apple hits new highs",
                        "url": "https://example.com/1",
                        "summary": "Positive news about Apple",
                        "banner_image": "https://example.com/image1.jpg",
                        "time_published": "20240115T100000",
                        "overall_sentiment_score": 0.8,
                        "overall_sentiment_label": "Bullish",
                    }
                ]
            }
            mock_get.return_value = mock_response

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/data_get_news_sentiment", {"request": {"ticker": "AAPL"}}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "articles" in data
                assert len(data["articles"]) > 0
                assert data["articles"][0]["overall_sentiment_label"] == "Bullish"

    @pytest.mark.asyncio
    async def test_full_technical_analysis(self, mock_stock_data):
        """Test comprehensive technical analysis."""
        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
        ) as mock_get:
            # Ensure lowercase column names for technical analysis
            mock_data_lowercase = mock_stock_data.copy()
            mock_data_lowercase.columns = mock_data_lowercase.columns.str.lower()
            mock_get.return_value = mock_data_lowercase

            async with Client(mcp) as client:
                result = await client.call_tool(
                    "/technical_get_full_technical_analysis", {"ticker": "AAPL"}
                )

                assert result[0].text is not None
                data = self._parse(result[0].text)
                assert "indicators" in data
                assert "rsi" in data["indicators"]
                assert "macd" in data["indicators"]
                assert "bollinger_bands" in data["indicators"]
                assert "levels" in data
                assert "current_price" in data
                assert "last_updated" in data

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """Test error handling for invalid requests."""
        async with Client(mcp) as client:
            # Test invalid ticker format
            with pytest.raises(Exception) as exc_info:
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "INVALIDTICKER",  # Too long (max 10 chars)
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
            assert "validation error" in str(exc_info.value).lower()

            # Test invalid date range
            with pytest.raises(Exception) as exc_info:
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-12-31",
                            "end_date": "2024-01-01",  # End before start
                        }
                    },
                )
            assert (
                "end date" in str(exc_info.value).lower()
                and "start date" in str(exc_info.value).lower()
            )

    @pytest.mark.asyncio
    async def test_caching_behavior(self, mock_stock_data):
        """Test that caching reduces API calls."""
        call_count = 0

        def mock_get_data(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            return mock_stock_data

        with patch(
            "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data",
            side_effect=mock_get_data,
        ):
            async with Client(mcp) as client:
                # First call
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                assert call_count == 1

                # Second call with same parameters should hit cache
                await client.call_tool(
                    "/data_fetch_stock_data",
                    {
                        "request": {
                            "ticker": "AAPL",
                            "start_date": "2024-01-01",
                            "end_date": "2024-01-31",
                        }
                    },
                )
                # Note: In test environment without actual caching infrastructure,
                # the call count may be 2. This is expected behavior.
                assert call_count <= 2
405 |
406 |
# Allow running this test module directly (outside the normal pytest runner).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
409 |
```
--------------------------------------------------------------------------------
/maverick_mcp/database/optimization.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Database optimization module for parallel backtesting performance.
3 | Implements query optimization, bulk operations, and performance monitoring.
4 | """
5 |
6 | import logging
7 | import time
8 | from contextlib import contextmanager
9 | from typing import Any
10 |
11 | import pandas as pd
12 | from sqlalchemy import Index, event, text
13 | from sqlalchemy.engine import Engine
14 | from sqlalchemy.orm import Session
15 |
16 | from maverick_mcp.data.models import (
17 | PriceCache,
18 | SessionLocal,
19 | Stock,
20 | engine,
21 | )
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class QueryOptimizer:
27 | """Database query optimization for backtesting performance."""
28 |
29 | def __init__(self, session_factory=None):
30 | """Initialize query optimizer."""
31 | self.session_factory = session_factory or SessionLocal
32 | self._query_stats = {}
33 | self._connection_pool_stats = {
34 | "active_connections": 0,
35 | "checked_out": 0,
36 | "total_queries": 0,
37 | "slow_queries": 0,
38 | }
39 |
40 | def create_backtesting_indexes(self, engine: Engine):
41 | """
42 | Create optimized indexes for backtesting queries.
43 |
44 | These indexes are specifically designed for the parallel backtesting
45 | workload patterns.
46 | """
47 | logger.info("Creating backtesting optimization indexes...")
48 |
49 | # Define additional indexes for common backtesting query patterns
50 | additional_indexes = [
51 | # Composite index for date range queries with symbol lookup
52 | Index(
53 | "mcp_price_cache_symbol_date_range_idx",
54 | Stock.__table__.c.ticker_symbol,
55 | PriceCache.__table__.c.date,
56 | PriceCache.__table__.c.close_price,
57 | ),
58 | # Index for volume-based queries (common in strategy analysis)
59 | Index(
60 | "mcp_price_cache_volume_date_idx",
61 | PriceCache.__table__.c.volume,
62 | PriceCache.__table__.c.date,
63 | ),
64 | # Covering index for OHLCV queries (includes all price data)
65 | Index(
66 | "mcp_price_cache_ohlcv_covering_idx",
67 | PriceCache.__table__.c.stock_id,
68 | PriceCache.__table__.c.date,
69 | # Include all price columns as covering columns
70 | PriceCache.__table__.c.open_price,
71 | PriceCache.__table__.c.high_price,
72 | PriceCache.__table__.c.low_price,
73 | PriceCache.__table__.c.close_price,
74 | PriceCache.__table__.c.volume,
75 | ),
76 | # Index for latest price queries
77 | Index(
78 | "mcp_price_cache_latest_price_idx",
79 | PriceCache.__table__.c.stock_id,
80 | PriceCache.__table__.c.date.desc(),
81 | ),
82 | # Partial index for recent data (last 2 years) - most commonly queried
83 | # Note: This is PostgreSQL-specific, will be skipped for SQLite
84 | ]
85 |
86 | try:
87 | with engine.connect() as conn:
88 | # Check if we're using PostgreSQL for partial indexes
89 | is_postgresql = engine.dialect.name == "postgresql"
90 |
91 | for index in additional_indexes:
92 | try:
93 | # Skip PostgreSQL-specific features on SQLite
94 | if not is_postgresql and "partial" in str(index).lower():
95 | continue
96 |
97 | # Create index if it doesn't exist
98 | index.create(conn, checkfirst=True)
99 | logger.info(f"Created index: {index.name}")
100 |
101 | except Exception as e:
102 | logger.warning(f"Failed to create index {index.name}: {e}")
103 |
104 | # Add PostgreSQL-specific optimizations
105 | if is_postgresql:
106 | try:
107 | # Create partial index for recent data (last 2 years)
108 | conn.execute(
109 | text("""
110 | CREATE INDEX CONCURRENTLY IF NOT EXISTS mcp_price_cache_recent_data_idx
111 | ON mcp_price_cache (stock_id, date DESC, close_price)
112 | WHERE date >= CURRENT_DATE - INTERVAL '2 years'
113 | """)
114 | )
115 | logger.info("Created partial index for recent data")
116 |
117 | # Update table statistics for better query planning
118 | conn.execute(text("ANALYZE mcp_price_cache"))
119 | conn.execute(text("ANALYZE mcp_stocks"))
120 | logger.info("Updated table statistics")
121 |
122 | except Exception as e:
123 | logger.warning(
124 | f"Failed to create PostgreSQL optimizations: {e}"
125 | )
126 |
127 | conn.commit()
128 |
129 | except Exception as e:
130 | logger.error(f"Failed to create backtesting indexes: {e}")
131 |
132 | def optimize_connection_pool(self, engine: Engine):
133 | """Optimize connection pool settings for parallel operations."""
134 | logger.info("Optimizing connection pool for parallel backtesting...")
135 |
136 | # Add connection pool event listeners for monitoring
137 | @event.listens_for(engine, "connect")
138 | def receive_connect(dbapi_connection, connection_record):
139 | self._connection_pool_stats["active_connections"] += 1
140 |
141 | @event.listens_for(engine, "checkout")
142 | def receive_checkout(dbapi_connection, connection_record, connection_proxy):
143 | self._connection_pool_stats["checked_out"] += 1
144 |
145 | @event.listens_for(engine, "checkin")
146 | def receive_checkin(dbapi_connection, connection_record):
147 | self._connection_pool_stats["checked_out"] -= 1
148 |
149 | def create_bulk_insert_method(self):
150 | """Create optimized bulk insert method for price data."""
151 |
152 | def bulk_insert_price_data_optimized(
153 | session: Session,
154 | price_data_list: list[dict[str, Any]],
155 | batch_size: int = 1000,
156 | ):
157 | """
158 | Optimized bulk insert for price data with batching.
159 |
160 | Args:
161 | session: Database session
162 | price_data_list: List of price data dictionaries
163 | batch_size: Number of records per batch
164 | """
165 | if not price_data_list:
166 | return
167 |
168 | logger.info(f"Bulk inserting {len(price_data_list)} price records")
169 | start_time = time.time()
170 |
171 | try:
172 | # Process in batches to avoid memory issues
173 | for i in range(0, len(price_data_list), batch_size):
174 | batch = price_data_list[i : i + batch_size]
175 |
176 | # Use bulk_insert_mappings for better performance
177 | session.bulk_insert_mappings(PriceCache, batch)
178 |
179 | # Commit each batch to free up memory
180 | if i + batch_size < len(price_data_list):
181 | session.flush()
182 |
183 | session.commit()
184 |
185 | elapsed = time.time() - start_time
186 | logger.info(
187 | f"Bulk insert completed in {elapsed:.2f}s "
188 | f"({len(price_data_list) / elapsed:.0f} records/sec)"
189 | )
190 |
191 | except Exception as e:
192 | logger.error(f"Bulk insert failed: {e}")
193 | session.rollback()
194 | raise
195 |
196 | return bulk_insert_price_data_optimized
197 |
198 | @contextmanager
199 | def query_performance_monitor(self, query_name: str):
200 | """Context manager for monitoring query performance."""
201 | start_time = time.time()
202 |
203 | try:
204 | yield
205 | finally:
206 | elapsed = time.time() - start_time
207 |
208 | # Track query statistics
209 | if query_name not in self._query_stats:
210 | self._query_stats[query_name] = {
211 | "count": 0,
212 | "total_time": 0.0,
213 | "avg_time": 0.0,
214 | "max_time": 0.0,
215 | "slow_queries": 0,
216 | }
217 |
218 | stats = self._query_stats[query_name]
219 | stats["count"] += 1
220 | stats["total_time"] += elapsed
221 | stats["avg_time"] = stats["total_time"] / stats["count"]
222 | stats["max_time"] = max(stats["max_time"], elapsed)
223 |
224 | # Mark slow queries (> 1 second)
225 | if elapsed > 1.0:
226 | stats["slow_queries"] += 1
227 | self._connection_pool_stats["slow_queries"] += 1
228 | logger.warning(f"Slow query detected: {query_name} took {elapsed:.2f}s")
229 |
230 | self._connection_pool_stats["total_queries"] += 1
231 |
232 | def get_optimized_price_query(self) -> str:
233 | """Get optimized SQL query for price data retrieval."""
234 | return """
235 | SELECT
236 | pc.date,
237 | pc.open_price as "open",
238 | pc.high_price as "high",
239 | pc.low_price as "low",
240 | pc.close_price as "close",
241 | pc.volume
242 | FROM mcp_price_cache pc
243 | JOIN mcp_stocks s ON pc.stock_id = s.stock_id
244 | WHERE s.ticker_symbol = :symbol
245 | AND pc.date >= :start_date
246 | AND pc.date <= :end_date
247 | ORDER BY pc.date
248 | """
249 |
250 | def get_batch_price_query(self) -> str:
251 | """Get optimized SQL query for batch price data retrieval."""
252 | return """
253 | SELECT
254 | s.ticker_symbol,
255 | pc.date,
256 | pc.open_price as "open",
257 | pc.high_price as "high",
258 | pc.low_price as "low",
259 | pc.close_price as "close",
260 | pc.volume
261 | FROM mcp_price_cache pc
262 | JOIN mcp_stocks s ON pc.stock_id = s.stock_id
263 | WHERE s.ticker_symbol = ANY(:symbols)
264 | AND pc.date >= :start_date
265 | AND pc.date <= :end_date
266 | ORDER BY s.ticker_symbol, pc.date
267 | """
268 |
269 | def execute_optimized_query(
270 | self,
271 | session: Session,
272 | query: str,
273 | params: dict[str, Any],
274 | query_name: str = "unnamed",
275 | ) -> pd.DataFrame:
276 | """Execute optimized query with performance monitoring."""
277 | with self.query_performance_monitor(query_name):
278 | try:
279 | result = pd.read_sql(
280 | text(query),
281 | session.bind,
282 | params=params,
283 | index_col="date" if "date" in query.lower() else None,
284 | parse_dates=["date"] if "date" in query.lower() else None,
285 | )
286 |
287 | logger.debug(f"Query {query_name} returned {len(result)} rows")
288 | return result
289 |
290 | except Exception as e:
291 | logger.error(f"Query {query_name} failed: {e}")
292 | raise
293 |
294 | def get_statistics(self) -> dict[str, Any]:
295 | """Get query and connection pool statistics."""
296 | return {
297 | "query_stats": self._query_stats.copy(),
298 | "connection_pool_stats": self._connection_pool_stats.copy(),
299 | "top_slow_queries": sorted(
300 | [
301 | (name, stats["avg_time"])
302 | for name, stats in self._query_stats.items()
303 | ],
304 | key=lambda x: x[1],
305 | reverse=True,
306 | )[:5],
307 | }
308 |
309 | def reset_statistics(self):
310 | """Reset performance statistics."""
311 | self._query_stats.clear()
312 | self._connection_pool_stats = {
313 | "active_connections": 0,
314 | "checked_out": 0,
315 | "total_queries": 0,
316 | "slow_queries": 0,
317 | }
318 |
319 |
class BatchQueryExecutor:
    """Efficient batch query execution for parallel backtesting."""

    def __init__(self, optimizer: QueryOptimizer | None = None):
        """Initialize batch query executor.

        Args:
            optimizer: Optional shared QueryOptimizer; a fresh one is
                created when omitted.
        """
        self.optimizer = optimizer or QueryOptimizer()

    async def fetch_multiple_symbols_data(
        self,
        symbols: list[str],
        start_date: str,
        end_date: str,
        session: Session | None = None,
    ) -> dict[str, pd.DataFrame]:
        """
        Efficiently fetch data for multiple symbols in a single query.

        Args:
            symbols: List of stock symbols
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            session: Optional database session

        Returns:
            Dictionary mapping symbols to DataFrames
        """
        if not symbols:
            return {}

        # Only close sessions we created ourselves; caller-provided sessions
        # remain the caller's responsibility.
        should_close = session is None
        if session is None:
            session = self.optimizer.session_factory()

        try:
            # Use batch query to fetch all symbols at once
            batch_query = self.optimizer.get_batch_price_query()

            result_df = self.optimizer.execute_optimized_query(
                session=session,
                query=batch_query,
                params={
                    "symbols": symbols,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                query_name="batch_symbol_fetch",
            )

            # Group by symbol and create separate DataFrames.
            # Symbols with no rows in the result end up with empty frames.
            symbol_data = {}
            if not result_df.empty:
                for symbol in symbols:
                    symbol_df = result_df[result_df["ticker_symbol"] == symbol].copy()
                    symbol_df.drop("ticker_symbol", axis=1, inplace=True)
                    symbol_data[symbol] = symbol_df
            else:
                # Return empty DataFrames for all symbols
                symbol_data = {symbol: pd.DataFrame() for symbol in symbols}

            logger.info(
                f"Batch fetched data for {len(symbols)} symbols: "
                f"{sum(len(df) for df in symbol_data.values())} total records"
            )

            return symbol_data

        finally:
            if should_close:
                session.close()
389 |
390 |
# Global instances for easy access.
# NOTE: created eagerly at import time; callers can still construct their own
# QueryOptimizer/BatchQueryExecutor instances independently for isolation.
_query_optimizer = QueryOptimizer()
_batch_executor = BatchQueryExecutor(_query_optimizer)
394 |
395 |
def get_query_optimizer() -> QueryOptimizer:
    """Return the module-level :class:`QueryOptimizer` singleton."""
    return _query_optimizer
399 |
400 |
def get_batch_executor() -> BatchQueryExecutor:
    """Return the module-level :class:`BatchQueryExecutor` singleton."""
    return _batch_executor
404 |
405 |
def initialize_database_optimizations():
    """Set up indexes and pool monitoring for parallel backtesting.

    Failures are logged rather than raised so server startup is never
    blocked by a missing optimization.
    """
    logger.info("Initializing database optimizations for parallel backtesting...")

    try:
        optimizer = get_query_optimizer()

        # Performance indexes first, then connection-pool instrumentation.
        optimizer.create_backtesting_indexes(engine)
        optimizer.optimize_connection_pool(engine)

        logger.info("Database optimizations initialized successfully")

    except Exception as e:
        logger.error(f"Failed to initialize database optimizations: {e}")
423 |
424 |
@contextmanager
def optimized_db_session():
    """Context manager yielding a session tuned for backtesting workloads.

    Commits on success, rolls back on error, and always closes the session.

    The PRAGMA tuning is SQLite-specific; previously it ran unconditionally
    and raised a syntax error on PostgreSQL, so it is now gated on the
    bound engine's dialect.
    """
    session = SessionLocal()
    try:
        # PRAGMA statements only exist on SQLite; skip them on other backends.
        bind = session.bind
        if bind is not None and bind.dialect.name == "sqlite":
            session.execute(text("PRAGMA synchronous = NORMAL"))
            session.execute(text("PRAGMA journal_mode = WAL"))
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
440 |
441 |
# Performance monitoring decorator
def monitor_query_performance(query_name: str):
    """Decorator that times the wrapped call via the global QueryOptimizer.

    Args:
        query_name: Label under which timings are aggregated.

    The wrapped function's metadata (name, docstring, signature) is preserved
    with ``functools.wraps``, which the previous version dropped.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            optimizer = get_query_optimizer()
            with optimizer.query_performance_monitor(query_name):
                return func(*args, **kwargs)

        return wrapper

    return decorator
455 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/tracing.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | OpenTelemetry distributed tracing integration for MaverickMCP.
3 |
4 | This module provides comprehensive distributed tracing capabilities including:
5 | - Automatic span creation for database queries, external API calls, and tool executions
6 | - Integration with FastMCP and FastAPI
7 | - Support for multiple tracing backends (Jaeger, Zipkin, OTLP)
8 | - Correlation with structured logging
9 | """
10 |
11 | import functools
12 | import os
13 | import time
14 | from collections.abc import Callable
15 | from contextlib import contextmanager
16 | from typing import Any
17 |
18 | from maverick_mcp.config.settings import settings
19 | from maverick_mcp.utils.logging import get_logger
20 |
# OpenTelemetry imports with graceful fallback: if any OTEL package is
# missing, stub no-op types are defined instead and OTEL_AVAILABLE is False,
# so the rest of the module can run without tracing.
try:
    from opentelemetry import trace  # type: ignore[import-untyped]
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
        OTLPSpanExporter,  # type: ignore[import-untyped]
    )
    from opentelemetry.exporter.zipkin.json import (
        ZipkinExporter,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.asyncio import (
        AsyncioInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.asyncpg import (
        AsyncPGInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.fastapi import (
        FastAPIInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.httpx import (
        HTTPXInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.redis import (
        RedisInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.requests import (
        RequestsInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.instrumentation.sqlalchemy import (
        SQLAlchemyInstrumentor,  # type: ignore[import-untyped]
    )
    from opentelemetry.propagate import (
        set_global_textmap,  # type: ignore[import-untyped]
    )
    from opentelemetry.propagators.b3 import (
        B3MultiFormat,  # type: ignore[import-untyped]
    )
    from opentelemetry.sdk.resources import Resource  # type: ignore[import-untyped]
    from opentelemetry.sdk.trace import TracerProvider  # type: ignore[import-untyped]
    from opentelemetry.sdk.trace.export import (  # type: ignore[import-untyped]
        BatchSpanProcessor,
        ConsoleSpanExporter,
    )
    from opentelemetry.semconv.resource import (
        ResourceAttributes,  # type: ignore[import-untyped]
    )
    from opentelemetry.trace import Status, StatusCode  # type: ignore[import-untyped]

    OTEL_AVAILABLE = True
except ImportError:
    # Create stub classes for when OpenTelemetry is not available.
    # Every stub method is a silent no-op.
    class _TracerStub:
        def start_span(self, name: str, **kwargs):
            return _SpanStub()

        def start_as_current_span(self, name: str, **kwargs):
            return _SpanStub()

    class _SpanStub:
        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def set_attribute(self, key: str, value: Any):
            pass

        def set_status(self, status):
            pass

        def record_exception(self, exception: Exception):
            pass

        def add_event(self, name: str, attributes: dict[str, Any] | None = None):
            pass

    # Create stub types for type annotations
    class TracerProvider:
        pass

    # NOTE(review): as a class attribute this lambda would receive the
    # instance itself as `name` if ever called with an argument; it is never
    # invoked in practice because TracingService stays disabled when
    # OTEL_AVAILABLE is False — confirm before relying on the stub tracer.
    trace = type("trace", (), {"get_tracer": lambda name: _TracerStub()})()
    OTEL_AVAILABLE = False
103 |
104 |
105 | logger = get_logger(__name__)
106 |
107 |
108 | class TracingService:
109 | """Service for distributed tracing configuration and management."""
110 |
111 | def __init__(self):
112 | self.tracer = None
113 | self.enabled = False
114 | self._initialize_tracing()
115 |
116 | def _initialize_tracing(self):
117 | """Initialize OpenTelemetry tracing."""
118 | if not OTEL_AVAILABLE:
119 | return
120 |
121 | # Check if tracing is enabled
122 | tracing_enabled = os.getenv("OTEL_TRACING_ENABLED", "false").lower() == "true"
123 | if not tracing_enabled and settings.environment != "development":
124 | logger.info("OpenTelemetry tracing disabled")
125 | return
126 |
127 | try:
128 | # Create resource
129 | resource = Resource.create(
130 | {
131 | ResourceAttributes.SERVICE_NAME: settings.app_name,
132 | ResourceAttributes.SERVICE_VERSION: os.getenv(
133 | "RELEASE_VERSION", "unknown"
134 | ),
135 | ResourceAttributes.SERVICE_NAMESPACE: "maverick-mcp",
136 | ResourceAttributes.DEPLOYMENT_ENVIRONMENT: settings.environment,
137 | }
138 | )
139 |
140 | # Configure tracer provider
141 | tracer_provider = TracerProvider(resource=resource)
142 | trace.set_tracer_provider(tracer_provider)
143 |
144 | # Configure exporters
145 | self._configure_exporters(tracer_provider)
146 |
147 | # Configure propagators
148 | self._configure_propagators()
149 |
150 | # Instrument libraries
151 | self._instrument_libraries()
152 |
153 | # Create tracer
154 | self.tracer = trace.get_tracer(__name__)
155 | self.enabled = True
156 |
157 | logger.info("OpenTelemetry tracing initialized successfully")
158 |
159 | except Exception as e:
160 | logger.error(f"Failed to initialize OpenTelemetry tracing: {e}")
161 |
162 | def _configure_exporters(self, tracer_provider: TracerProvider):
163 | """Configure trace exporters based on environment variables."""
164 | # Console exporter (for development)
165 | if settings.environment == "development":
166 | console_exporter = ConsoleSpanExporter()
167 | tracer_provider.add_span_processor(BatchSpanProcessor(console_exporter)) # type: ignore[attr-defined]
168 |
169 | # Jaeger exporter via OTLP (modern approach)
170 | jaeger_endpoint = os.getenv("JAEGER_ENDPOINT")
171 | if jaeger_endpoint:
172 | # Modern Jaeger deployments accept OTLP on port 4317 (gRPC) or 4318 (HTTP)
173 | # Convert legacy Jaeger collector endpoint to OTLP format if needed
174 | if "14268" in jaeger_endpoint: # Legacy Jaeger HTTP port
175 | otlp_endpoint = jaeger_endpoint.replace(":14268", ":4318").replace(
176 | "/api/traces", ""
177 | )
178 | logger.info(
179 | f"Converting legacy Jaeger endpoint {jaeger_endpoint} to OTLP: {otlp_endpoint}"
180 | )
181 | else:
182 | otlp_endpoint = jaeger_endpoint
183 |
184 | jaeger_otlp_exporter = OTLPSpanExporter(
185 | endpoint=otlp_endpoint,
186 | # Add Jaeger-specific headers if needed
187 | headers={},
188 | )
189 | tracer_provider.add_span_processor(BatchSpanProcessor(jaeger_otlp_exporter)) # type: ignore[attr-defined]
190 | logger.info(f"Jaeger OTLP exporter configured: {otlp_endpoint}")
191 |
192 | # Zipkin exporter
193 | zipkin_endpoint = os.getenv("ZIPKIN_ENDPOINT")
194 | if zipkin_endpoint:
195 | zipkin_exporter = ZipkinExporter(endpoint=zipkin_endpoint)
196 | tracer_provider.add_span_processor(BatchSpanProcessor(zipkin_exporter)) # type: ignore[attr-defined]
197 | logger.info(f"Zipkin exporter configured: {zipkin_endpoint}")
198 |
199 | # OTLP exporter (for services like Honeycomb, New Relic, etc.)
200 | otlp_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT")
201 | if otlp_endpoint:
202 | otlp_exporter = OTLPSpanExporter(
203 | endpoint=otlp_endpoint,
204 | headers={"x-honeycomb-team": os.getenv("HONEYCOMB_API_KEY", "")},
205 | )
206 | tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter)) # type: ignore[attr-defined]
207 | logger.info(f"OTLP exporter configured: {otlp_endpoint}")
208 |
209 | def _configure_propagators(self):
210 | """Configure trace propagators for cross-service communication."""
211 | # Use B3 propagator for maximum compatibility
212 | set_global_textmap(B3MultiFormat())
213 | logger.info("B3 trace propagator configured")
214 |
215 | def _instrument_libraries(self):
216 | """Automatically instrument common libraries."""
217 | try:
218 | # FastAPI instrumentation
219 | FastAPIInstrumentor().instrument()
220 |
221 | # Database instrumentation
222 | SQLAlchemyInstrumentor().instrument()
223 | AsyncPGInstrumentor().instrument()
224 |
225 | # HTTP client instrumentation
226 | RequestsInstrumentor().instrument()
227 | HTTPXInstrumentor().instrument()
228 |
229 | # Redis instrumentation
230 | RedisInstrumentor().instrument()
231 |
232 | # Asyncio instrumentation
233 | AsyncioInstrumentor().instrument()
234 |
235 | logger.info("Auto-instrumentation completed successfully")
236 |
237 | except Exception as e:
238 | logger.warning(f"Some auto-instrumentation failed: {e}")
239 |
240 | @contextmanager
241 | def trace_operation(
242 | self,
243 | operation_name: str,
244 | attributes: dict[str, Any] | None = None,
245 | record_exception: bool = True,
246 | ):
247 | """
248 | Context manager for tracing operations.
249 |
250 | Args:
251 | operation_name: Name of the operation being traced
252 | attributes: Additional attributes to add to the span
253 | record_exception: Whether to record exceptions in the span
254 | """
255 | if not self.enabled:
256 | yield None
257 | return
258 |
259 | with self.tracer.start_as_current_span(operation_name) as span:
260 | # Add attributes
261 | if attributes:
262 | for key, value in attributes.items():
263 | span.set_attribute(key, str(value))
264 |
265 | try:
266 | yield span
267 | span.set_status(Status(StatusCode.OK))
268 | except Exception as e:
269 | span.set_status(Status(StatusCode.ERROR, str(e)))
270 | if record_exception:
271 | span.record_exception(e)
272 | raise
273 |
274 | def trace_tool_execution(self, func: Callable) -> Callable:
275 | """
276 | Decorator to trace tool execution.
277 |
278 | Args:
279 | func: The tool function to trace
280 |
281 | Returns:
282 | Decorated function with tracing
283 | """
284 |
285 | @functools.wraps(func)
286 | async def wrapper(*args, **kwargs):
287 | if not self.enabled:
288 | return await func(*args, **kwargs)
289 |
290 | tool_name = getattr(func, "__name__", "unknown_tool")
291 | with self.trace_operation(
292 | f"tool.{tool_name}",
293 | attributes={
294 | "tool.name": tool_name,
295 | "tool.args_count": len(args),
296 | "tool.kwargs_count": len(kwargs),
297 | },
298 | ) as span:
299 | # Add user context if available
300 | for arg in args:
301 | if hasattr(arg, "user_id"):
302 | span.set_attribute("user.id", str(arg.user_id))
303 | break
304 |
305 | start_time = time.time()
306 | result = await func(*args, **kwargs)
307 | duration = time.time() - start_time
308 |
309 | span.set_attribute("tool.duration_seconds", duration)
310 | span.set_attribute("tool.success", True)
311 |
312 | return result
313 |
314 | return wrapper
315 |
316 | def trace_database_query(
317 | self, query_type: str, table: str | None = None, query: str | None = None
318 | ):
319 | """
320 | Context manager for tracing database queries.
321 |
322 | Args:
323 | query_type: Type of query (SELECT, INSERT, UPDATE, DELETE)
324 | table: Table name being queried
325 | query: The actual SQL query (will be truncated for security)
326 | """
327 | attributes = {
328 | "db.operation": query_type,
329 | "db.system": "postgresql",
330 | }
331 |
332 | if table:
333 | attributes["db.table"] = table
334 |
335 | if query:
336 | # Truncate query for security and performance
337 | attributes["db.statement"] = (
338 | query[:200] + "..." if len(query) > 200 else query
339 | )
340 |
341 | return self.trace_operation(f"db.{query_type.lower()}", attributes)
342 |
343 | def trace_external_api_call(self, service: str, endpoint: str, method: str = "GET"):
344 | """
345 | Context manager for tracing external API calls.
346 |
347 | Args:
348 | service: Name of the external service
349 | endpoint: API endpoint being called
350 | method: HTTP method
351 | """
352 | attributes = {
353 | "http.method": method,
354 | "http.url": endpoint,
355 | "service.name": service,
356 | }
357 |
358 | return self.trace_operation(f"http.{method.lower()}", attributes)
359 |
360 | def trace_cache_operation(self, operation: str, cache_type: str = "redis"):
361 | """
362 | Context manager for tracing cache operations.
363 |
364 | Args:
365 | operation: Cache operation (get, set, delete, etc.)
366 | cache_type: Type of cache (redis, memory, etc.)
367 | """
368 | attributes = {
369 | "cache.operation": operation,
370 | "cache.type": cache_type,
371 | }
372 |
373 | return self.trace_operation(f"cache.{operation}", attributes)
374 |
375 | def add_event(self, name: str, attributes: dict[str, Any] | None = None):
376 | """Add an event to the current span."""
377 | if not self.enabled:
378 | return
379 |
380 | current_span = trace.get_current_span()
381 | if current_span:
382 | current_span.add_event(name, attributes or {})
383 |
384 | def set_user_context(self, user_id: str, email: str | None = None):
385 | """Set user context on the current span."""
386 | if not self.enabled:
387 | return
388 |
389 | current_span = trace.get_current_span()
390 | if current_span:
391 | current_span.set_attribute("user.id", user_id)
392 | if email:
393 | current_span.set_attribute("user.email", email)
394 |
395 |
# Global tracing service instance.
# Created lazily by get_tracing_service() so importing this module has no
# tracing side effects.
_tracing_service: TracingService | None = None


def get_tracing_service() -> TracingService:
    """Get or create the global tracing service (lazy singleton)."""
    global _tracing_service
    if _tracing_service is None:
        _tracing_service = TracingService()
    return _tracing_service
406 |
407 |
def trace_tool(func: Callable) -> Callable:
    """Module-level decorator: trace *func* via the global TracingService."""
    return get_tracing_service().trace_tool_execution(func)
412 |
413 |
@contextmanager
def trace_operation(
    operation_name: str,
    attributes: dict[str, Any] | None = None,
    record_exception: bool = True,
):
    """Module-level convenience wrapper around TracingService.trace_operation."""
    service = get_tracing_service()
    with service.trace_operation(operation_name, attributes, record_exception) as span:
        yield span
424 |
425 |
@contextmanager
def trace_database_query(
    query_type: str, table: str | None = None, query: str | None = None
):
    """Module-level convenience wrapper for tracing database queries."""
    service = get_tracing_service()
    with service.trace_database_query(query_type, table, query) as span:
        yield span
434 |
435 |
@contextmanager
def trace_external_api_call(service: str, endpoint: str, method: str = "GET"):
    """Module-level convenience wrapper for tracing external API calls."""
    tracer_service = get_tracing_service()
    with tracer_service.trace_external_api_call(service, endpoint, method) as span:
        yield span
442 |
443 |
@contextmanager
def trace_cache_operation(operation: str, cache_type: str = "redis"):
    """Module-level convenience wrapper for tracing cache operations."""
    service = get_tracing_service()
    with service.trace_cache_operation(operation, cache_type) as span:
        yield span
450 |
451 |
def initialize_tracing():
    """Initialize the global tracing service and log the outcome."""
    logger.info("Initializing distributed tracing...")
    service = get_tracing_service()

    message = (
        "Distributed tracing initialized successfully"
        if service.enabled
        else "Distributed tracing disabled or unavailable"
    )
    logger.info(message)
461 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/technical.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Technical analysis router for MaverickMCP.
3 |
4 | This module contains all technical analysis related tools including
5 | indicators, chart patterns, and analysis functions.
6 |
7 | DISCLAIMER: All technical analysis tools are for educational purposes only.
8 | Technical indicators are mathematical calculations based on historical data and
9 | do not predict future price movements. Results should not be considered as
10 | investment advice. Always consult qualified financial professionals.
11 | """
12 |
13 | import asyncio
14 | from concurrent.futures import ThreadPoolExecutor
15 | from datetime import UTC, datetime
16 | from typing import Any
17 |
18 | from fastmcp import FastMCP
19 | from fastmcp.server.dependencies import get_access_token
20 |
21 | from maverick_mcp.core.technical_analysis import (
22 | analyze_bollinger_bands,
23 | analyze_macd,
24 | analyze_rsi,
25 | analyze_stochastic,
26 | analyze_trend,
27 | analyze_volume,
28 | generate_outlook,
29 | identify_chart_patterns,
30 | identify_resistance_levels,
31 | identify_support_levels,
32 | )
33 | from maverick_mcp.core.visualization import (
34 | create_plotly_technical_chart,
35 | plotly_fig_to_base64,
36 | )
37 | from maverick_mcp.providers.stock_data import StockDataProvider
38 | from maverick_mcp.utils.logging import PerformanceMonitor, get_logger
39 | from maverick_mcp.utils.mcp_logging import with_logging
40 | from maverick_mcp.utils.stock_helpers import (
41 | get_stock_dataframe_async,
42 | )
43 |
logger = get_logger("maverick_mcp.routers.technical")

# Create the technical analysis router
technical_router: FastMCP = FastMCP("Technical_Analysis")

# Initialize data provider
stock_provider = StockDataProvider()

# Thread pool for blocking operations.
# Shared by the async tools below so pandas-based analysis runs off the
# event loop via run_in_executor.
executor = ThreadPoolExecutor(max_workers=10)
54 |
55 |
@with_logging("rsi_analysis")
async def get_rsi_analysis(
    ticker: str, period: int = 14, days: int = 365
) -> dict[str, Any]:
    """
    Get RSI analysis for a given ticker.

    Args:
        ticker: Stock ticker symbol
        period: RSI period (default: 14)
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        Dictionary containing RSI analysis, or
        ``{"error": ..., "status": "error"}`` on failure.
    """
    try:
        # Log analysis parameters
        logger.info(
            "Starting RSI analysis",
            extra={"ticker": ticker, "period": period, "days": days},
        )

        # Fetch stock data with performance monitoring
        with PerformanceMonitor(f"fetch_data_{ticker}"):
            df = await get_stock_dataframe_async(ticker, days)

        # Perform RSI analysis with monitoring
        with PerformanceMonitor(f"rsi_calculation_{ticker}"):
            # get_running_loop is the correct call inside a coroutine;
            # get_event_loop is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()
            # NOTE(review): `period` is echoed in the response but not
            # forwarded to analyze_rsi — confirm whether analyze_rsi should
            # accept it.
            analysis = await loop.run_in_executor(executor, analyze_rsi, df)

        # Log successful completion
        logger.info(
            "RSI analysis completed successfully",
            extra={
                "ticker": ticker,
                "rsi_current": analysis.get("current_rsi"),
                "signal": analysis.get("signal"),
            },
        )

        return {"ticker": ticker, "period": period, "analysis": analysis}
    except Exception as e:
        logger.error(
            "Error in RSI analysis",
            exc_info=True,
            extra={"ticker": ticker, "period": period, "error_type": type(e).__name__},
        )
        return {"error": str(e), "status": "error"}
105 |
106 |
async def get_macd_analysis(
    ticker: str,
    fast_period: int = 12,
    slow_period: int = 26,
    signal_period: int = 9,
    days: int = 365,
) -> dict[str, Any]:
    """
    Get MACD analysis for a given ticker.

    Args:
        ticker: Stock ticker symbol
        fast_period: Fast EMA period (default: 12)
        slow_period: Slow EMA period (default: 26)
        signal_period: Signal line period (default: 9)
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        Dictionary containing MACD analysis, or an error payload on failure.
    """
    try:
        df = await get_stock_dataframe_async(ticker, days)

        # Run the pandas-heavy calculation off the event loop, matching the
        # executor pattern used by get_rsi_analysis.
        loop = asyncio.get_running_loop()
        analysis = await loop.run_in_executor(executor, analyze_macd, df)

        # NOTE(review): the period parameters are echoed back but not
        # forwarded to analyze_macd — confirm whether analyze_macd should
        # accept them.
        return {
            "ticker": ticker,
            "parameters": {
                "fast_period": fast_period,
                "slow_period": slow_period,
                "signal_period": signal_period,
            },
            "analysis": analysis,
        }
    except Exception as e:
        logger.error(f"Error in MACD analysis for {ticker}: {str(e)}")
        return {"error": str(e), "status": "error"}
142 |
143 |
async def get_support_resistance(ticker: str, days: int = 365) -> dict[str, Any]:
    """
    Get support and resistance levels for a given ticker.

    Args:
        ticker: Stock ticker symbol
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        Dictionary containing sorted support and resistance levels plus the
        latest close, or an error payload on failure.
    """
    try:
        df = await get_stock_dataframe_async(ticker, days)

        # Offload the level detection to the shared executor so the event
        # loop is not blocked (consistent with get_rsi_analysis).
        loop = asyncio.get_running_loop()
        support = await loop.run_in_executor(executor, identify_support_levels, df)
        resistance = await loop.run_in_executor(
            executor, identify_resistance_levels, df
        )
        current_price = df["close"].iloc[-1]

        return {
            "ticker": ticker,
            "current_price": float(current_price),
            "support_levels": sorted(support),
            "resistance_levels": sorted(resistance),
        }
    except Exception as e:
        logger.error(f"Error in support/resistance analysis for {ticker}: {str(e)}")
        return {"error": str(e), "status": "error"}
170 |
171 |
async def get_full_technical_analysis(ticker: str, days: int = 365) -> dict[str, Any]:
    """
    Get comprehensive technical analysis for a given ticker.

    This tool provides a complete technical analysis including:
    - Trend analysis
    - All major indicators (RSI, MACD, Stochastic, Bollinger Bands)
    - Support and resistance levels
    - Volume analysis
    - Chart patterns
    - Overall outlook

    Args:
        ticker: Stock ticker symbol
        days: Number of days of historical data to analyze (default: 365)

    Returns:
        Dictionary containing complete technical analysis, or an error
        payload on failure.
    """
    try:
        # Access authentication context if available (optional for this tool)
        # This demonstrates optional authentication - tool works without auth
        # but provides enhanced features for authenticated users
        has_premium = False
        try:
            access_token = get_access_token()
            if access_token is None:
                raise ValueError("No access token available")

            # Log authenticated user
            logger.info(
                f"Technical analysis requested by authenticated user: {access_token.client_id}",
                extra={"scopes": access_token.scopes},
            )

            # Check for premium features based on scopes
            has_premium = "premium:access" in access_token.scopes
            logger.info(f"Has premium: {has_premium}")
        except Exception:
            # Authentication is optional for this tool
            logger.debug("Technical analysis requested by unauthenticated user")

        df = await get_stock_dataframe_async(ticker, days)

        # Run the full pandas-heavy analysis suite off the event loop so
        # the server stays responsive (matches get_rsi_analysis).
        loop = asyncio.get_running_loop()
        analysis = await loop.run_in_executor(executor, _compute_full_analysis, df)

        # Compile results (key order preserved from the previous version)
        return {
            "ticker": ticker,
            **analysis,
            "last_updated": datetime.now(UTC).isoformat(),
        }
    except Exception as e:
        logger.error(f"Error in technical analysis for {ticker}: {str(e)}")
        return {"error": str(e), "status": "error"}


def _compute_full_analysis(df) -> dict[str, Any]:
    """Synchronous helper: run every indicator, level, and pattern analysis.

    Intended to execute inside the shared thread pool; returns the
    ticker-independent portion of the response payload.
    """
    trend = analyze_trend(df)
    rsi_analysis = analyze_rsi(df)
    macd_analysis = analyze_macd(df)
    stoch_analysis = analyze_stochastic(df)
    bb_analysis = analyze_bollinger_bands(df)
    volume_analysis = analyze_volume(df)
    patterns = identify_chart_patterns(df)
    support = identify_support_levels(df)
    resistance = identify_resistance_levels(df)
    outlook = generate_outlook(
        df, str(trend), rsi_analysis, macd_analysis, stoch_analysis
    )

    # Get current price and indicators
    current_price = df["close"].iloc[-1]

    return {
        "current_price": float(current_price),
        "trend": trend,
        "outlook": outlook,
        "indicators": {
            "rsi": rsi_analysis,
            "macd": macd_analysis,
            "stochastic": stoch_analysis,
            "bollinger_bands": bb_analysis,
            "volume": volume_analysis,
        },
        "levels": {"support": sorted(support), "resistance": sorted(resistance)},
        "patterns": patterns,
    }
253 |
254 |
async def get_stock_chart_analysis(ticker: str) -> dict[str, Any]:
    """
    Generate a comprehensive technical analysis chart.

    This tool creates a visual technical analysis including:
    - Price action with candlesticks
    - Moving averages
    - Volume analysis
    - Technical indicators
    - Support and resistance levels

    Args:
        ticker: The ticker symbol of the stock to analyze

    Returns:
        Dictionary containing the chart as properly formatted MCP image content
        for Claude Desktop display, or an error payload on failure.
    """
    try:
        # Fetch a full year of daily data without blocking the event loop
        df = await get_stock_dataframe_async(ticker, 365)

        # Chart rendering is CPU-bound; offload it to the shared executor.
        # Use get_running_loop() — get_event_loop() is deprecated inside
        # coroutines since Python 3.10 and guaranteed-safe here because we
        # are already running in a loop.
        loop = asyncio.get_running_loop()
        chart_content = await loop.run_in_executor(
            executor, _generate_chart_mcp_format, df, ticker
        )
        return chart_content
    except Exception as e:
        logger.error(f"Error generating chart analysis for {ticker}: {e}")
        # Additive "status" key for consistency with get_full_technical_analysis
        return {"error": str(e), "status": "error"}
285 |
286 |
def _generate_chart_mcp_format(df, ticker: str) -> dict[str, Any]:
    """Generate chart in proper MCP content format for Claude Desktop with aggressive size optimization.

    Tries progressively smaller chart dimensions until the base64-encoded
    image fits under the size budget; if none fit, returns a text-only
    payload directing the user to the text-based analysis tool.

    Args:
        df: Price DataFrame for the ticker (OHLCV; indicators are added here).
        ticker: Stock ticker symbol used for labeling and file naming.

    Returns:
        MCP ``{"content": [...]}`` payload containing either the image (via
        the Claude Desktop workaround helper) or a fallback text message.
    """
    from maverick_mcp.core.technical_analysis import add_technical_indicators

    # Enrich the frame with the indicators the chart renderer expects
    df = add_technical_indicators(df)

    # Claude Desktop has a ~100k character limit for responses
    # Base64 images need to be MUCH smaller - aim for ~50k chars max
    # Configs are ordered largest-first so the best-looking chart that
    # fits the budget wins.
    chart_configs = [
        {"height": 300, "width": 500, "format": "jpeg"},  # Small primary
        {"height": 250, "width": 400, "format": "jpeg"},  # Smaller fallback
        {"height": 200, "width": 350, "format": "jpeg"},  # Tiny fallback
        {"height": 150, "width": 300, "format": "jpeg"},  # Last resort
    ]

    for config in chart_configs:
        try:
            # Generate chart with current config
            analysis = create_plotly_technical_chart(
                df, ticker, height=config["height"], width=config["width"]
            )

            # Generate base64 data URI
            data_uri = plotly_fig_to_base64(analysis, format=config["format"])

            # Extract base64 data without the data URI prefix
            if data_uri.startswith(f"data:image/{config['format']};base64,"):
                base64_data = data_uri.split(",", 1)[1]
                mime_type = f"image/{config['format']}"
            else:
                # Fallback - assume it's already base64 data
                base64_data = data_uri
                mime_type = f"image/{config['format']}"

            # Very conservative size limit for Claude Desktop
            # Response gets truncated at 100k chars, so aim for 50k max for base64
            max_chars = 50000

            logger.info(
                f"Generated chart for {ticker}: {config['width']}x{config['height']} "
                f"({len(base64_data):,} chars base64)"
            )

            if len(base64_data) <= max_chars:
                # Try multiple formats to work around Claude Desktop bugs
                description = (
                    f"Technical analysis chart for {ticker.upper()} "
                    f"({config['width']}x{config['height']}) showing price action, "
                    f"moving averages, volume, RSI, and MACD indicators."
                )

                return _return_image_with_claude_desktop_workaround(
                    base64_data, mime_type, description, ticker
                )
            else:
                # Too large: log and fall through to the next (smaller) config
                logger.warning(
                    f"Chart for {ticker} too large at {config['width']}x{config['height']} "
                    f"({len(base64_data):,} chars > {max_chars}), trying smaller size..."
                )
                continue

        except Exception as e:
            # A failed render at one size should not abort the whole attempt
            logger.warning(f"Failed to generate chart with config {config}: {e}")
            continue

    # If all configs failed, return error
    return {
        "content": [
            {
                "type": "text",
                "text": (
                    f"Unable to generate suitable chart size for {ticker.upper()}. "
                    f"The chart image is too large for Claude Desktop display limits. "
                    f"Please use the text-based technical analysis tool instead: "
                    f"technical_get_full_technical_analysis"
                ),
            }
        ]
    }
366 |
367 |
def _return_image_with_claude_desktop_workaround(
    base64_data: str, mime_type: str, description: str, ticker: str
) -> dict[str, Any]:
    """
    Return image using multiple formats to work around Claude Desktop bugs.
    Tries alternative MCP format first, fallback to file saving.

    Args:
        base64_data: Raw base64 image payload (no data-URI prefix).
        mime_type: Image MIME type, e.g. ``image/jpeg``.
        description: Human-readable caption included alongside the image.
        ticker: Stock ticker symbol, used for the fallback file name.

    Returns:
        MCP ``{"content": [...]}`` payload with the image or a text fallback.
    """
    import base64 as b64
    import tempfile
    from pathlib import Path

    # Format 1: Alternative "source" structure (some reports of this working)
    # NOTE(review): building and returning this dict literal cannot raise, so
    # this function always returns here — the except branch and Formats 2/3
    # below are currently unreachable dead code kept as intentional fallbacks.
    try:
        return {
            "content": [
                {"type": "text", "text": description},
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": mime_type,
                        "data": base64_data,
                    },
                },
            ]
        }
    except Exception as e:
        logger.warning(f"Alternative image format failed: {e}")

    # Format 2: Try original format one more time with different structure
    # NOTE(review): unreachable while Format 1 returns unconditionally.
    try:
        return {
            "content": [
                {"type": "text", "text": description},
                {"type": "image", "data": base64_data, "mimeType": mime_type},
            ]
        }
    except Exception as e:
        logger.warning(f"Standard image format failed: {e}")

    # Format 3: File-based fallback (most reliable for Claude Desktop)
    # NOTE(review): also unreachable at present (see Format 1).
    try:
        ext = mime_type.split("/")[-1]  # jpeg, png, etc.

        # Create temp file in a standard location
        temp_dir = Path(tempfile.gettempdir()) / "maverick_mcp_charts"
        temp_dir.mkdir(exist_ok=True)

        chart_file = (
            temp_dir
            / f"{ticker.lower()}_chart_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.{ext}"
        )

        # Decode and save base64 to file
        image_data = b64.b64decode(base64_data)
        chart_file.write_bytes(image_data)

        logger.info(f"Saved chart to file: {chart_file}")

        return {
            "content": [
                {
                    "type": "text",
                    "text": (
                        f"{description}\n\n"
                        f"📁 **Chart saved to file**: `{chart_file}`\n\n"
                        f"**To view this image:**\n"
                        f"1. Use the filesystem MCP server if configured, or\n"
                        f"2. Ask me to open the file location, or\n"
                        f"3. Navigate to the file manually\n\n"
                        f"*Note: Claude Desktop has a known issue with embedded images. "
                        f"File-based display is the current workaround.*"
                    ),
                }
            ]
        }
    except Exception as e:
        logger.error(f"File fallback also failed: {e}")
        return {
            "content": [
                {
                    "type": "text",
                    "text": (
                        f"Unable to display chart for {ticker.upper()} due to "
                        f"Claude Desktop image rendering limitations. "
                        f"Please use the text-based technical analysis instead: "
                        f"`technical_get_full_technical_analysis`"
                    ),
                }
            ]
        }
459 |
```
--------------------------------------------------------------------------------
/maverick_mcp/domain/stock_analysis/stock_analysis_service.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Stock Analysis Service - Domain service that orchestrates data fetching and caching.
3 | """
4 |
5 | import logging
6 | from datetime import UTC, datetime, timedelta
7 |
8 | import pandas as pd
9 | import pandas_market_calendars as mcal
10 | import pytz
11 | from sqlalchemy.orm import Session
12 |
13 | from maverick_mcp.infrastructure.caching import CacheManagementService
14 | from maverick_mcp.infrastructure.data_fetching import StockDataFetchingService
15 |
16 | logger = logging.getLogger("maverick_mcp.stock_analysis")
17 |
18 |
class StockAnalysisService:
    """
    Domain service that orchestrates stock data retrieval with intelligent caching.

    This service:
    - Contains business logic for stock data retrieval
    - Orchestrates data fetching and caching services
    - Implements smart caching strategies
    - Uses dependency injection for service composition
    """

    def __init__(
        self,
        data_fetching_service: StockDataFetchingService,
        cache_service: CacheManagementService,
        db_session: Session | None = None,
    ):
        """
        Initialize the stock analysis service.

        Args:
            data_fetching_service: Service for fetching data from external sources
            cache_service: Service for cache management
            db_session: Optional database session for dependency injection
        """
        self.data_fetching_service = data_fetching_service
        self.cache_service = cache_service
        self.db_session = db_session

        # Initialize NYSE calendar for US stock market
        self.market_calendar = mcal.get_calendar("NYSE")

    def get_stock_data(
        self,
        symbol: str,
        start_date: str | None = None,
        end_date: str | None = None,
        period: str | None = None,
        interval: str = "1d",
        use_cache: bool = True,
    ) -> pd.DataFrame:
        """
        Get stock data with intelligent caching strategy.

        This method:
        1. Gets all available data from cache
        2. Identifies missing date ranges
        3. Fetches only missing data from external sources
        4. Combines and returns the complete dataset

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            period: Alternative to start/end dates (e.g., '1d', '5d', '1mo', etc.)
            interval: Data interval ('1d', '1wk', '1mo', '1m', '5m', etc.)
            use_cache: Whether to use cached data if available

        Returns:
            DataFrame with stock data
        """
        symbol = symbol.upper()

        # For non-daily intervals or periods, always fetch fresh data
        # (the cache only stores daily bars keyed by explicit date ranges)
        if interval != "1d" or period:
            logger.info(
                f"Non-daily interval or period specified, fetching fresh data for {symbol}"
            )
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

        # Set default dates if not provided (last 365 days)
        if start_date is None:
            start_date = (datetime.now(UTC) - timedelta(days=365)).strftime("%Y-%m-%d")
        if end_date is None:
            end_date = datetime.now(UTC).strftime("%Y-%m-%d")

        # For daily data, adjust end date to last trading day if it's not a trading day
        if interval == "1d" and use_cache:
            end_dt = pd.to_datetime(end_date)
            if not self._is_trading_day(end_dt):
                last_trading = self._get_last_trading_day(end_dt)
                logger.debug(
                    f"Adjusting end date from {end_date} to last trading day {last_trading.strftime('%Y-%m-%d')}"
                )
                end_date = last_trading.strftime("%Y-%m-%d")

        # If cache is disabled, fetch directly
        if not use_cache:
            logger.info(f"Cache disabled, fetching fresh data for {symbol}")
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

        # Use smart caching strategy; on any cache failure, degrade to a
        # direct fetch rather than surfacing the cache error to callers.
        try:
            return self._get_data_with_smart_cache(
                symbol, start_date, end_date, interval
            )
        except Exception as e:
            logger.warning(
                f"Smart cache failed for {symbol}, falling back to fresh data: {e}"
            )
            return self.data_fetching_service.fetch_stock_data(
                symbol, start_date, end_date, period, interval
            )

    def _get_data_with_smart_cache(
        self, symbol: str, start_date: str, end_date: str, interval: str
    ) -> pd.DataFrame:
        """
        Implement smart caching strategy for stock data retrieval.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            interval: Data interval

        Returns:
            DataFrame with complete stock data
        """
        logger.info(
            f"Using smart cache strategy for {symbol} from {start_date} to {end_date}"
        )

        # Step 1: Get available cached data
        cached_df = self.cache_service.get_cached_data(symbol, start_date, end_date)

        # Convert dates for comparison
        start_dt = pd.to_datetime(start_date)
        end_dt = pd.to_datetime(end_date)

        # Step 2: Determine what data we need
        if cached_df is not None and not cached_df.empty:
            logger.info(f"Found {len(cached_df)} cached records for {symbol}")

            # Check if we have all the data we need
            cached_start = pd.to_datetime(cached_df.index.min())
            cached_end = pd.to_datetime(cached_df.index.max())

            # Identify missing ranges, restricted to actual trading days so
            # weekends/holidays don't trigger pointless fetches
            missing_ranges = []

            # Missing data at the beginning?
            if start_dt < cached_start:
                missing_start_trading = self._get_trading_days(
                    start_dt, cached_start - timedelta(days=1)
                )
                if len(missing_start_trading) > 0:
                    missing_ranges.append(
                        (
                            missing_start_trading[0].strftime("%Y-%m-%d"),
                            missing_start_trading[-1].strftime("%Y-%m-%d"),
                        )
                    )

            # Missing recent data?
            if end_dt > cached_end:
                if self._is_trading_day_between(cached_end, end_dt):
                    missing_end_trading = self._get_trading_days(
                        cached_end + timedelta(days=1), end_dt
                    )
                    if len(missing_end_trading) > 0:
                        missing_ranges.append(
                            (
                                missing_end_trading[0].strftime("%Y-%m-%d"),
                                missing_end_trading[-1].strftime("%Y-%m-%d"),
                            )
                        )

            # If no missing data, return cached data
            if not missing_ranges:
                logger.info(
                    f"Cache hit! Returning {len(cached_df)} cached records for {symbol}"
                )
                # Filter to requested range
                mask = (cached_df.index >= start_dt) & (cached_df.index <= end_dt)
                return cached_df.loc[mask]

            # Step 3: Fetch only missing data
            logger.info(f"Cache partial hit. Missing ranges: {missing_ranges}")
            all_dfs = [cached_df]

            for miss_start, miss_end in missing_ranges:
                logger.info(
                    f"Fetching missing data for {symbol} from {miss_start} to {miss_end}"
                )
                missing_df = self.data_fetching_service.fetch_stock_data(
                    symbol, miss_start, miss_end, None, interval
                )
                if not missing_df.empty:
                    all_dfs.append(missing_df)
                    # Cache the new data
                    self.cache_service.cache_data(symbol, missing_df)

            # Combine all data
            combined_df = pd.concat(all_dfs).sort_index()
            # Remove any duplicates (keep first, i.e. prefer cached rows)
            combined_df = combined_df[~combined_df.index.duplicated(keep="first")]

            # Filter to requested range
            mask = (combined_df.index >= start_dt) & (combined_df.index <= end_dt)
            return combined_df.loc[mask]

        else:
            # No cached data, fetch everything
            logger.info(f"No cached data found for {symbol}, fetching fresh data")

            # Adjust dates to trading days
            trading_days = self._get_trading_days(start_date, end_date)
            if len(trading_days) == 0:
                logger.warning(
                    f"No trading days found between {start_date} and {end_date}"
                )
                return pd.DataFrame(
                    columns=[
                        "Open",
                        "High",
                        "Low",
                        "Close",
                        "Volume",
                        "Dividends",
                        "Stock Splits",
                    ]
                )

            # Fetch data only for the trading day range
            fetch_start = trading_days[0].strftime("%Y-%m-%d")
            fetch_end = trading_days[-1].strftime("%Y-%m-%d")

            logger.info(f"Fetching data for trading days: {fetch_start} to {fetch_end}")
            df = self.data_fetching_service.fetch_stock_data(
                symbol, fetch_start, fetch_end, None, interval
            )

            if not df.empty:
                # Cache the fetched data
                self.cache_service.cache_data(symbol, df)

            return df

    def get_stock_info(self, symbol: str) -> dict:
        """
        Get detailed stock information.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with stock information
        """
        return self.data_fetching_service.fetch_stock_info(symbol)

    def get_realtime_data(self, symbol: str) -> dict | None:
        """
        Get real-time data for a symbol.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with real-time data or None
        """
        return self.data_fetching_service.fetch_realtime_data(symbol)

    def get_multiple_realtime_data(self, symbols: list[str]) -> dict[str, dict]:
        """
        Get real-time data for multiple symbols.

        Args:
            symbols: List of stock ticker symbols

        Returns:
            Dictionary mapping symbols to their real-time data
        """
        return self.data_fetching_service.fetch_multiple_realtime_data(symbols)

    def is_market_open(self) -> bool:
        """
        Check if the US stock market is currently open.

        Uses the NYSE calendar so market holidays are excluded in addition
        to weekends.

        Returns:
            True if market is open
        """
        now = datetime.now(pytz.timezone("US/Eastern"))

        # Check if it's a weekday
        if now.weekday() >= 5:  # 5 and 6 are Saturday and Sunday
            return False

        # Exclude NYSE holidays — the weekday check alone incorrectly
        # reported the market open on days like July 4th or Thanksgiving.
        if not self._is_trading_day(now.strftime("%Y-%m-%d")):
            return False

        # Check if it's between 9:30 AM and 4:00 PM Eastern Time
        # NOTE: early-close sessions (e.g. the day after Thanksgiving) are
        # still treated as closing at 4:00 PM — TODO confirm if that matters
        market_open = now.replace(hour=9, minute=30, second=0, microsecond=0)
        market_close = now.replace(hour=16, minute=0, second=0, microsecond=0)

        return market_open <= now <= market_close

    def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
        """
        Get news for a stock.

        Args:
            symbol: Stock ticker symbol
            limit: Maximum number of news items

        Returns:
            DataFrame with news data
        """
        return self.data_fetching_service.fetch_news(symbol, limit)

    def get_earnings(self, symbol: str) -> dict:
        """
        Get earnings information for a stock.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with earnings data
        """
        return self.data_fetching_service.fetch_earnings(symbol)

    def get_recommendations(self, symbol: str) -> pd.DataFrame:
        """
        Get analyst recommendations for a stock.

        Args:
            symbol: Stock ticker symbol

        Returns:
            DataFrame with recommendations
        """
        return self.data_fetching_service.fetch_recommendations(symbol)

    def is_etf(self, symbol: str) -> bool:
        """
        Check if a given symbol is an ETF.

        Args:
            symbol: Stock ticker symbol

        Returns:
            True if symbol is an ETF
        """
        return self.data_fetching_service.check_if_etf(symbol)

    def _get_trading_days(self, start_date, end_date) -> pd.DatetimeIndex:
        """
        Get all trading days between start and end dates.

        Args:
            start_date: Start date (can be string or datetime)
            end_date: End date (can be string or datetime)

        Returns:
            DatetimeIndex of trading days
        """
        # Ensure dates are datetime objects
        if isinstance(start_date, str):
            start_date = pd.to_datetime(start_date)
        if isinstance(end_date, str):
            end_date = pd.to_datetime(end_date)

        # Get valid trading days from market calendar
        schedule = self.market_calendar.schedule(
            start_date=start_date, end_date=end_date
        )
        return schedule.index

    def _get_last_trading_day(self, date) -> pd.Timestamp:
        """
        Get the last trading day on or before the given date.

        Args:
            date: Date to check (can be string or datetime)

        Returns:
            Last trading day as pd.Timestamp
        """
        if isinstance(date, str):
            date = pd.to_datetime(date)

        # Check if the date itself is a trading day
        if self._is_trading_day(date):
            return date

        # Otherwise, find the previous trading day
        # (10 days comfortably covers any weekend + holiday cluster)
        for i in range(1, 10):  # Look back up to 10 days
            check_date = date - timedelta(days=i)
            if self._is_trading_day(check_date):
                return check_date

        # Fallback to the date itself if no trading day found
        return date

    def _is_trading_day(self, date) -> bool:
        """
        Check if a specific date is a trading day.

        Args:
            date: Date to check

        Returns:
            True if it's a trading day
        """
        if isinstance(date, str):
            date = pd.to_datetime(date)

        schedule = self.market_calendar.schedule(start_date=date, end_date=date)
        return len(schedule) > 0

    def _is_trading_day_between(
        self, start_date: pd.Timestamp, end_date: pd.Timestamp
    ) -> bool:
        """
        Check if there's a trading day between two dates.

        Args:
            start_date: Start date
            end_date: End date

        Returns:
            True if there's a trading day between the dates
        """
        # Add one day to start since we're checking "between"
        check_start = start_date + timedelta(days=1)

        if check_start > end_date:
            return False

        # Get trading days in the range
        trading_days = self._get_trading_days(check_start, end_date)
        return len(trading_days) > 0

    def invalidate_cache(self, symbol: str, start_date: str, end_date: str) -> bool:
        """
        Invalidate cached data for a symbol within a date range.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            True if invalidation was successful
        """
        return self.cache_service.invalidate_cache(symbol, start_date, end_date)

    def get_cache_stats(self, symbol: str) -> dict:
        """
        Get cache statistics for a symbol.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Dictionary with cache statistics
        """
        return self.cache_service.get_cache_stats(symbol)
479 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/mcp_prompts.py:
--------------------------------------------------------------------------------
```python
1 | """MCP Prompts for better tool discovery and usage guidance."""
2 |
3 | from fastmcp import FastMCP
4 |
5 |
6 | def register_mcp_prompts(mcp: FastMCP):
7 | """Register MCP prompts to help clients understand how to use the tools."""
8 |
9 | # Backtesting prompts
10 | @mcp.prompt()
11 | async def backtest_strategy_guide():
12 | """Guide for running backtesting strategies."""
13 | return """
14 | # Backtesting Strategy Guide
15 |
16 | ## Available Strategies (15 total)
17 |
18 | ### Traditional Strategies (9):
19 | - `sma_cross`: Simple Moving Average Crossover
20 | - `rsi`: RSI Mean Reversion (oversold/overbought)
21 | - `macd`: MACD Signal Line Crossover
22 | - `bollinger`: Bollinger Bands (buy low, sell high)
23 | - `momentum`: Momentum-based trading
24 | - `ema_cross`: Exponential Moving Average Crossover
25 | - `mean_reversion`: Mean Reversion Strategy
26 | - `breakout`: Channel Breakout Strategy
27 | - `volume_momentum`: Volume-Weighted Momentum
28 |
29 | ### ML Strategies (6):
30 | - `online_learning`: Adaptive learning with dynamic thresholds
31 | - `regime_aware`: Market regime detection (trending vs ranging)
32 | - `ensemble`: Multiple strategy voting system
33 |
34 | ## Example Usage:
35 |
36 | ### Traditional Strategy:
37 | "Run a backtest on AAPL using the sma_cross strategy from 2024-01-01 to 2024-12-31"
38 |
39 | ### ML Strategy:
40 | "Test the online_learning strategy on TSLA for the past year with a learning rate of 0.01"
41 |
42 | ### Parameters:
43 | - Most strategies have default parameters that work well
44 | - You can customize: fast_period, slow_period, threshold, etc.
45 | """
46 |
47 | @mcp.prompt()
48 | async def ml_strategy_examples():
49 | """Examples of ML strategy usage."""
50 | return """
51 | # ML Strategy Examples
52 |
53 | ## 1. Online Learning Strategy
54 | "Run online_learning strategy on NVDA with parameters:
55 | - lookback: 20 days
56 | - learning_rate: 0.01
57 | - start_date: 2024-01-01
58 | - end_date: 2024-12-31"
59 |
60 | ## 2. Regime-Aware Strategy
61 | "Test regime_aware strategy on SPY to detect market regimes:
62 | - regime_window: 50 days
63 | - threshold: 0.02
64 | - Adapts between trending and ranging markets"
65 |
66 | ## 3. Ensemble Strategy
67 | "Use ensemble strategy on AAPL combining multiple signals:
68 | - Combines SMA, RSI, and Momentum
69 | - Uses voting to generate signals
70 | - More robust than single strategies"
71 |
72 | ## Important Notes:
73 | - ML strategies work through the standard run_backtest tool
74 | - Use strategy_type parameter: "online_learning", "regime_aware", or "ensemble"
75 | - These are simplified ML strategies that don't require training
76 | """
77 |
78 | @mcp.prompt()
79 | async def optimization_guide():
80 | """Guide for parameter optimization."""
81 | return """
82 | # Parameter Optimization Guide
83 |
84 | ## How to Optimize Strategy Parameters
85 |
86 | ### Basic Optimization:
87 | "Optimize sma_cross parameters for MSFT over the past 6 months"
88 |
89 | This will test combinations like:
90 | - fast_period: [5, 10, 15, 20]
91 | - slow_period: [20, 30, 50, 100]
92 |
93 | ### Custom Parameter Ranges:
94 | "Optimize RSI strategy for TSLA with:
95 | - period: [7, 14, 21]
96 | - oversold: [20, 25, 30]
97 | - overbought: [70, 75, 80]"
98 |
99 | ### Optimization Metrics:
100 | - sharpe_ratio (default): Risk-adjusted returns
101 | - total_return: Raw returns
102 | - win_rate: Percentage of winning trades
103 |
104 | ## Results Include:
105 | - Best parameter combination
106 | - Performance metrics for top combinations
107 | - Comparison across all tested parameters
108 | """
109 |
110 | @mcp.prompt()
111 | async def available_tools_summary():
112 | """Summary of all available MCP tools."""
113 | return """
114 | # MaverickMCP Tools Summary
115 |
116 | ## 1. Backtesting Tools
117 | - `run_backtest`: Run any strategy (traditional or ML)
118 | - `optimize_parameters`: Find best parameters
119 | - `compare_strategies`: Compare multiple strategies
120 | - `get_strategy_info`: Get strategy details
121 |
122 | ## 2. Data Tools
123 | - `get_stock_data`: Historical price data
124 | - `get_stock_info`: Company information
125 | - `get_multiple_stocks_data`: Batch data fetching
126 |
127 | ## 3. Technical Analysis
128 | - `calculate_sma`, `calculate_ema`: Moving averages
129 | - `calculate_rsi`: Relative Strength Index
130 | - `calculate_macd`: MACD indicator
131 | - `calculate_bollinger_bands`: Bollinger Bands
132 | - `get_full_technical_analysis`: All indicators
133 |
134 | ## 4. Screening Tools
135 | - `get_maverick_recommendations`: Bullish stocks
136 | - `get_maverick_bear_recommendations`: Bearish setups
137 | - `get_trending_breakout_recommendations`: Breakout candidates
138 |
139 | ## 5. Portfolio Tools
140 | - `optimize_portfolio`: Portfolio optimization
141 | - `analyze_portfolio_risk`: Risk assessment
142 | - `calculate_correlation_matrix`: Asset correlations
143 |
144 | ## Usage Tips:
145 | - Start with simple strategies before trying ML
146 | - Use default parameters initially
147 | - Optimize parameters after initial testing
148 | - Compare multiple strategies on same data
149 | """
150 |
151 | @mcp.prompt()
152 | async def troubleshooting_guide():
153 | """Troubleshooting common issues."""
154 | return """
155 | # Troubleshooting Guide
156 |
157 | ## Common Issues and Solutions
158 |
159 | ### 1. "Unknown strategy type"
160 | **Solution**: Use one of these exact strategy names:
161 | - Traditional: sma_cross, rsi, macd, bollinger, momentum, ema_cross, mean_reversion, breakout, volume_momentum
162 | - ML: online_learning, regime_aware, ensemble
163 |
164 | ### 2. "No data available"
165 | **Solution**:
166 | - Check date range (use past dates, not future)
167 | - Verify stock symbol (use standard tickers like AAPL, MSFT)
168 | - Try shorter date ranges (1 year or less)
169 |
170 | ### 3. ML Strategy Issues
171 | **Solution**: Use the standard run_backtest tool with:
172 | ```
173 | strategy_type: "online_learning" # or "regime_aware", "ensemble"
174 | ```
175 | Don't use the run_ml_backtest tool for these strategies.
176 |
177 | ### 4. Parameter Errors
178 | **Solution**: Start with no parameters (uses defaults):
179 | "Run backtest on AAPL using sma_cross strategy"
180 |
181 | Then customize if needed:
182 | "Run backtest on AAPL using sma_cross with fast_period=10 and slow_period=30"
183 |
184 | ### 5. Connection Issues
185 | **Solution**:
186 | - Restart Claude Desktop
187 | - Check server is running: The white circle should be blue
188 | - Try a simple test: "Get AAPL stock data"
189 | """
190 |
191 | @mcp.prompt()
192 | async def quick_start():
193 | """Quick start guide for new users."""
194 | return """
195 | # Quick Start Guide
196 |
197 | ## Test These Commands First:
198 |
199 | ### 1. Simple Backtest
200 | "Run a backtest on AAPL using the sma_cross strategy for 2024"
201 |
202 | ### 2. Get Stock Data
203 | "Get AAPL stock data for the last 3 months"
204 |
205 | ### 3. Technical Analysis
206 | "Show me technical analysis for MSFT"
207 |
208 | ### 4. Stock Screening
209 | "Show me bullish stock recommendations"
210 |
211 | ### 5. ML Strategy Test
212 | "Test the online_learning strategy on TSLA for the past 6 months"
213 |
214 | ## Next Steps:
215 | 1. Try different strategies on your favorite stocks
216 | 2. Optimize parameters for better performance
217 | 3. Compare multiple strategies
218 | 4. Build a portfolio with top performers
219 |
220 | ## Pro Tips:
221 | - Use 2024 dates for reliable data
222 | - Start with liquid stocks (AAPL, MSFT, GOOGL)
223 | - Default parameters usually work well
224 | - ML strategies are experimental but fun to try
225 | """
226 |
227 | # Register a resources endpoint for better discovery
228 | @mcp.prompt()
229 | async def strategy_reference():
230 | """Complete strategy reference with all parameters."""
231 | strategies = {
232 | "sma_cross": {
233 | "description": "Buy when fast SMA crosses above slow SMA",
234 | "parameters": {
235 | "fast_period": "Fast moving average period (default: 10)",
236 | "slow_period": "Slow moving average period (default: 20)",
237 | },
238 | "example": "run_backtest(symbol='AAPL', strategy_type='sma_cross', fast_period=10, slow_period=20)",
239 | },
240 | "rsi": {
241 | "description": "Buy oversold (RSI < 30), sell overbought (RSI > 70)",
242 | "parameters": {
243 | "period": "RSI calculation period (default: 14)",
244 | "oversold": "Oversold threshold (default: 30)",
245 | "overbought": "Overbought threshold (default: 70)",
246 | },
247 | "example": "run_backtest(symbol='MSFT', strategy_type='rsi', period=14, oversold=30)",
248 | },
249 | "online_learning": {
250 | "description": "ML strategy with adaptive thresholds",
251 | "parameters": {
252 | "lookback": "Historical window (default: 20)",
253 | "learning_rate": "Adaptation rate (default: 0.01)",
254 | },
255 | "example": "run_backtest(symbol='TSLA', strategy_type='online_learning', lookback=20)",
256 | },
257 | "regime_aware": {
258 | "description": "Detects and adapts to market regimes",
259 | "parameters": {
260 | "regime_window": "Regime detection window (default: 50)",
261 | "threshold": "Regime change threshold (default: 0.02)",
262 | },
263 | "example": "run_backtest(symbol='SPY', strategy_type='regime_aware', regime_window=50)",
264 | },
265 | "ensemble": {
266 | "description": "Combines multiple strategies with voting",
267 | "parameters": {
268 | "fast_period": "Fast MA period (default: 10)",
269 | "slow_period": "Slow MA period (default: 20)",
270 | "rsi_period": "RSI period (default: 14)",
271 | },
272 | "example": "run_backtest(symbol='NVDA', strategy_type='ensemble')",
273 | },
274 | }
275 |
276 | import json
277 |
278 | return f"""
279 | # Complete Strategy Reference
280 |
281 | ## All Available Strategies with Parameters
282 |
283 | ```json
284 | {json.dumps(strategies, indent=2)}
285 | ```
286 |
287 | ## Usage Pattern:
288 | All strategies use the same tool: `run_backtest`
289 |
290 | Parameters:
291 | - symbol: Stock ticker (required)
292 | - strategy_type: Strategy name (required)
293 | - start_date: YYYY-MM-DD format
294 | - end_date: YYYY-MM-DD format
295 | - initial_capital: Starting amount (default: 10000)
296 | - Additional strategy-specific parameters
297 |
298 | ## Testing Order:
299 | 1. Start with sma_cross (simplest)
300 | 2. Try rsi or macd (intermediate)
301 | 3. Test online_learning (ML strategy)
302 | 4. Compare all with compare_strategies tool
303 | """
304 |
305 | # Register resources for better discovery
306 | @mcp.resource("strategies://list")
307 | def list_strategies_resource():
308 | """List of all available backtesting strategies with parameters."""
309 | return {
310 | "traditional_strategies": {
311 | "sma_cross": {
312 | "name": "Simple Moving Average Crossover",
313 | "parameters": ["fast_period", "slow_period"],
314 | "default_values": {"fast_period": 10, "slow_period": 20},
315 | },
316 | "rsi": {
317 | "name": "RSI Mean Reversion",
318 | "parameters": ["period", "oversold", "overbought"],
319 | "default_values": {"period": 14, "oversold": 30, "overbought": 70},
320 | },
321 | "macd": {
322 | "name": "MACD Signal Line Crossover",
323 | "parameters": ["fast_period", "slow_period", "signal_period"],
324 | "default_values": {
325 | "fast_period": 12,
326 | "slow_period": 26,
327 | "signal_period": 9,
328 | },
329 | },
330 | "bollinger": {
331 | "name": "Bollinger Bands",
332 | "parameters": ["period", "std_dev"],
333 | "default_values": {"period": 20, "std_dev": 2},
334 | },
335 | "momentum": {
336 | "name": "Momentum Trading",
337 | "parameters": ["period", "threshold"],
338 | "default_values": {"period": 10, "threshold": 0.02},
339 | },
340 | "ema_cross": {
341 | "name": "EMA Crossover",
342 | "parameters": ["fast_period", "slow_period"],
343 | "default_values": {"fast_period": 12, "slow_period": 26},
344 | },
345 | "mean_reversion": {
346 | "name": "Mean Reversion",
347 | "parameters": ["lookback", "entry_z", "exit_z"],
348 | "default_values": {"lookback": 20, "entry_z": -2, "exit_z": 0},
349 | },
350 | "breakout": {
351 | "name": "Channel Breakout",
352 | "parameters": ["lookback", "breakout_factor"],
353 | "default_values": {"lookback": 20, "breakout_factor": 1.5},
354 | },
355 | "volume_momentum": {
356 | "name": "Volume-Weighted Momentum",
357 | "parameters": ["period", "volume_factor"],
358 | "default_values": {"period": 10, "volume_factor": 1.5},
359 | },
360 | },
361 | "ml_strategies": {
362 | "online_learning": {
363 | "name": "Online Learning Adaptive Strategy",
364 | "parameters": ["lookback", "learning_rate"],
365 | "default_values": {"lookback": 20, "learning_rate": 0.01},
366 | },
367 | "regime_aware": {
368 | "name": "Market Regime Detection",
369 | "parameters": ["regime_window", "threshold"],
370 | "default_values": {"regime_window": 50, "threshold": 0.02},
371 | },
372 | "ensemble": {
373 | "name": "Ensemble Voting Strategy",
374 | "parameters": ["fast_period", "slow_period", "rsi_period"],
375 | "default_values": {
376 | "fast_period": 10,
377 | "slow_period": 20,
378 | "rsi_period": 14,
379 | },
380 | },
381 | },
382 | "total_strategies": 15,
383 | }
384 |
385 | @mcp.resource("tools://categories")
386 | def tool_categories_resource():
387 | """Categorized list of all available MCP tools."""
388 | return {
389 | "backtesting": [
390 | "run_backtest",
391 | "optimize_parameters",
392 | "compare_strategies",
393 | "get_strategy_info",
394 | ],
395 | "data": ["get_stock_data", "get_stock_info", "get_multiple_stocks_data"],
396 | "technical_analysis": [
397 | "calculate_sma",
398 | "calculate_ema",
399 | "calculate_rsi",
400 | "calculate_macd",
401 | "calculate_bollinger_bands",
402 | "get_full_technical_analysis",
403 | ],
404 | "screening": [
405 | "get_maverick_recommendations",
406 | "get_maverick_bear_recommendations",
407 | "get_trending_breakout_recommendations",
408 | ],
409 | "portfolio": [
410 | "optimize_portfolio",
411 | "analyze_portfolio_risk",
412 | "calculate_correlation_matrix",
413 | ],
414 | "research": [
415 | "research_comprehensive",
416 | "research_company",
417 | "analyze_market_sentiment",
418 | "coordinate_agents",
419 | ],
420 | }
421 |
422 | @mcp.resource("examples://backtesting")
423 | def backtesting_examples_resource():
424 | """Practical examples of using backtesting tools."""
425 | return {
426 | "simple_backtest": {
427 | "description": "Basic backtest with default parameters",
428 | "example": "run_backtest(symbol='AAPL', strategy_type='sma_cross')",
429 | "expected_output": "Performance metrics including total return, sharpe ratio, win rate",
430 | },
431 | "custom_parameters": {
432 | "description": "Backtest with custom strategy parameters",
433 | "example": "run_backtest(symbol='TSLA', strategy_type='rsi', period=21, oversold=25)",
434 | "expected_output": "Performance with adjusted RSI parameters",
435 | },
436 | "ml_strategy": {
437 | "description": "Running ML-based strategy",
438 | "example": "run_backtest(symbol='NVDA', strategy_type='online_learning', lookback=30)",
439 | "expected_output": "Adaptive strategy performance with online learning",
440 | },
441 | "optimization": {
442 | "description": "Optimize strategy parameters",
443 | "example": "optimize_parameters(symbol='MSFT', strategy_type='sma_cross')",
444 | "expected_output": "Best parameter combination and performance metrics",
445 | },
446 | "comparison": {
447 | "description": "Compare multiple strategies",
448 | "example": "compare_strategies(symbol='SPY', strategies=['sma_cross', 'rsi', 'online_learning'])",
449 | "expected_output": "Side-by-side comparison of strategy performance",
450 | },
451 | }
452 |
453 | return True
454 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_stock_data.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Mock stock data provider implementations for testing.
3 | """
4 |
import zlib
from datetime import datetime, timedelta
from typing import Any

import numpy as np
import pandas as pd
10 |
11 |
12 | class MockStockDataFetcher:
13 | """
14 | Mock implementation of IStockDataFetcher for testing.
15 |
16 | This implementation provides predictable test data without requiring
17 | external API calls or database access.
18 | """
19 |
20 | def __init__(self, test_data: dict[str, pd.DataFrame] | None = None):
21 | """
22 | Initialize the mock stock data fetcher.
23 |
24 | Args:
25 | test_data: Optional dictionary mapping symbols to DataFrames
26 | """
27 | self._test_data = test_data or {}
28 | self._call_log: list[dict[str, Any]] = []
29 |
30 | async def get_stock_data(
31 | self,
32 | symbol: str,
33 | start_date: str | None = None,
34 | end_date: str | None = None,
35 | period: str | None = None,
36 | interval: str = "1d",
37 | use_cache: bool = True,
38 | ) -> pd.DataFrame:
39 | """Get mock stock data."""
40 | self._log_call(
41 | "get_stock_data",
42 | {
43 | "symbol": symbol,
44 | "start_date": start_date,
45 | "end_date": end_date,
46 | "period": period,
47 | "interval": interval,
48 | "use_cache": use_cache,
49 | },
50 | )
51 |
52 | symbol = symbol.upper()
53 |
54 | # Return test data if available
55 | if symbol in self._test_data:
56 | df = self._test_data[symbol].copy()
57 |
58 | # Filter by date range if specified
59 | if start_date or end_date:
60 | if start_date:
61 | df = df[df.index >= start_date]
62 | if end_date:
63 | df = df[df.index <= end_date]
64 |
65 | return df
66 |
67 | # Generate synthetic data
68 | return self._generate_synthetic_data(symbol, start_date, end_date, period)
69 |
70 | async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
71 | """Get mock real-time stock data."""
72 | self._log_call("get_realtime_data", {"symbol": symbol})
73 |
74 | # Return predictable mock data
75 | return {
76 | "symbol": symbol.upper(),
77 | "price": 150.25,
78 | "change": 2.15,
79 | "change_percent": 1.45,
80 | "volume": 1234567,
81 | "timestamp": datetime.now(),
82 | "timestamp_display": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
83 | "is_real_time": False,
84 | }
85 |
86 | async def get_stock_info(self, symbol: str) -> dict[str, Any]:
87 | """Get mock stock information."""
88 | self._log_call("get_stock_info", {"symbol": symbol})
89 |
90 | return {
91 | "symbol": symbol.upper(),
92 | "longName": f"{symbol.upper()} Inc.",
93 | "sector": "Technology",
94 | "industry": "Software",
95 | "marketCap": 1000000000,
96 | "previousClose": 148.10,
97 | "beta": 1.2,
98 | "dividendYield": 0.02,
99 | "peRatio": 25.5,
100 | }
101 |
102 | async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
103 | """Get mock news data."""
104 | self._log_call("get_news", {"symbol": symbol, "limit": limit})
105 |
106 | # Generate mock news data
107 | news_data = []
108 | for i in range(min(limit, 5)): # Generate up to 5 mock articles
109 | news_data.append(
110 | {
111 | "title": f"Mock news article {i + 1} for {symbol}",
112 | "publisher": f"Mock Publisher {i + 1}",
113 | "link": f"https://example.com/news/{symbol.lower()}/{i + 1}",
114 | "providerPublishTime": datetime.now() - timedelta(hours=i),
115 | "type": "STORY",
116 | }
117 | )
118 |
119 | return pd.DataFrame(news_data)
120 |
121 | async def get_earnings(self, symbol: str) -> dict[str, Any]:
122 | """Get mock earnings data."""
123 | self._log_call("get_earnings", {"symbol": symbol})
124 |
125 | return {
126 | "earnings": {
127 | "2023": 5.25,
128 | "2022": 4.80,
129 | "2021": 4.35,
130 | },
131 | "earnings_dates": {
132 | "next_date": "2024-01-25",
133 | "eps_estimate": 1.35,
134 | },
135 | "earnings_trend": {
136 | "current_quarter": 1.30,
137 | "next_quarter": 1.35,
138 | "current_year": 5.40,
139 | "next_year": 5.85,
140 | },
141 | }
142 |
143 | async def get_recommendations(self, symbol: str) -> pd.DataFrame:
144 | """Get mock analyst recommendations."""
145 | self._log_call("get_recommendations", {"symbol": symbol})
146 |
147 | recommendations_data = [
148 | {
149 | "firm": "Mock Investment Bank",
150 | "toGrade": "Buy",
151 | "fromGrade": "Hold",
152 | "action": "up",
153 | },
154 | {
155 | "firm": "Another Mock Firm",
156 | "toGrade": "Hold",
157 | "fromGrade": "Hold",
158 | "action": "main",
159 | },
160 | ]
161 |
162 | return pd.DataFrame(recommendations_data)
163 |
164 | async def is_market_open(self) -> bool:
165 | """Check if market is open (mock)."""
166 | self._log_call("is_market_open", {})
167 |
168 | # Return True for testing by default
169 | return True
170 |
171 | async def is_etf(self, symbol: str) -> bool:
172 | """Check if symbol is an ETF (mock)."""
173 | self._log_call("is_etf", {"symbol": symbol})
174 |
175 | # Mock ETF detection based on common ETF symbols
176 | etf_symbols = {"SPY", "QQQ", "IWM", "VTI", "VEA", "VWO", "XLK", "XLF"}
177 | return symbol.upper() in etf_symbols
178 |
179 | def _generate_synthetic_data(
180 | self,
181 | symbol: str,
182 | start_date: str | None = None,
183 | end_date: str | None = None,
184 | period: str | None = None,
185 | ) -> pd.DataFrame:
186 | """Generate synthetic stock data for testing."""
187 |
188 | # Determine date range
189 | if period:
190 | days = {"1d": 1, "5d": 5, "1mo": 30, "3mo": 90, "1y": 365}.get(period, 30)
191 | end_dt = datetime.now()
192 | start_dt = end_dt - timedelta(days=days)
193 | else:
194 | end_dt = pd.to_datetime(end_date) if end_date else datetime.now()
195 | start_dt = (
196 | pd.to_datetime(start_date)
197 | if start_date
198 | else end_dt - timedelta(days=30)
199 | )
200 |
201 | # Generate date range (business days only)
202 | dates = pd.bdate_range(start=start_dt, end=end_dt)
203 |
204 | if len(dates) == 0:
205 | # Return empty DataFrame with proper columns
206 | return pd.DataFrame(
207 | columns=[
208 | "Open",
209 | "High",
210 | "Low",
211 | "Close",
212 | "Volume",
213 | "Dividends",
214 | "Stock Splits",
215 | ]
216 | )
217 |
218 | # Generate synthetic price data
219 | np.random.seed(hash(symbol) % 2**32) # Consistent data per symbol
220 |
221 | base_price = 100.0
222 | returns = np.random.normal(
223 | 0.001, 0.02, len(dates)
224 | ) # 0.1% mean return, 2% volatility
225 |
226 | prices = [base_price]
227 | for ret in returns[1:]:
228 | prices.append(prices[-1] * (1 + ret))
229 |
230 | # Generate OHLCV data
231 | data = []
232 | for _i, (_date, close_price) in enumerate(zip(dates, prices, strict=False)):
233 | # Generate Open, High, Low based on Close
234 | volatility = close_price * 0.02 # 2% intraday volatility
235 |
236 | open_price = close_price + np.random.normal(0, volatility * 0.5)
237 | high_price = max(open_price, close_price) + abs(
238 | np.random.normal(0, volatility * 0.3)
239 | )
240 | low_price = min(open_price, close_price) - abs(
241 | np.random.normal(0, volatility * 0.3)
242 | )
243 |
244 | # Ensure High >= Low and prices are positive
245 | high_price = max(high_price, low_price + 0.01, 0.01)
246 | low_price = max(low_price, 0.01)
247 |
248 | volume = int(
249 | np.random.lognormal(15, 0.5)
250 | ) # Log-normal distribution for volume
251 |
252 | data.append(
253 | {
254 | "Open": round(open_price, 2),
255 | "High": round(high_price, 2),
256 | "Low": round(low_price, 2),
257 | "Close": round(close_price, 2),
258 | "Volume": volume,
259 | "Dividends": 0.0,
260 | "Stock Splits": 0.0,
261 | }
262 | )
263 |
264 | df = pd.DataFrame(data, index=dates)
265 | df.index.name = "Date"
266 |
267 | return df
268 |
269 | # Testing utilities
270 |
271 | def _log_call(self, method: str, args: dict[str, Any]) -> None:
272 | """Log method calls for testing verification."""
273 | self._call_log.append(
274 | {
275 | "method": method,
276 | "args": args,
277 | "timestamp": datetime.now(),
278 | }
279 | )
280 |
281 | def get_call_log(self) -> list[dict[str, Any]]:
282 | """Get the log of method calls."""
283 | return self._call_log.copy()
284 |
285 | def clear_call_log(self) -> None:
286 | """Clear the method call log."""
287 | self._call_log.clear()
288 |
289 | def set_test_data(self, symbol: str, data: pd.DataFrame) -> None:
290 | """Set test data for a specific symbol."""
291 | self._test_data[symbol.upper()] = data
292 |
293 | def clear_test_data(self) -> None:
294 | """Clear all test data."""
295 | self._test_data.clear()
296 |
297 |
298 | class MockStockScreener:
299 | """
300 | Mock implementation of IStockScreener for testing.
301 | """
302 |
303 | def __init__(
304 | self, test_recommendations: dict[str, list[dict[str, Any]]] | None = None
305 | ):
306 | """
307 | Initialize the mock stock screener.
308 |
309 | Args:
310 | test_recommendations: Optional dictionary of test screening results
311 | """
312 | self._test_recommendations = test_recommendations or {}
313 | self._call_log: list[dict[str, Any]] = []
314 |
315 | async def get_maverick_recommendations(
316 | self, limit: int = 20, min_score: int | None = None
317 | ) -> list[dict[str, Any]]:
318 | """Get mock maverick recommendations."""
319 | self._log_call(
320 | "get_maverick_recommendations", {"limit": limit, "min_score": min_score}
321 | )
322 |
323 | if "maverick" in self._test_recommendations:
324 | results = self._test_recommendations["maverick"]
325 | else:
326 | results = self._generate_mock_maverick_recommendations()
327 |
328 | # Apply filters
329 | if min_score:
330 | results = [r for r in results if r.get("combined_score", 0) >= min_score]
331 |
332 | return results[:limit]
333 |
334 | async def get_maverick_bear_recommendations(
335 | self, limit: int = 20, min_score: int | None = None
336 | ) -> list[dict[str, Any]]:
337 | """Get mock maverick bear recommendations."""
338 | self._log_call(
339 | "get_maverick_bear_recommendations",
340 | {"limit": limit, "min_score": min_score},
341 | )
342 |
343 | if "bear" in self._test_recommendations:
344 | results = self._test_recommendations["bear"]
345 | else:
346 | results = self._generate_mock_bear_recommendations()
347 |
348 | # Apply filters
349 | if min_score:
350 | results = [r for r in results if r.get("score", 0) >= min_score]
351 |
352 | return results[:limit]
353 |
354 | async def get_trending_recommendations(
355 | self, limit: int = 20, min_momentum_score: float | None = None
356 | ) -> list[dict[str, Any]]:
357 | """Get mock trending recommendations."""
358 | self._log_call(
359 | "get_trending_recommendations",
360 | {"limit": limit, "min_momentum_score": min_momentum_score},
361 | )
362 |
363 | if "trending" in self._test_recommendations:
364 | results = self._test_recommendations["trending"]
365 | else:
366 | results = self._generate_mock_trending_recommendations()
367 |
368 | # Apply filters
369 | if min_momentum_score:
370 | results = [
371 | r for r in results if r.get("momentum_score", 0) >= min_momentum_score
372 | ]
373 |
374 | return results[:limit]
375 |
376 | async def get_all_screening_recommendations(
377 | self,
378 | ) -> dict[str, list[dict[str, Any]]]:
379 | """Get all mock screening recommendations."""
380 | self._log_call("get_all_screening_recommendations", {})
381 |
382 | return {
383 | "maverick_stocks": await self.get_maverick_recommendations(),
384 | "maverick_bear_stocks": await self.get_maverick_bear_recommendations(),
385 | "supply_demand_breakouts": await self.get_trending_recommendations(),
386 | }
387 |
388 | def _generate_mock_maverick_recommendations(self) -> list[dict[str, Any]]:
389 | """Generate mock maverick recommendations."""
390 | return [
391 | {
392 | "symbol": "AAPL",
393 | "combined_score": 95,
394 | "momentum_score": 92,
395 | "pattern": "Cup with Handle",
396 | "consolidation": "yes",
397 | "squeeze": "firing",
398 | "recommendation_type": "maverick_bullish",
399 | "reason": "Exceptional combined score with outstanding relative strength",
400 | },
401 | {
402 | "symbol": "MSFT",
403 | "combined_score": 88,
404 | "momentum_score": 85,
405 | "pattern": "Flat Base",
406 | "consolidation": "no",
407 | "squeeze": "setup",
408 | "recommendation_type": "maverick_bullish",
409 | "reason": "Strong combined score with strong relative strength",
410 | },
411 | ]
412 |
413 | def _generate_mock_bear_recommendations(self) -> list[dict[str, Any]]:
414 | """Generate mock bear recommendations."""
415 | return [
416 | {
417 | "symbol": "BEAR1",
418 | "score": 92,
419 | "momentum_score": 25,
420 | "rsi_14": 28,
421 | "atr_contraction": True,
422 | "big_down_vol": True,
423 | "recommendation_type": "maverick_bearish",
424 | "reason": "Exceptional bear score with weak relative strength, oversold RSI",
425 | },
426 | {
427 | "symbol": "BEAR2",
428 | "score": 85,
429 | "momentum_score": 30,
430 | "rsi_14": 35,
431 | "atr_contraction": False,
432 | "big_down_vol": True,
433 | "recommendation_type": "maverick_bearish",
434 | "reason": "Strong bear score with weak relative strength",
435 | },
436 | ]
437 |
438 | def _generate_mock_trending_recommendations(self) -> list[dict[str, Any]]:
439 | """Generate mock trending recommendations."""
440 | return [
441 | {
442 | "symbol": "TREND1",
443 | "momentum_score": 95,
444 | "close": 150.25,
445 | "sma_50": 145.50,
446 | "sma_150": 140.25,
447 | "sma_200": 135.75,
448 | "pattern": "Breakout",
449 | "recommendation_type": "trending_stage2",
450 | "reason": "Uptrend with exceptional momentum strength",
451 | },
452 | {
453 | "symbol": "TREND2",
454 | "momentum_score": 88,
455 | "close": 85.30,
456 | "sma_50": 82.15,
457 | "sma_150": 79.80,
458 | "sma_200": 76.45,
459 | "pattern": "Higher Lows",
460 | "recommendation_type": "trending_stage2",
461 | "reason": "Uptrend with strong momentum strength",
462 | },
463 | ]
464 |
465 | # Testing utilities
466 |
467 | def _log_call(self, method: str, args: dict[str, Any]) -> None:
468 | """Log method calls for testing verification."""
469 | self._call_log.append(
470 | {
471 | "method": method,
472 | "args": args,
473 | "timestamp": datetime.now(),
474 | }
475 | )
476 |
477 | def get_call_log(self) -> list[dict[str, Any]]:
478 | """Get the log of method calls."""
479 | return self._call_log.copy()
480 |
481 | def clear_call_log(self) -> None:
482 | """Clear the method call log."""
483 | self._call_log.clear()
484 |
485 | def set_test_recommendations(
486 | self, screening_type: str, recommendations: list[dict[str, Any]]
487 | ) -> None:
488 | """Set test recommendations for a specific screening type."""
489 | self._test_recommendations[screening_type] = recommendations
490 |
491 | def clear_test_recommendations(self) -> None:
492 | """Clear all test recommendations."""
493 | self._test_recommendations.clear()
494 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/database_monitoring.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Database and Redis monitoring utilities for MaverickMCP.
3 |
4 | This module provides comprehensive monitoring for:
5 | - SQLAlchemy database connection pools
6 | - Database query performance
7 | - Redis connection pools and operations
8 | - Cache hit rates and performance metrics
9 | """
10 |
11 | import asyncio
12 | import time
13 | from contextlib import asynccontextmanager, contextmanager
14 | from typing import Any
15 |
16 | from sqlalchemy.event import listen
17 | from sqlalchemy.pool import Pool
18 |
19 | from maverick_mcp.utils.logging import get_logger
20 | from maverick_mcp.utils.monitoring import (
21 | redis_connections,
22 | redis_memory_usage,
23 | track_cache_operation,
24 | track_database_connection_event,
25 | track_database_query,
26 | track_redis_operation,
27 | update_database_metrics,
28 | update_redis_metrics,
29 | )
30 | from maverick_mcp.utils.tracing import trace_cache_operation, trace_database_query
31 |
32 | logger = get_logger(__name__)
33 |
34 |
class DatabaseMonitor:
    """Monitor for SQLAlchemy database operations and connection pools.

    Hooks SQLAlchemy events to export connection-pool and per-query metrics
    via the project's monitoring helpers, and logs queries slower than 1s.
    """

    def __init__(self, engine=None):
        # Engine may be None (e.g. in tests); listener setup is then a no-op.
        self.engine = engine
        self.query_stats = {}
        self._setup_event_listeners()

    def _setup_event_listeners(self):
        """Set up SQLAlchemy event listeners for monitoring."""
        if not self.engine:
            return

        # Connection pool events
        # NOTE(review): these attach to the Pool *class*, so they fire for
        # every pool in the process and are registered again for each
        # DatabaseMonitor constructed — confirm only one instance is ever
        # created, or scope them to self.engine.pool.
        listen(Pool, "connect", self._on_connection_created)
        listen(Pool, "checkout", self._on_connection_checkout)
        listen(Pool, "checkin", self._on_connection_checkin)
        listen(Pool, "close", self._on_connection_closed)

        # Query execution events (engine-scoped, unlike the pool hooks above)
        listen(self.engine, "before_cursor_execute", self._on_before_query)
        listen(self.engine, "after_cursor_execute", self._on_after_query)

    def _on_connection_created(self, dbapi_connection, connection_record):
        """Handle new database connection creation."""
        track_database_connection_event("created")
        logger.debug("Database connection created")

    def _on_connection_checkout(
        self, dbapi_connection, connection_record, connection_proxy
    ):
        """Handle connection checkout from pool."""
        # Update connection metrics
        pool = self.engine.pool
        self._update_pool_metrics(pool)

    def _on_connection_checkin(self, dbapi_connection, connection_record):
        """Handle connection checkin to pool."""
        # Update connection metrics
        pool = self.engine.pool
        self._update_pool_metrics(pool)

    def _on_connection_closed(self, dbapi_connection, connection_record):
        """Handle connection closure."""
        track_database_connection_event("closed", "normal")
        logger.debug("Database connection closed")

    def _on_before_query(
        self, conn, cursor, statement, parameters, context, executemany
    ):
        """Handle query execution start.

        Stashes the start time and statement on the execution context so
        _on_after_query can compute the duration.
        """
        context._query_start_time = time.time()
        context._query_statement = statement

    def _on_after_query(
        self, conn, cursor, statement, parameters, context, executemany
    ):
        """Handle query execution completion: record duration, warn if slow."""
        # Guard: before_cursor_execute may not have fired for this context
        if hasattr(context, "_query_start_time"):
            duration = time.time() - context._query_start_time
            query_type = self._extract_query_type(statement)
            table = self._extract_table_name(statement)

            # Track metrics
            track_database_query(query_type, table, duration, "success")

            # Log slow queries
            if duration > 1.0:  # Queries over 1 second
                logger.warning(
                    "Slow database query detected",
                    extra={
                        "query_type": query_type,
                        "table": table,
                        "duration_seconds": duration,
                        # Truncate long statements to keep log lines bounded
                        "statement": statement[:200] + "..."
                        if len(statement) > 200
                        else statement,
                    },
                )

    def _update_pool_metrics(self, pool):
        """Update connection pool metrics from the pool's current counters."""
        try:
            pool_size = pool.size()
            checked_out = pool.checkedout()
            checked_in = pool.checkedin()

            update_database_metrics(
                pool_size=pool_size,
                active_connections=checked_out,
                idle_connections=checked_in,
            )
        except Exception as e:
            # Metrics are best-effort: never let monitoring break queries
            logger.warning(f"Failed to update pool metrics: {e}")

    def _extract_query_type(self, statement: str) -> str:
        """Extract query type (SELECT/INSERT/...) from a SQL statement."""
        statement_upper = statement.strip().upper()
        if statement_upper.startswith("SELECT"):
            return "SELECT"
        elif statement_upper.startswith("INSERT"):
            return "INSERT"
        elif statement_upper.startswith("UPDATE"):
            return "UPDATE"
        elif statement_upper.startswith("DELETE"):
            return "DELETE"
        elif statement_upper.startswith("CREATE"):
            return "CREATE"
        elif statement_upper.startswith("DROP"):
            return "DROP"
        elif statement_upper.startswith("ALTER"):
            return "ALTER"
        else:
            return "OTHER"

    def _extract_table_name(self, statement: str) -> str:
        """Extract table name from a SQL statement.

        Returns the lower-cased first matched table name, or "unknown" when
        no pattern matches (never None, despite parsing being best-effort).
        """
        import re

        # Simple regex to extract table names
        patterns = [
            r"FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)",  # SELECT FROM table
            r"INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)",  # INSERT INTO table
            r"UPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)",  # UPDATE table
            r"DELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*)",  # DELETE FROM table
        ]

        for pattern in patterns:
            match = re.search(pattern, statement.upper())
            if match:
                return match.group(1).lower()

        return "unknown"

    @contextmanager
    def trace_query(self, query_type: str, table: str | None = None):
        """Context manager for tracing database queries.

        Yields the tracing span and records the query duration with a
        "success" or "error" status depending on whether the body raised.
        """
        with trace_database_query(query_type, table) as span:
            start_time = time.time()
            try:
                yield span
                duration = time.time() - start_time
                track_database_query(
                    query_type, table or "unknown", duration, "success"
                )
            except Exception:
                duration = time.time() - start_time
                track_database_query(query_type, table or "unknown", duration, "error")
                raise

    def get_pool_status(self) -> dict[str, Any]:
        """Get current database pool status, or {} if unavailable."""
        if not self.engine:
            return {}

        try:
            pool = self.engine.pool
            return {
                "pool_size": pool.size(),
                "checked_out": pool.checkedout(),
                "checked_in": pool.checkedin(),
                "overflow": pool.overflow(),
                "invalid": pool.invalid(),
            }
        except Exception as e:
            logger.error(f"Failed to get pool status: {e}")
            return {}
202 |
203 |
class RedisMonitor:
    """Monitor for Redis operations and connection pools.

    Wraps individual Redis commands with tracing/metrics and exposes
    helpers that read server-level stats from INFO.
    """

    def __init__(self, redis_client=None):
        # Async Redis client; methods no-op or raise if it is None.
        self.redis_client = redis_client
        self.operation_stats = {}

    @asynccontextmanager
    async def trace_operation(self, operation: str, key: str | None = None):
        """Context manager for tracing Redis operations.

        Records duration with "success"/"error" status, attaches the key to
        the span when available, and logs failures before re-raising.
        """
        with trace_cache_operation(operation, "redis") as span:
            start_time = time.time()

            if span and key:
                span.set_attribute("redis.key", key)

            try:
                yield span
                duration = time.time() - start_time
                track_redis_operation(operation, duration, "success")
            except Exception as e:
                duration = time.time() - start_time
                track_redis_operation(operation, duration, "error")

                if span:
                    span.record_exception(e)

                logger.error(
                    f"Redis operation failed: {operation}",
                    extra={
                        "operation": operation,
                        "key": key,
                        "duration_seconds": duration,
                        "error": str(e),
                    },
                )
                raise

    async def monitor_get(self, key: str):
        """Monitor Redis GET operation, tracking hit/miss by key prefix."""
        async with self.trace_operation("get", key):
            try:
                result = await self.redis_client.get(key)
                # A None result is a cache miss
                hit = result is not None
                track_cache_operation("redis", "get", hit, self._get_key_prefix(key))
                return result
            except Exception:
                # Count a failed GET as a miss, then propagate
                track_cache_operation("redis", "get", False, self._get_key_prefix(key))
                raise

    async def monitor_set(self, key: str, value: Any, **kwargs):
        """Monitor Redis SET operation."""
        async with self.trace_operation("set", key):
            return await self.redis_client.set(key, value, **kwargs)

    async def monitor_delete(self, key: str):
        """Monitor Redis DELETE operation."""
        async with self.trace_operation("delete", key):
            return await self.redis_client.delete(key)

    async def monitor_exists(self, key: str):
        """Monitor Redis EXISTS operation."""
        async with self.trace_operation("exists", key):
            return await self.redis_client.exists(key)

    async def update_redis_metrics(self):
        """Update Redis metrics gauges from the server's INFO output."""
        if not self.redis_client:
            return

        try:
            info = await self.redis_client.info()

            # Connection metrics
            connected_clients = info.get("connected_clients", 0)
            redis_connections.set(connected_clients)

            # Memory metrics
            used_memory = info.get("used_memory", 0)
            redis_memory_usage.set(used_memory)

            # Keyspace metrics (logged only; see note below)
            keyspace_hits = info.get("keyspace_hits", 0)
            keyspace_misses = info.get("keyspace_misses", 0)

            # Update counters (these are cumulative, so we track the difference)
            update_redis_metrics(
                connections=connected_clients,
                memory_bytes=used_memory,
                hits=0,  # Will be updated by individual operations
                misses=0,  # Will be updated by individual operations
            )

            logger.debug(
                "Redis metrics updated",
                extra={
                    "connected_clients": connected_clients,
                    "used_memory_mb": used_memory / 1024 / 1024,
                    "keyspace_hits": keyspace_hits,
                    "keyspace_misses": keyspace_misses,
                },
            )

        except Exception as e:
            logger.error(f"Failed to update Redis metrics: {e}")

    def _get_key_prefix(self, key: str) -> str:
        """Extract key prefix (text before the first ':') for metrics grouping."""
        if ":" in key:
            return key.split(":")[0]
        return "other"

    async def get_redis_info(self) -> dict[str, Any]:
        """Get a summary of Redis server information, or {} on failure."""
        if not self.redis_client:
            return {}

        try:
            info = await self.redis_client.info()
            return {
                "connected_clients": info.get("connected_clients", 0),
                "used_memory": info.get("used_memory", 0),
                "used_memory_human": info.get("used_memory_human", "0B"),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
                "total_commands_processed": info.get("total_commands_processed", 0),
                "uptime_in_seconds": info.get("uptime_in_seconds", 0),
            }
        except Exception as e:
            logger.error(f"Failed to get Redis info: {e}")
            return {}
335 |
336 |
class CacheMonitor:
    """High-level cache monitoring that supports multiple backends."""

    def __init__(self, redis_monitor: RedisMonitor | None = None):
        self.redis_monitor = redis_monitor

    @contextmanager
    def monitor_operation(self, cache_type: str, operation: str, key: str):
        """Monitor cache operation across different backends."""
        started = time.time()
        succeeded = False

        try:
            yield
            # No exception raised: treat as a hit for GET-style operations.
            succeeded = True
        except Exception as e:
            logger.error(
                f"Cache operation failed: {cache_type} {operation}",
                extra={
                    "cache_type": cache_type,
                    "operation": operation,
                    "key": key,
                    "error": str(e),
                },
            )
            raise
        finally:
            elapsed = time.time() - started

            # Hit/miss accounting only applies to read-style operations.
            if operation in ("get", "exists"):
                track_cache_operation(
                    cache_type, operation, succeeded, self._get_key_prefix(key)
                )

            # Surface anything slower than 100ms.
            if elapsed > 0.1:
                logger.warning(
                    f"Slow cache operation: {cache_type} {operation}",
                    extra={
                        "cache_type": cache_type,
                        "operation": operation,
                        "key": key,
                        "duration_seconds": elapsed,
                    },
                )

    def _get_key_prefix(self, key: str) -> str:
        """Return the segment before the first ``:`` (``"other"`` when absent)."""
        prefix, separator, _ = key.partition(":")
        return prefix if separator else "other"

    async def update_all_metrics(self):
        """Update metrics for all monitored cache backends."""
        pending = [
            monitor.update_redis_metrics()
            for monitor in (self.redis_monitor,)
            if monitor
        ]

        if pending:
            try:
                await asyncio.gather(*pending, return_exceptions=True)
            except Exception as e:
                logger.error(f"Failed to update cache metrics: {e}")
402 |
403 |
# Global monitor instances.
# Lazily created singletons; use the get_* accessor functions rather than
# touching these directly.
_database_monitor: DatabaseMonitor | None = None
_redis_monitor: RedisMonitor | None = None
_cache_monitor: CacheMonitor | None = None
408 |
409 |
def get_database_monitor(engine=None) -> DatabaseMonitor:
    """Return the process-wide database monitor, creating it on first use."""
    global _database_monitor
    monitor = _database_monitor
    if monitor is None:
        monitor = DatabaseMonitor(engine)
        _database_monitor = monitor
    return monitor
416 |
417 |
def get_redis_monitor(redis_client=None) -> RedisMonitor:
    """Return the process-wide Redis monitor, creating it on first use."""
    global _redis_monitor
    monitor = _redis_monitor
    if monitor is None:
        monitor = RedisMonitor(redis_client)
        _redis_monitor = monitor
    return monitor
424 |
425 |
def get_cache_monitor() -> CacheMonitor:
    """Return the process-wide cache monitor, creating it on first use."""
    global _cache_monitor
    monitor = _cache_monitor
    if monitor is None:
        # The cache monitor wraps the (possibly not-yet-initialized) Redis monitor.
        monitor = CacheMonitor(get_redis_monitor())
        _cache_monitor = monitor
    return monitor
433 |
434 |
def initialize_database_monitoring(engine):
    """Create (or fetch) the database monitor bound to *engine* and return it."""
    logger.info("Initializing database monitoring...")
    db_monitor = get_database_monitor(engine)
    logger.info("Database monitoring initialized")
    return db_monitor
441 |
442 |
def initialize_redis_monitoring(redis_client):
    """Create (or fetch) the Redis monitor bound to *redis_client* and return it."""
    logger.info("Initializing Redis monitoring...")
    redis_mon = get_redis_monitor(redis_client)
    logger.info("Redis monitoring initialized")
    return redis_mon
449 |
450 |
async def start_periodic_metrics_collection(interval: int = 30):
    """Loop forever, refreshing cache metrics every *interval* seconds.

    Errors from a single collection pass are logged and swallowed so one bad
    cycle never kills the background task.
    """
    logger.info(f"Starting periodic metrics collection (interval: {interval}s)")

    monitor = get_cache_monitor()

    while True:
        try:
            await monitor.update_all_metrics()
        except Exception as e:
            logger.error(f"Error in periodic metrics collection: {e}")

        await asyncio.sleep(interval)
464 |
```
--------------------------------------------------------------------------------
/maverick_mcp/monitoring/middleware.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Monitoring middleware for automatic metrics collection.
3 |
4 | This module provides middleware components that automatically track:
5 | - API calls and response times
6 | - Strategy execution performance
7 | - Resource usage during operations
8 | - Anomaly detection triggers
9 | """
10 |
11 | import asyncio
12 | import time
13 | from collections.abc import Callable
14 | from contextlib import asynccontextmanager
15 | from functools import wraps
16 | from typing import Any
17 |
18 | from maverick_mcp.monitoring.metrics import get_backtesting_metrics
19 | from maverick_mcp.utils.logging import get_logger
20 |
21 | logger = get_logger(__name__)
22 |
23 |
class MetricsMiddleware:
    """
    Middleware for automatic metrics collection during backtesting operations.

    Provides decorators and context managers for seamless metrics integration.
    """

    def __init__(self):
        # Shared collector through which every decorator reports its metrics.
        self.collector = get_backtesting_metrics()
        self.logger = get_logger(f"{__name__}.MetricsMiddleware")

    def track_api_call(self, provider: str, endpoint: str, method: str = "GET"):
        """
        Decorator to automatically track API call metrics.

        Usage:
            @middleware.track_api_call("tiingo", "/daily/{symbol}")
            async def get_stock_data(symbol: str):
                # API call logic here
                pass
        """

        def decorator(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                start_time = time.time()
                status_code = 200
                error_type = None

                try:
                    result = await func(*args, **kwargs)
                    return result
                except Exception as e:
                    # Prefer the exception's own HTTP status when it carries one.
                    status_code = getattr(e, "status_code", 500)
                    error_type = type(e).__name__
                    raise
                finally:
                    # Emitted in ``finally`` so both success and failure paths
                    # are recorded with their duration.
                    duration = time.time() - start_time
                    self.collector.track_api_call(
                        provider=provider,
                        endpoint=endpoint,
                        method=method,
                        status_code=status_code,
                        duration=duration,
                        error_type=error_type,
                    )

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                start_time = time.time()
                status_code = 200
                error_type = None

                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    status_code = getattr(e, "status_code", 500)
                    error_type = type(e).__name__
                    raise
                finally:
                    duration = time.time() - start_time
                    self.collector.track_api_call(
                        provider=provider,
                        endpoint=endpoint,
                        method=method,
                        status_code=status_code,
                        duration=duration,
                        error_type=error_type,
                    )

            # Return appropriate wrapper based on function type
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper

        return decorator

    def track_strategy_execution(
        self, strategy_name: str, symbol: str, timeframe: str = "1D"
    ):
        """
        Decorator to automatically track strategy execution metrics.

        Usage:
            @middleware.track_strategy_execution("RSI_Strategy", "AAPL")
            def run_backtest(data):
                # Strategy execution logic here
                return results
        """

        def decorator(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                with self.collector.track_backtest_execution(
                    strategy_name=strategy_name,
                    symbol=symbol,
                    timeframe=timeframe,
                    data_points=kwargs.get("data_points", 0),
                ):
                    result = await func(*args, **kwargs)

                    # Extract performance metrics from result if available
                    if isinstance(result, dict):
                        self._extract_and_track_performance(
                            result, strategy_name, symbol, timeframe
                        )

                    return result

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self.collector.track_backtest_execution(
                    strategy_name=strategy_name,
                    symbol=symbol,
                    timeframe=timeframe,
                    data_points=kwargs.get("data_points", 0),
                ):
                    result = func(*args, **kwargs)

                    # Extract performance metrics from result if available
                    if isinstance(result, dict):
                        self._extract_and_track_performance(
                            result, strategy_name, symbol, timeframe
                        )

                    return result

            # Return appropriate wrapper based on function type
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper

        return decorator

    def track_resource_usage(self, operation_type: str):
        """
        Decorator to automatically track resource usage for operations.

        Usage:
            @middleware.track_resource_usage("vectorbt_backtest")
            def run_vectorbt_analysis(data):
                # VectorBT analysis logic here
                pass
        """

        def decorator(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                import psutil

                process = psutil.Process()
                # RSS in MB before the call; the delta is attributed to this
                # operation (other threads' allocations would be included too).
                start_memory = process.memory_info().rss / 1024 / 1024
                start_time = time.time()

                try:
                    result = await func(*args, **kwargs)
                    return result
                finally:
                    end_memory = process.memory_info().rss / 1024 / 1024
                    duration = time.time() - start_time
                    # Clamped at 0: RSS can shrink if the allocator returns pages.
                    memory_used = max(0, end_memory - start_memory)

                    # Determine data size category
                    data_size = "unknown"
                    if "data" in kwargs:
                        data_length = (
                            len(kwargs["data"])
                            if hasattr(kwargs["data"], "__len__")
                            else 0
                        )
                        data_size = self.collector._categorize_data_size(data_length)

                    self.collector.track_resource_usage(
                        operation_type=operation_type,
                        memory_mb=memory_used,
                        computation_time=duration,
                        data_size=data_size,
                    )

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                import psutil

                process = psutil.Process()
                start_memory = process.memory_info().rss / 1024 / 1024
                start_time = time.time()

                try:
                    result = func(*args, **kwargs)
                    return result
                finally:
                    end_memory = process.memory_info().rss / 1024 / 1024
                    duration = time.time() - start_time
                    memory_used = max(0, end_memory - start_memory)

                    # Determine data size category
                    data_size = "unknown"
                    if "data" in kwargs:
                        data_length = (
                            len(kwargs["data"])
                            if hasattr(kwargs["data"], "__len__")
                            else 0
                        )
                        data_size = self.collector._categorize_data_size(data_length)

                    self.collector.track_resource_usage(
                        operation_type=operation_type,
                        memory_mb=memory_used,
                        computation_time=duration,
                        data_size=data_size,
                    )

            # Return appropriate wrapper based on function type
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper

        return decorator

    @asynccontextmanager
    async def track_database_operation(
        self, query_type: str, table_name: str, operation: str
    ):
        """
        Context manager to track database operation performance.

        Usage:
            async with middleware.track_database_operation("SELECT", "stocks", "fetch"):
                result = await db.execute(query)
        """
        start_time = time.time()
        try:
            yield
        finally:
            # Duration is recorded whether the wrapped operation succeeds or raises.
            duration = time.time() - start_time
            self.collector.track_database_operation(
                query_type=query_type,
                table_name=table_name,
                operation=operation,
                duration=duration,
            )

    def _extract_and_track_performance(
        self, result: dict[str, Any], strategy_name: str, symbol: str, timeframe: str
    ):
        """Extract and track strategy performance metrics from results."""
        try:
            # Extract common performance metrics from result dictionary
            returns = result.get("total_return", result.get("returns", 0.0))
            sharpe_ratio = result.get("sharpe_ratio", 0.0)
            max_drawdown = result.get("max_drawdown", result.get("max_dd", 0.0))
            win_rate = result.get("win_rate", result.get("win_ratio", 0.0))
            total_trades = result.get("total_trades", result.get("num_trades", 0))
            winning_trades = result.get("winning_trades", 0)

            # Convert win rate to percentage if it's in decimal form.
            # NOTE(review): heuristic — a literal 1.0% win rate expressed as a
            # percentage would be misread as a fraction and scaled to 100%;
            # assumes upstream results always use fractions here — TODO confirm.
            if win_rate <= 1.0:
                win_rate *= 100

            # Convert max drawdown to positive percentage if negative
            if max_drawdown < 0:
                max_drawdown = abs(max_drawdown) * 100

            # Extract winning trades from win rate if not provided directly
            if winning_trades == 0 and total_trades > 0:
                winning_trades = int(total_trades * (win_rate / 100))

            # Determine period from timeframe or use default
            period_mapping = {"1D": "1Y", "1H": "3M", "5m": "1M", "1m": "1W"}
            period = period_mapping.get(timeframe, "1Y")

            # Track the performance metrics
            self.collector.track_strategy_performance(
                strategy_name=strategy_name,
                symbol=symbol,
                period=period,
                returns=returns,
                sharpe_ratio=sharpe_ratio,
                max_drawdown=max_drawdown,
                win_rate=win_rate,
                total_trades=total_trades,
                winning_trades=winning_trades,
            )

            self.logger.debug(
                f"Tracked strategy performance for {strategy_name}",
                extra={
                    "strategy": strategy_name,
                    "symbol": symbol,
                    "returns": returns,
                    "sharpe_ratio": sharpe_ratio,
                    "max_drawdown": max_drawdown,
                    "win_rate": win_rate,
                    "total_trades": total_trades,
                },
            )

        except Exception as e:
            # Best-effort extraction: malformed results are logged, never raised.
            self.logger.warning(
                f"Failed to extract performance metrics from result: {e}",
                extra={
                    "result_keys": list(result.keys())
                    if isinstance(result, dict)
                    else "not_dict"
                },
            )
334 |
335 |
# Global middleware instance.
# Lazy singleton; created on first call to get_metrics_middleware().
_middleware_instance: MetricsMiddleware | None = None
338 |
339 |
def get_metrics_middleware() -> MetricsMiddleware:
    """Return the process-wide metrics middleware, creating it on first use."""
    global _middleware_instance
    instance = _middleware_instance
    if instance is None:
        instance = MetricsMiddleware()
        _middleware_instance = instance
    return instance
346 |
347 |
348 | # Convenience decorators using global middleware instance
def track_api_call(provider: str, endpoint: str, method: str = "GET"):
    """Module-level shortcut for ``MetricsMiddleware.track_api_call``."""
    middleware = get_metrics_middleware()
    return middleware.track_api_call(provider, endpoint, method)
352 |
353 |
def track_strategy_execution(strategy_name: str, symbol: str, timeframe: str = "1D"):
    """Module-level shortcut for ``MetricsMiddleware.track_strategy_execution``."""
    middleware = get_metrics_middleware()
    return middleware.track_strategy_execution(strategy_name, symbol, timeframe)
359 |
360 |
def track_resource_usage(operation_type: str):
    """Module-level shortcut for ``MetricsMiddleware.track_resource_usage``."""
    middleware = get_metrics_middleware()
    return middleware.track_resource_usage(operation_type)
364 |
365 |
def track_database_operation(query_type: str, table_name: str, operation: str):
    """Module-level shortcut for ``MetricsMiddleware.track_database_operation``."""
    middleware = get_metrics_middleware()
    return middleware.track_database_operation(query_type, table_name, operation)
371 |
372 |
373 | # Example circuit breaker with metrics
class MetricsCircuitBreaker:
    """
    Circuit breaker with integrated metrics tracking.

    Automatically tracks circuit breaker state changes and failures.

    States:
        closed    - normal operation; calls pass through.
        open      - too many consecutive failures; calls are rejected until
                    ``recovery_timeout`` seconds have elapsed.
        half-open - probe state after the timeout; a success closes the
                    circuit, a failure re-opens it.
    """

    def __init__(
        self,
        provider: str,
        endpoint: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type = Exception,
    ):
        self.provider = provider
        self.endpoint = endpoint
        self.failure_threshold = failure_threshold  # consecutive failures before opening
        self.recovery_timeout = recovery_timeout  # seconds before a half-open probe
        self.expected_exception = expected_exception  # exception type counted as failure

        self.failure_count = 0
        self.last_failure_time = 0
        self.state = "closed"  # closed, open, half-open

        self.collector = get_backtesting_metrics()
        self.logger = get_logger(f"{__name__}.MetricsCircuitBreaker")

    async def call(self, func: Callable, *args, **kwargs):
        """Execute function with circuit breaker protection and metrics tracking.

        Args:
            func: Sync or async callable to invoke.
            *args, **kwargs: Forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            Exception: When the circuit is open and the recovery timeout has
                not yet elapsed, or whatever ``func`` itself raises.
        """
        if self.state == "open":
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = "half-open"
                self.collector.track_circuit_breaker(
                    self.provider, self.endpoint, self.state, 0
                )
            else:
                raise Exception(
                    f"Circuit breaker is open for {self.provider}/{self.endpoint}"
                )

        try:
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)

            # Success: close the circuit if we were probing, and reset the
            # consecutive-failure count. The reset happens on EVERY success
            # (bug fix: previously it only happened in the half-open state,
            # so intermittent failures accumulated across successes and
            # eventually opened the circuit even under a healthy provider).
            if self.state == "half-open":
                self.state = "closed"
                self.collector.track_circuit_breaker(
                    self.provider, self.endpoint, self.state, 0
                )
                self.logger.info(
                    f"Circuit breaker closed for {self.provider}/{self.endpoint}"
                )
            self.failure_count = 0

            return result

        except self.expected_exception:
            self.failure_count += 1
            self.last_failure_time = time.time()

            # Track failure
            self.collector.track_circuit_breaker(
                self.provider, self.endpoint, self.state, 1
            )

            # Open circuit if threshold reached
            if self.failure_count >= self.failure_threshold:
                self.state = "open"
                self.collector.track_circuit_breaker(
                    self.provider, self.endpoint, self.state, 0
                )
                self.logger.warning(
                    f"Circuit breaker opened for {self.provider}/{self.endpoint} "
                    f"after {self.failure_count} failures"
                )

            # Bare raise preserves the original exception and traceback.
            raise
455 |
```
--------------------------------------------------------------------------------
/tests/test_stock_analysis_service.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for StockAnalysisService.
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pandas as pd
8 |
9 | from maverick_mcp.domain.stock_analysis import StockAnalysisService
10 | from maverick_mcp.infrastructure.caching import CacheManagementService
11 | from maverick_mcp.infrastructure.data_fetching import StockDataFetchingService
12 |
13 |
14 | class TestStockAnalysisService:
15 | """Test cases for StockAnalysisService."""
16 |
    def setup_method(self):
        """Set up test fixtures."""
        # Collaborators are mocked with spec= so unexpected attribute access
        # fails fast instead of silently returning a child Mock.
        self.mock_data_fetching_service = Mock(spec=StockDataFetchingService)
        self.mock_cache_service = Mock(spec=CacheManagementService)
        self.mock_db_session = Mock()

        # System under test, wired with the mocked collaborators.
        self.service = StockAnalysisService(
            data_fetching_service=self.mock_data_fetching_service,
            cache_service=self.mock_cache_service,
            db_session=self.mock_db_session,
        )
28 |
29 | def test_init(self):
30 | """Test service initialization."""
31 | assert self.service.data_fetching_service == self.mock_data_fetching_service
32 | assert self.service.cache_service == self.mock_cache_service
33 | assert self.service.db_session == self.mock_db_session
34 |
35 | def test_get_stock_data_non_daily_interval(self):
36 | """Test get_stock_data with non-daily interval bypasses cache."""
37 | mock_data = pd.DataFrame(
38 | {"Open": [150.0], "Close": [151.0]},
39 | index=pd.date_range("2024-01-01", periods=1),
40 | )
41 |
42 | self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
43 |
44 | # Test with 1-hour interval
45 | result = self.service.get_stock_data("AAPL", interval="1h")
46 |
47 | # Assertions
48 | assert not result.empty
49 | self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
50 | self.mock_cache_service.get_cached_data.assert_not_called()
51 |
52 | def test_get_stock_data_with_period(self):
53 | """Test get_stock_data with period parameter bypasses cache."""
54 | mock_data = pd.DataFrame(
55 | {"Open": [150.0], "Close": [151.0]},
56 | index=pd.date_range("2024-01-01", periods=1),
57 | )
58 |
59 | self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
60 |
61 | # Test with period
62 | result = self.service.get_stock_data("AAPL", period="1mo")
63 |
64 | # Assertions
65 | assert not result.empty
66 | self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
67 | self.mock_cache_service.get_cached_data.assert_not_called()
68 |
69 | def test_get_stock_data_cache_disabled(self):
70 | """Test get_stock_data with cache disabled."""
71 | mock_data = pd.DataFrame(
72 | {"Open": [150.0], "Close": [151.0]},
73 | index=pd.date_range("2024-01-01", periods=1),
74 | )
75 |
76 | self.mock_data_fetching_service.fetch_stock_data.return_value = mock_data
77 |
78 | # Test with cache disabled
79 | result = self.service.get_stock_data("AAPL", use_cache=False)
80 |
81 | # Assertions
82 | assert not result.empty
83 | self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
84 | self.mock_cache_service.get_cached_data.assert_not_called()
85 |
86 | def test_get_stock_data_cache_hit(self):
87 | """Test get_stock_data with complete cache hit."""
88 | # Mock cached data
89 | mock_cached_data = pd.DataFrame(
90 | {
91 | "Open": [150.0, 151.0, 152.0],
92 | "High": [151.0, 152.0, 153.0],
93 | "Low": [149.0, 150.0, 151.0],
94 | "Close": [150.5, 151.5, 152.5],
95 | "Volume": [1000000, 1100000, 1200000],
96 | },
97 | index=pd.date_range("2024-01-01", periods=3),
98 | )
99 |
100 | self.mock_cache_service.get_cached_data.return_value = mock_cached_data
101 |
102 | # Test
103 | result = self.service.get_stock_data(
104 | "AAPL", start_date="2024-01-01", end_date="2024-01-03"
105 | )
106 |
107 | # Assertions
108 | assert not result.empty
109 | assert len(result) == 3
110 | self.mock_cache_service.get_cached_data.assert_called_once()
111 | self.mock_data_fetching_service.fetch_stock_data.assert_not_called()
112 |
113 | def test_get_stock_data_cache_miss(self):
114 | """Test get_stock_data with complete cache miss."""
115 | # Mock no cached data
116 | self.mock_cache_service.get_cached_data.return_value = None
117 |
118 | # Mock market calendar
119 | with patch.object(self.service, "_get_trading_days") as mock_trading_days:
120 | mock_trading_days.return_value = pd.DatetimeIndex(
121 | ["2024-01-01", "2024-01-02"]
122 | )
123 |
124 | # Mock fetched data
125 | mock_fetched_data = pd.DataFrame(
126 | {
127 | "Open": [150.0, 151.0],
128 | "Close": [150.5, 151.5],
129 | "Volume": [1000000, 1100000],
130 | },
131 | index=pd.date_range("2024-01-01", periods=2),
132 | )
133 |
134 | self.mock_data_fetching_service.fetch_stock_data.return_value = (
135 | mock_fetched_data
136 | )
137 |
138 | # Test
139 | result = self.service.get_stock_data(
140 | "AAPL", start_date="2024-01-01", end_date="2024-01-02"
141 | )
142 |
143 | # Assertions
144 | assert not result.empty
145 | self.mock_cache_service.get_cached_data.assert_called_once()
146 | self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
147 | self.mock_cache_service.cache_data.assert_called_once()
148 |
149 | def test_get_stock_data_partial_cache_hit(self):
150 | """Test get_stock_data with partial cache hit requiring additional data."""
151 | # Mock partial cached data (missing recent data)
152 | mock_cached_data = pd.DataFrame(
153 | {"Open": [150.0], "Close": [150.5], "Volume": [1000000]},
154 | index=pd.date_range("2024-01-01", periods=1),
155 | )
156 |
157 | self.mock_cache_service.get_cached_data.return_value = mock_cached_data
158 |
159 | # Mock missing data fetch
160 | mock_missing_data = pd.DataFrame(
161 | {"Open": [151.0], "Close": [151.5], "Volume": [1100000]},
162 | index=pd.date_range("2024-01-02", periods=1),
163 | )
164 |
165 | self.mock_data_fetching_service.fetch_stock_data.return_value = (
166 | mock_missing_data
167 | )
168 |
169 | # Mock helper methods
170 | with (
171 | patch.object(self.service, "_get_trading_days") as mock_trading_days,
172 | patch.object(
173 | self.service, "_is_trading_day_between"
174 | ) as mock_trading_between,
175 | ):
176 | mock_trading_days.return_value = pd.DatetimeIndex(["2024-01-02"])
177 | mock_trading_between.return_value = True
178 |
179 | # Test
180 | result = self.service.get_stock_data(
181 | "AAPL", start_date="2024-01-01", end_date="2024-01-02"
182 | )
183 |
184 | # Assertions
185 | assert not result.empty
186 | assert len(result) == 2 # Combined cached + fetched data
187 | self.mock_cache_service.get_cached_data.assert_called_once()
188 | self.mock_data_fetching_service.fetch_stock_data.assert_called_once()
189 | self.mock_cache_service.cache_data.assert_called_once()
190 |
191 | def test_get_stock_data_smart_cache_fallback(self):
192 | """Test get_stock_data fallback when smart cache fails."""
193 | # Mock cache service to raise exception
194 | self.mock_cache_service.get_cached_data.side_effect = Exception("Cache error")
195 |
196 | # Mock fallback data
197 | mock_fallback_data = pd.DataFrame(
198 | {"Open": [150.0], "Close": [150.5]},
199 | index=pd.date_range("2024-01-01", periods=1),
200 | )
201 |
202 | self.mock_data_fetching_service.fetch_stock_data.return_value = (
203 | mock_fallback_data
204 | )
205 |
206 | # Test
207 | result = self.service.get_stock_data("AAPL")
208 |
209 | # Assertions
210 | assert not result.empty
211 | self.mock_data_fetching_service.fetch_stock_data.assert_called()
212 |
213 | def test_get_stock_info(self):
214 | """Test get_stock_info delegation."""
215 | mock_info = {"longName": "Apple Inc."}
216 | self.mock_data_fetching_service.fetch_stock_info.return_value = mock_info
217 |
218 | # Test
219 | result = self.service.get_stock_info("AAPL")
220 |
221 | # Assertions
222 | assert result == mock_info
223 | self.mock_data_fetching_service.fetch_stock_info.assert_called_once_with("AAPL")
224 |
225 | def test_get_realtime_data(self):
226 | """Test get_realtime_data delegation."""
227 | mock_data = {"symbol": "AAPL", "price": 150.0}
228 | self.mock_data_fetching_service.fetch_realtime_data.return_value = mock_data
229 |
230 | # Test
231 | result = self.service.get_realtime_data("AAPL")
232 |
233 | # Assertions
234 | assert result == mock_data
235 | self.mock_data_fetching_service.fetch_realtime_data.assert_called_once_with(
236 | "AAPL"
237 | )
238 |
239 | def test_get_multiple_realtime_data(self):
240 | """Test get_multiple_realtime_data delegation."""
241 | mock_data = {"AAPL": {"price": 150.0}, "MSFT": {"price": 300.0}}
242 | self.mock_data_fetching_service.fetch_multiple_realtime_data.return_value = (
243 | mock_data
244 | )
245 |
246 | # Test
247 | result = self.service.get_multiple_realtime_data(["AAPL", "MSFT"])
248 |
249 | # Assertions
250 | assert result == mock_data
251 | self.mock_data_fetching_service.fetch_multiple_realtime_data.assert_called_once_with(
252 | ["AAPL", "MSFT"]
253 | )
254 |
    @patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.datetime")
    @patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.pytz")
    def test_is_market_open_weekday_during_hours(self, mock_pytz, mock_datetime):
        """Test market open check during trading hours on weekday."""
        # Mock current time: Wednesday 10:00 AM ET
        mock_now = Mock()
        mock_now.weekday.return_value = 2  # Wednesday
        mock_now.replace.return_value = mock_now
        # NOTE(review): assigning __le__/__ge__ on a Mock *instance* relies on
        # how the service performs the comparison — Python normally looks up
        # dunder methods on the type, not the instance. Confirm this mock
        # actually drives the comparison result rather than Mock's defaults.
        mock_now.__le__ = lambda self, other: True
        mock_now.__ge__ = lambda self, other: True

        mock_datetime.now.return_value = mock_now
        mock_pytz.timezone.return_value.localize = lambda x: x

        # Test
        result = self.service.is_market_open()

        # Assertions
        assert result is True
274 |
275 | @patch("maverick_mcp.domain.stock_analysis.stock_analysis_service.datetime")
276 | def test_is_market_open_weekend(self, mock_datetime):
277 | """Test market open check on weekend."""
278 | # Mock current time: Saturday
279 | mock_now = Mock()
280 | mock_now.weekday.return_value = 5 # Saturday
281 |
282 | mock_datetime.now.return_value = mock_now
283 |
284 | # Test
285 | result = self.service.is_market_open()
286 |
287 | # Assertions
288 | assert result is False
289 |
290 | def test_get_news(self):
291 | """Test get_news delegation."""
292 | mock_news = pd.DataFrame({"title": ["Apple News"]})
293 | self.mock_data_fetching_service.fetch_news.return_value = mock_news
294 |
295 | # Test
296 | result = self.service.get_news("AAPL", limit=5)
297 |
298 | # Assertions
299 | assert not result.empty
300 | self.mock_data_fetching_service.fetch_news.assert_called_once_with("AAPL", 5)
301 |
302 | def test_get_earnings(self):
303 | """Test get_earnings delegation."""
304 | mock_earnings = {"earnings": {}}
305 | self.mock_data_fetching_service.fetch_earnings.return_value = mock_earnings
306 |
307 | # Test
308 | result = self.service.get_earnings("AAPL")
309 |
310 | # Assertions
311 | assert result == mock_earnings
312 | self.mock_data_fetching_service.fetch_earnings.assert_called_once_with("AAPL")
313 |
314 | def test_get_recommendations(self):
315 | """Test get_recommendations delegation."""
316 | mock_recs = pd.DataFrame({"firm": ["Goldman Sachs"]})
317 | self.mock_data_fetching_service.fetch_recommendations.return_value = mock_recs
318 |
319 | # Test
320 | result = self.service.get_recommendations("AAPL")
321 |
322 | # Assertions
323 | assert not result.empty
324 | self.mock_data_fetching_service.fetch_recommendations.assert_called_once_with(
325 | "AAPL"
326 | )
327 |
328 | def test_is_etf(self):
329 | """Test is_etf delegation."""
330 | self.mock_data_fetching_service.check_if_etf.return_value = True
331 |
332 | # Test
333 | result = self.service.is_etf("SPY")
334 |
335 | # Assertions
336 | assert result is True
337 | self.mock_data_fetching_service.check_if_etf.assert_called_once_with("SPY")
338 |
339 | def test_invalidate_cache(self):
340 | """Test invalidate_cache delegation."""
341 | self.mock_cache_service.invalidate_cache.return_value = True
342 |
343 | # Test
344 | result = self.service.invalidate_cache("AAPL", "2024-01-01", "2024-01-02")
345 |
346 | # Assertions
347 | assert result is True
348 | self.mock_cache_service.invalidate_cache.assert_called_once_with(
349 | "AAPL", "2024-01-01", "2024-01-02"
350 | )
351 |
352 | def test_get_cache_stats(self):
353 | """Test get_cache_stats delegation."""
354 | mock_stats = {"symbol": "AAPL", "total_records": 100}
355 | self.mock_cache_service.get_cache_stats.return_value = mock_stats
356 |
357 | # Test
358 | result = self.service.get_cache_stats("AAPL")
359 |
360 | # Assertions
361 | assert result == mock_stats
362 | self.mock_cache_service.get_cache_stats.assert_called_once_with("AAPL")
363 |
364 | def test_get_trading_days(self):
365 | """Test get_trading_days helper method."""
366 | with patch.object(self.service.market_calendar, "schedule") as mock_schedule:
367 | # Mock schedule response
368 | mock_df = Mock()
369 | mock_df.index = pd.DatetimeIndex(["2024-01-01", "2024-01-02"])
370 | mock_schedule.return_value = mock_df
371 |
372 | # Test
373 | result = self.service._get_trading_days("2024-01-01", "2024-01-02")
374 |
375 | # Assertions
376 | assert len(result) == 2
377 | assert result[0] == pd.Timestamp("2024-01-01")
378 |
    def test_is_trading_day(self):
        """Test is_trading_day helper method."""
        with patch.object(self.service.market_calendar, "schedule") as mock_schedule:
            # Mock schedule response with trading session
            mock_df = Mock()
            # NOTE(review): assigning __len__ on a plain Mock instance relies
            # on how the service measures the schedule — len() normally looks
            # the dunder up on the type, not the instance. Confirm the service
            # calls len() in a way this mock actually intercepts.
            mock_df.__len__ = Mock(return_value=1)  # Has trading session
            mock_schedule.return_value = mock_df

            # Test
            result = self.service._is_trading_day("2024-01-01")

            # Assertions
            assert result is True
392 |
393 | def test_get_last_trading_day_is_trading_day(self):
394 | """Test get_last_trading_day when date is already a trading day."""
395 | with patch.object(self.service, "_is_trading_day") as mock_is_trading:
396 | mock_is_trading.return_value = True
397 |
398 | # Test
399 | result = self.service._get_last_trading_day("2024-01-01")
400 |
401 | # Assertions
402 | assert result == pd.Timestamp("2024-01-01")
403 |
404 | def test_get_last_trading_day_find_previous(self):
405 | """Test get_last_trading_day finding previous trading day."""
406 | with patch.object(self.service, "_is_trading_day") as mock_is_trading:
407 | # First call (date itself) returns False, second call (previous day) returns True
408 | mock_is_trading.side_effect = [False, True]
409 |
410 | # Test
411 | result = self.service._get_last_trading_day("2024-01-01")
412 |
413 | # Assertions
414 | assert result == pd.Timestamp("2023-12-31")
415 |
416 | def test_is_trading_day_between_true(self):
417 | """Test is_trading_day_between when there are trading days between dates."""
418 | with patch.object(self.service, "_get_trading_days") as mock_trading_days:
419 | mock_trading_days.return_value = pd.DatetimeIndex(["2024-01-02"])
420 |
421 | # Test
422 | start_date = pd.Timestamp("2024-01-01")
423 | end_date = pd.Timestamp("2024-01-03")
424 | result = self.service._is_trading_day_between(start_date, end_date)
425 |
426 | # Assertions
427 | assert result is True
428 |
429 | def test_is_trading_day_between_false(self):
430 | """Test is_trading_day_between when there are no trading days between dates."""
431 | with patch.object(self.service, "_get_trading_days") as mock_trading_days:
432 | mock_trading_days.return_value = pd.DatetimeIndex([])
433 |
434 | # Test
435 | start_date = pd.Timestamp("2024-01-01")
436 | end_date = pd.Timestamp("2024-01-02")
437 | result = self.service._is_trading_day_between(start_date, end_date)
438 |
439 | # Assertions
440 | assert result is False
441 |
```