This is page 5 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_mcp_tool_fixes_pytest.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Pytest-compatible test suite for MCP tool fixes.
3 |
4 | This test validates that the fixes for:
5 | 1. Research returning empty results (API keys not passed to DeepResearchAgent)
6 | 2. Portfolio risk analysis cryptic "'high'" error (DataFrame validation and column case)
7 | 3. External API key hard dependency (graceful degradation)
8 |
9 | All continue to work correctly after code changes.
10 | """
11 |
12 | import os
13 | from datetime import UTC, datetime, timedelta
14 |
15 | import pytest
16 |
17 | from maverick_mcp.api.routers.data import get_stock_info
18 | from maverick_mcp.api.routers.portfolio import risk_adjusted_analysis, stock_provider
19 | from maverick_mcp.validation.data import GetStockInfoRequest
20 |
21 |
@pytest.mark.integration
@pytest.mark.external
def test_portfolio_risk_analysis_fix():
    """
    Test Issue #2: Portfolio risk analysis DataFrame validation and column case fix.

    Validates:
    - DataFrame is properly retrieved with correct columns
    - Column name case sensitivity is handled correctly
    - Date range calculation avoids weekend/holiday issues
    - Risk calculations complete successfully
    """
    # Derive both dates from a single timestamp so they are computed from the
    # same instant (two separate now() calls could straddle midnight).
    now = datetime.now(UTC)
    # The window ends a week in the past (presumably to avoid the most recent,
    # possibly non-trading or incomplete days — see docstring).
    end_date = (now - timedelta(days=7)).strftime("%Y-%m-%d")
    start_date = (now - timedelta(days=365)).strftime("%Y-%m-%d")

    # Test the data provider directly first
    df = stock_provider.get_stock_data("MSFT", start_date=start_date, end_date=end_date)

    # Verify DataFrame has the expected structure
    assert not df.empty, "DataFrame should not be empty"
    assert df.shape[0] > 200, "Should have substantial historical data"
    expected_cols = ["Open", "High", "Low", "Close", "Volume"]
    # Report every missing column at once instead of stopping at the first one.
    missing = [col for col in expected_cols if col not in df.columns]
    assert not missing, (
        f"Missing expected columns: {missing} (got {list(df.columns)})"
    )

    # Test the actual portfolio risk analysis function
    result = risk_adjusted_analysis("MSFT", 75.0)

    # Verify successful result structure
    assert "error" not in result, f"Should not have error: {result}"
    assert "current_price" in result, "Should include current price"
    assert "risk_level" in result, "Should include risk level"
    assert "position_sizing" in result, "Should include position sizing"
    assert "analysis" in result, "Should include analysis"

    # Verify data types and ranges
    assert isinstance(result["current_price"], int | float), (
        "Current price should be numeric"
    )
    assert result["current_price"] > 0, "Current price should be positive"
    assert result["risk_level"] == 75.0, "Risk level should match input"

    position_size = result["position_sizing"]["suggested_position_size"]
    assert isinstance(position_size, int | float), "Position size should be numeric"
    assert position_size > 0, "Position size should be positive"
66 |
67 |
@pytest.mark.integration
@pytest.mark.database
def test_stock_info_external_api_graceful_fallback():
    """
    Test Issue #3: stock info degrades gracefully without the external API.

    Checks that:
    - the external data API is an optional dependency,
    - a missing EXTERNAL_DATA_API_KEY does not surface as a hard key error,
    - the core company / market-data payload is still returned.
    """
    result = get_stock_info(GetStockInfoRequest(ticker="MSFT"))

    # A soft error is tolerated, but never one caused by a rejected API key.
    if "error" in result:
        error_text = str(result.get("error"))
        assert "Invalid API key" not in error_text, (
            f"Should not have hard API key error: {result}"
        )

    # Basic company information must be present regardless of the external API.
    assert "company" in result, "Should include company information"
    assert "market_data" in result, "Should include market data"

    assert result.get("company", {}).get("name"), "Should have company name"

    # The price is optional; validate it only when one was returned.
    price = result.get("market_data", {}).get("current_price")
    if not price:
        return
    assert isinstance(price, int | float), "Price should be numeric"
    assert price > 0, "Price should be positive"
100 |
101 |
@pytest.mark.integration
@pytest.mark.external
@pytest.mark.asyncio
async def test_research_agent_api_key_configuration():
    """
    Test Issue #1: Research agent API key configuration fix.

    Validates:
    - DeepResearchAgent is created with API keys from settings
    - Search providers are properly initialized
    - API keys are correctly passed through the configuration chain

    Skipped unless both EXA_API_KEY and TAVILY_API_KEY are set.
    """
    from maverick_mcp.api.routers.research import get_research_agent

    # Both keys are required; otherwise the provider-count check below is meaningless.
    if not (os.getenv("EXA_API_KEY") and os.getenv("TAVILY_API_KEY")):
        pytest.skip("EXA_API_KEY and TAVILY_API_KEY required for research test")

    agent = get_research_agent()

    # Verify agent has search providers
    assert hasattr(agent, "search_providers"), "Agent should have search_providers"
    assert len(agent.search_providers) > 0, "Should have at least one search provider"

    # Count providers that actually received a non-empty API key. A single
    # ">= 2" check suffices; a separate "> 0" assertion would be redundant.
    providers_configured = sum(
        1 for provider in agent.search_providers if getattr(provider, "api_key", None)
    )
    assert providers_configured >= 2, (
        "Should have both EXA and Tavily providers configured "
        f"(found {providers_configured} with an API key)"
    )
142 |
143 |
@pytest.mark.unit
def test_llm_configuration_compatibility():
    """
    Check that the LLM factory produces a usable, correctly configured model.

    Covers:
    - successful LLM construction,
    - temperature/streaming settings that work with gpt-5-mini,
    - a trivial round-trip query (only when OPENAI_API_KEY is set).
    """
    from maverick_mcp.providers.llm_factory import get_llm

    # Construction is always exercised, with or without credentials.
    model = get_llm()
    assert model is not None, "LLM should be created successfully"

    # Guard clause: the query below needs real credentials.
    # NOTE(review): when the key is present this performs a live API call even
    # though the test carries the `unit` marker.
    if not os.getenv("OPENAI_API_KEY"):
        pytest.skip("OPENAI_API_KEY required for LLM test")

    reply = model.invoke("What is 2+2?")
    assert reply is not None, "LLM should return a response"
    assert hasattr(reply, "content"), "Response should have content attribute"
    assert "4" in reply.content, "LLM should correctly answer 2+2=4"
169 |
170 |
@pytest.mark.integration
@pytest.mark.external
@pytest.mark.database
def test_all_mcp_fixes_integration():
    """
    Integration test to verify all three MCP tool fixes work together.

    This is a comprehensive test that ensures all fixes are compatible
    and work correctly in combination. The research-agent step only runs
    when both EXA_API_KEY and TAVILY_API_KEY are set.
    """
    # Test 1: Portfolio analysis (include the result so failures are diagnosable)
    portfolio_result = risk_adjusted_analysis("AAPL", 50.0)
    assert "error" not in portfolio_result, (
        f"Portfolio analysis should work: {portfolio_result}"
    )

    # Test 2: Stock info
    stock_info_result = get_stock_info(GetStockInfoRequest(ticker="AAPL"))
    assert "company" in stock_info_result, (
        f"Stock info should work: {stock_info_result}"
    )

    # Test 3: Research agent (only if both search API keys are available)
    if os.getenv("EXA_API_KEY") and os.getenv("TAVILY_API_KEY"):
        from maverick_mcp.api.routers.research import get_research_agent

        agent = get_research_agent()
        provider_count = len(agent.search_providers)
        assert provider_count >= 2, (
            "Research agent should have EXA and Tavily providers "
            f"(found {provider_count})"
        )

    # Test 4: LLM configuration
    from maverick_mcp.providers.llm_factory import get_llm

    llm = get_llm()
    assert llm is not None, "LLM should be configured correctly"
205 |
```
--------------------------------------------------------------------------------
/alembic/versions/009_rename_to_supply_demand.py:
--------------------------------------------------------------------------------
```python
1 | """Rename tables to Supply/Demand terminology
2 |
3 | Revision ID: 009_rename_to_supply_demand
4 | Revises: 008_performance_optimization_indexes
5 | Create Date: 2025-01-27
6 |
7 | This migration renames all database objects to use
8 | supply/demand market structure terminology, removing trademarked references.
9 | """
10 |
11 | import sqlalchemy as sa
12 |
13 | from alembic import op
14 |
15 | # revision identifiers
16 | revision = "009_rename_to_supply_demand"
17 | down_revision = "008_performance_optimization_indexes"
18 | branch_labels = None
19 | depends_on = None
20 |
21 |
def upgrade():
    """Rename the breakout table and its indexes to supply/demand terminology.

    PostgreSQL: renames ``stocks_minervinistocks`` to
    ``stocks_supply_demand_breakouts`` in place, renames the known indexes
    (``ALTER INDEX IF EXISTS`` makes each rename a no-op when the index is
    absent) and renames the symbol foreign key if it exists on that table.

    SQLite: recreates the table under the new name, copies all rows across,
    drops the old table and rebuilds the indexes.

    Any other dialect is left untouched and a warning is printed instead of
    a success message.
    """
    bind = op.get_bind()
    dialect_name = bind.dialect.name

    if dialect_name == "postgresql":
        # 1. Rename the main table (PostgreSQL supports this natively).
        op.rename_table("stocks_minervinistocks", "stocks_supply_demand_breakouts")

        # 2. Rename indexes.
        index_renames = (
            (
                "idx_stocks_minervinistocks_rs_rating_desc",
                "idx_stocks_supply_demand_breakouts_rs_rating_desc",
            ),
            (
                "idx_stocks_minervinistocks_date_analyzed",
                "idx_stocks_supply_demand_breakouts_date_analyzed",
            ),
            (
                "idx_stocks_minervinistocks_rs_date",
                "idx_stocks_supply_demand_breakouts_rs_date",
            ),
            (
                "idx_minervini_stocks_rs_rating",
                "idx_supply_demand_breakouts_rs_rating",
            ),
        )
        for old_name, new_name in index_renames:
            # Names are fixed literals above, not external input.
            op.execute(f"ALTER INDEX IF EXISTS {old_name} RENAME TO {new_name}")

        # 3. Rename the symbol foreign key if present. The existence check is
        #    scoped to the renamed table: a same-named constraint on another
        #    table would otherwise trigger an ALTER TABLE that fails.
        op.execute("""
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.table_constraints
                WHERE constraint_name = 'fk_minervinistocks_symbol'
                  AND table_name = 'stocks_supply_demand_breakouts'
            ) THEN
                ALTER TABLE stocks_supply_demand_breakouts
                RENAME CONSTRAINT fk_minervinistocks_symbol TO fk_supply_demand_breakouts_symbol;
            END IF;
        END $$;
        """)

    elif dialect_name == "sqlite":
        # SQLite doesn't support these renames well, so recreate the table.

        # 1. Create the new table with the same structure.
        op.create_table(
            "stocks_supply_demand_breakouts",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("symbol", sa.String(10), nullable=False),
            sa.Column("date_analyzed", sa.Date(), nullable=False),
            sa.Column("rs_rating", sa.Integer(), nullable=True),
            sa.Column("price", sa.Float(), nullable=True),
            sa.Column("volume", sa.BigInteger(), nullable=True),
            sa.Column("meets_criteria", sa.Boolean(), nullable=True),
            sa.Column("created_at", sa.DateTime(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), nullable=True),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "symbol", "date_analyzed", name="uq_supply_demand_breakouts_symbol_date"
            ),
        )

        # 2. Copy data from the old table. SELECT * assumes the old table has
        #    exactly these columns in this order; if it does not, the INSERT
        #    fails (before the old table is dropped) rather than silently
        #    discarding columns.
        op.execute("""
        INSERT INTO stocks_supply_demand_breakouts
        SELECT * FROM stocks_minervinistocks
        """)

        # 3. Drop the old table.
        op.drop_table("stocks_minervinistocks")

        # 4. Create indexes on the new table. The postgresql_* index options
        #    were dropped here because this branch only runs on SQLite.
        op.create_index(
            "idx_stocks_supply_demand_breakouts_rs_rating_desc",
            "stocks_supply_demand_breakouts",
            ["rs_rating"],
        )
        op.create_index(
            "idx_stocks_supply_demand_breakouts_date_analyzed",
            "stocks_supply_demand_breakouts",
            ["date_analyzed"],
        )
        op.create_index(
            "idx_stocks_supply_demand_breakouts_rs_date",
            "stocks_supply_demand_breakouts",
            ["symbol", "date_analyzed"],
        )

    else:
        # Nothing was changed; do not report success.
        print(
            f"⚠️ Unsupported dialect '{dialect_name}': "
            "no Supply/Demand rename was performed"
        )
        return

    # Log successful migration
    print("✅ Successfully renamed tables to Supply/Demand Breakout terminology")
    print(" - stocks_minervinistocks → stocks_supply_demand_breakouts")
    print(" - All related indexes have been renamed")
118 |
119 |
def downgrade():
    """Revert table, index and constraint names to the original terminology.

    PostgreSQL: renames ``stocks_supply_demand_breakouts`` back to
    ``stocks_minervinistocks`` and reverses the index and foreign-key
    renames (each guarded so absent objects are skipped).

    SQLite: recreates the old table, copies all rows back, drops the new
    table and rebuilds the old indexes.

    Any other dialect is left untouched and a warning is printed.
    """
    bind = op.get_bind()
    dialect_name = bind.dialect.name

    if dialect_name == "postgresql":
        # Rename table back
        op.rename_table("stocks_supply_demand_breakouts", "stocks_minervinistocks")

        # Rename indexes back (names are fixed literals, not external input).
        index_renames = (
            (
                "idx_stocks_supply_demand_breakouts_rs_rating_desc",
                "idx_stocks_minervinistocks_rs_rating_desc",
            ),
            (
                "idx_stocks_supply_demand_breakouts_date_analyzed",
                "idx_stocks_minervinistocks_date_analyzed",
            ),
            (
                "idx_stocks_supply_demand_breakouts_rs_date",
                "idx_stocks_minervinistocks_rs_date",
            ),
            (
                "idx_supply_demand_breakouts_rs_rating",
                "idx_minervini_stocks_rs_rating",
            ),
        )
        for new_name, old_name in index_renames:
            op.execute(f"ALTER INDEX IF EXISTS {new_name} RENAME TO {old_name}")

        # Rename the constraint back. The existence check is scoped to the
        # table being altered so a same-named constraint elsewhere cannot
        # trigger an ALTER TABLE that fails.
        op.execute("""
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.table_constraints
                WHERE constraint_name = 'fk_supply_demand_breakouts_symbol'
                  AND table_name = 'stocks_minervinistocks'
            ) THEN
                ALTER TABLE stocks_minervinistocks
                RENAME CONSTRAINT fk_supply_demand_breakouts_symbol TO fk_minervinistocks_symbol;
            END IF;
        END $$;
        """)

    elif dialect_name == "sqlite":
        # Create old table structure
        op.create_table(
            "stocks_minervinistocks",
            sa.Column("id", sa.Integer(), nullable=False),
            sa.Column("symbol", sa.String(10), nullable=False),
            sa.Column("date_analyzed", sa.Date(), nullable=False),
            sa.Column("rs_rating", sa.Integer(), nullable=True),
            sa.Column("price", sa.Float(), nullable=True),
            sa.Column("volume", sa.BigInteger(), nullable=True),
            sa.Column("meets_criteria", sa.Boolean(), nullable=True),
            sa.Column("created_at", sa.DateTime(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), nullable=True),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "symbol", "date_analyzed", name="uq_minervinistocks_symbol_date"
            ),
        )

        # Copy data back. SELECT * relies on both tables having identical
        # column order and count; a mismatch fails before anything is dropped.
        op.execute("""
        INSERT INTO stocks_minervinistocks
        SELECT * FROM stocks_supply_demand_breakouts
        """)

        # Drop new table
        op.drop_table("stocks_supply_demand_breakouts")

        # Recreate old indexes
        op.create_index(
            "idx_stocks_minervinistocks_rs_rating_desc",
            "stocks_minervinistocks",
            ["rs_rating"],
        )
        op.create_index(
            "idx_stocks_minervinistocks_date_analyzed",
            "stocks_minervinistocks",
            ["date_analyzed"],
        )
        op.create_index(
            "idx_stocks_minervinistocks_rs_date",
            "stocks_minervinistocks",
            ["symbol", "date_analyzed"],
        )

    else:
        # Previously a silent no-op; make the skip visible.
        print(
            f"⚠️ Unsupported dialect '{dialect_name}': "
            "no Supply/Demand rename was reverted"
        )
202 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/interfaces/cache.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Cache manager interface.
3 |
4 | This module defines the abstract interface for caching operations,
5 | enabling different caching implementations (Redis, in-memory, etc.)
6 | to be used interchangeably throughout the application.
7 | """
8 |
9 | from typing import Any, Protocol, runtime_checkable
10 |
11 |
12 | @runtime_checkable
13 | class ICacheManager(Protocol):
14 | """
15 | Interface for cache management operations.
16 |
17 | This interface abstracts caching operations to enable different
18 | implementations (Redis, in-memory, etc.) to be used interchangeably.
19 | All methods should be async-compatible to support non-blocking operations.
20 | """
21 |
22 | async def get(self, key: str) -> Any:
23 | """
24 | Get data from cache.
25 |
26 | Args:
27 | key: Cache key to retrieve
28 |
29 | Returns:
30 | Cached data or None if not found or expired
31 | """
32 | ...
33 |
34 | async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
35 | """
36 | Store data in cache.
37 |
38 | Args:
39 | key: Cache key
40 | value: Data to cache (must be JSON serializable)
41 | ttl: Time-to-live in seconds (None for default TTL)
42 |
43 | Returns:
44 | True if successfully cached, False otherwise
45 | """
46 | ...
47 |
48 | async def delete(self, key: str) -> bool:
49 | """
50 | Delete a key from cache.
51 |
52 | Args:
53 | key: Cache key to delete
54 |
55 | Returns:
56 | True if key was deleted, False if key didn't exist
57 | """
58 | ...
59 |
60 | async def exists(self, key: str) -> bool:
61 | """
62 | Check if a key exists in cache.
63 |
64 | Args:
65 | key: Cache key to check
66 |
67 | Returns:
68 | True if key exists and hasn't expired, False otherwise
69 | """
70 | ...
71 |
72 | async def clear(self, pattern: str | None = None) -> int:
73 | """
74 | Clear cache entries.
75 |
76 | Args:
77 | pattern: Pattern to match keys (e.g., "stock:*")
78 | If None, clears all cache entries
79 |
80 | Returns:
81 | Number of entries cleared
82 | """
83 | ...
84 |
85 | async def get_many(self, keys: list[str]) -> dict[str, Any]:
86 | """
87 | Get multiple values at once for better performance.
88 |
89 | Args:
90 | keys: List of cache keys to retrieve
91 |
92 | Returns:
93 | Dictionary mapping keys to their cached values
94 | (missing keys will not be in the result)
95 | """
96 | ...
97 |
98 | async def set_many(self, items: list[tuple[str, Any, int | None]]) -> int:
99 | """
100 | Set multiple values at once for better performance.
101 |
102 | Args:
103 | items: List of tuples (key, value, ttl)
104 |
105 | Returns:
106 | Number of items successfully cached
107 | """
108 | ...
109 |
110 | async def delete_many(self, keys: list[str]) -> int:
111 | """
112 | Delete multiple keys for better performance.
113 |
114 | Args:
115 | keys: List of keys to delete
116 |
117 | Returns:
118 | Number of keys successfully deleted
119 | """
120 | ...
121 |
122 | async def exists_many(self, keys: list[str]) -> dict[str, bool]:
123 | """
124 | Check existence of multiple keys for better performance.
125 |
126 | Args:
127 | keys: List of keys to check
128 |
129 | Returns:
130 | Dictionary mapping keys to their existence status
131 | """
132 | ...
133 |
134 | async def count_keys(self, pattern: str) -> int:
135 | """
136 | Count keys matching a pattern.
137 |
138 | Args:
139 | pattern: Pattern to match (e.g., "stock:*")
140 |
141 | Returns:
142 | Number of matching keys
143 | """
144 | ...
145 |
146 | async def get_or_set(
147 | self, key: str, default_value: Any, ttl: int | None = None
148 | ) -> Any:
149 | """
150 | Get value from cache, setting it if it doesn't exist.
151 |
152 | Args:
153 | key: Cache key
154 | default_value: Value to set if key doesn't exist
155 | ttl: Time-to-live for the default value
156 |
157 | Returns:
158 | Either the existing cached value or the default value
159 | """
160 | ...
161 |
162 | async def increment(self, key: str, amount: int = 1) -> int:
163 | """
164 | Increment a numeric value in cache.
165 |
166 | Args:
167 | key: Cache key
168 | amount: Amount to increment by
169 |
170 | Returns:
171 | New value after increment
172 |
173 | Raises:
174 | ValueError: If the key exists but doesn't contain a numeric value
175 | """
176 | ...
177 |
178 | async def set_if_not_exists(
179 | self, key: str, value: Any, ttl: int | None = None
180 | ) -> bool:
181 | """
182 | Set a value only if the key doesn't already exist.
183 |
184 | Args:
185 | key: Cache key
186 | value: Value to set
187 | ttl: Time-to-live in seconds
188 |
189 | Returns:
190 | True if the value was set, False if key already existed
191 | """
192 | ...
193 |
194 | async def get_ttl(self, key: str) -> int | None:
195 | """
196 | Get the remaining time-to-live for a key.
197 |
198 | Args:
199 | key: Cache key
200 |
201 | Returns:
202 | Remaining TTL in seconds, None if key doesn't exist or has no TTL
203 | """
204 | ...
205 |
206 | async def expire(self, key: str, ttl: int) -> bool:
207 | """
208 | Set expiration time for an existing key.
209 |
210 | Args:
211 | key: Cache key
212 | ttl: Time-to-live in seconds
213 |
214 | Returns:
215 | True if expiration was set, False if key doesn't exist
216 | """
217 | ...
218 |
219 |
220 | class CacheConfig:
221 | """
222 | Configuration class for cache implementations.
223 |
224 | This class encapsulates cache-related configuration parameters
225 | to reduce coupling between cache implementations and configuration sources.
226 | """
227 |
228 | def __init__(
229 | self,
230 | enabled: bool = True,
231 | default_ttl: int = 3600,
232 | max_memory_size: int = 1000,
233 | redis_host: str = "localhost",
234 | redis_port: int = 6379,
235 | redis_db: int = 0,
236 | redis_password: str | None = None,
237 | redis_ssl: bool = False,
238 | connection_pool_size: int = 20,
239 | socket_timeout: int = 5,
240 | socket_connect_timeout: int = 5,
241 | ):
242 | """
243 | Initialize cache configuration.
244 |
245 | Args:
246 | enabled: Whether caching is enabled
247 | default_ttl: Default time-to-live in seconds
248 | max_memory_size: Maximum in-memory cache size
249 | redis_host: Redis server host
250 | redis_port: Redis server port
251 | redis_db: Redis database number
252 | redis_password: Redis password (if required)
253 | redis_ssl: Whether to use SSL for Redis connection
254 | connection_pool_size: Redis connection pool size
255 | socket_timeout: Socket timeout in seconds
256 | socket_connect_timeout: Socket connection timeout in seconds
257 | """
258 | self.enabled = enabled
259 | self.default_ttl = default_ttl
260 | self.max_memory_size = max_memory_size
261 | self.redis_host = redis_host
262 | self.redis_port = redis_port
263 | self.redis_db = redis_db
264 | self.redis_password = redis_password
265 | self.redis_ssl = redis_ssl
266 | self.connection_pool_size = connection_pool_size
267 | self.socket_timeout = socket_timeout
268 | self.socket_connect_timeout = socket_connect_timeout
269 |
270 | def get_redis_url(self) -> str:
271 | """
272 | Get Redis connection URL.
273 |
274 | Returns:
275 | Redis connection URL string
276 | """
277 | scheme = "rediss" if self.redis_ssl else "redis"
278 | auth = f":{self.redis_password}@" if self.redis_password else ""
279 | return f"{scheme}://{auth}{self.redis_host}:{self.redis_port}/{self.redis_db}"
280 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_cache.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Mock cache manager implementation for testing.
3 | """
4 |
5 | import time
6 | from typing import Any
7 |
8 |
9 | class MockCacheManager:
10 | """
11 | Mock implementation of ICacheManager for testing.
12 |
13 | This implementation uses in-memory storage and provides predictable
14 | behavior for testing cache-dependent functionality.
15 | """
16 |
17 | def __init__(self):
18 | """Initialize the mock cache manager."""
19 | self._data: dict[str, dict[str, Any]] = {}
20 | self._call_log: list[dict[str, Any]] = []
21 |
22 | async def get(self, key: str) -> Any:
23 | """Get data from mock cache."""
24 | self._log_call("get", {"key": key})
25 |
26 | if key not in self._data:
27 | return None
28 |
29 | entry = self._data[key]
30 |
31 | # Check if expired
32 | if "expires_at" in entry and entry["expires_at"] < time.time():
33 | del self._data[key]
34 | return None
35 |
36 | return entry["value"]
37 |
38 | async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
39 | """Store data in mock cache."""
40 | self._log_call("set", {"key": key, "value": value, "ttl": ttl})
41 |
42 | entry = {"value": value}
43 |
44 | if ttl is not None:
45 | entry["expires_at"] = time.time() + ttl
46 |
47 | self._data[key] = entry
48 | return True
49 |
50 | async def delete(self, key: str) -> bool:
51 | """Delete a key from mock cache."""
52 | self._log_call("delete", {"key": key})
53 |
54 | if key in self._data:
55 | del self._data[key]
56 | return True
57 |
58 | return False
59 |
60 | async def exists(self, key: str) -> bool:
61 | """Check if a key exists in mock cache."""
62 | self._log_call("exists", {"key": key})
63 |
64 | if key not in self._data:
65 | return False
66 |
67 | entry = self._data[key]
68 |
69 | # Check if expired
70 | if "expires_at" in entry and entry["expires_at"] < time.time():
71 | del self._data[key]
72 | return False
73 |
74 | return True
75 |
76 | async def clear(self, pattern: str | None = None) -> int:
77 | """Clear cache entries."""
78 | self._log_call("clear", {"pattern": pattern})
79 |
80 | if pattern is None:
81 | count = len(self._data)
82 | self._data.clear()
83 | return count
84 |
85 | # Simple pattern matching (only supports prefix*)
86 | if pattern.endswith("*"):
87 | prefix = pattern[:-1]
88 | keys_to_delete = [k for k in self._data.keys() if k.startswith(prefix)]
89 | else:
90 | keys_to_delete = [k for k in self._data.keys() if k == pattern]
91 |
92 | for key in keys_to_delete:
93 | del self._data[key]
94 |
95 | return len(keys_to_delete)
96 |
97 | async def get_many(self, keys: list[str]) -> dict[str, Any]:
98 | """Get multiple values at once."""
99 | self._log_call("get_many", {"keys": keys})
100 |
101 | results = {}
102 | for key in keys:
103 | value = await self.get(key)
104 | if value is not None:
105 | results[key] = value
106 |
107 | return results
108 |
109 | async def set_many(self, items: list[tuple[str, Any, int | None]]) -> int:
110 | """Set multiple values at once."""
111 | self._log_call("set_many", {"items_count": len(items)})
112 |
113 | success_count = 0
114 | for key, value, ttl in items:
115 | if await self.set(key, value, ttl):
116 | success_count += 1
117 |
118 | return success_count
119 |
120 | async def delete_many(self, keys: list[str]) -> int:
121 | """Delete multiple keys."""
122 | self._log_call("delete_many", {"keys": keys})
123 |
124 | deleted_count = 0
125 | for key in keys:
126 | if await self.delete(key):
127 | deleted_count += 1
128 |
129 | return deleted_count
130 |
131 | async def exists_many(self, keys: list[str]) -> dict[str, bool]:
132 | """Check existence of multiple keys."""
133 | self._log_call("exists_many", {"keys": keys})
134 |
135 | results = {}
136 | for key in keys:
137 | results[key] = await self.exists(key)
138 |
139 | return results
140 |
141 | async def count_keys(self, pattern: str) -> int:
142 | """Count keys matching a pattern."""
143 | self._log_call("count_keys", {"pattern": pattern})
144 |
145 | if pattern.endswith("*"):
146 | prefix = pattern[:-1]
147 | return len([k for k in self._data.keys() if k.startswith(prefix)])
148 | else:
149 | return 1 if pattern in self._data else 0
150 |
151 | async def get_or_set(
152 | self, key: str, default_value: Any, ttl: int | None = None
153 | ) -> Any:
154 | """Get value from cache, setting it if it doesn't exist."""
155 | self._log_call(
156 | "get_or_set", {"key": key, "default_value": default_value, "ttl": ttl}
157 | )
158 |
159 | value = await self.get(key)
160 | if value is not None:
161 | return value
162 |
163 | await self.set(key, default_value, ttl)
164 | return default_value
165 |
166 | async def increment(self, key: str, amount: int = 1) -> int:
167 | """Increment a numeric value in cache."""
168 | self._log_call("increment", {"key": key, "amount": amount})
169 |
170 | current = await self.get(key)
171 |
172 | if current is None:
173 | new_value = amount
174 | else:
175 | try:
176 | current_int = int(current)
177 | new_value = current_int + amount
178 | except (ValueError, TypeError):
179 | raise ValueError(f"Key {key} contains non-numeric value: {current}")
180 |
181 | await self.set(key, new_value)
182 | return new_value
183 |
184 | async def set_if_not_exists(
185 | self, key: str, value: Any, ttl: int | None = None
186 | ) -> bool:
187 | """Set a value only if the key doesn't already exist."""
188 | self._log_call("set_if_not_exists", {"key": key, "value": value, "ttl": ttl})
189 |
190 | if await self.exists(key):
191 | return False
192 |
193 | return await self.set(key, value, ttl)
194 |
195 | async def get_ttl(self, key: str) -> int | None:
196 | """Get the remaining time-to-live for a key."""
197 | self._log_call("get_ttl", {"key": key})
198 |
199 | if key not in self._data:
200 | return None
201 |
202 | entry = self._data[key]
203 |
204 | if "expires_at" not in entry:
205 | return None
206 |
207 | remaining = int(entry["expires_at"] - time.time())
208 | return max(0, remaining)
209 |
210 | async def expire(self, key: str, ttl: int) -> bool:
211 | """Set expiration time for an existing key."""
212 | self._log_call("expire", {"key": key, "ttl": ttl})
213 |
214 | if key not in self._data:
215 | return False
216 |
217 | self._data[key]["expires_at"] = time.time() + ttl
218 | return True
219 |
220 | # Testing utilities
221 |
222 | def _log_call(self, method: str, args: dict[str, Any]) -> None:
223 | """Log method calls for testing verification."""
224 | self._call_log.append(
225 | {
226 | "method": method,
227 | "args": args,
228 | "timestamp": time.time(),
229 | }
230 | )
231 |
232 | def get_call_log(self) -> list[dict[str, Any]]:
233 | """Get the log of method calls for testing verification."""
234 | return self._call_log.copy()
235 |
236 | def clear_call_log(self) -> None:
237 | """Clear the method call log."""
238 | self._call_log.clear()
239 |
240 | def get_cache_contents(self) -> dict[str, Any]:
241 | """Get all cache contents for testing verification."""
242 | return {k: v["value"] for k, v in self._data.items()}
243 |
244 | def set_cache_contents(self, contents: dict[str, Any]) -> None:
245 | """Set cache contents directly for testing setup."""
246 | self._data.clear()
247 | for key, value in contents.items():
248 | self._data[key] = {"value": value}
249 |
250 | def simulate_cache_expiry(self, key: str) -> None:
251 | """Simulate cache expiry for testing."""
252 | if key in self._data:
253 | self._data[key]["expires_at"] = time.time() - 1
254 |
```
--------------------------------------------------------------------------------
/maverick_mcp/infrastructure/sse_optimizer.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | SSE Transport Optimizer for FastMCP server stability.
3 |
4 | Provides SSE-specific optimizations to prevent connection drops
5 | and ensure persistent tool availability in Claude Desktop.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | from typing import Any
11 |
12 | from fastmcp import FastMCP
13 | from starlette.middleware.base import BaseHTTPMiddleware
14 | from starlette.requests import Request
15 | from starlette.responses import Response
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class SSEStabilityMiddleware(BaseHTTPMiddleware):
21 | """
22 | Middleware to enhance SSE connection stability.
23 |
24 | Features:
25 | - Connection keepalive headers
26 | - Proper CORS for SSE
27 | - Connection state tracking
28 | - Automatic reconnection support
29 | """
30 |
31 | async def dispatch(self, request: Request, call_next) -> Response:
32 | # Add SSE-specific headers for stability
33 | response = await call_next(request)
34 |
35 | # SSE connection optimizations
36 | if request.url.path.endswith("/sse"):
37 | # Keepalive and caching headers
38 | response.headers["Cache-Control"] = "no-cache"
39 | response.headers["Connection"] = "keep-alive"
40 | response.headers["Content-Type"] = "text/event-stream"
41 |
42 | # CORS headers for cross-origin SSE
43 | response.headers["Access-Control-Allow-Origin"] = "*"
44 | response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
45 | response.headers["Access-Control-Allow-Headers"] = "*"
46 | response.headers["Access-Control-Allow-Credentials"] = "false"
47 |
48 | # Prevent proxy buffering
49 | response.headers["X-Accel-Buffering"] = "no"
50 |
51 | return response
52 |
53 |
54 | class SSEHeartbeat:
55 | """
56 | Heartbeat mechanism for SSE connections.
57 |
58 | Sends periodic keepalive messages to maintain connection
59 | and detect client disconnections early.
60 | """
61 |
62 | def __init__(self, interval: float = 30.0):
63 | self.interval = interval
64 | self.active_connections: dict[str, asyncio.Task] = {}
65 |
66 | async def start_heartbeat(self, connection_id: str, send_function):
67 | """Start heartbeat for a specific connection."""
68 | try:
69 | while True:
70 | await asyncio.sleep(self.interval)
71 |
72 | # Send heartbeat event
73 | heartbeat_event = {
74 | "event": "heartbeat",
75 | "data": {
76 | "timestamp": asyncio.get_event_loop().time(),
77 | "connection_id": connection_id[:8],
78 | },
79 | }
80 |
81 | await send_function(heartbeat_event)
82 |
83 | except asyncio.CancelledError:
84 | logger.info(f"Heartbeat stopped for connection: {connection_id[:8]}")
85 | except Exception as e:
86 | logger.error(f"Heartbeat error for {connection_id[:8]}: {e}")
87 |
88 | def register_connection(self, connection_id: str, send_function) -> None:
89 | """Register a new connection for heartbeat."""
90 | if connection_id in self.active_connections:
91 | # Cancel existing heartbeat
92 | self.active_connections[connection_id].cancel()
93 |
94 | # Start new heartbeat task
95 | task = asyncio.create_task(self.start_heartbeat(connection_id, send_function))
96 | self.active_connections[connection_id] = task
97 |
98 | logger.info(f"Heartbeat registered for connection: {connection_id[:8]}")
99 |
100 | def unregister_connection(self, connection_id: str) -> None:
101 | """Unregister connection and stop heartbeat."""
102 | if connection_id in self.active_connections:
103 | self.active_connections[connection_id].cancel()
104 | del self.active_connections[connection_id]
105 | logger.info(f"Heartbeat unregistered for connection: {connection_id[:8]}")
106 |
107 | async def shutdown(self):
108 | """Shutdown all heartbeats."""
109 | for task in self.active_connections.values():
110 | task.cancel()
111 |
112 | if self.active_connections:
113 | await asyncio.gather(
114 | *self.active_connections.values(), return_exceptions=True
115 | )
116 |
117 | self.active_connections.clear()
118 | logger.info("All heartbeats shutdown")
119 |
120 |
121 | class SSEOptimizer:
122 | """
123 | SSE Transport Optimizer for enhanced stability.
124 |
125 | Provides comprehensive optimizations for SSE connections:
126 | - Stability middleware
127 | - Heartbeat mechanism
128 | - Connection monitoring
129 | - Automatic recovery
130 | """
131 |
132 | def __init__(self, mcp_server: FastMCP):
133 | self.mcp_server = mcp_server
134 | self.heartbeat = SSEHeartbeat(interval=25.0) # 25-second heartbeat
135 | self.connection_count = 0
136 |
137 | def optimize_server(self) -> None:
138 | """Apply SSE optimizations to the FastMCP server."""
139 |
140 | # Add stability middleware
141 | if hasattr(self.mcp_server, "fastapi_app") and self.mcp_server.fastapi_app:
142 | self.mcp_server.fastapi_app.add_middleware(SSEStabilityMiddleware)
143 | logger.info("SSE stability middleware added")
144 |
145 | # Register SSE event handlers
146 | self._register_sse_handlers()
147 |
148 | logger.info("SSE transport optimizations applied")
149 |
150 | def _register_sse_handlers(self) -> None:
151 | """Register SSE-specific event handlers."""
152 |
153 | @self.mcp_server.event("sse_connection_opened")
154 | async def on_sse_connection_open(connection_id: str, send_function):
155 | """Handle SSE connection open with optimization."""
156 | self.connection_count += 1
157 | logger.info(
158 | f"SSE connection opened: {connection_id[:8]} (total: {self.connection_count})"
159 | )
160 |
161 | # Register heartbeat
162 | self.heartbeat.register_connection(connection_id, send_function)
163 |
164 | # Send connection confirmation
165 | await send_function(
166 | {
167 | "event": "connection_ready",
168 | "data": {
169 | "connection_id": connection_id[:8],
170 | "server": "maverick-mcp",
171 | "transport": "sse",
172 | "optimization": "enabled",
173 | },
174 | }
175 | )
176 |
177 | @self.mcp_server.event("sse_connection_closed")
178 | async def on_sse_connection_close(connection_id: str):
179 | """Handle SSE connection close with cleanup."""
180 | self.connection_count = max(0, self.connection_count - 1)
181 | logger.info(
182 | f"SSE connection closed: {connection_id[:8]} (remaining: {self.connection_count})"
183 | )
184 |
185 | # Unregister heartbeat
186 | self.heartbeat.unregister_connection(connection_id)
187 |
188 | async def shutdown(self):
189 | """Shutdown SSE optimizer."""
190 | await self.heartbeat.shutdown()
191 | logger.info("SSE optimizer shutdown complete")
192 |
193 | def get_sse_status(self) -> dict[str, Any]:
194 | """Get SSE connection status."""
195 | return {
196 | "active_connections": self.connection_count,
197 | "heartbeat_connections": len(self.heartbeat.active_connections),
198 | "heartbeat_interval": self.heartbeat.interval,
199 | "optimization_status": "enabled",
200 | }
201 |
202 |
203 | # Global SSE optimizer instance
204 | _sse_optimizer: SSEOptimizer | None = None
205 |
206 |
207 | def get_sse_optimizer(mcp_server: FastMCP) -> SSEOptimizer:
208 | """Get or create the global SSE optimizer."""
209 | global _sse_optimizer
210 | if _sse_optimizer is None:
211 | _sse_optimizer = SSEOptimizer(mcp_server)
212 | return _sse_optimizer
213 |
214 |
215 | def apply_sse_optimizations(mcp_server: FastMCP) -> SSEOptimizer:
216 | """Apply SSE transport optimizations to FastMCP server."""
217 | optimizer = get_sse_optimizer(mcp_server)
218 | optimizer.optimize_server()
219 | logger.info("SSE transport optimizations applied for enhanced stability")
220 | return optimizer
221 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/interfaces/stock_data.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Stock data provider interfaces.
3 |
4 | This module defines abstract interfaces for stock data fetching and screening operations.
5 | These interfaces separate concerns between basic data retrieval and advanced screening logic,
6 | following the Interface Segregation Principle.
7 | """
8 |
9 | from abc import ABC, abstractmethod
10 | from typing import Any, Protocol, runtime_checkable
11 |
12 | import pandas as pd
13 |
14 |
15 | @runtime_checkable
16 | class IStockDataFetcher(Protocol):
17 | """
18 | Interface for fetching basic stock data.
19 |
20 | This interface defines the contract for retrieving historical price data,
21 | real-time quotes, company information, and related financial data.
22 | """
23 |
24 | async def get_stock_data(
25 | self,
26 | symbol: str,
27 | start_date: str | None = None,
28 | end_date: str | None = None,
29 | period: str | None = None,
30 | interval: str = "1d",
31 | use_cache: bool = True,
32 | ) -> pd.DataFrame:
33 | """
34 | Fetch historical stock data.
35 |
36 | Args:
37 | symbol: Stock ticker symbol
38 | start_date: Start date in YYYY-MM-DD format
39 | end_date: End date in YYYY-MM-DD format
40 | period: Alternative to start/end dates (e.g., '1y', '6mo')
41 | interval: Data interval ('1d', '1wk', '1mo', etc.)
42 | use_cache: Whether to use cached data if available
43 |
44 | Returns:
45 | DataFrame with OHLCV data indexed by date
46 | """
47 | ...
48 |
49 | async def get_realtime_data(self, symbol: str) -> dict[str, Any] | None:
50 | """
51 | Get real-time stock data.
52 |
53 | Args:
54 | symbol: Stock ticker symbol
55 |
56 | Returns:
57 | Dictionary with current price, change, volume, etc. or None if unavailable
58 | """
59 | ...
60 |
61 | async def get_stock_info(self, symbol: str) -> dict[str, Any]:
62 | """
63 | Get detailed stock information and fundamentals.
64 |
65 | Args:
66 | symbol: Stock ticker symbol
67 |
68 | Returns:
69 | Dictionary with company info, financials, and market data
70 | """
71 | ...
72 |
73 | async def get_news(self, symbol: str, limit: int = 10) -> pd.DataFrame:
74 | """
75 | Get news articles for a stock.
76 |
77 | Args:
78 | symbol: Stock ticker symbol
79 | limit: Maximum number of articles to return
80 |
81 | Returns:
82 | DataFrame with news articles
83 | """
84 | ...
85 |
86 | async def get_earnings(self, symbol: str) -> dict[str, Any]:
87 | """
88 | Get earnings information for a stock.
89 |
90 | Args:
91 | symbol: Stock ticker symbol
92 |
93 | Returns:
94 | Dictionary with earnings data and dates
95 | """
96 | ...
97 |
98 | async def get_recommendations(self, symbol: str) -> pd.DataFrame:
99 | """
100 | Get analyst recommendations for a stock.
101 |
102 | Args:
103 | symbol: Stock ticker symbol
104 |
105 | Returns:
106 | DataFrame with analyst recommendations
107 | """
108 | ...
109 |
110 | async def is_market_open(self) -> bool:
111 | """
112 | Check if the stock market is currently open.
113 |
114 | Returns:
115 | True if market is open, False otherwise
116 | """
117 | ...
118 |
119 | async def is_etf(self, symbol: str) -> bool:
120 | """
121 | Check if a symbol represents an ETF.
122 |
123 | Args:
124 | symbol: Stock ticker symbol
125 |
126 | Returns:
127 | True if symbol is an ETF, False otherwise
128 | """
129 | ...
130 |
131 |
132 | @runtime_checkable
133 | class IStockScreener(Protocol):
134 | """
135 | Interface for stock screening and recommendation operations.
136 |
137 | This interface defines the contract for generating stock recommendations
138 | based on various technical and fundamental criteria.
139 | """
140 |
141 | async def get_maverick_recommendations(
142 | self, limit: int = 20, min_score: int | None = None
143 | ) -> list[dict[str, Any]]:
144 | """
145 | Get bullish Maverick stock recommendations.
146 |
147 | Args:
148 | limit: Maximum number of recommendations
149 | min_score: Minimum combined score filter
150 |
151 | Returns:
152 | List of stock recommendations with technical analysis
153 | """
154 | ...
155 |
156 | async def get_maverick_bear_recommendations(
157 | self, limit: int = 20, min_score: int | None = None
158 | ) -> list[dict[str, Any]]:
159 | """
160 | Get bearish Maverick stock recommendations.
161 |
162 | Args:
163 | limit: Maximum number of recommendations
164 | min_score: Minimum score filter
165 |
166 | Returns:
167 | List of bear stock recommendations
168 | """
169 | ...
170 |
171 | async def get_trending_recommendations(
172 | self, limit: int = 20, min_momentum_score: float | None = None
173 | ) -> list[dict[str, Any]]:
174 | """
175 | Get trending stock recommendations.
176 |
177 | Args:
178 | limit: Maximum number of recommendations
179 | min_momentum_score: Minimum momentum score filter
180 |
181 | Returns:
182 | List of trending stock recommendations
183 | """
184 | ...
185 |
186 | async def get_all_screening_recommendations(
187 | self,
188 | ) -> dict[str, list[dict[str, Any]]]:
189 | """
190 | Get all screening recommendations in one call.
191 |
192 | Returns:
193 | Dictionary with all screening types and their recommendations
194 | """
195 | ...
196 |
197 |
198 | class StockDataProviderBase(ABC):
199 | """
200 | Abstract base class for stock data providers.
201 |
202 | This class provides a foundation for implementing both IStockDataFetcher
203 | and IStockScreener interfaces, with common functionality and error handling.
204 | """
205 |
206 | @abstractmethod
207 | def _fetch_stock_data_from_source(
208 | self,
209 | symbol: str,
210 | start_date: str | None = None,
211 | end_date: str | None = None,
212 | period: str | None = None,
213 | interval: str = "1d",
214 | ) -> pd.DataFrame:
215 | """
216 | Fetch stock data from the underlying data source.
217 |
218 | This method must be implemented by concrete providers to define
219 | how data is actually retrieved (e.g., from yfinance, Alpha Vantage, etc.)
220 | """
221 | pass
222 |
223 | def _validate_symbol(self, symbol: str) -> str:
224 | """
225 | Validate and normalize a stock symbol.
226 |
227 | Args:
228 | symbol: Raw stock symbol
229 |
230 | Returns:
231 | Normalized symbol (uppercase, stripped)
232 |
233 | Raises:
234 | ValueError: If symbol is invalid
235 | """
236 | if not symbol or not isinstance(symbol, str):
237 | raise ValueError("Symbol must be a non-empty string")
238 |
239 | normalized = symbol.strip().upper()
240 | if not normalized:
241 | raise ValueError("Symbol cannot be empty after normalization")
242 |
243 | return normalized
244 |
245 | def _validate_date_range(
246 | self, start_date: str | None, end_date: str | None
247 | ) -> tuple[str | None, str | None]:
248 | """
249 | Validate date range parameters.
250 |
251 | Args:
252 | start_date: Start date string
253 | end_date: End date string
254 |
255 | Returns:
256 | Tuple of validated dates
257 |
258 | Raises:
259 | ValueError: If date format is invalid
260 | """
261 | # Basic validation - can be extended with actual date parsing
262 | if start_date is not None and not isinstance(start_date, str):
263 | raise ValueError("start_date must be a string in YYYY-MM-DD format")
264 |
265 | if end_date is not None and not isinstance(end_date, str):
266 | raise ValueError("end_date must be a string in YYYY-MM-DD format")
267 |
268 | return start_date, end_date
269 |
270 | def _handle_provider_error(self, error: Exception, context: str) -> None:
271 | """
272 | Handle provider-specific errors with consistent logging.
273 |
274 | Args:
275 | error: The exception that occurred
276 | context: Context information for debugging
277 | """
278 | # This would integrate with the logging system
279 | # For now, we'll re-raise to maintain existing behavior
280 | raise error
281 |
```
--------------------------------------------------------------------------------
/tools/templates/screening_strategy_template.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Template for creating new stock screening strategies.
3 |
4 | Copy this file and modify it to create new screening strategies quickly.
5 | """
6 |
7 | from datetime import datetime, timedelta
8 | from typing import Any
9 |
10 | import pandas as pd
11 |
12 | from maverick_mcp.core.technical_analysis import (
13 | calculate_atr,
14 | calculate_rsi,
15 | calculate_sma,
16 | )
17 | from maverick_mcp.data.models import Stock, get_db
18 | from maverick_mcp.providers.stock_data import StockDataProvider
19 | from maverick_mcp.utils.logging import get_logger
20 |
21 | logger = get_logger(__name__)
22 |
23 |
24 | class YourScreeningStrategy:
25 | """
26 | Your custom screening strategy.
27 |
28 | This strategy identifies stocks that meet specific criteria
29 | based on technical indicators and price action.
30 | """
31 |
32 | def __init__(
33 | self,
34 | min_price: float = 10.0,
35 | min_volume: int = 1_000_000,
36 | lookback_days: int = 90,
37 | ):
38 | """
39 | Initialize the screening strategy.
40 |
41 | Args:
42 | min_price: Minimum stock price to consider
43 | min_volume: Minimum average daily volume
44 | lookback_days: Number of days to analyze
45 | """
46 | self.min_price = min_price
47 | self.min_volume = min_volume
48 | self.lookback_days = lookback_days
49 | self.stock_provider = StockDataProvider()
50 |
51 | def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
52 | """
53 | Calculate a composite score for the stock.
54 |
55 | Args:
56 | symbol: Stock symbol
57 | data: Historical price data
58 |
59 | Returns:
60 | Score between 0 and 100
61 | """
62 | score = 0.0
63 |
64 | try:
65 | # Price above moving averages
66 | sma_20 = calculate_sma(data, 20).iloc[-1]
67 | sma_50 = calculate_sma(data, 50).iloc[-1]
68 | current_price = data["Close"].iloc[-1]
69 |
70 | if current_price > sma_20:
71 | score += 20
72 | if current_price > sma_50:
73 | score += 15
74 |
75 | # RSI in optimal range (not overbought/oversold)
76 | rsi = calculate_rsi(data, 14).iloc[-1]
77 | if 40 <= rsi <= 70:
78 | score += 20
79 | elif 30 <= rsi <= 80:
80 | score += 10
81 |
82 | # MACD bullish (using pandas_ta as alternative)
83 | try:
84 | import pandas_ta as ta
85 |
86 | macd = ta.macd(data["close"])
87 | if macd["MACD_12_26_9"].iloc[-1] > macd["MACDs_12_26_9"].iloc[-1]:
88 | score += 15
89 | except ImportError:
90 | # Skip MACD if pandas_ta not available
91 | pass
92 |
93 | # Volume increasing
94 | avg_volume_recent = data["Volume"].iloc[-5:].mean()
95 | avg_volume_prior = data["Volume"].iloc[-20:-5].mean()
96 | if avg_volume_recent > avg_volume_prior * 1.2:
97 | score += 15
98 |
99 | # Price momentum
100 | price_change_1m = (current_price / data["Close"].iloc[-20] - 1) * 100
101 | if price_change_1m > 10:
102 | score += 15
103 | elif price_change_1m > 5:
104 | score += 10
105 |
106 | logger.debug(
107 | f"Score calculated for {symbol}: {score}",
108 | extra={
109 | "symbol": symbol,
110 | "price": current_price,
111 | "rsi": rsi,
112 | "score": score,
113 | },
114 | )
115 |
116 | except Exception as e:
117 | logger.error(f"Error calculating score for {symbol}: {e}")
118 | score = 0.0
119 |
120 | return min(score, 100.0)
121 |
def screen_stocks(
    self,
    symbols: list[str] | None = None,
    min_score: float = 70.0,
) -> list[dict[str, Any]]:
    """
    Screen stocks based on the strategy criteria.

    Price history covering the last ``self.lookback_days`` calendar days is
    fetched for each symbol via ``self.stock_provider``. A symbol is skipped
    in any of these cases:

    - it has fewer than 50 bars;
    - its latest close is below ``self.min_price``;
    - its 20-bar average volume is below ``self.min_volume``.

    The remaining symbols are scored with ``self.calculate_score`` and kept
    when the score is at least ``min_score``. Any exception raised while
    processing a symbol is logged, and that symbol is skipped.

    Args:
        symbols: List of symbols to screen (None screens every active stock
            in the database)
        min_score: Minimum score to include in results

    Returns:
        List of stocks meeting criteria, sorted by score descending. Numeric
        and boolean fields are plain Python ``float``/``int``/``bool`` values,
        so the result can be JSON-serialized directly.
    """
    results: list[dict[str, Any]] = []
    # One timestamp so both window bounds derive from the same instant.
    now = datetime.now()
    end_date = now.strftime("%Y-%m-%d")
    start_date = (now - timedelta(days=self.lookback_days)).strftime("%Y-%m-%d")

    if symbols is None:
        # Get all active stocks from database
        db = next(get_db())
        try:
            stocks = db.query(Stock).filter(Stock.is_active).all()
            symbols = [stock.symbol for stock in stocks]
        finally:
            db.close()

    logger.info(f"Screening {len(symbols)} stocks")

    for symbol in symbols:
        try:
            data = self.stock_provider.get_stock_data(symbol, start_date, end_date)

            if len(data) < 50:  # Need enough data for indicators
                continue

            close = data["Close"]
            current_price = float(close.iloc[-1])
            avg_volume = float(data["Volume"].iloc[-20:].mean())

            if current_price < self.min_price or avg_volume < self.min_volume:
                continue

            score = self.calculate_score(symbol, data)
            if score < min_score:
                continue

            atr = float(calculate_atr(data, 14).iloc[-1])
            # A 5-day change compares against the close five bars back
            # (iloc[-6]); iloc[-5] would span only four bars.
            price_change_5d = (current_price / float(close.iloc[-6]) - 1) * 100
            rsi = float(calculate_rsi(data, 14).iloc[-1])
            sma_20 = float(calculate_sma(data, 20).iloc[-1])
            sma_50 = float(calculate_sma(data, 50).iloc[-1])

            results.append(
                {
                    "symbol": symbol,
                    "score": round(float(score), 2),
                    "price": round(current_price, 2),
                    "volume": int(avg_volume),
                    "atr": round(atr, 2),
                    "price_change_5d": round(price_change_5d, 2),
                    "rsi": round(rsi, 2),
                    "above_sma_20": current_price > sma_20,
                    "above_sma_50": current_price > sma_50,
                }
            )
            logger.info(f"Stock passed screening: {symbol} (score: {score})")

        except Exception as e:
            logger.error(f"Error screening {symbol}: {e}")
            continue

    # Sort by score descending
    results.sort(key=lambda x: x["score"], reverse=True)

    logger.info(f"Screening complete: {len(results)} stocks found")
    return results
207 |
def get_entry_exit_levels(
    self, symbol: str, data: pd.DataFrame
) -> dict[str, float]:
    """
    Calculate entry, stop loss, and target levels.

    Entry and stop:

    - The entry is the latest close.
    - The stop is the higher (tighter) of ``entry - 2 * ATR(14)`` and 98% of
      the lowest low over the last 20 bars.

    Targets:

    - They start at ``entry + 2 * ATR`` and ``entry + 3 * ATR``.
    - If that yields a reward/risk below 2:1, they are moved to
      ``entry + 2 * risk`` and ``entry + 3 * risk``.

    Args:
        symbol: Stock symbol (not used in the calculation)
        data: Historical price data with ``Close`` and ``Low`` columns, plus
            whatever ``calculate_atr`` reads

    Returns:
        Dictionary with entry, stop, and target levels, plus the
        reward/risk ratio of the *returned* ``target1``. The ratio is
        reported as 0.0 when risk is not positive (e.g. zero or NaN ATR),
        since it is undefined in that case.
    """
    current_price = data["Close"].iloc[-1]
    atr = calculate_atr(data, 14).iloc[-1]

    # Find recent support
    recent_low = data["Low"].iloc[-20:].min()

    # Calculate levels
    entry = current_price
    stop_loss = max(current_price - (2 * atr), recent_low * 0.98)
    target1 = current_price + (2 * atr)
    target2 = current_price + (3 * atr)

    # Ensure minimum risk/reward. A non-positive (or NaN) risk makes the
    # ratio meaningless, so the adjustment is skipped in that case.
    risk = entry - stop_loss
    reward = target1 - entry
    has_risk = risk > 0
    if has_risk and reward / risk < 2:
        target1 = entry + (2 * risk)
        target2 = entry + (3 * risk)
        # Recompute so the reported ratio matches the adjusted target.
        reward = target1 - entry

    risk_reward_ratio = reward / risk if has_risk else 0.0

    return {
        "entry": round(float(entry), 2),
        "stop_loss": round(float(stop_loss), 2),
        "target1": round(float(target1), 2),
        "target2": round(float(target2), 2),
        "risk_reward_ratio": round(float(risk_reward_ratio), 2),
    }
247 |
```
--------------------------------------------------------------------------------
/tests/test_session_management.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tests for enhanced database session management.
3 |
4 | Tests the new context managers and connection pool monitoring
5 | introduced to fix Issue #55: Database Session Management.
6 | """
7 |
8 | from unittest.mock import Mock, patch
9 |
10 | import pytest
11 |
12 | from maverick_mcp.data.session_management import (
13 | check_connection_pool_health,
14 | get_connection_pool_status,
15 | get_db_session,
16 | get_db_session_read_only,
17 | )
18 |
19 |
class TestSessionManagement:
    """Unit tests for the get_db_session / get_db_session_read_only context managers.

    ``SessionLocal`` is patched in every test. Each test therefore works
    against a ``Mock`` session and can check which lifecycle methods
    (commit / rollback / close) were invoked.
    """

    @staticmethod
    def _install_mock_session(mock_session_local):
        """Make the patched SessionLocal factory return a fresh Mock session."""
        fake_session = Mock()
        mock_session_local.return_value = fake_session
        return fake_session

    @patch("maverick_mcp.data.session_management.SessionLocal")
    def test_get_db_session_success(self, mock_session_local):
        """A clean exit from get_db_session commits and closes, with no rollback."""
        fake_session = self._install_mock_session(mock_session_local)

        with get_db_session() as session:
            assert session == fake_session

        fake_session.rollback.assert_not_called()
        fake_session.commit.assert_called_once()
        fake_session.close.assert_called_once()

    @patch("maverick_mcp.data.session_management.SessionLocal")
    def test_get_db_session_exception_rollback(self, mock_session_local):
        """An exception inside get_db_session rolls back, skips commit, and closes."""
        fake_session = self._install_mock_session(mock_session_local)

        with pytest.raises(ValueError), get_db_session() as session:
            assert session == fake_session
            raise ValueError("Test exception")

        fake_session.commit.assert_not_called()
        fake_session.rollback.assert_called_once()
        fake_session.close.assert_called_once()

    @patch("maverick_mcp.data.session_management.SessionLocal")
    def test_get_db_session_read_only_success(self, mock_session_local):
        """The read-only session never commits; it is simply closed."""
        fake_session = self._install_mock_session(mock_session_local)

        with get_db_session_read_only() as session:
            assert session == fake_session

        fake_session.rollback.assert_not_called()
        fake_session.commit.assert_not_called()
        fake_session.close.assert_called_once()

    @patch("maverick_mcp.data.session_management.SessionLocal")
    def test_get_db_session_read_only_exception_rollback(self, mock_session_local):
        """An exception inside the read-only session rolls back and closes."""
        fake_session = self._install_mock_session(mock_session_local)

        with pytest.raises(RuntimeError), get_db_session_read_only() as session:
            assert session == fake_session
            raise RuntimeError("Read operation failed")

        fake_session.commit.assert_not_called()
        fake_session.rollback.assert_called_once()
        fake_session.close.assert_called_once()
84 |
85 |
class TestConnectionPoolMonitoring:
    """Tests for get_connection_pool_status and check_connection_pool_health.

    The engine's pool is replaced by a ``Mock`` whose counters return
    fixed values. This allows the status report and the health verdict to
    be checked for chosen utilization levels.
    """

    @staticmethod
    def _fake_pool(*, size, checked_in, checked_out, overflow=0, invalid=0):
        """Build a Mock pool whose counter methods return the given values."""
        pool = Mock()
        pool.size.return_value = size
        pool.checkedin.return_value = checked_in
        pool.checkedout.return_value = checked_out
        pool.overflow.return_value = overflow
        pool.invalid.return_value = invalid
        return pool

    @patch("maverick_mcp.data.models.engine")
    def test_get_connection_pool_status(self, mock_engine):
        """Low utilization is reported with every counter and a healthy status."""
        mock_engine.pool = self._fake_pool(size=10, checked_in=5, checked_out=3)

        # 3 of 10 connections checked out -> 30% utilization, under the 80% mark.
        assert get_connection_pool_status() == {
            "pool_size": 10,
            "checked_in": 5,
            "checked_out": 3,
            "overflow": 0,
            "invalid": 0,
            "pool_status": "healthy",
        }

    @patch("maverick_mcp.data.models.engine")
    def test_get_connection_pool_status_warning(self, mock_engine):
        """90% utilization flips the reported status to a warning."""
        mock_engine.pool = self._fake_pool(size=10, checked_in=1, checked_out=9)

        report = get_connection_pool_status()

        assert report["pool_status"] == "warning"
        assert report["checked_out"] == 9

    @patch("maverick_mcp.data.session_management.get_connection_pool_status")
    def test_check_connection_pool_health_healthy(self, mock_get_status):
        """Moderate utilization (50%) with no invalid connections is healthy."""
        mock_get_status.return_value = {"pool_size": 10, "checked_out": 5, "invalid": 0}

        assert check_connection_pool_health() is True

    @patch("maverick_mcp.data.session_management.get_connection_pool_status")
    def test_check_connection_pool_health_high_utilization(self, mock_get_status):
        """Utilization above the 80% threshold (here 90%) is unhealthy."""
        mock_get_status.return_value = {"pool_size": 10, "checked_out": 9, "invalid": 0}

        assert check_connection_pool_health() is False

    @patch("maverick_mcp.data.session_management.get_connection_pool_status")
    def test_check_connection_pool_health_invalid_connections(self, mock_get_status):
        """Any invalid connection makes the pool unhealthy, even at low utilization."""
        mock_get_status.return_value = {"pool_size": 10, "checked_out": 3, "invalid": 2}

        assert check_connection_pool_health() is False

    @patch("maverick_mcp.data.session_management.get_connection_pool_status")
    def test_check_connection_pool_health_exception(self, mock_get_status):
        """A failure while reading pool status is reported as unhealthy."""
        mock_get_status.side_effect = Exception("Pool access failed")

        assert check_connection_pool_health() is False
167 |
168 |
class TestSessionManagementIntegration:
    """Integration tests for session management with real database.

    Each test is skipped when the database cannot be reached. Only the
    database round-trip sits inside the ``try``. The assertions run
    afterwards, so an incorrect result fails the test instead of being
    swallowed by ``except Exception`` and reported as a skip.
    """

    @pytest.mark.integration
    def test_session_context_manager_real_db(self):
        """Test session context manager with real database connection."""
        # SQLAlchemy 2.x rejects plain SQL strings in Session.execute();
        # raw SQL must be wrapped in text().
        from sqlalchemy import text

        try:
            with get_db_session_read_only() as session:
                # Simple query that works on any SQL database
                row = session.execute(text("SELECT 1 as test_value")).fetchone()
        except Exception as e:
            # If database is not available, skip this test
            pytest.skip(f"Database not available for integration test: {e}")

        assert row[0] == 1

    @pytest.mark.integration
    def test_connection_pool_status_real(self):
        """Test connection pool status with real database."""
        try:
            status = get_connection_pool_status()
        except Exception as e:
            # If database is not available, skip this test
            pytest.skip(f"Database not available for integration test: {e}")

        # Verify the status has expected keys
        required_keys = [
            "pool_size",
            "checked_in",
            "checked_out",
            "overflow",
            "invalid",
            "pool_status",
        ]
        for key in required_keys:
            assert key in status

        # Verify status values are reasonable
        assert isinstance(status["pool_size"], int)
        assert status["pool_size"] > 0
        assert status["pool_status"] in ["healthy", "warning"]
211 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/mocks/mock_config.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Mock configuration provider implementation for testing.
3 | """
4 |
5 | from typing import Any
6 |
7 |
8 | class MockConfigurationProvider:
9 | """
10 | Mock implementation of IConfigurationProvider for testing.
11 |
12 | This implementation provides safe test defaults and allows
13 | easy configuration overrides for specific test scenarios.
14 | """
15 |
16 | def __init__(self, overrides: dict[str, Any] | None = None):
17 | """
18 | Initialize the mock configuration provider.
19 |
20 | Args:
21 | overrides: Optional dictionary of configuration overrides
22 | """
23 | self._overrides = overrides or {}
24 | self._defaults = {
25 | "DATABASE_URL": "sqlite:///:memory:",
26 | "REDIS_HOST": "localhost",
27 | "REDIS_PORT": 6379,
28 | "REDIS_DB": 1, # Use different DB for tests
29 | "REDIS_PASSWORD": None,
30 | "REDIS_SSL": False,
31 | "CACHE_ENABLED": False, # Disable cache in tests by default
32 | "CACHE_TTL_SECONDS": 300, # 5 minutes for tests
33 | "FRED_API_KEY": "",
34 | "CAPITAL_COMPANION_API_KEY": "",
35 | "TIINGO_API_KEY": "",
36 | "AUTH_ENABLED": False,
37 | "JWT_SECRET_KEY": "test-secret-key",
38 | "LOG_LEVEL": "DEBUG",
39 | "ENVIRONMENT": "test",
40 | "REQUEST_TIMEOUT": 5,
41 | "MAX_RETRIES": 1,
42 | "DB_POOL_SIZE": 1,
43 | "DB_MAX_OVERFLOW": 0,
44 | }
45 | self._call_log: list[dict[str, Any]] = []
46 |
47 | def get_database_url(self) -> str:
48 | """Get mock database URL."""
49 | self._log_call("get_database_url", {})
50 | return self._get_value("DATABASE_URL")
51 |
52 | def get_redis_host(self) -> str:
53 | """Get mock Redis host."""
54 | self._log_call("get_redis_host", {})
55 | return self._get_value("REDIS_HOST")
56 |
57 | def get_redis_port(self) -> int:
58 | """Get mock Redis port."""
59 | self._log_call("get_redis_port", {})
60 | return int(self._get_value("REDIS_PORT"))
61 |
62 | def get_redis_db(self) -> int:
63 | """Get mock Redis database."""
64 | self._log_call("get_redis_db", {})
65 | return int(self._get_value("REDIS_DB"))
66 |
67 | def get_redis_password(self) -> str | None:
68 | """Get mock Redis password."""
69 | self._log_call("get_redis_password", {})
70 | return self._get_value("REDIS_PASSWORD")
71 |
72 | def get_redis_ssl(self) -> bool:
73 | """Get mock Redis SSL setting."""
74 | self._log_call("get_redis_ssl", {})
75 | return bool(self._get_value("REDIS_SSL"))
76 |
77 | def is_cache_enabled(self) -> bool:
78 | """Check if mock caching is enabled."""
79 | self._log_call("is_cache_enabled", {})
80 | return bool(self._get_value("CACHE_ENABLED"))
81 |
82 | def get_cache_ttl(self) -> int:
83 | """Get mock cache TTL."""
84 | self._log_call("get_cache_ttl", {})
85 | return int(self._get_value("CACHE_TTL_SECONDS"))
86 |
87 | def get_fred_api_key(self) -> str:
88 | """Get mock FRED API key."""
89 | self._log_call("get_fred_api_key", {})
90 | return str(self._get_value("FRED_API_KEY"))
91 |
92 | def get_external_api_key(self) -> str:
93 | """Get mock External API key."""
94 | self._log_call("get_external_api_key", {})
95 | return str(self._get_value("CAPITAL_COMPANION_API_KEY"))
96 |
97 | def get_tiingo_api_key(self) -> str:
98 | """Get mock Tiingo API key."""
99 | self._log_call("get_tiingo_api_key", {})
100 | return str(self._get_value("TIINGO_API_KEY"))
101 |
102 | def is_auth_enabled(self) -> bool:
103 | """Check if mock auth is enabled."""
104 | self._log_call("is_auth_enabled", {})
105 | return bool(self._get_value("AUTH_ENABLED"))
106 |
107 | def get_jwt_secret_key(self) -> str:
108 | """Get mock JWT secret key."""
109 | self._log_call("get_jwt_secret_key", {})
110 | return str(self._get_value("JWT_SECRET_KEY"))
111 |
112 | def get_log_level(self) -> str:
113 | """Get mock log level."""
114 | self._log_call("get_log_level", {})
115 | return str(self._get_value("LOG_LEVEL"))
116 |
117 | def is_development_mode(self) -> bool:
118 | """Check if in mock development mode."""
119 | self._log_call("is_development_mode", {})
120 | env = str(self._get_value("ENVIRONMENT")).lower()
121 | return env in ("development", "dev", "test")
122 |
123 | def is_production_mode(self) -> bool:
124 | """Check if in mock production mode."""
125 | self._log_call("is_production_mode", {})
126 | env = str(self._get_value("ENVIRONMENT")).lower()
127 | return env in ("production", "prod")
128 |
129 | def get_request_timeout(self) -> int:
130 | """Get mock request timeout."""
131 | self._log_call("get_request_timeout", {})
132 | return int(self._get_value("REQUEST_TIMEOUT"))
133 |
134 | def get_max_retries(self) -> int:
135 | """Get mock max retries."""
136 | self._log_call("get_max_retries", {})
137 | return int(self._get_value("MAX_RETRIES"))
138 |
139 | def get_pool_size(self) -> int:
140 | """Get mock pool size."""
141 | self._log_call("get_pool_size", {})
142 | return int(self._get_value("DB_POOL_SIZE"))
143 |
144 | def get_max_overflow(self) -> int:
145 | """Get mock max overflow."""
146 | self._log_call("get_max_overflow", {})
147 | return int(self._get_value("DB_MAX_OVERFLOW"))
148 |
149 | def get_config_value(self, key: str, default: Any = None) -> Any:
150 | """Get mock configuration value."""
151 | self._log_call("get_config_value", {"key": key, "default": default})
152 |
153 | if key in self._overrides:
154 | return self._overrides[key]
155 | elif key in self._defaults:
156 | return self._defaults[key]
157 | else:
158 | return default
159 |
160 | def set_config_value(self, key: str, value: Any) -> None:
161 | """Set mock configuration value."""
162 | self._log_call("set_config_value", {"key": key, "value": value})
163 | self._overrides[key] = value
164 |
165 | def get_all_config(self) -> dict[str, Any]:
166 | """Get all mock configuration."""
167 | self._log_call("get_all_config", {})
168 |
169 | config = self._defaults.copy()
170 | config.update(self._overrides)
171 | return config
172 |
173 | def reload_config(self) -> None:
174 | """Reload mock configuration (no-op)."""
175 | self._log_call("reload_config", {})
176 | # No-op for mock implementation
177 |
178 | def _get_value(self, key: str) -> Any:
179 | """Get a configuration value with override support."""
180 | if key in self._overrides:
181 | return self._overrides[key]
182 | return self._defaults.get(key)
183 |
184 | # Testing utilities
185 |
186 | def _log_call(self, method: str, args: dict[str, Any]) -> None:
187 | """Log method calls for testing verification."""
188 | self._call_log.append(
189 | {
190 | "method": method,
191 | "args": args,
192 | }
193 | )
194 |
195 | def get_call_log(self) -> list[dict[str, Any]]:
196 | """Get the log of method calls."""
197 | return self._call_log.copy()
198 |
199 | def clear_call_log(self) -> None:
200 | """Clear the method call log."""
201 | self._call_log.clear()
202 |
203 | def set_override(self, key: str, value: Any) -> None:
204 | """Set a configuration override for testing."""
205 | self._overrides[key] = value
206 |
207 | def clear_overrides(self) -> None:
208 | """Clear all configuration overrides."""
209 | self._overrides.clear()
210 |
211 | def enable_cache(self) -> None:
212 | """Enable caching for testing."""
213 | self.set_override("CACHE_ENABLED", True)
214 |
215 | def disable_cache(self) -> None:
216 | """Disable caching for testing."""
217 | self.set_override("CACHE_ENABLED", False)
218 |
219 | def enable_auth(self) -> None:
220 | """Enable authentication for testing."""
221 | self.set_override("AUTH_ENABLED", True)
222 |
223 | def disable_auth(self) -> None:
224 | """Disable authentication for testing."""
225 | self.set_override("AUTH_ENABLED", False)
226 |
227 | def set_production_mode(self) -> None:
228 | """Set production mode for testing."""
229 | self.set_override("ENVIRONMENT", "production")
230 |
231 | def set_development_mode(self) -> None:
232 | """Set development mode for testing."""
233 | self.set_override("ENVIRONMENT", "development")
234 |
```
--------------------------------------------------------------------------------
/tests/domain/test_technical_analysis_service.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for the TechnicalAnalysisService domain service.
3 |
4 | These tests demonstrate that the domain service can be tested
5 | without any infrastructure dependencies (no mocks needed).
6 | """
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import pytest
11 |
12 | from maverick_mcp.domain.services.technical_analysis_service import (
13 | TechnicalAnalysisService,
14 | )
15 | from maverick_mcp.domain.value_objects.technical_indicators import (
16 | Signal,
17 | TrendDirection,
18 | )
19 |
20 |
class TestTechnicalAnalysisService:
    """Test the technical analysis domain service.

    The synthetic price fixtures are drawn from a seeded NumPy generator.
    Every run therefore sees the same data and failures are reproducible;
    the previous global ``np.random`` calls produced different data on
    each run.
    """

    # Seed for the synthetic price fixtures.
    RANDOM_SEED = 42

    @pytest.fixture
    def service(self):
        """Create a technical analysis service instance."""
        return TechnicalAnalysisService()

    @pytest.fixture
    def sample_prices(self):
        """Create a deterministic 100-day random-walk close series."""
        rng = np.random.default_rng(self.RANDOM_SEED)
        dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
        prices = 100 + np.cumsum(rng.standard_normal(100) * 2)
        return pd.Series(prices, index=dates)

    @pytest.fixture
    def sample_ohlc(self):
        """Create deterministic high/low/close data for testing."""
        rng = np.random.default_rng(self.RANDOM_SEED)
        dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
        close = 100 + np.cumsum(rng.standard_normal(100) * 2)

        # High/low straddle the close by a non-negative random amount.
        high = close + np.abs(rng.standard_normal(100))
        low = close - np.abs(rng.standard_normal(100))

        return pd.DataFrame(
            {
                "high": high,
                "low": low,
                "close": close,
            },
            index=dates,
        )

    def test_calculate_rsi(self, service, sample_prices):
        """Test RSI calculation."""
        rsi = service.calculate_rsi(sample_prices, period=14)

        # RSI should be between 0 and 100
        assert 0 <= rsi.value <= 100
        assert rsi.period == 14

        # Check signal logic
        if rsi.value >= 70:
            assert rsi.is_overbought
        if rsi.value <= 30:
            assert rsi.is_oversold

    def test_calculate_rsi_insufficient_data(self, service):
        """Test RSI with insufficient data."""
        prices = pd.Series([100, 101, 102])  # Only 3 prices

        with pytest.raises(ValueError, match="Need at least 14 prices"):
            service.calculate_rsi(prices, period=14)

    def test_calculate_macd(self, service, sample_prices):
        """Test MACD calculation."""
        macd = service.calculate_macd(sample_prices)

        # Check structure
        assert hasattr(macd, "macd_line")
        assert hasattr(macd, "signal_line")
        assert hasattr(macd, "histogram")

        # Histogram should be difference between MACD and signal
        assert abs(macd.histogram - (macd.macd_line - macd.signal_line)) < 0.01

        # Check signal logic
        if macd.macd_line > macd.signal_line and macd.histogram > 0:
            assert macd.is_bullish_crossover
        if macd.macd_line < macd.signal_line and macd.histogram < 0:
            assert macd.is_bearish_crossover

    def test_calculate_bollinger_bands(self, service, sample_prices):
        """Test Bollinger Bands calculation."""
        bb = service.calculate_bollinger_bands(sample_prices)

        # Check structure
        assert bb.upper_band > bb.middle_band
        assert bb.middle_band > bb.lower_band
        assert bb.period == 20
        assert bb.std_dev == 2

        # Check bandwidth calculation
        expected_bandwidth = (bb.upper_band - bb.lower_band) / bb.middle_band
        assert abs(bb.bandwidth - expected_bandwidth) < 0.01

        # Check %B calculation
        expected_percent_b = (bb.current_price - bb.lower_band) / (
            bb.upper_band - bb.lower_band
        )
        assert abs(bb.percent_b - expected_percent_b) < 0.01

    def test_calculate_stochastic(self, service, sample_ohlc):
        """Test Stochastic Oscillator calculation."""
        stoch = service.calculate_stochastic(
            sample_ohlc["high"],
            sample_ohlc["low"],
            sample_ohlc["close"],
            period=14,
        )

        # Values should be between 0 and 100
        assert 0 <= stoch.k_value <= 100
        assert 0 <= stoch.d_value <= 100
        assert stoch.period == 14

        # Check overbought/oversold logic
        if stoch.k_value >= 80:
            assert stoch.is_overbought
        if stoch.k_value <= 20:
            assert stoch.is_oversold

    def test_identify_trend_uptrend(self, service):
        """Test trend identification for uptrend."""
        # Create clear uptrend data
        dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
        prices = pd.Series(range(100, 200), index=dates)  # Linear uptrend

        trend = service.identify_trend(prices, period=50)
        assert trend in [TrendDirection.UPTREND, TrendDirection.STRONG_UPTREND]

    def test_identify_trend_downtrend(self, service):
        """Test trend identification for downtrend."""
        # Create clear downtrend data
        dates = pd.date_range(start="2024-01-01", periods=100, freq="D")
        prices = pd.Series(range(200, 100, -1), index=dates)  # Linear downtrend

        trend = service.identify_trend(prices, period=50)
        assert trend in [TrendDirection.DOWNTREND, TrendDirection.STRONG_DOWNTREND]

    def test_analyze_volume(self, service):
        """Test volume analysis."""
        # Create volume data with spike
        dates = pd.date_range(start="2024-01-01", periods=30, freq="D")
        volume = pd.Series([1000000] * 29 + [3000000], index=dates)  # Spike at end

        volume_profile = service.analyze_volume(volume, period=20)

        assert volume_profile.current_volume == 3000000
        assert volume_profile.average_volume < 1500000
        assert volume_profile.relative_volume > 2.0
        assert volume_profile.unusual_activity  # 3x average is unusual

    def test_calculate_composite_signal_bullish(self, service):
        """Test composite signal calculation with bullish indicators."""
        # Manually create bullish indicators for testing
        from maverick_mcp.domain.value_objects.technical_indicators import (
            MACDIndicator,
            RSIIndicator,
        )

        bullish_rsi = RSIIndicator(value=25, period=14)  # Oversold
        bullish_macd = MACDIndicator(
            macd_line=1.0,
            signal_line=0.5,
            histogram=0.5,
        )  # Bullish crossover

        signal = service.calculate_composite_signal(
            rsi=bullish_rsi,
            macd=bullish_macd,
        )

        assert signal in [Signal.BUY, Signal.STRONG_BUY]

    def test_calculate_composite_signal_mixed(self, service):
        """Test composite signal with mixed indicators."""
        from maverick_mcp.domain.value_objects.technical_indicators import (
            BollingerBands,
            MACDIndicator,
            RSIIndicator,
        )

        # Create mixed signals
        neutral_rsi = RSIIndicator(value=50, period=14)  # Neutral
        bearish_macd = MACDIndicator(
            macd_line=-0.5,
            signal_line=0.0,
            histogram=-0.5,
        )  # Bearish
        neutral_bb = BollingerBands(
            upper_band=110,
            middle_band=100,
            lower_band=90,
            current_price=100,
        )  # Neutral

        signal = service.calculate_composite_signal(
            rsi=neutral_rsi,
            macd=bearish_macd,
            bollinger=neutral_bb,
        )

        # With mixed signals, should be neutral or slightly bearish
        assert signal in [Signal.NEUTRAL, Signal.SELL]

    def test_domain_service_has_no_infrastructure_dependencies(self, service):
        """Verify the domain service has no infrastructure dependencies."""
        # Check that the service has no database, API, or cache attributes
        assert not hasattr(service, "db")
        assert not hasattr(service, "session")
        assert not hasattr(service, "cache")
        assert not hasattr(service, "api_client")
        assert not hasattr(service, "http_client")

        # The other tests in this class exercise the service without any
        # mocks, which is the evidence that its methods need no infrastructure.
230 |
```
--------------------------------------------------------------------------------
/tests/test_financial_search.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Test script for enhanced financial search capabilities in DeepResearchAgent.
4 |
5 | This script demonstrates the improved Exa client usage for financial records search
6 | with different strategies and optimizations.
7 | """
8 |
9 | import asyncio
10 | import os
11 | import sys
12 | from datetime import datetime
13 |
# Add the project root (the parent of this tests/ directory) to the Python
# path, so ``maverick_mcp`` is importable when this file is run as a script.
# dirname(abspath(__file__)) alone would be tests/, not the project root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 |
17 | from maverick_mcp.agents.deep_research import DeepResearchAgent, ExaSearchProvider
18 |
19 |
async def test_financial_search_strategies():
    """Run each search strategy against a fixed set of financial queries.

    For every (query, strategy) pair this prints the result count, the
    elapsed wall-clock time and a summary of the top hit. A failing strategy
    is reported and the loop moves on. The function returns immediately when
    EXA_API_KEY is not set.
    """
    api_key = os.getenv("EXA_API_KEY")
    if not api_key:
        print("❌ EXA_API_KEY environment variable not set")
        return

    print("🔍 Testing Enhanced Financial Search Capabilities")
    print("=" * 60)

    # (text sent to the provider, human-readable label)
    scenarios = [
        ("AAPL financial performance", "Apple stock analysis"),
        ("Tesla quarterly earnings 2024", "Tesla earnings report"),
        ("Microsoft revenue growth", "Microsoft financial growth"),
        ("S&P 500 market analysis", "Market index analysis"),
        ("Federal Reserve interest rates", "Fed policy analysis"),
    ]
    strategy_names = ("hybrid", "authoritative", "comprehensive", "auto")

    exa = ExaSearchProvider(api_key)

    for search_text, label in scenarios:
        print(f"\n📊 Testing Query: {label}")
        print(f" Query: '{search_text}'")
        print("-" * 40)

        for name in strategy_names:
            try:
                started = datetime.now()
                hits = await exa.search_financial(
                    query=search_text, num_results=5, strategy=name
                )
                elapsed = (datetime.now() - started).total_seconds()

                print(f" 🎯 Strategy: {name.upper()}")
                print(f" Results: {len(hits)}")
                print(f" Duration: {elapsed:.2f}s")

                if hits:
                    best = hits[0]
                    print(" Top Result:")
                    print(f" Title: {best.get('title', 'N/A')[:80]}...")
                    print(f" Domain: {best.get('domain', 'N/A')}")
                    print(
                        f" Financial Relevance: {best.get('financial_relevance', 0):.2f}"
                    )
                    print(f" Authoritative: {best.get('is_authoritative', False)}")
                    print(f" Score: {best.get('score', 0):.2f}")

                print()
            except Exception as exc:
                print(f" ❌ Strategy {name} failed: {str(exc)}")
                print()
85 |
86 |
async def test_query_enhancement():
    """Print how ``_enhance_financial_query`` rewrites a few sample queries.

    For each sample this shows the original text, the enhanced text, and
    whether the two differ. The function returns after printing its header
    when EXA_API_KEY is not set.
    """

    print("\n🔧 Testing Query Enhancement")
    print("=" * 40)

    api_key = os.getenv("EXA_API_KEY")
    if not api_key:
        print("❌ EXA_API_KEY environment variable not set")
        return

    exa = ExaSearchProvider(api_key)

    samples = (
        "AAPL",  # Stock symbol
        "Tesla company",  # Company name
        "Microsoft analysis",  # Analysis request
        "Amazon earnings financial",  # Already has financial context
    )

    for original in samples:
        rewritten = exa._enhance_financial_query(original)
        changed = "Yes" if rewritten != original else "No"
        print(f"Original: '{original}'")
        print(f"Enhanced: '{rewritten}'")
        print(f"Changed: {changed}")
        print()
114 |
115 |
async def test_financial_relevance_scoring():
    """Print relevance, authority and domain for a few hand-built results.

    Three stand-in search results are scored with the provider's private
    helpers, and each score is printed. The function returns after printing
    its header when EXA_API_KEY is not set.
    """

    print("\n📈 Testing Financial Relevance Scoring")
    print("=" * 45)

    api_key = os.getenv("EXA_API_KEY")
    if not api_key:
        print("❌ EXA_API_KEY environment variable not set")
        return

    exa = ExaSearchProvider(api_key)

    class MockResult:
        """Minimal stand-in exposing the attributes the scoring helpers read."""

        def __init__(self, url, title, text, published_date=None):
            self.url = url
            self.title = title
            self.text = text
            self.published_date = published_date

    # (url, title, text, published_date)
    fixtures = [
        (
            "https://sec.gov/filing/aapl-10k-2024",
            "Apple Inc. Annual Report (Form 10-K)",
            "Apple Inc. reported quarterly earnings of $1.50 per share, with revenue of $95 billion for the quarter ending March 31, 2024.",
            "2024-01-15T00:00:00Z",
        ),
        (
            "https://bloomberg.com/news/apple-stock-analysis",
            "Apple Stock Analysis: Strong Financial Performance",
            "Apple's financial performance continues to show strong growth with increased market cap and dividend distributions.",
            "2024-01-10T00:00:00Z",
        ),
        (
            "https://example.com/random-article",
            "Random Article About Technology",
            "This is just a random article about technology trends without specific financial information.",
            "2024-01-01T00:00:00Z",
        ),
    ]
    samples = [MockResult(*fields) for fields in fixtures]

    for number, sample in enumerate(samples, start=1):
        score = exa._calculate_financial_relevance(sample)
        authoritative = exa._is_authoritative_source(sample.url)
        host = exa._extract_domain(sample.url)

        print(f"Result {number}:")
        print(f" URL: {sample.url}")
        print(f" Domain: {host}")
        print(f" Title: {sample.title}")
        print(f" Financial Relevance: {score:.2f}")
        print(f" Authoritative: {authoritative}")
        print()
170 |
171 |
async def test_deep_research_agent_integration():
    """Run one authoritative financial search through DeepResearchAgent.

    Prints the summary fields of the search response and, when any results
    came back, the first one. Any exception raised while building the agent
    or searching is caught and printed. Returns early when EXA_API_KEY is
    not set.
    """
    print("\n🤖 Testing DeepResearchAgent Integration")
    print("=" * 45)

    api_key = os.getenv("EXA_API_KEY")
    if not api_key:
        print("❌ EXA_API_KEY environment variable not set")
        return

    try:
        agent = DeepResearchAgent(
            llm=None,  # Will be set by initialize if needed
            persona="financial_analyst",
            exa_api_key=api_key,
        )
        await agent.initialize()

        search = await agent._perform_financial_search(
            query="Apple quarterly earnings Q4 2024",
            num_results=3,
            provider="exa",
            strategy="authoritative",
        )

        print(f"Search Results: {search.get('total_results', 0)} found")
        print(f"Strategy Used: {search.get('search_strategy', 'N/A')}")
        print(f"Duration: {search.get('search_duration', 0):.2f}s")
        print(f"Enhanced Search: {search.get('enhanced_search', False)}")

        if not search.get("results"):
            return

        print("\nTop Result:")
        first_hit = search["results"][0]
        print(f" Title: {first_hit.get('title', 'N/A')[:80]}...")
        print(f" Financial Relevance: {first_hit.get('financial_relevance', 0):.2f}")
        print(f" Authoritative: {first_hit.get('is_authoritative', False)}")

    except Exception as exc:
        print(f"❌ Integration test failed: {exc}")
215 |
216 |
async def main():
    """Run every check in order, bracketed by start/finish timestamps.

    The first exception that escapes a check stops the remaining checks.
    It is reported with a full traceback, and the finish timestamp is still
    printed.
    """
    print("🚀 Enhanced Financial Search Testing Suite")
    print("=" * 60)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    try:
        suite = (
            test_financial_search_strategies,
            test_query_enhancement,
            test_financial_relevance_scoring,
            test_deep_research_agent_integration,
        )
        for check in suite:
            await check()

        print("\n✅ All tests completed successfully!")

    except Exception as exc:
        print(f"\n❌ Test suite failed: {exc}")
        import traceback

        traceback.print_exc()

    print(f"\nCompleted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


if __name__ == "__main__":
    asyncio.run(main())
243 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/shutdown.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Graceful shutdown handler for MaverickMCP servers.
3 |
4 | This module provides signal handling and graceful shutdown capabilities
5 | for all server components to ensure safe deployments and prevent data loss.
6 | """
7 |
import asyncio
import inspect
import signal
import sys
import time
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any

from maverick_mcp.utils.logging import get_logger
17 |
18 | logger = get_logger(__name__)
19 |
20 |
class GracefulShutdownHandler:
    """Coordinate graceful, signal-driven shutdown of one server component.

    After ``install_signal_handlers()``, SIGTERM/SIGINT (and SIGHUP where the
    platform has it) trigger a phased shutdown:

    1. mark shutdown in progress and set the shutdown event;
    2. wait up to ``drain_timeout`` for tasks registered with
       ``track_request()``, then cancel the stragglers;
    3. run callbacks registered with ``register_cleanup()``;
    4. terminate the process via ``sys.exit(0)``.
    """

    # Timeout (seconds) applied to each awaitable produced by a cleanup callback.
    _CLEANUP_CALLBACK_TIMEOUT = 5.0
    # Minimum interval (seconds) between "waiting for requests" progress logs.
    _DRAIN_LOG_INTERVAL = 5.0

    def __init__(
        self,
        name: str,
        shutdown_timeout: float = 30.0,
        drain_timeout: float = 10.0,
    ):
        """
        Initialize shutdown handler.

        Args:
            name: Name of the component, used as a prefix in log messages.
            shutdown_timeout: Maximum time to wait for shutdown (seconds).
                NOTE(review): stored, but not enforced anywhere in this class.
            drain_timeout: Time to wait for tracked requests to finish before
                they are cancelled (seconds).
        """
        self.name = name
        self.shutdown_timeout = shutdown_timeout
        self.drain_timeout = drain_timeout
        self._shutdown_event = asyncio.Event()
        self._cleanup_callbacks: list[Callable] = []
        self._active_requests: set[asyncio.Task] = set()
        # Handlers in place before install_signal_handlers(), keyed by signal.
        self._original_handlers: dict[int, Any] = {}
        self._shutdown_in_progress = False
        # Strong reference to the scheduled shutdown task: the event loop only
        # keeps weak references to tasks.
        self._shutdown_task: asyncio.Task | None = None
        self._start_time = time.time()

    @staticmethod
    def _callback_name(callback: Callable) -> str:
        """Return a printable name for *callback*.

        ``functools.partial`` objects and callable instances have no
        ``__name__``, so fall back to their ``repr``.
        """
        return getattr(callback, "__name__", repr(callback))

    def register_cleanup(self, callback: Callable) -> None:
        """Register a cleanup callback to run during shutdown.

        The callback is called with no arguments. If it returns an awaitable
        (async functions, or e.g. a partial wrapping one), the async shutdown
        path awaits it with a per-callback timeout.
        """
        self._cleanup_callbacks.append(callback)
        logger.debug(f"Registered cleanup callback: {self._callback_name(callback)}")

    def track_request(self, task: asyncio.Task) -> None:
        """Track an active request/task; it is forgotten automatically when done."""
        self._active_requests.add(task)
        task.add_done_callback(self._active_requests.discard)

    @contextmanager
    def track_sync_request(self):
        """Context manager to track synchronous requests.

        Safe to use with or without a running event loop. Debug logging
        happens only when called from inside an asyncio task.
        """
        try:
            current = asyncio.current_task()
        except RuntimeError:
            # No running event loop: a plain synchronous caller.
            current = None
        request_id = id(current) if current is not None else None
        try:
            if request_id:
                logger.debug(f"Tracking sync request: {request_id}")
            yield
        finally:
            if request_id:
                logger.debug(f"Completed sync request: {request_id}")

    async def wait_for_shutdown(self) -> None:
        """Block until a shutdown has been started."""
        await self._shutdown_event.wait()

    def is_shutting_down(self) -> bool:
        """Check if shutdown is in progress."""
        return self._shutdown_in_progress

    def install_signal_handlers(self) -> None:
        """Install signal handlers for graceful shutdown."""
        signals = [signal.SIGTERM, signal.SIGINT]
        # Also handle SIGHUP for reload scenarios (not available on Windows).
        if hasattr(signal, "SIGHUP"):
            signals.append(signal.SIGHUP)

        for sig in signals:
            previous = signal.signal(sig, self._signal_handler)
            # Keep the first-seen handler so a repeated install does not
            # record our own handler as the "original".
            self._original_handlers.setdefault(sig, previous)

        logger.info(f"{self.name}: Signal handlers installed")

    def _signal_handler(self, signum: int, frame: Any) -> None:
        """Handle shutdown signals by starting async or sync shutdown."""
        signal_name = signal.Signals(signum).name
        logger.info(f"{self.name}: Received {signal_name} signal")

        if self._shutdown_in_progress:
            logger.warning(
                f"{self.name}: Shutdown already in progress, ignoring signal"
            )
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None:
            # call_soon_threadsafe also wakes a loop that is blocked waiting
            # for I/O, so the shutdown task starts promptly.
            loop.call_soon_threadsafe(self._start_shutdown_task, signal_name)
        else:
            # Fallback for non-async context
            self._sync_shutdown(signal_name)

    def _start_shutdown_task(self, signal_name: str) -> None:
        """Create the shutdown task on the running loop (at most once)."""
        if self._shutdown_task is None:
            self._shutdown_task = asyncio.get_running_loop().create_task(
                self._async_shutdown(signal_name)
            )

    async def _async_shutdown(self, signal_name: str) -> None:
        """Perform async graceful shutdown, then exit the process."""
        if self._shutdown_in_progress:
            return

        self._shutdown_in_progress = True
        shutdown_start = time.time()

        logger.info(
            f"{self.name}: Starting graceful shutdown (signal: {signal_name}, "
            f"uptime: {shutdown_start - self._start_time:.1f}s)"
        )

        # Set shutdown event to notify waiting coroutines
        self._shutdown_event.set()

        # Phase 1: Stop accepting new requests
        logger.info(f"{self.name}: Phase 1 - Stopping new requests")

        await self._drain_requests()
        await self._run_cleanup_callbacks()

        # Phase 4: Final shutdown
        shutdown_duration = time.time() - shutdown_start
        logger.info(
            f"{self.name}: Graceful shutdown completed in {shutdown_duration:.1f}s"
        )

        # Exit the process
        sys.exit(0)

    async def _drain_requests(self) -> None:
        """Phase 2: wait up to drain_timeout for tracked tasks, then cancel the rest."""
        if not self._active_requests:
            return

        logger.info(
            f"{self.name}: Phase 2 - Draining {len(self._active_requests)} "
            f"active requests (timeout: {self.drain_timeout}s)"
        )

        try:
            await asyncio.wait_for(
                self._wait_for_requests(),
                timeout=self.drain_timeout,
            )
            logger.info(f"{self.name}: All requests completed")
        except TimeoutError:
            remaining = len(self._active_requests)
            logger.warning(
                f"{self.name}: Drain timeout reached, {remaining} requests remaining"
            )
            # Cancel remaining requests. Iterate a snapshot, because the
            # done-callbacks discard tasks from the live set.
            for task in list(self._active_requests):
                task.cancel()

    async def _run_cleanup_callbacks(self) -> None:
        """Phase 3: run every cleanup callback, logging (not raising) failures."""
        logger.info(f"{self.name}: Phase 3 - Running cleanup callbacks")
        for callback in self._cleanup_callbacks:
            callback_name = self._callback_name(callback)
            try:
                logger.debug(f"Running cleanup: {callback_name}")
                result = callback()
                # Covers async functions and sync callables (such as partials)
                # that return awaitables.
                if inspect.isawaitable(result):
                    await asyncio.wait_for(
                        result, timeout=self._CLEANUP_CALLBACK_TIMEOUT
                    )
            except Exception as e:
                logger.error(f"Error in cleanup callback {callback_name}: {e}")

    def _sync_shutdown(self, signal_name: str) -> None:
        """Perform synchronous shutdown (fallback when no event loop is running)."""
        if self._shutdown_in_progress:
            return

        self._shutdown_in_progress = True
        logger.info(f"{self.name}: Starting sync shutdown (signal: {signal_name})")

        # Run sync cleanup callbacks; async ones cannot be awaited here.
        for callback in self._cleanup_callbacks:
            callback_name = self._callback_name(callback)
            if inspect.iscoroutinefunction(callback):
                logger.warning(
                    f"{self.name}: Skipping async cleanup callback "
                    f"{callback_name} (no running event loop)"
                )
                continue
            try:
                result = callback()
                if inspect.iscoroutine(result):
                    # Close it to avoid a "coroutine was never awaited" warning.
                    result.close()
                    logger.warning(
                        f"{self.name}: Cleanup callback {callback_name} returned "
                        f"a coroutine that cannot be awaited without an event loop"
                    )
            except Exception as e:
                logger.error(f"Error in cleanup callback {callback_name}: {e}")

        logger.info(f"{self.name}: Sync shutdown completed")
        sys.exit(0)

    async def _wait_for_requests(self) -> None:
        """Poll until every tracked request has finished, logging progress periodically."""
        last_log = time.monotonic()
        while self._active_requests:
            # Wait a bit and check again
            await asyncio.sleep(0.1)

            # Log progress at most once per _DRAIN_LOG_INTERVAL seconds.
            now = time.monotonic()
            if self._active_requests and now - last_log >= self._DRAIN_LOG_INTERVAL:
                logger.info(
                    f"{self.name}: Waiting for {len(self._active_requests)} requests"
                )
                last_log = now

    def restore_signal_handlers(self) -> None:
        """Restore the signal handlers that were active before installation."""
        for sig, handler in self._original_handlers.items():
            # signal.signal() reports None when the previous handler was not
            # installed from Python, and None is not accepted back, so fall
            # back to the default action.
            signal.signal(sig, handler if handler is not None else signal.SIG_DFL)
        # Forget them so a later install records the then-current handlers.
        self._original_handlers.clear()
        logger.debug(f"{self.name}: Signal handlers restored")
208 |
209 |
# Process-wide handler shared by get_shutdown_handler() and graceful_shutdown().
_shutdown_handler: GracefulShutdownHandler | None = None


def get_shutdown_handler(
    name: str = "Server",
    shutdown_timeout: float = 30.0,
    drain_timeout: float = 10.0,
) -> GracefulShutdownHandler:
    """Return the process-wide shutdown handler, creating it on first use.

    The arguments only matter on the first call. Later calls return the
    already-created instance unchanged.
    """
    global _shutdown_handler
    if _shutdown_handler is not None:
        return _shutdown_handler

    _shutdown_handler = GracefulShutdownHandler(
        name=name,
        shutdown_timeout=shutdown_timeout,
        drain_timeout=drain_timeout,
    )
    return _shutdown_handler
226 |
227 |
@contextmanager
def graceful_shutdown(
    name: str = "Server",
    shutdown_timeout: float = 30.0,
    drain_timeout: float = 10.0,
):
    """Install shutdown signal handlers for the duration of a ``with`` block.

    Yields the shared handler from ``get_shutdown_handler``. The previously
    active signal handlers are restored on exit, even if the body raises.
    """
    shutdown_handler = get_shutdown_handler(
        name=name,
        shutdown_timeout=shutdown_timeout,
        drain_timeout=drain_timeout,
    )
    shutdown_handler.install_signal_handlers()
    try:
        yield shutdown_handler
    finally:
        shutdown_handler.restore_signal_handlers()
242 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/implementations/cache_adapter.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Cache manager adapter.
3 |
4 | This module provides adapters that make the existing cache system
5 | compatible with the new ICacheManager interface.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | from typing import Any
11 |
12 | from maverick_mcp.data.cache import (
13 | CacheManager as ExistingCacheManager,
14 | )
15 | from maverick_mcp.data.cache import (
16 | clear_cache,
17 | get_from_cache,
18 | save_to_cache,
19 | )
20 | from maverick_mcp.providers.interfaces.cache import CacheConfig, ICacheManager
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
class RedisCacheAdapter(ICacheManager):
    """
    Adapter that makes the existing cache system compatible with ICacheManager interface.

    ``get``/``set``/``clear`` delegate to the module-level synchronous helpers
    (``get_from_cache``, ``save_to_cache``, ``clear_cache``). Those helpers are
    run in the default executor so they do not block the event loop. All other
    primitive operations are awaited on the wrapped ``CacheManager`` instance.

    NOTE: the compound operations (``get_or_set``, ``increment``,
    ``set_if_not_exists``, ``expire``) are implemented as separate read and
    write calls and are therefore not atomic under concurrent access.
    """

    def __init__(self, config: CacheConfig | None = None):
        """
        Initialize the cache adapter.

        Args:
            config: Cache configuration (optional). Stored on the instance, but
                not consulted by the methods of this adapter.
        """
        self._config = config
        self._cache_manager = ExistingCacheManager()

        logger.debug("RedisCacheAdapter initialized")

    @staticmethod
    async def _run_sync(func, *args) -> Any:
        """Run a blocking cache helper in the running loop's default executor."""
        # get_running_loop() is the documented choice inside coroutines;
        # get_event_loop() has policy-dependent behaviour.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    async def get(self, key: str) -> Any:
        """
        Get data from cache (async wrapper).

        Args:
            key: Cache key to retrieve

        Returns:
            Cached data or None if not found or expired
        """
        return await self._run_sync(get_from_cache, key)

    async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
        """
        Store data in cache (async wrapper).

        Args:
            key: Cache key
            value: Data to cache (must be JSON serializable)
            ttl: Time-to-live in seconds (None for default TTL)

        Returns:
            True if successfully cached, False otherwise
        """
        return await self._run_sync(save_to_cache, key, value, ttl)

    async def delete(self, key: str) -> bool:
        """
        Delete a key from cache.

        Args:
            key: Cache key to delete

        Returns:
            True if key was deleted, False if key didn't exist
        """
        return await self._cache_manager.delete(key)

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in cache.

        Args:
            key: Cache key to check

        Returns:
            True if key exists and hasn't expired, False otherwise
        """
        return await self._cache_manager.exists(key)

    async def clear(self, pattern: str | None = None) -> int:
        """
        Clear cache entries.

        Args:
            pattern: Pattern to match keys (e.g., "stock:*")
                    If None, clears all cache entries

        Returns:
            Number of entries cleared
        """
        return await self._run_sync(clear_cache, pattern)

    async def get_many(self, keys: list[str]) -> dict[str, Any]:
        """
        Get multiple values at once for better performance.

        Args:
            keys: List of cache keys to retrieve

        Returns:
            Dictionary mapping keys to their cached values
            (missing keys will not be in the result)
        """
        return await self._cache_manager.get_many(keys)

    async def set_many(self, items: list[tuple[str, Any, int | None]]) -> int:
        """
        Set multiple values at once for better performance.

        Args:
            items: List of tuples (key, value, ttl)

        Returns:
            Number of items successfully cached
        """
        return await self._cache_manager.batch_save(items)

    async def delete_many(self, keys: list[str]) -> int:
        """
        Delete multiple keys for better performance.

        Args:
            keys: List of keys to delete

        Returns:
            Number of keys successfully deleted
        """
        return await self._cache_manager.batch_delete(keys)

    async def exists_many(self, keys: list[str]) -> dict[str, bool]:
        """
        Check existence of multiple keys for better performance.

        Args:
            keys: List of keys to check

        Returns:
            Dictionary mapping keys to their existence status
        """
        return await self._cache_manager.batch_exists(keys)

    async def count_keys(self, pattern: str) -> int:
        """
        Count keys matching a pattern.

        Args:
            pattern: Pattern to match (e.g., "stock:*")

        Returns:
            Number of matching keys
        """
        return await self._cache_manager.count_keys(pattern)

    async def get_or_set(
        self, key: str, default_value: Any, ttl: int | None = None
    ) -> Any:
        """
        Get value from cache, setting it if it doesn't exist.

        A cached ``None`` is indistinguishable from a miss, so it is replaced
        by ``default_value``.

        Args:
            key: Cache key
            default_value: Value to set if key doesn't exist
            ttl: Time-to-live for the default value

        Returns:
            Either the existing cached value or the default value
        """
        existing_value = await self.get(key)
        if existing_value is not None:
            return existing_value

        # Set default value and return it
        await self.set(key, default_value, ttl)
        return default_value

    async def increment(self, key: str, amount: int = 1) -> int:
        """
        Increment a numeric value in cache (read-modify-write, not atomic).

        The new value is stored with the default TTL.

        Args:
            key: Cache key
            amount: Amount to increment by

        Returns:
            New value after increment

        Raises:
            ValueError: If the key exists but doesn't contain a numeric value
        """
        current = await self.get(key)

        if current is None:
            # Key doesn't exist, start from 0
            new_value = amount
        else:
            try:
                new_value = int(current) + amount
            except (ValueError, TypeError) as err:
                raise ValueError(
                    f"Key {key} contains non-numeric value: {current}"
                ) from err

        await self.set(key, new_value)
        return new_value

    async def set_if_not_exists(
        self, key: str, value: Any, ttl: int | None = None
    ) -> bool:
        """
        Set a value only if the key doesn't already exist.

        The existence check and the write are separate calls, so two
        concurrent callers may both succeed.

        Args:
            key: Cache key
            value: Value to set
            ttl: Time-to-live in seconds

        Returns:
            True if the value was set, False if key already existed
        """
        if await self.exists(key):
            return False

        return await self.set(key, value, ttl)

    async def get_ttl(self, key: str) -> int | None:
        """
        Get the remaining time-to-live for a key.

        Not supported by the wrapped cache: always logs a warning and returns
        None.

        Args:
            key: Cache key

        Returns:
            Remaining TTL in seconds, None if key doesn't exist or has no TTL
        """
        logger.warning(f"TTL introspection not implemented for key: {key}")
        return None

    async def expire(self, key: str, ttl: int) -> bool:
        """
        Set expiration time for an existing key by re-saving its value.

        Args:
            key: Cache key
            ttl: Time-to-live in seconds

        Returns:
            True if expiration was set, False if key doesn't exist
        """
        if not await self.exists(key):
            return False

        # Get current value and re-set with new TTL
        current_value = await self.get(key)
        if current_value is not None:
            return await self.set(key, current_value, ttl)

        return False

    def get_sync_cache_manager(self) -> ExistingCacheManager:
        """
        Get the underlying synchronous cache manager for backward compatibility.

        Returns:
            The wrapped CacheManager instance
        """
        return self._cache_manager
293 |
```
--------------------------------------------------------------------------------
/maverick_mcp/config/security_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Security utilities for applying centralized security configuration.
3 |
4 | This module provides utility functions to apply the SecurityConfig
5 | across different server implementations consistently.
6 | """
7 |
8 | from fastapi import FastAPI
9 | from fastapi.middleware.cors import CORSMiddleware
10 | from fastapi.middleware.trustedhost import TrustedHostMiddleware
11 | from starlette.applications import Starlette
12 | from starlette.middleware import Middleware
13 | from starlette.middleware.base import BaseHTTPMiddleware
14 | from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware
15 | from starlette.requests import Request
16 |
17 | from maverick_mcp.config.security import get_security_config, validate_security_config
18 | from maverick_mcp.utils.logging import get_logger
19 |
20 | logger = get_logger(__name__)
21 |
22 |
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach the configured security headers to every outgoing response.

    Header names and values come from ``security_config.get_security_headers()``.
    When no config is supplied, the one returned by ``get_security_config()``
    is used.
    """

    def __init__(self, app, security_config=None):
        super().__init__(app)
        self.security_config = security_config or get_security_config()

    async def dispatch(self, request: Request, call_next):
        outgoing = await call_next(request)

        # Existing headers with the same name are overwritten.
        configured_headers = self.security_config.get_security_headers()
        for header_name, header_value in configured_headers.items():
            outgoing.headers[header_name] = header_value

        return outgoing
39 |
40 |
def apply_cors_to_fastapi(app: FastAPI, security_config=None) -> None:
    """Validate the security configuration, then install CORS on a FastAPI app.

    Validation warnings are logged. CORS settings come from
    ``config.get_cors_middleware_config()``.

    Raises:
        ValueError: if ``validate_security_config()`` reports the configuration
            as invalid.
    """
    config = security_config or get_security_config()

    # NOTE(review): validation inspects the globally configured settings, not
    # an explicitly passed ``security_config`` — confirm this is intended.
    report = validate_security_config()
    if not report["valid"]:
        problems = report["issues"]
        logger.error(f"Security validation failed: {problems}")
        for problem in problems:
            logger.error(f"SECURITY ISSUE: {problem}")
        raise ValueError(f"Security configuration is invalid: {problems}")

    for warning in report["warnings"] or ():
        logger.warning(f"SECURITY WARNING: {warning}")

    cors_settings = config.get_cors_middleware_config()
    app.add_middleware(CORSMiddleware, **cors_settings)

    logger.info(
        f"CORS configured for {config.environment} environment: "
        f"origins={cors_settings['allow_origins']}, "
        f"credentials={cors_settings['allow_credentials']}"
    )
66 |
67 |
def apply_cors_to_starlette(
    app: Starlette | None = None, security_config=None
) -> list[Middleware]:
    """Build the CORS + security-headers middleware list for a Starlette app.

    ``app`` is not used: the caller installs the returned list itself
    (``create_secure_starlette_middleware`` passes ``None``). It is therefore
    optional and typed ``Starlette | None``. Callers that pass an app
    positionally keep working.

    Args:
        app: Unused; kept for signature compatibility.
        security_config: Configuration to use. Defaults to
            ``get_security_config()``.

    Returns:
        ``[CORS middleware, SecurityHeadersMiddleware]`` in that order.

    Raises:
        ValueError: if ``validate_security_config()`` reports the configuration
            as invalid.
    """
    config = security_config or get_security_config()

    # Validate security before applying.
    # NOTE(review): this validates the globally configured settings, not an
    # explicitly passed ``security_config`` — confirm this is intended.
    validation = validate_security_config()
    if not validation["valid"]:
        logger.error(f"Security validation failed: {validation['issues']}")
        for issue in validation["issues"]:
            logger.error(f"SECURITY ISSUE: {issue}")
        raise ValueError(f"Security configuration is invalid: {validation['issues']}")

    if validation["warnings"]:
        for warning in validation["warnings"]:
            logger.warning(f"SECURITY WARNING: {warning}")

    # Create middleware configuration
    cors_config = config.get_cors_middleware_config()

    middleware_list = [
        Middleware(StarletteCORSMiddleware, **cors_config),
        Middleware(SecurityHeadersMiddleware, security_config=config),
    ]

    logger.info(
        f"Starlette CORS configured for {config.environment} environment: "
        f"origins={cors_config['allow_origins']}, "
        f"credentials={cors_config['allow_credentials']}"
    )

    return middleware_list
99 |
100 |
def apply_trusted_hosts_to_fastapi(app: FastAPI, security_config=None) -> None:
    """Install TrustedHostMiddleware when the configuration calls for it.

    Host checking is enforced in production or under strict security. Outside
    those, it is enforced only when ``trusted_hosts.enforce_in_development`` is
    set. Otherwise nothing is installed and that fact is logged.
    """
    config = security_config or get_security_config()

    if config.is_production() or config.strict_security:
        banner = "Trusted hosts configured: {}"
    elif config.trusted_hosts.enforce_in_development:
        banner = "Trusted hosts configured for development: {}"
    else:
        logger.info("Trusted hosts validation disabled for development")
        return

    app.add_middleware(
        TrustedHostMiddleware, allowed_hosts=config.trusted_hosts.allowed_hosts
    )
    logger.info(banner.format(config.trusted_hosts.allowed_hosts))
120 |
121 |
def apply_security_headers_to_fastapi(app: FastAPI, security_config=None) -> None:
    """Register SecurityHeadersMiddleware on *app*.

    Uses *security_config* when given, else ``get_security_config()``.
    """
    effective_config = security_config or get_security_config()
    app.add_middleware(SecurityHeadersMiddleware, security_config=effective_config)
    logger.info("Security headers middleware applied")
127 |
128 |
def get_safe_cors_config() -> dict:
    """Return CORS settings, falling back to a conservative preset when invalid.

    If ``validate_security_config()`` reports the configuration valid, the
    configured CORS middleware settings are returned as-is. Otherwise a fixed
    fallback is returned. Its single allowed origin is the production site
    when ``config.is_production()`` is true, else ``http://localhost:3000``.
    """
    config = get_security_config()

    if validate_security_config()["valid"]:
        return config.get_cors_middleware_config()

    logger.error("Using fallback safe CORS configuration due to validation errors")
    fallback_origin = (
        "https://maverick-mcp.com"
        if config.is_production()
        else "http://localhost:3000"
    )
    return {
        "allow_origins": [fallback_origin],
        "allow_credentials": True,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Authorization", "Content-Type"],
        "expose_headers": [],
        "max_age": 86400,
    }
159 |
160 |
def log_security_status() -> None:
    """Log a summary of the active security settings and the validation result."""
    config = get_security_config()
    validation = validate_security_config()

    logger.info("=== Security Configuration Status ===")
    settings = (
        ("Environment", config.environment),
        ("Force HTTPS", config.force_https),
        ("Strict Security", config.strict_security),
        ("CORS Origins", config.cors.allowed_origins),
        ("CORS Credentials", config.cors.allow_credentials),
        ("Rate Limiting", config.rate_limiting.enabled),
        ("Trusted Hosts", config.trusted_hosts.allowed_hosts),
    )
    for label, value in settings:
        logger.info(f"{label}: {value}")

    if not validation["valid"]:
        logger.error("❌ Security validation: FAILED")
        for issue in validation["issues"]:
            logger.error(f" - {issue}")
    else:
        logger.info("✅ Security validation: PASSED")

    warnings = validation["warnings"]
    if warnings:
        logger.warning("⚠️ Security warnings:")
        for warning in warnings:
            logger.warning(f" - {warning}")

    logger.info("=====================================")
188 |
189 |
def create_secure_fastapi_app(
    title: str = "Maverick MCP API",
    description: str = "Secure API with centralized security configuration",
    version: str = "1.0.0",
    **kwargs,
) -> FastAPI:
    """Build a FastAPI app with trusted-host, CORS and security-header middleware.

    Extra keyword arguments are forwarded to the ``FastAPI`` constructor. The
    security status is logged once everything is installed. Propagates the
    ``ValueError`` raised by ``apply_cors_to_fastapi`` when validation fails.
    """
    secured_app = FastAPI(
        title=title, description=description, version=version, **kwargs
    )

    installers = (
        apply_trusted_hosts_to_fastapi,
        apply_cors_to_fastapi,
        apply_security_headers_to_fastapi,
    )
    for install in installers:
        install(secured_app)

    log_security_status()
    return secured_app
208 |
209 |
def create_secure_starlette_middleware() -> list[Middleware]:
    """Return the CORS + security-headers middleware list for the current config.

    The security status is logged before the list is returned.
    """
    middleware_stack = apply_cors_to_starlette(None, get_security_config())
    log_security_status()
    return middleware_stack
221 |
222 |
# Export validation function for easy access
def check_security_config() -> bool:
    """Return whether ``validate_security_config()`` considers the config valid."""
    report = validate_security_config()
    return report["valid"]
228 |
```
--------------------------------------------------------------------------------
/scripts/dev.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash

# Maverick-MCP Development Script
# Starts the backend MCP server for personal stock analysis.
# Any failing command outside a condition aborts the script (set -e).

set -e

# ANSI colour codes used by the status messages
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

echo -e "${GREEN}Starting Maverick-MCP Development Environment${NC}"

# Free port 8003 so the new server can bind to it
echo -e "${YELLOW}Checking for existing processes on port 8003...${NC}"
EXISTING_PID=$(lsof -ti:8003 2>/dev/null || true)
if [ -z "$EXISTING_PID" ]; then
    echo -e "${GREEN}No existing processes found on port 8003${NC}"
else
    echo -e "${YELLOW}Found existing process(es) on port 8003: $EXISTING_PID${NC}"
    echo -e "${YELLOW}Killing existing processes...${NC}"
    # Deliberately unquoted: lsof may print several PIDs, one per line.
    kill -9 $EXISTING_PID 2>/dev/null || true
    sleep 1
fi

# Make sure Redis is up: prefer a Homebrew service, else a daemonized redis-server
if pgrep -x "redis-server" > /dev/null; then
    echo -e "${GREEN}Redis is already running${NC}"
else
    echo -e "${YELLOW}Starting Redis...${NC}"
    if command -v brew &> /dev/null; then
        brew services start redis
    else
        redis-server --daemonize yes
    fi
fi
39 |
# Function to cleanup on exit (installed as the EXIT/INT/TERM trap)
cleanup() {
    # Capture the status that triggered the trap before any command overwrites $?,
    # so a failure such as `exit 1` is not reported as success.
    local exit_code=$?
    # Disarm the traps: `exit` inside an INT/TERM trap fires the EXIT trap,
    # which would otherwise run this function a second time.
    trap - EXIT INT TERM

    echo -e "\n${YELLOW}Shutting down services...${NC}"
    # Kill backend process
    if [ -n "$BACKEND_PID" ]; then
        kill "$BACKEND_PID" 2>/dev/null || true
    fi
    # BACKEND_PID is taken from `$!` after the backend is backgrounded; for a
    # backgrounded pipeline (`... | tee backend.log &`) that is the last stage
    # (tee), not the server. Stop whatever still listens on port 8003 as well.
    SERVER_PIDS=$(lsof -ti:8003 2>/dev/null || true)
    if [ -n "$SERVER_PIDS" ]; then
        # Deliberately unquoted: lsof may print several PIDs.
        kill $SERVER_PIDS 2>/dev/null || true
    fi
    echo -e "${GREEN}Development environment stopped${NC}"
    exit "$exit_code"
}
50 |
# Set trap to cleanup on script exit
trap cleanup EXIT INT TERM

# Start backend
echo -e "${YELLOW}Starting backend MCP server...${NC}"
cd "$(dirname "$0")/.."
echo -e "${YELLOW}Current directory: $(pwd)${NC}"

# Source .env if it exists. Plain KEY=value lines become shell variables
# (used by the checks below); they reach child processes only if exported.
if [ -f .env ]; then
    source .env
fi

# Check if uv is available (more relevant than python since we use uv run)
if ! command -v uv &> /dev/null; then
    echo -e "${RED}uv not found! Please install uv: curl -LsSf https://astral.sh/uv/install.sh | sh${NC}"
    exit 1
fi

# Validate critical environment variables (warnings only; startup continues)
echo -e "${YELLOW}Validating environment...${NC}"
if [ -z "$TIINGO_API_KEY" ]; then
    echo -e "${RED}Warning: TIINGO_API_KEY not set - stock data tools may not work${NC}"
fi

if [ -z "$EXA_API_KEY" ] && [ -z "$TAVILY_API_KEY" ]; then
    echo -e "${RED}Warning: Neither EXA_API_KEY nor TAVILY_API_KEY set - research tools may be limited${NC}"
fi

# Choose transport based on environment variable or default to SSE for reliability
TRANSPORT=${MAVERICK_TRANSPORT:-sse}
echo -e "${YELLOW}Starting backend with: uv run python -m maverick_mcp.api.server --transport ${TRANSPORT} --host 0.0.0.0 --port 8003${NC}"
echo -e "${YELLOW}Transport: ${TRANSPORT} (recommended for Claude Desktop stability)${NC}"

# Run backend with FastMCP in development mode (show real-time output)
echo -e "${YELLOW}Starting server with real-time output...${NC}"
# Output is mirrored to backend.log through a process substitution instead of
# a `| tee` pipeline: for a backgrounded pipeline `$!` is the PID of its last
# command (tee), so the liveness check and cleanup would target tee rather
# than the server launcher.
# Set PYTHONWARNINGS to suppress websockets deprecation warnings from uvicorn
PYTHONWARNINGS="ignore::DeprecationWarning:websockets.*,ignore::DeprecationWarning:uvicorn.*" \
    uv run python -m maverick_mcp.api.server --transport "${TRANSPORT}" --host 0.0.0.0 --port 8003 > >(tee backend.log) 2>&1 &
BACKEND_PID=$!
echo -e "${YELLOW}Backend PID: $BACKEND_PID${NC}"

# Wait for backend to start
echo -e "${YELLOW}Waiting for backend to start...${NC}"

# Wait up to 45 seconds for the backend to start and tools to register
TOOLS_REGISTERED=false
for i in {1..45}; do
    # Check if backend process is still running first
    if ! kill -0 "$BACKEND_PID" 2>/dev/null; then
        echo -e "${RED}Backend process died! Check output above for errors.${NC}"
        exit 1
    fi

    # Check if port is open
    if nc -z localhost 8003 2>/dev/null || curl -s http://localhost:8003/health >/dev/null 2>&1; then
        if [ "$TOOLS_REGISTERED" = false ]; then
            echo -e "${GREEN}Backend port is open, checking for tool registration...${NC}"

            # Check backend.log for tool registration messages
            if grep -q "Research tools registered successfully" backend.log 2>/dev/null ||
                grep -q "Tool registration process completed" backend.log 2>/dev/null ||
                grep -q "Tools registered successfully" backend.log 2>/dev/null; then
                echo -e "${GREEN}Research tools successfully registered!${NC}"
                TOOLS_REGISTERED=true
                break
            else
                echo -e "${YELLOW}Backend running but tools not yet registered... ($i/45)${NC}"
            fi
        fi
    else
        echo -e "${YELLOW}Still waiting for backend to start... ($i/45)${NC}"
    fi

    if [ "$i" -eq 45 ]; then
        echo -e "${RED}Backend failed to fully initialize after 45 seconds!${NC}"
        echo -e "${RED}Server may be running but tools not registered. Check output above.${NC}"
        # Don't exit - let it continue in case tools load later
    fi

    sleep 1
done

if [ "$TOOLS_REGISTERED" = true ]; then
    echo -e "${GREEN}Backend is ready with tools registered!${NC}"
else
    echo -e "${YELLOW}Backend appears to be running but tool registration status unclear${NC}"
fi

echo -e "${GREEN}Backend started successfully on http://localhost:8003${NC}"

# Show information
echo -e "\n${GREEN}Development environment is running!${NC}"
echo -e "${YELLOW}MCP Server:${NC} http://localhost:8003"
echo -e "${YELLOW}Health Check:${NC} http://localhost:8003/health"
147 | # Show endpoint based on transport type
148 | if [ "$TRANSPORT" = "sse" ]; then
149 | echo -e "${YELLOW}MCP SSE Endpoint:${NC} http://localhost:8003/sse/"
150 | elif [ "$TRANSPORT" = "streamable-http" ]; then
151 | echo -e "${YELLOW}MCP HTTP Endpoint:${NC} http://localhost:8003/mcp"
152 | echo -e "${YELLOW}Test with curl:${NC} curl -X POST http://localhost:8003/mcp"
153 | elif [ "$TRANSPORT" = "stdio" ]; then
154 | echo -e "${YELLOW}MCP Transport:${NC} STDIO (no HTTP endpoint)"
155 | fi
156 |
157 | echo -e "${YELLOW}Logs:${NC} tail -f backend.log"
158 |
159 | if [ "$TOOLS_REGISTERED" = true ]; then
160 | echo -e "\n${GREEN}✓ Research tools are registered and ready${NC}"
161 | else
162 | echo -e "\n${YELLOW}⚠ Tool registration status unclear${NC}"
163 | echo -e "${YELLOW}Debug: Check backend.log for tool registration messages${NC}"
164 | echo -e "${YELLOW}Debug: Look for 'Successfully registered' or 'research tools' in logs${NC}"
165 | fi
166 |
167 | echo -e "\n${YELLOW}Claude Desktop Configuration:${NC}"
168 | if [ "$TRANSPORT" = "sse" ]; then
169 | echo -e "${GREEN}SSE Transport (tested and stable):${NC}"
170 | echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/sse/"]}}}'
171 | elif [ "$TRANSPORT" = "stdio" ]; then
172 | echo -e "${GREEN}STDIO Transport (direct connection):${NC}"
173 | echo -e '{"mcpServers": {"maverick-mcp": {"command": "uv", "args": ["run", "python", "-m", "maverick_mcp.api.server", "--transport", "stdio"], "cwd": "'$(pwd)'"}}}'
174 | elif [ "$TRANSPORT" = "streamable-http" ]; then
175 | echo -e "${GREEN}Streamable-HTTP Transport (for testing):${NC}"
176 | echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/mcp"]}}}'
177 | else
178 | echo -e '{"mcpServers": {"maverick-mcp": {"command": "npx", "args": ["-y", "mcp-remote", "http://localhost:8003/mcp"]}}}'
179 | fi
180 |
181 | echo -e "\n${YELLOW}Connection Stability Features:${NC}"
182 | if [ "$TRANSPORT" = "sse" ]; then
183 | echo -e " • SSE transport (tested and stable for Claude Desktop)"
184 | echo -e " • Uses mcp-remote bridge for reliable connection"
185 | echo -e " • Prevents tools from disappearing"
186 | echo -e " • Persistent connection with session management"
187 | echo -e " • Adaptive timeout system for research tools"
188 | elif [ "$TRANSPORT" = "stdio" ]; then
189 | echo -e " • Direct STDIO transport (no network layer)"
190 | echo -e " • No mcp-remote needed (direct Claude Desktop integration)"
191 | echo -e " • No session management issues"
192 | echo -e " • No timeout problems"
193 | elif [ "$TRANSPORT" = "streamable-http" ]; then
194 | echo -e " • Streamable-HTTP transport (FastMCP 2.0 standard)"
195 | echo -e " • Uses mcp-remote bridge for Claude Desktop"
196 | echo -e " • Ideal for testing with curl/Postman/REST clients"
197 | echo -e " • Good for debugging transport-specific issues"
198 | echo -e " • Alternative to SSE for compatibility testing"
199 | else
200 | echo -e " • HTTP transport with mcp-remote bridge"
201 | echo -e " • Alternative to SSE for compatibility"
202 | echo -e " • Single process management"
203 | fi
204 | echo -e "\nPress Ctrl+C to stop the server"
205 |
206 | # Wait for process
207 | wait
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/screening.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Stock screening router for Maverick-MCP.
3 |
4 | This module contains all stock screening related tools including
5 | Maverick, supply/demand breakouts, and other screening strategies.
6 | """
7 |
8 | import logging
9 | from typing import Any
10 |
11 | from fastmcp import FastMCP
12 |
# Module-level logger; every screening function below reports failures through it.
logger = logging.getLogger(__name__)

# Create the screening router: a FastMCP instance named "Stock_Screening".
# NOTE(review): the functions in this module are not decorated with this router
# here; presumably they are registered as tools elsewhere — confirm.
screening_router: FastMCP = FastMCP("Stock_Screening")
17 |
18 |
def get_maverick_stocks(limit: int = 20) -> dict[str, Any]:
    """
    Get top Maverick stocks from the screening results.

    DISCLAIMER: Stock screening results are for educational and research purposes only.
    This is not investment advice. Past performance does not guarantee future results.
    Always conduct thorough research and consult financial professionals before investing.

    The Maverick screening strategy identifies stocks with:
    - High momentum strength
    - Technical patterns (Cup & Handle, consolidation, etc.)
    - Momentum characteristics
    - Strong combined scores

    Args:
        limit: Maximum number of stocks to return (default: 20)

    Returns:
        On success, a dictionary with ``status`` ("success"), ``count``,
        ``stocks`` (one ``to_dict()`` per row), ``screening_type`` and
        ``description``. On any failure, ``{"error": <message>, "status": "error"}``.
    """
    try:
        # Imported inside the try so that a failing data-layer import is
        # reported as an error result instead of propagating to the caller.
        from maverick_mcp.data.models import MaverickStocks, SessionLocal

        with SessionLocal() as session:
            top = MaverickStocks.get_top_stocks(session, limit=limit)
            count = len(top)
            # Rows are serialized before the session closes (presumably so
            # to_dict() can still read session-bound attributes).
            serialized = [row.to_dict() for row in top]
            result: dict[str, Any] = {
                "status": "success",
                "count": count,
                "stocks": serialized,
                "screening_type": "maverick_bullish",
                "description": "High momentum stocks with bullish technical setups",
            }
        return result
    except Exception as exc:
        logger.error(f"Error fetching Maverick stocks: {exc}")
        return {"error": str(exc), "status": "error"}
55 |
56 |
def get_maverick_bear_stocks(limit: int = 20) -> dict[str, Any]:
    """
    Get top Maverick Bear stocks from the screening results.

    DISCLAIMER: Bearish screening results are for educational purposes only.
    This is not advice to sell short or make bearish trades. Short selling involves
    unlimited risk potential. Always consult financial professionals before trading.

    The Maverick Bear screening identifies stocks with:
    - Weak momentum strength
    - Bearish technical patterns
    - Distribution characteristics
    - High bear scores

    Args:
        limit: Maximum number of stocks to return (default: 20)

    Returns:
        On success, a dictionary with ``status`` ("success"), ``count``,
        ``stocks``, ``screening_type`` ("maverick_bearish") and
        ``description``. On any failure, ``{"error": <message>, "status": "error"}``.
    """
    try:
        # Lazy import: an import failure is turned into an error payload below.
        from maverick_mcp.data.models import MaverickBearStocks, SessionLocal

        with SessionLocal() as session:
            bear_rows = MaverickBearStocks.get_top_stocks(session, limit=limit)
            return {
                "status": "success",
                "count": len(bear_rows),
                "stocks": list(map(lambda row: row.to_dict(), bear_rows)),
                "screening_type": "maverick_bearish",
                "description": "Weak stocks with bearish technical setups",
            }
    except Exception as exc:
        logger.error(f"Error fetching Maverick Bear stocks: {exc}")
        return {"error": str(exc), "status": "error"}
93 |
94 |
def get_supply_demand_breakouts(
    limit: int = 20, filter_moving_averages: bool = False
) -> dict[str, Any]:
    """
    Get stocks showing supply/demand breakout patterns from accumulation.

    This screening identifies stocks in the demand expansion phase with:
    - Price above all major moving averages (demand zone)
    - Moving averages in proper alignment indicating accumulation (50 > 150 > 200)
    - Strong momentum strength showing institutional interest
    - Market structure indicating supply absorption and demand dominance

    Args:
        limit: Maximum number of stocks to return (default: 20)
        filter_moving_averages: If True, only return stocks above all moving averages

    Returns:
        On success, a dictionary with ``status`` ("success"), ``count``,
        ``stocks``, ``screening_type`` ("supply_demand_breakout") and
        ``description``. On any failure, ``{"error": <message>, "status": "error"}``.
    """
    try:
        from maverick_mcp.data.models import SessionLocal, SupplyDemandBreakoutStocks

        model = SupplyDemandBreakoutStocks
        with SessionLocal() as session:
            # The moving-average query is called without a limit here, so its
            # result is truncated by slicing instead.
            rows = (
                model.get_stocks_above_moving_averages(session)[:limit]
                if filter_moving_averages
                else model.get_top_stocks(session, limit=limit)
            )
            return {
                "status": "success",
                "count": len(rows),
                "stocks": [row.to_dict() for row in rows],
                "screening_type": "supply_demand_breakout",
                "description": "Stocks breaking out from accumulation with strong demand dynamics",
            }
    except Exception as exc:
        logger.error(f"Error fetching supply/demand breakout stocks: {exc}")
        return {"error": str(exc), "status": "error"}
135 |
136 |
def get_all_screening_recommendations() -> dict[str, Any]:
    """
    Get comprehensive screening results from all strategies.

    This tool returns the top stocks from each screening strategy:
    - Maverick Bullish: High momentum growth stocks
    - Maverick Bearish: Weak stocks for short opportunities
    - Supply/Demand Breakouts: Stocks breaking out from accumulation phases

    Returns:
        Whatever ``StockDataProvider.get_all_screening_recommendations()``
        returns. On failure, a dictionary with ``error``, ``status`` ("error")
        and an empty list for each strategy key.
    """
    try:
        from maverick_mcp.providers.stock_data import StockDataProvider

        return StockDataProvider().get_all_screening_recommendations()
    except Exception as exc:
        logger.error(f"Error getting all screening recommendations: {exc}")
        fallback: dict[str, Any] = {"error": str(exc), "status": "error"}
        # Keep the per-strategy keys present so consumers can iterate them safely.
        for strategy_key in (
            "maverick_stocks",
            "maverick_bear_stocks",
            "supply_demand_breakouts",
        ):
            fallback[strategy_key] = []
        return fallback
163 |
164 |
def get_screening_by_criteria(
    min_momentum_score: float | str | None = None,
    min_volume: int | str | None = None,
    max_price: float | str | None = None,
    sector: str | None = None,
    limit: int | str = 20,
) -> dict[str, Any]:
    """
    Get stocks filtered by specific screening criteria.

    This tool allows custom filtering across all screening results based on:
    - Momentum score rating
    - Volume requirements
    - Price constraints
    - Sector preferences (accepted but not yet applied; see ``note`` in the result)

    Numeric arguments may be given as strings. ``None`` or a blank string
    means "no filter". An explicit ``0`` is applied as a real threshold.

    Args:
        min_momentum_score: Minimum momentum score rating (0-100)
        min_volume: Minimum average daily volume
        max_price: Maximum stock price
        sector: Specific sector to filter (e.g., "Technology"); currently not applied
        limit: Maximum number of results

    Returns:
        On success, a dictionary with ``status`` ("success"), ``count``,
        ``stocks``, and ``criteria`` (the parsed filter values). When
        ``sector`` is given, it also includes a ``note`` explaining that sector
        filtering was not applied. On failure, including unparseable numeric
        arguments, it is ``{"error": <message>, "status": "error"}``.
    """

    def _parse(value: Any, cast: Any) -> Any:
        # None or a blank string means the filter was not supplied.
        if value is None or (isinstance(value, str) and not value.strip()):
            return None
        return cast(value)

    try:
        from maverick_mcp.data.models import MaverickStocks, SessionLocal

        # Convert string inputs to appropriate numeric types
        min_momentum_score = _parse(min_momentum_score, float)
        min_volume = _parse(min_volume, int)
        max_price = _parse(max_price, float)
        if isinstance(limit, str):
            limit = int(limit)

        with SessionLocal() as session:
            query = session.query(MaverickStocks)

            # Compare against None, not truthiness, so explicit zero thresholds
            # (e.g. max_price=0) are still applied instead of silently dropped.
            if min_momentum_score is not None:
                query = query.filter(
                    MaverickStocks.momentum_score >= min_momentum_score
                )

            if min_volume is not None:
                query = query.filter(MaverickStocks.avg_vol_30d >= min_volume)

            if max_price is not None:
                query = query.filter(MaverickStocks.close_price <= max_price)

            # Note: Sector filtering would require joining with Stock table
            # This is a simplified version

            stocks = (
                query.order_by(MaverickStocks.combined_score.desc()).limit(limit).all()
            )

            result: dict[str, Any] = {
                "status": "success",
                "count": len(stocks),
                "stocks": [stock.to_dict() for stock in stocks],
                "criteria": {
                    "min_momentum_score": min_momentum_score,
                    "min_volume": min_volume,
                    "max_price": max_price,
                    "sector": sector,
                },
            }
            if sector:
                # The sector is echoed in "criteria" but not used in the query;
                # say so explicitly rather than implying the results are filtered.
                result["note"] = (
                    "Sector filtering is not yet supported; "
                    "results were not filtered by sector."
                )
            return result
    except Exception as e:
        logger.error(f"Error in custom screening: {str(e)}")
        return {"error": str(e), "status": "error"}
239 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/interfaces/config.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Configuration provider interface.
3 |
4 | This module defines the abstract interface for configuration management,
5 | enabling different configuration sources (environment variables, files, etc.)
6 | to be used interchangeably throughout the application.
7 | """
8 |
9 | from typing import Any, Protocol, runtime_checkable
10 |
11 |
12 | @runtime_checkable
13 | class IConfigurationProvider(Protocol):
14 | """
15 | Interface for configuration management.
16 |
17 | This interface abstracts configuration access to enable different
18 | sources (environment variables, config files, etc.) to be used interchangeably.
19 | """
20 |
21 | def get_database_url(self) -> str:
22 | """
23 | Get database connection URL.
24 |
25 | Returns:
26 | Database connection URL string
27 | """
28 | ...
29 |
30 | def get_redis_host(self) -> str:
31 | """Get Redis server host."""
32 | ...
33 |
34 | def get_redis_port(self) -> int:
35 | """Get Redis server port."""
36 | ...
37 |
38 | def get_redis_db(self) -> int:
39 | """Get Redis database number."""
40 | ...
41 |
42 | def get_redis_password(self) -> str | None:
43 | """Get Redis password."""
44 | ...
45 |
46 | def get_redis_ssl(self) -> bool:
47 | """Get Redis SSL setting."""
48 | ...
49 |
50 | def is_cache_enabled(self) -> bool:
51 | """Check if caching is enabled."""
52 | ...
53 |
54 | def get_cache_ttl(self) -> int:
55 | """Get default cache TTL in seconds."""
56 | ...
57 |
58 | def get_fred_api_key(self) -> str:
59 | """Get FRED API key for macroeconomic data."""
60 | ...
61 |
62 | def get_external_api_key(self) -> str:
63 | """Get External API key for market data."""
64 | ...
65 |
66 | def get_tiingo_api_key(self) -> str:
67 | """Get Tiingo API key for market data."""
68 | ...
69 |
70 | def get_log_level(self) -> str:
71 | """Get logging level."""
72 | ...
73 |
74 | def is_development_mode(self) -> bool:
75 | """Check if running in development mode."""
76 | ...
77 |
78 | def is_production_mode(self) -> bool:
79 | """Check if running in production mode."""
80 | ...
81 |
82 | def get_request_timeout(self) -> int:
83 | """Get default request timeout in seconds."""
84 | ...
85 |
86 | def get_max_retries(self) -> int:
87 | """Get maximum retry attempts for API calls."""
88 | ...
89 |
90 | def get_pool_size(self) -> int:
91 | """Get database connection pool size."""
92 | ...
93 |
94 | def get_max_overflow(self) -> int:
95 | """Get database connection pool overflow."""
96 | ...
97 |
98 | def get_config_value(self, key: str, default: Any = None) -> Any:
99 | """
100 | Get a configuration value by key.
101 |
102 | Args:
103 | key: Configuration key
104 | default: Default value if key not found
105 |
106 | Returns:
107 | Configuration value or default
108 | """
109 | ...
110 |
111 | def set_config_value(self, key: str, value: Any) -> None:
112 | """
113 | Set a configuration value.
114 |
115 | Args:
116 | key: Configuration key
117 | value: Value to set
118 | """
119 | ...
120 |
121 | def get_all_config(self) -> dict[str, Any]:
122 | """
123 | Get all configuration as a dictionary.
124 |
125 | Returns:
126 | Dictionary of all configuration values
127 | """
128 | ...
129 |
130 | def reload_config(self) -> None:
131 | """Reload configuration from source."""
132 | ...
133 |
134 |
135 | class ConfigurationError(Exception):
136 | """Base exception for configuration-related errors."""
137 |
138 | pass
139 |
140 |
141 | class MissingConfigurationError(ConfigurationError):
142 | """Raised when required configuration is missing."""
143 |
144 | def __init__(self, key: str, message: str | None = None):
145 | self.key = key
146 | super().__init__(message or f"Missing required configuration: {key}")
147 |
148 |
149 | class InvalidConfigurationError(ConfigurationError):
150 | """Raised when configuration value is invalid."""
151 |
152 | def __init__(self, key: str, value: Any, message: str | None = None):
153 | self.key = key
154 | self.value = value
155 | super().__init__(message or f"Invalid configuration value for {key}: {value}")
156 |
157 |
158 | class EnvironmentConfigurationProvider:
159 | """
160 | Environment-based configuration provider.
161 |
162 | This is a concrete implementation that can be used as a default
163 | or reference implementation for the IConfigurationProvider interface.
164 | """
165 |
166 | def __init__(self):
167 | """Initialize with environment variables."""
168 | import os
169 |
170 | self._env = os.environ
171 | self._cache: dict[str, Any] = {}
172 |
173 | def get_database_url(self) -> str:
174 | """Get database URL from DATABASE_URL environment variable."""
175 | return self._env.get("DATABASE_URL", "sqlite:///maverick_mcp.db")
176 |
177 | def get_redis_host(self) -> str:
178 | """Get Redis host from REDIS_HOST environment variable."""
179 | return self._env.get("REDIS_HOST", "localhost")
180 |
181 | def get_redis_port(self) -> int:
182 | """Get Redis port from REDIS_PORT environment variable."""
183 | return int(self._env.get("REDIS_PORT", "6379"))
184 |
185 | def get_redis_db(self) -> int:
186 | """Get Redis database from REDIS_DB environment variable."""
187 | return int(self._env.get("REDIS_DB", "0"))
188 |
189 | def get_redis_password(self) -> str | None:
190 | """Get Redis password from REDIS_PASSWORD environment variable."""
191 | password = self._env.get("REDIS_PASSWORD", "")
192 | return password if password else None
193 |
194 | def get_redis_ssl(self) -> bool:
195 | """Get Redis SSL setting from REDIS_SSL environment variable."""
196 | return self._env.get("REDIS_SSL", "False").lower() == "true"
197 |
198 | def is_cache_enabled(self) -> bool:
199 | """Check if caching is enabled from CACHE_ENABLED environment variable."""
200 | return self._env.get("CACHE_ENABLED", "True").lower() == "true"
201 |
202 | def get_cache_ttl(self) -> int:
203 | """Get cache TTL from CACHE_TTL_SECONDS environment variable."""
204 | return int(self._env.get("CACHE_TTL_SECONDS", "604800"))
205 |
206 | def get_fred_api_key(self) -> str:
207 | """Get FRED API key from FRED_API_KEY environment variable."""
208 | return self._env.get("FRED_API_KEY", "")
209 |
210 | def get_external_api_key(self) -> str:
211 | """Get External API key from CAPITAL_COMPANION_API_KEY environment variable."""
212 | return self._env.get("CAPITAL_COMPANION_API_KEY", "")
213 |
214 | def get_tiingo_api_key(self) -> str:
215 | """Get Tiingo API key from TIINGO_API_KEY environment variable."""
216 | return self._env.get("TIINGO_API_KEY", "")
217 |
218 | def get_log_level(self) -> str:
219 | """Get log level from LOG_LEVEL environment variable."""
220 | return self._env.get("LOG_LEVEL", "INFO")
221 |
222 | def is_development_mode(self) -> bool:
223 | """Check if in development mode from ENVIRONMENT environment variable."""
224 | env = self._env.get("ENVIRONMENT", "development").lower()
225 | return env in ("development", "dev", "test")
226 |
227 | def is_production_mode(self) -> bool:
228 | """Check if in production mode from ENVIRONMENT environment variable."""
229 | env = self._env.get("ENVIRONMENT", "development").lower()
230 | return env in ("production", "prod")
231 |
232 | def get_request_timeout(self) -> int:
233 | """Get request timeout from REQUEST_TIMEOUT environment variable."""
234 | return int(self._env.get("REQUEST_TIMEOUT", "30"))
235 |
236 | def get_max_retries(self) -> int:
237 | """Get max retries from MAX_RETRIES environment variable."""
238 | return int(self._env.get("MAX_RETRIES", "3"))
239 |
240 | def get_pool_size(self) -> int:
241 | """Get pool size from DB_POOL_SIZE environment variable."""
242 | return int(self._env.get("DB_POOL_SIZE", "5"))
243 |
244 | def get_max_overflow(self) -> int:
245 | """Get max overflow from DB_MAX_OVERFLOW environment variable."""
246 | return int(self._env.get("DB_MAX_OVERFLOW", "10"))
247 |
248 | def get_config_value(self, key: str, default: Any = None) -> Any:
249 | """Get configuration value from environment variables."""
250 | if key in self._cache:
251 | return self._cache[key]
252 |
253 | value = self._env.get(key, default)
254 | self._cache[key] = value
255 | return value
256 |
257 | def set_config_value(self, key: str, value: Any) -> None:
258 | """Set configuration value (updates cache, not environment)."""
259 | self._cache[key] = value
260 |
261 | def get_all_config(self) -> dict[str, Any]:
262 | """Get all configuration as dictionary."""
263 | config = {}
264 | config.update(self._env)
265 | config.update(self._cache)
266 | return config
267 |
268 | def reload_config(self) -> None:
269 | """Clear cache to force reload from environment."""
270 | self._cache.clear()
271 |
```
--------------------------------------------------------------------------------
/tests/integration/base.py:
--------------------------------------------------------------------------------
```python
1 | """Base classes and utilities for integration testing."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import fnmatch
7 | import time
8 | from collections import defaultdict
9 | from typing import Any
10 | from unittest.mock import AsyncMock, MagicMock
11 |
12 | import pytest
13 |
14 |
class InMemoryPubSub:
    """Lightweight pub/sub implementation for the in-memory Redis stub.

    Each subscribed channel gets its own ``asyncio.Queue``, registered with
    the owning redis stub via ``register_subscriber``. :meth:`listen` then
    multiplexes those queues into a single async stream of message dicts.
    """

    # Upper bound on how long listen() blocks before re-checking whether the
    # subscription was closed or the set of channels changed.
    _POLL_INTERVAL = 0.1

    def __init__(self, redis: InMemoryRedis) -> None:
        self._redis = redis
        self._queues: dict[str, asyncio.Queue[dict[str, Any]]] = {}
        self._active = True

    async def subscribe(self, channel: str) -> None:
        """Start receiving messages published to *channel*.

        Subscribing to an already-subscribed channel is a no-op. Otherwise the
        old queue would be replaced locally but stay registered on the server,
        where it would silently accumulate messages.
        """
        if channel in self._queues:
            return
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._queues[channel] = queue
        self._redis.register_subscriber(channel, queue)

    async def unsubscribe(self, channel: str) -> None:
        """Stop receiving messages for *channel*; unknown channels are ignored."""
        queue = self._queues.pop(channel, None)
        if queue is not None:
            self._redis.unregister_subscriber(channel, queue)

    async def close(self) -> None:
        """Stop :meth:`listen` and unsubscribe from every channel."""
        self._active = False
        for channel in list(self._queues):
            await self.unsubscribe(channel)

    async def listen(self):  # pragma: no cover - simple async generator
        """Yield message dicts from all subscribed channels until :meth:`close`.

        Each wait is bounded by ``_POLL_INTERVAL``. A ``close()`` or a new
        ``subscribe()`` that happens while no messages are flowing is
        therefore picked up, instead of blocking forever on the old queues.
        """
        while self._active:
            getters = [
                asyncio.create_task(queue.get()) for queue in self._queues.values()
            ]
            if not getters:
                await asyncio.sleep(0.01)
                continue
            try:
                done, _ = await asyncio.wait(
                    getters,
                    timeout=self._POLL_INTERVAL,
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                # Cancel every getter that is still waiting. This also runs
                # when the consumer is cancelled mid-wait, so no orphaned
                # queue.get() tasks are left on the event loop.
                for task in getters:
                    if not task.done():
                        task.cancel()
            for task in done:
                yield task.result()
54 |
55 |
class InMemoryRedis:
    """A minimal asynchronous Redis replacement used in tests.

    Implements only the commands the integration tests use:

    - string keys with optional TTLs (``set``/``setex``/``get``/``mget``/``delete``)
    - hashes (``hset``/``hincrby``/``hgetall``)
    - ``scan``
    - a compare-and-delete ``eval``
    - pub/sub fan-out to :class:`InMemoryPubSub` queues

    Expired keys are evicted lazily when they are looked up.
    """

    def __init__(self) -> None:
        self._data: dict[str, bytes] = {}
        self._hashes: dict[str, dict[str, str]] = defaultdict(dict)
        self._expiry: dict[str, float] = {}
        self._pubsub_channels: dict[str, list[asyncio.Queue[dict[str, Any]]]] = (
            defaultdict(list)
        )

    def _is_expired(self, key: str) -> bool:
        """Return True if *key*'s TTL has elapsed, evicting the key as a side effect."""
        expiry = self._expiry.get(key)
        if expiry is None:
            return False
        if expiry < time.time():
            self._data.pop(key, None)
            self._hashes.pop(key, None)
            self._expiry.pop(key, None)
            return True
        return False

    def _exists(self, key: str) -> bool:
        """Return True if *key* currently holds a live string or hash value."""
        if self._is_expired(key):
            return False
        return key in self._data or key in self._hashes

    def register_subscriber(
        self, channel: str, queue: asyncio.Queue[dict[str, Any]]
    ) -> None:
        """Route future publishes on *channel* into *queue*."""
        self._pubsub_channels[channel].append(queue)

    def unregister_subscriber(
        self, channel: str, queue: asyncio.Queue[dict[str, Any]]
    ) -> None:
        """Stop routing *channel* into *queue*; drops the channel when it has no queues left."""
        if channel in self._pubsub_channels:
            try:
                self._pubsub_channels[channel].remove(queue)
            except ValueError:
                pass
            if not self._pubsub_channels[channel]:
                del self._pubsub_channels[channel]

    async def setex(self, key: str, ttl: int, value: Any) -> None:
        """Store *value* under *key* with a TTL of *ttl* seconds."""
        self._data[key] = self._encode(value)
        self._expiry[key] = time.time() + ttl

    async def set(
        self,
        key: str,
        value: Any,
        *,
        nx: bool = False,
        ex: int | None = None,
    ) -> str | None:
        """Store *value* under *key*, mirroring Redis ``SET``.

        With ``nx=True`` nothing is written, and None is returned, when a live
        value already exists. As in real Redis, a write without ``ex``
        discards any TTL left over from a previous write of the same key.
        """
        if nx and key in self._data and not self._is_expired(key):
            return None
        self._data[key] = self._encode(value)
        if ex is not None:
            self._expiry[key] = time.time() + ex
        else:
            self._expiry.pop(key, None)
        return "OK"

    async def get(self, key: str) -> bytes | None:
        """Return the stored bytes for *key*, or None if absent or expired."""
        if self._is_expired(key):
            return None
        return self._data.get(key)

    async def delete(self, *keys: str) -> int:
        """Delete *keys* (string or hash) and return how many live keys were removed."""
        removed = 0
        for key in keys:
            if self._exists(key):
                removed += 1
            self._data.pop(key, None)
            self._hashes.pop(key, None)
            self._expiry.pop(key, None)
        return removed

    async def scan(
        self, cursor: int, match: str | None = None, count: int = 100
    ) -> tuple[int, list[str]]:
        """Return all live keys, optionally filtered by the glob *match*.

        The whole keyspace comes back in one call with cursor 0; Redis allows
        this because COUNT is only a hint. Truncating to *count* while also
        returning cursor 0 would make callers stop early and miss keys.
        *cursor* and *count* are accepted for API compatibility.
        """
        # Snapshot the candidates first: _is_expired() removes entries from
        # these dicts, which must not happen while iterating over them.
        candidates = list(dict.fromkeys([*self._data, *self._hashes]))
        keys = [key for key in candidates if self._exists(key)]
        if match:
            # fnmatchcase: Redis patterns are case-sensitive on every platform.
            keys = [key for key in keys if fnmatch.fnmatchcase(key, match)]
        return 0, keys

    async def mget(self, keys: list[str]) -> list[bytes | None]:
        """Return ``get(key)`` for each key, in order."""
        return [await self.get(key) for key in keys]

    async def hincrby(self, key: str, field: str, amount: int) -> int:
        """Add *amount* to the integer hash field and return the new value."""
        self._is_expired(key)  # evict a stale key so the counter restarts at 0
        current = int(self._hashes[key].get(field, "0")) + amount
        self._hashes[key][field] = str(current)
        return current

    async def hgetall(self, key: str) -> dict[bytes, bytes]:
        """Return the hash at *key* with UTF-8 encoded fields and values."""
        if self._is_expired(key):
            return {}
        mapping = self._hashes.get(key, {})
        return {
            field.encode("utf-8"): value.encode("utf-8")
            for field, value in mapping.items()
        }

    async def hset(self, key: str, mapping: dict[str, Any]) -> None:
        """Set each field of *mapping* on the hash at *key* (values stored as str)."""
        self._is_expired(key)
        for field, value in mapping.items():
            self._hashes[key][field] = str(value)

    async def eval(self, script: str, keys: list[str], args: list[str]) -> int:
        """Emulate a compare-and-delete script; *script* itself is ignored.

        Deletes ``keys[0]`` and returns 1 only when its value equals
        ``args[0]``; otherwise returns 0.
        """
        if not keys:
            return 0
        key = keys[0]
        expected = args[0] if args else ""
        stored = await self.get(key)
        if stored is not None and stored.decode("utf-8") == expected:
            await self.delete(key)
            return 1
        return 0

    async def publish(self, channel: str, message: Any) -> int:
        """Deliver *message* to every subscriber queue of *channel*.

        Returns the number of subscribers reached, as Redis ``PUBLISH`` does.
        """
        payload = self._encode(message).decode("utf-8")
        subscribers = list(self._pubsub_channels.get(channel, []))
        for queue in subscribers:
            await queue.put({"type": "message", "channel": channel, "data": payload})
        return len(subscribers)

    def pubsub(self) -> InMemoryPubSub:
        """Create a new pub/sub handle bound to this store."""
        return InMemoryPubSub(self)

    def _encode(self, value: Any) -> bytes:
        """Convert *value* to bytes: bytes as-is, str as UTF-8, anything else via str()."""
        if isinstance(value, bytes):
            return value
        if isinstance(value, str):
            return value.encode("utf-8")
        return str(value).encode("utf-8")

    async def close(self) -> None:
        """Drop all stored data, TTLs and subscriptions."""
        self._data.clear()
        self._hashes.clear()
        self._expiry.clear()
        self._pubsub_channels.clear()
191 |
192 |
class BaseIntegrationTest:
    """Base class for integration tests with common utilities."""

    def setup_test(self):
        """Hook for per-test setup; the base implementation does nothing."""
        return None

    @staticmethod
    def _describe_body(response) -> object:
        """Best-effort rendering of a response body for assertion messages.

        Returns parsed JSON when possible. If the body is not valid JSON (e.g.
        an HTML error page), falls back to ``text`` (or the raw ``content``),
        so building the message never raises.
        """
        if not (hasattr(response, "content") and response.content):
            return "No content"
        try:
            return response.json()
        except ValueError:  # json.JSONDecodeError and client-library variants
            return getattr(response, "text", response.content)

    def assert_response_success(self, response, expected_status: int = 200):
        """Assert that *response* has ``status_code == expected_status``.

        Objects without a ``status_code`` attribute pass without being
        checked. On failure, the message includes the response body.
        """
        if hasattr(response, "status_code"):
            assert response.status_code == expected_status, (
                f"Expected status {expected_status}, got {response.status_code}. "
                f"Response: {self._describe_body(response)}"
            )
207 |
208 |
class RedisIntegrationTest(BaseIntegrationTest):
    """Base for integration tests that need a Redis-like backend.

    An autouse fixture gives every test a fresh ``InMemoryRedis`` as
    ``self.redis_client`` and closes it after the test finishes.
    """

    redis_client: InMemoryRedis

    @pytest.fixture(autouse=True)
    async def _setup_redis(self):
        # Fresh, empty store for each test.
        self.redis_client = InMemoryRedis()
        yield
        # Teardown closes whatever client the test ended up with.
        await self.redis_client.close()

    async def assert_cache_exists(self, key: str) -> None:
        """Fail unless ``redis_client.get(key)`` returns a value."""
        assert (
            await self.redis_client.get(key)
        ) is not None, f"Expected cache key {key} to exist"

    async def assert_cache_not_exists(self, key: str) -> None:
        """Fail if ``redis_client.get(key)`` returns a value."""
        assert (
            await self.redis_client.get(key)
        ) is None, f"Expected cache key {key} to be absent"
227 |
228 |
class MockLLMBase:
    """Base mock LLM for consistent testing.

    - ``ainvoke`` is an ``AsyncMock`` that resolves to a response whose
      ``content`` is a fixed JSON string.
    - ``bind_tools`` returns the mock itself, so chained
      ``llm.bind_tools(...).ainvoke(...)`` calls keep working.
    - ``invoke`` is a plain ``MagicMock``.
    """

    def __init__(self):
        canned_reply = MagicMock()
        canned_reply.content = '{"insights": ["Test insight"], "sentiment": {"direction": "neutral", "confidence": 0.5}}'

        self.ainvoke = AsyncMock(return_value=canned_reply)
        self.bind_tools = MagicMock(return_value=self)
        self.invoke = MagicMock()
240 |
241 |
class MockCacheManager:
    """Mock cache manager for testing.

    - ``get`` and ``set`` are ``AsyncMock`` objects for asserting calls;
      ``get`` always resolves to None.
    - ``get_cached`` and ``set_cached`` are backed by a real in-memory dict.
    """

    def __init__(self):
        self._cache: dict[str, Any] = {}
        self.get = AsyncMock(return_value=None)
        self.set = AsyncMock()

    async def get_cached(self, key: str) -> Any:
        """Return the value stored under *key*, or None if there is none."""
        return self._cache[key] if key in self._cache else None

    async def set_cached(self, key: str, value: Any) -> None:
        """Store *value* under *key*, replacing any previous value."""
        self._cache.update({key: value})
257 |
```