This is page 19 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/integration/test_config_management.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Integration tests for configuration management features.
3 |
4 | This module tests the integration of ToolEstimationConfig and DatabasePoolConfig
5 | with the actual server implementation and other components. Tests verify:
6 | - server.py correctly uses ToolEstimationConfig
7 | - Database connections work with DatabasePoolConfig
8 | - Configuration changes are properly applied
9 | - Monitoring and logging functionality works end-to-end
10 | - Real-world usage patterns work correctly
11 | """
12 |
13 | import os
14 | from unittest.mock import Mock, patch
15 |
16 | import pytest
17 | from sqlalchemy import create_engine, text
18 |
19 | from maverick_mcp.config.database import (
20 | DatabasePoolConfig,
21 | validate_production_config,
22 | )
23 | from maverick_mcp.config.tool_estimation import (
24 | get_tool_estimate,
25 | get_tool_estimation_config,
26 | should_alert_for_usage,
27 | )
28 |
29 |
30 | @pytest.mark.integration
31 | class TestServerToolEstimationIntegration:
32 | """Test integration of ToolEstimationConfig with server.py."""
33 |
34 | def test_server_imports_configuration_correctly(self):
35 | """Test that server.py can import and use tool estimation configuration."""
36 | # This tests the import path used in server.py
37 | from maverick_mcp.config.tool_estimation import (
38 | get_tool_estimate,
39 | get_tool_estimation_config,
40 | should_alert_for_usage,
41 | )
42 |
43 | # Should work without errors
44 | config = get_tool_estimation_config()
45 | estimate = get_tool_estimate("get_stock_price")
46 | should_alert, message = should_alert_for_usage("test_tool", 5, 1000)
47 |
48 | assert config is not None
49 | assert estimate is not None
50 | assert isinstance(should_alert, bool)
51 |
52 | @patch("maverick_mcp.config.tool_estimation.logger")
53 | def test_server_logging_pattern_with_low_confidence(self, mock_logger):
54 | """Test the logging pattern used in server.py for low confidence estimates."""
55 | config = get_tool_estimation_config()
56 |
57 | # Find a tool with low confidence (< 0.8)
58 | low_confidence_tool = None
59 | for tool_name, estimate in config.tool_estimates.items():
60 | if estimate.confidence < 0.8:
61 | low_confidence_tool = tool_name
62 | break
63 |
64 | if low_confidence_tool:
65 | # Simulate the server.py logging pattern
66 | tool_estimate = get_tool_estimate(low_confidence_tool)
67 |
68 | # This mimics the server.py code path
69 | if tool_estimate.confidence < 0.8:
70 | # Log the warning as server.py would
71 | logger_extra = {
72 | "tool_name": low_confidence_tool,
73 | "confidence": tool_estimate.confidence,
74 | "basis": tool_estimate.based_on.value,
75 | "complexity": tool_estimate.complexity.value,
76 | "estimated_llm_calls": tool_estimate.llm_calls,
77 | "estimated_tokens": tool_estimate.total_tokens,
78 | }
79 |
80 | # Verify the data structure matches server.py expectations
81 | assert "tool_name" in logger_extra
82 | assert "confidence" in logger_extra
83 | assert "basis" in logger_extra
84 | assert "complexity" in logger_extra
85 | assert "estimated_llm_calls" in logger_extra
86 | assert "estimated_tokens" in logger_extra
87 |
88 | # Values should be in expected formats
89 | assert isinstance(logger_extra["confidence"], float)
90 | assert isinstance(logger_extra["basis"], str)
91 | assert isinstance(logger_extra["complexity"], str)
92 | assert isinstance(logger_extra["estimated_llm_calls"], int)
93 | assert isinstance(logger_extra["estimated_tokens"], int)
94 |
95 | def test_server_error_handling_fallback_pattern(self):
96 | """Test the error handling pattern used in server.py."""
97 | config = get_tool_estimation_config()
98 |
99 | # Simulate the server.py error handling pattern
100 | actual_tool_name = "nonexistent_tool"
101 | tool_estimate = None
102 |
103 | try:
104 | tool_estimate = get_tool_estimate(actual_tool_name)
105 | llm_calls = tool_estimate.llm_calls
106 | total_tokens = tool_estimate.total_tokens
107 | except Exception:
108 | # Fallback to conservative defaults (server.py pattern)
109 | fallback_estimate = config.unknown_tool_estimate
110 | llm_calls = fallback_estimate.llm_calls
111 | total_tokens = fallback_estimate.total_tokens
112 |
113 | # Should have fallback values
114 | assert llm_calls > 0
115 | assert total_tokens > 0
116 | assert tool_estimate == config.unknown_tool_estimate
117 |
118 | def test_server_usage_estimates_integration(self):
119 | """Test integration with usage estimation as done in server.py."""
120 | # Test known tools that should have specific estimates
121 | test_tools = [
122 | ("get_stock_price", "simple"),
123 | ("get_rsi_analysis", "standard"),
124 | ("get_full_technical_analysis", "complex"),
125 | ("analyze_market_with_agent", "premium"),
126 | ]
127 |
128 | for tool_name, expected_complexity in test_tools:
129 | estimate = get_tool_estimate(tool_name)
130 |
131 | # Verify estimate has all fields needed for server.py
132 | assert hasattr(estimate, "llm_calls")
133 | assert hasattr(estimate, "total_tokens")
134 | assert hasattr(estimate, "confidence")
135 | assert hasattr(estimate, "based_on")
136 | assert hasattr(estimate, "complexity")
137 |
138 | # Verify complexity matches expectations
139 | assert expected_complexity in estimate.complexity.value.lower()
140 |
141 | # Verify estimates are reasonable for usage tracking
142 | if expected_complexity == "simple":
143 | assert estimate.llm_calls <= 1
144 | elif expected_complexity == "premium":
145 | assert estimate.llm_calls >= 8
146 |
147 |
148 | @pytest.mark.integration
149 | class TestDatabasePoolConfigIntegration:
150 | """Test integration of DatabasePoolConfig with database operations."""
151 |
152 | def test_database_config_with_real_sqlite(self):
153 | """Test database configuration with real SQLite database."""
154 | # Use SQLite for testing (no external dependencies)
155 | database_url = "sqlite:///test_integration.db"
156 |
157 | config = DatabasePoolConfig(
158 | pool_size=5,
159 | max_overflow=2,
160 | pool_timeout=30,
161 | pool_recycle=3600,
162 | max_database_connections=20,
163 | expected_concurrent_users=3,
164 | connections_per_user=1.0,
165 | )
166 |
167 | # Create engine with configuration
168 | engine_kwargs = {
169 | "url": database_url,
170 | **config.get_pool_kwargs(),
171 | }
172 |
173 | # Remove poolclass for SQLite (not applicable)
174 | if "sqlite" in database_url:
175 | engine_kwargs.pop("poolclass", None)
176 |
177 | engine = create_engine(**engine_kwargs)
178 |
179 | try:
180 | # Test connection
181 | with engine.connect() as conn:
182 | result = conn.execute(text("SELECT 1"))
183 | assert result.scalar() == 1
184 |
185 | # Test monitoring setup (should not error)
186 | config.setup_pool_monitoring(engine)
187 |
188 | finally:
189 | engine.dispose()
190 | # Clean up test database
191 | if os.path.exists("test_integration.db"):
192 | os.remove("test_integration.db")
193 |
194 | @patch.dict(
195 | os.environ,
196 | {
197 | "DB_POOL_SIZE": "8",
198 | "DB_MAX_OVERFLOW": "4",
199 | "DB_POOL_TIMEOUT": "45",
200 | },
201 | )
202 | def test_config_respects_environment_variables(self):
203 | """Test that configuration respects environment variables in integration."""
204 | config = DatabasePoolConfig()
205 |
206 | # Should use environment variable values
207 | assert config.pool_size == 8
208 | assert config.max_overflow == 4
209 | assert config.pool_timeout == 45
210 |
211 | def test_legacy_compatibility_integration(self):
212 | """Test legacy DatabaseConfig compatibility in real usage."""
213 | from maverick_mcp.providers.interfaces.persistence import DatabaseConfig
214 |
215 | # Create enhanced config
216 | enhanced_config = DatabasePoolConfig(
217 | pool_size=10,
218 | max_overflow=5,
219 | pool_timeout=30,
220 | pool_recycle=1800,
221 | )
222 |
223 | # Convert to legacy format
224 | database_url = "sqlite:///test_legacy.db"
225 | legacy_config = enhanced_config.to_legacy_config(database_url)
226 |
227 | # Should be usable with existing code patterns
228 | assert isinstance(legacy_config, DatabaseConfig)
229 | assert legacy_config.database_url == database_url
230 | assert legacy_config.pool_size == 10
231 | assert legacy_config.max_overflow == 5
232 |
233 | def test_production_validation_integration(self):
234 | """Test production validation with realistic configurations."""
235 | # Test development config - should warn but not fail
236 | dev_config = DatabasePoolConfig(
237 | pool_size=5,
238 | max_overflow=2,
239 | pool_timeout=30,
240 | pool_recycle=3600,
241 | )
242 |
243 | with patch("maverick_mcp.config.database.logger") as mock_logger:
244 | result = validate_production_config(dev_config)
245 | assert result is True # Should pass with warnings
246 | # Should have logged warnings about small pool size
247 | assert mock_logger.warning.called
248 |
249 | # Test production config - should pass without warnings
250 | prod_config = DatabasePoolConfig(
251 | pool_size=25,
252 | max_overflow=15,
253 | pool_timeout=30,
254 | pool_recycle=3600,
255 | )
256 |
257 | with patch("maverick_mcp.config.database.logger") as mock_logger:
258 | result = validate_production_config(prod_config)
259 | assert result is True
260 | # Should have passed without warnings
261 | info_call = mock_logger.info.call_args[0][0]
262 | assert "validation passed" in info_call.lower()
263 |
264 |
265 | @pytest.mark.integration
266 | class TestConfigurationMonitoring:
267 | """Test monitoring and alerting integration."""
268 |
269 | def test_tool_estimation_alerting_integration(self):
270 | """Test tool estimation alerting with realistic usage patterns."""
271 | get_tool_estimation_config()
272 |
273 | # Test scenarios that should trigger alerts
274 | alert_scenarios = [
275 | # High LLM usage
276 | ("get_stock_price", 10, 1000, "should alert on unexpected LLM usage"),
277 | # High token usage
278 | ("calculate_sma", 1, 50000, "should alert on excessive tokens"),
279 | # Both high
280 | ("get_market_hours", 20, 40000, "should alert on both metrics"),
281 | ]
282 |
283 | for tool_name, llm_calls, tokens, description in alert_scenarios:
284 | should_alert, message = should_alert_for_usage(tool_name, llm_calls, tokens)
285 | assert should_alert, f"Failed: {description}"
286 | assert len(message) > 0, f"Alert message should not be empty: {description}"
287 | assert "Critical" in message or "Warning" in message
288 |
289 | def test_database_pool_monitoring_integration(self):
290 | """Test database pool monitoring integration."""
291 | config = DatabasePoolConfig(pool_size=10, echo_pool=True)
292 |
293 | # Create mock engine to test monitoring
294 | mock_engine = Mock()
295 | mock_pool = Mock()
296 | mock_engine.pool = mock_pool
297 |
298 | # Test different pool usage scenarios
299 | scenarios = [
300 | (5, "normal usage", False, False), # 50% usage
301 | (8, "warning usage", True, False), # 80% usage
302 | (10, "critical usage", True, True), # 100% usage
303 | ]
304 |
305 | with patch("maverick_mcp.config.database.event") as mock_event:
306 | config.setup_pool_monitoring(mock_engine)
307 |
308 | # Get the connect listener function
309 | connect_listener = None
310 | for call in mock_event.listens_for.call_args_list:
311 | if call[0][1] == "connect":
312 | connect_listener = call[0][2]
313 | break
314 |
315 | assert connect_listener is not None
316 |
317 | # Test each scenario
318 | for checked_out, _description, should_warn, should_error in scenarios:
319 | mock_pool.checkedout.return_value = checked_out
320 | mock_pool.checkedin.return_value = 10 - checked_out
321 |
322 | with patch("maverick_mcp.config.database.logger") as mock_logger:
323 | connect_listener(None, None)
324 |
325 | if should_warn:
326 | mock_logger.warning.assert_called()
327 | if should_error:
328 | mock_logger.error.assert_called()
329 |
330 | def test_configuration_logging_integration(self):
331 | """Test that configuration logging works correctly."""
332 | with patch("maverick_mcp.config.database.logger") as mock_logger:
333 | DatabasePoolConfig(
334 | pool_size=15,
335 | max_overflow=8,
336 | expected_concurrent_users=20,
337 | connections_per_user=1.2,
338 | max_database_connections=100,
339 | )
340 |
341 | # Should have logged configuration summary
342 | assert mock_logger.info.called
343 | log_message = mock_logger.info.call_args[0][0]
344 | assert "Database pool configured" in log_message
345 | assert "pool_size=15" in log_message
346 |
347 |
348 | @pytest.mark.integration
349 | class TestRealWorldIntegrationScenarios:
350 | """Test realistic integration scenarios."""
351 |
352 | def test_microservice_deployment_scenario(self):
353 | """Test configuration for microservice deployment."""
354 | # Simulate microservice environment
355 | with patch.dict(
356 | os.environ,
357 | {
358 | "DB_POOL_SIZE": "8",
359 | "DB_MAX_OVERFLOW": "4",
360 | "DB_MAX_CONNECTIONS": "50",
361 | "DB_EXPECTED_CONCURRENT_USERS": "10",
362 | "ENVIRONMENT": "production",
363 | },
364 | ):
365 | # Get configuration from environment
366 | db_config = DatabasePoolConfig()
367 |
368 | # Should be suitable for microservice
369 | assert db_config.pool_size == 8
370 | assert db_config.max_overflow == 4
371 | assert db_config.expected_concurrent_users == 10
372 |
373 | # Should pass production validation
374 | assert validate_production_config(db_config) is True
375 |
376 | # Test tool estimation in this context
377 | get_tool_estimation_config()
378 |
379 | # Should handle typical microservice tools
380 | api_tools = [
381 | "get_stock_price",
382 | "get_company_info",
383 | "get_rsi_analysis",
384 | "fetch_stock_data",
385 | ]
386 |
387 | for tool in api_tools:
388 | estimate = get_tool_estimate(tool)
389 | assert estimate is not None
390 | assert estimate.confidence > 0.0
391 |
392 | def test_development_environment_scenario(self):
393 | """Test configuration for development environment."""
394 | # Simulate development environment
395 | with patch.dict(
396 | os.environ,
397 | {
398 | "DB_POOL_SIZE": "3",
399 | "DB_MAX_OVERFLOW": "1",
400 | "DB_ECHO_POOL": "true",
401 | "ENVIRONMENT": "development",
402 | },
403 | ):
404 | db_config = DatabasePoolConfig()
405 |
406 | # Should use development-friendly settings
407 | assert db_config.pool_size == 3
408 | assert db_config.max_overflow == 1
409 | assert db_config.echo_pool is True
410 |
411 | # Should handle development testing
412 | get_tool_estimation_config()
413 |
414 | # Should provide estimates for development tools
415 | dev_tools = ["generate_dev_token", "clear_cache", "get_cached_price_data"]
416 |
417 | for tool in dev_tools:
418 | estimate = get_tool_estimate(tool)
419 | assert estimate.complexity.value in ["simple", "standard"]
420 |
421 | def test_high_traffic_scenario(self):
422 | """Test configuration for high traffic scenario."""
423 | # High traffic configuration
424 | db_config = DatabasePoolConfig(
425 | pool_size=50,
426 | max_overflow=30,
427 | expected_concurrent_users=100,
428 | connections_per_user=1.2,
429 | max_database_connections=200,
430 | )
431 |
432 | # Should handle the expected load
433 | total_capacity = db_config.pool_size + db_config.max_overflow
434 | expected_demand = (
435 | db_config.expected_concurrent_users * db_config.connections_per_user
436 | )
437 | assert total_capacity >= expected_demand
438 |
439 | # Should pass production validation
440 | assert validate_production_config(db_config) is True
441 |
442 | # Test tool estimation for high-usage tools
443 | high_usage_tools = [
444 | "get_full_technical_analysis",
445 | "analyze_market_with_agent",
446 | "get_all_screening_recommendations",
447 | ]
448 |
449 | for tool in high_usage_tools:
450 | estimate = get_tool_estimate(tool)
451 | # Should have monitoring in place for expensive tools
452 | should_alert, _ = should_alert_for_usage(
453 | tool,
454 | estimate.llm_calls * 2, # Double the expected usage
455 | estimate.total_tokens * 2,
456 | )
457 | assert should_alert # Should trigger alerts for high usage
458 |
459 | def test_configuration_change_propagation(self):
460 | """Test that configuration changes propagate correctly."""
461 | # Start with one configuration
462 | original_config = get_tool_estimation_config()
463 | original_estimate = get_tool_estimate("get_stock_price")
464 |
465 | # Configuration should be singleton
466 | new_config = get_tool_estimation_config()
467 | assert new_config is original_config
468 |
469 | # Estimates should be consistent
470 | new_estimate = get_tool_estimate("get_stock_price")
471 | assert new_estimate == original_estimate
472 |
473 | def test_error_recovery_integration(self):
474 | """Test error recovery in integrated scenarios."""
475 | # Test database connection failure recovery
476 | config = DatabasePoolConfig(
477 | pool_size=5,
478 | max_overflow=2,
479 | pool_timeout=1, # Short timeout for testing
480 | )
481 |
482 | # Should handle connection errors gracefully
483 | try:
484 | # This would fail in a real scenario with invalid URL
485 | engine_kwargs = config.get_pool_kwargs()
486 | assert "pool_size" in engine_kwargs
487 | except Exception:
488 | # Should not prevent configuration from working
489 | assert config.pool_size == 5
490 |
491 | def test_monitoring_data_collection(self):
492 | """Test that monitoring data can be collected for analysis."""
493 | tool_config = get_tool_estimation_config()
494 |
495 | # Collect monitoring data
496 | stats = tool_config.get_summary_stats()
497 |
498 | # Should provide useful metrics
499 | assert "total_tools" in stats
500 | assert "by_complexity" in stats
501 | assert "avg_confidence" in stats
502 |
503 | # Should be suitable for monitoring dashboards
504 | assert stats["total_tools"] > 0
505 | assert 0 <= stats["avg_confidence"] <= 1
506 |
507 | # Complexity distribution should make sense
508 | complexity_counts = stats["by_complexity"]
509 | total_by_complexity = sum(complexity_counts.values())
510 | assert total_by_complexity == stats["total_tools"]
511 |
512 | def test_configuration_validation_end_to_end(self):
513 | """Test end-to-end configuration validation."""
514 | # Test complete validation pipeline
515 |
516 | # 1. Tool estimation configuration
517 | tool_config = get_tool_estimation_config()
518 | assert (
519 | len(tool_config.tool_estimates) > 20
520 | ) # Should have substantial tool coverage
521 |
522 | # 2. Database configuration
523 | db_config = DatabasePoolConfig(
524 | pool_size=20,
525 | max_overflow=10,
526 | expected_concurrent_users=25,
527 | connections_per_user=1.2,
528 | max_database_connections=100,
529 | )
530 |
531 | # 3. Production readiness
532 | assert validate_production_config(db_config) is True
533 |
534 | # 4. Integration compatibility
535 | legacy_config = db_config.to_legacy_config("postgresql://test")
536 | enhanced_again = DatabasePoolConfig.from_legacy_config(legacy_config)
537 | assert enhanced_again.pool_size == db_config.pool_size
538 |
539 | # 5. Monitoring setup
540 | thresholds = db_config.get_monitoring_thresholds()
541 | assert thresholds["warning_threshold"] > 0
542 | assert thresholds["critical_threshold"] > thresholds["warning_threshold"]
543 |
```
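
The two wiring patterns these integration tests exercise — the conservative estimation fallback and the pool-config-to-engine handoff — compose in a few lines of application code. A minimal sketch, assuming only the APIs the tests themselves import (`get_tool_estimate`, `get_tool_estimation_config`, `should_alert_for_usage`, `DatabasePoolConfig`); the SQLite URL, tool name, and pool sizes are illustrative, and dropping `poolclass` mirrors the tests' own SQLite workaround:

```python
from sqlalchemy import create_engine

from maverick_mcp.config.database import DatabasePoolConfig
from maverick_mcp.config.tool_estimation import (
    get_tool_estimate,
    get_tool_estimation_config,
    should_alert_for_usage,
)


def estimate_usage(tool_name: str) -> tuple[int, int]:
    """Return (llm_calls, total_tokens), falling back to conservative defaults."""
    try:
        estimate = get_tool_estimate(tool_name)
    except Exception:
        # Unknown tools fall back to the config's conservative default,
        # the same server.py pattern the tests above verify.
        estimate = get_tool_estimation_config().unknown_tool_estimate
    return estimate.llm_calls, estimate.total_tokens


# Alert when observed usage exceeds the configured thresholds for a tool.
llm_calls, tokens = estimate_usage("get_stock_price")
should_alert, message = should_alert_for_usage("get_stock_price", llm_calls, tokens)

# Build an engine from the pool config, as TestDatabasePoolConfigIntegration does.
pool_config = DatabasePoolConfig(pool_size=5, max_overflow=2)
engine_kwargs = pool_config.get_pool_kwargs()
engine_kwargs.pop("poolclass", None)  # not applicable to SQLite
engine = create_engine("sqlite:///example.db", **engine_kwargs)
pool_config.setup_pool_monitoring(engine)
```
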
--------------------------------------------------------------------------------
/maverick_mcp/providers/optimized_stock_data.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Optimized stock data provider with performance enhancements.
3 |
4 | This module provides enhanced stock data access with:
5 | - Request-level caching for expensive operations
6 | - Optimized database queries with proper indexing
7 | - Connection pooling and query monitoring
8 | - Smart cache invalidation strategies
9 | """
10 |
11 | import logging
12 | from collections.abc import Awaitable, Callable
13 | from datetime import datetime
14 | from typing import Any
15 |
16 | import pandas as pd
17 | from sqlalchemy import select, text
18 | from sqlalchemy.ext.asyncio import AsyncSession
19 | from sqlalchemy.orm import joinedload
20 |
21 | from maverick_mcp.data.models import (
22 | MaverickStocks,
23 | PriceCache,
24 | Stock,
25 | )
26 | from maverick_mcp.data.performance import (
27 | cached,
28 | monitored_db_session,
29 | query_optimizer,
30 | request_cache,
31 | )
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | class OptimizedStockDataProvider:
37 | """
38 | Performance-optimized stock data provider.
39 |
40 | This provider implements:
41 | - Smart caching strategies for different data types
42 | - Optimized database queries with minimal N+1 issues
43 | - Connection pooling and query monitoring
44 | - Efficient bulk operations for large datasets
45 | """
46 |
47 | def __init__(self):
48 | self.cache_ttl_stock_data = 3600 # 1 hour for stock data
49 | self.cache_ttl_screening = 7200 # 2 hours for screening results
50 | self.cache_ttl_market_data = 300 # 5 minutes for real-time data
51 |
52 | @cached(data_type="stock_data", ttl=3600)
53 | @query_optimizer.monitor_query("get_stock_basic_info")
54 | async def get_stock_basic_info(self, symbol: str) -> dict[str, Any] | None:
55 | """
56 | Get basic stock information with caching.
57 |
58 | Args:
59 | symbol: Stock ticker symbol
60 |
61 | Returns:
62 | Stock information dictionary or None if not found
63 | """
64 | async with monitored_db_session("get_stock_basic_info") as session:
65 | async_session: AsyncSession = session
66 | stmt = select(Stock).where(Stock.ticker_symbol == symbol.upper())
67 | result = await async_session.execute(stmt)
68 | stock = result.scalars().first()
69 |
70 | if stock:
71 | return {
72 | "symbol": stock.ticker_symbol,
73 | "name": stock.company_name,
74 | "sector": stock.sector,
75 | "industry": stock.industry,
76 | "exchange": stock.exchange,
77 | "country": stock.country,
78 | "currency": stock.currency,
79 | }
80 |
81 | return None
82 |
83 | @cached(data_type="stock_data", ttl=1800)
84 | @query_optimizer.monitor_query("get_stock_price_data")
85 | async def get_stock_price_data(
86 | self,
87 | symbol: str,
88 | start_date: str,
89 | end_date: str | None = None,
90 | use_optimized_query: bool = True,
91 | ) -> pd.DataFrame:
92 | """
93 | Get stock price data with optimized queries and caching.
94 |
95 | Args:
96 | symbol: Stock ticker symbol
97 | start_date: Start date in YYYY-MM-DD format
98 | end_date: End date in YYYY-MM-DD format
99 | use_optimized_query: Use optimized query with proper indexing
100 |
101 | Returns:
102 | DataFrame with OHLCV data
103 | """
104 | if not end_date:
105 | end_date = datetime.now().strftime("%Y-%m-%d")
106 |
107 | async with monitored_db_session("get_stock_price_data") as session:
108 | async_session: AsyncSession = session
109 | if use_optimized_query:
110 | # Optimized query using the composite index (stock_id, date)
111 | query = text(
112 | """
113 | SELECT
114 | pc.date,
115 | pc.open_price as "open",
116 | pc.high_price as "high",
117 | pc.low_price as "low",
118 | pc.close_price as "close",
119 | pc.volume
120 | FROM stocks_pricecache pc
121 | INNER JOIN stocks_stock s ON pc.stock_id = s.stock_id
122 | WHERE s.ticker_symbol = :symbol
123 | AND pc.date >= :start_date::date
124 | AND pc.date <= :end_date::date
125 | ORDER BY pc.date
126 | """
127 | )
128 |
129 | result = await async_session.execute(
130 | query,
131 | {
132 | "symbol": symbol.upper(),
133 | "start_date": start_date,
134 | "end_date": end_date,
135 | },
136 | )
137 |
138 | rows = result.fetchall()
139 | column_index = pd.Index([str(key) for key in result.keys()])
140 | df = pd.DataFrame(rows, columns=column_index)
141 | else:
142 | # Traditional SQLAlchemy query (for comparison)
143 | stmt = (
144 | select(
145 | PriceCache.date,
146 | PriceCache.open_price.label("open"),
147 | PriceCache.high_price.label("high"),
148 | PriceCache.low_price.label("low"),
149 | PriceCache.close_price.label("close"),
150 | PriceCache.volume,
151 | )
152 | .join(Stock)
153 | .where(
154 | Stock.ticker_symbol == symbol.upper(),
155 | PriceCache.date >= pd.to_datetime(start_date).date(),
156 | PriceCache.date <= pd.to_datetime(end_date).date(),
157 | )
158 | .order_by(PriceCache.date)
159 | )
160 |
161 | result = await async_session.execute(stmt)
162 | rows = result.fetchall()
163 | column_index = pd.Index([str(key) for key in result.keys()])
164 | df = pd.DataFrame(rows, columns=column_index)
165 |
166 | if not df.empty:
167 | df["date"] = pd.to_datetime(df["date"])
168 | df.set_index("date", inplace=True)
169 |
170 | # Convert decimal types to float for performance
171 | for col in ["open", "high", "low", "close"]:
172 | df[col] = pd.to_numeric(df[col], errors="coerce")
173 |
174 | df["volume"] = pd.to_numeric(df["volume"], errors="coerce")
175 |
176 | return df
177 |
178 | @cached(data_type="screening", ttl=7200)
179 | @query_optimizer.monitor_query("get_maverick_recommendations")
180 | async def get_maverick_recommendations(
181 | self,
182 | limit: int = 50,
183 | min_score: float | None = None,
184 | use_optimized_query: bool = True,
185 | ) -> list[dict[str, Any]]:
186 | """
187 | Get Maverick bullish recommendations with performance optimizations.
188 |
189 | Args:
190 | limit: Maximum number of results
191 | min_score: Minimum score threshold
192 | use_optimized_query: Use optimized query with proper indexing
193 |
194 | Returns:
195 | List of recommendation dictionaries
196 | """
197 | async with monitored_db_session("get_maverick_recommendations") as session:
198 | async_session: AsyncSession = session
199 | if use_optimized_query:
200 | # Use raw SQL with optimized indexes
201 | where_clause = ""
202 | params: dict[str, Any] = {"limit": limit}
203 |
204 | if min_score is not None:
205 | where_clause = "WHERE ms.combined_score >= :min_score"
206 | params["min_score"] = min_score
207 |
208 | query = text(
209 | f"""
210 | SELECT
211 | s.ticker_symbol,
212 | s.company_name,
213 | s.sector,
214 | s.industry,
215 | ms.combined_score AS score,
216 | ms.pattern_detected AS rank,
217 | ms.date_analyzed,
218 | ms.analysis_data
219 | FROM stocks_maverickstocks ms
220 | INNER JOIN stocks_stock s ON ms.stock_id = s.stock_id
221 | {where_clause}
222 | ORDER BY ms.combined_score DESC, ms.pattern_detected ASC
223 | LIMIT :limit
224 | """
225 | )
226 |
227 | result = await async_session.execute(query, params)
228 | rows = result.fetchall()
229 |
230 | return [
231 | {
232 | "symbol": row.ticker_symbol,
233 | "name": row.company_name,
234 | "sector": row.sector,
235 | "industry": row.industry,
236 | "score": float(getattr(row, "score", 0) or 0),
237 | "rank": getattr(row, "rank", None),
238 | "date_analyzed": (
239 | row.date_analyzed.isoformat() if row.date_analyzed else None
240 | ),
241 | "analysis_data": getattr(row, "analysis_data", None),
242 | }
243 | for row in rows
244 | ]
245 | else:
246 | # Traditional SQLAlchemy query with eager loading
247 | stmt = (
248 | select(MaverickStocks)
249 | .options(joinedload(MaverickStocks.stock))
250 | .order_by(
251 | MaverickStocks.combined_score.desc(),
252 | MaverickStocks.pattern_detected.asc(),
253 | )
254 | .limit(limit)
255 | )
256 |
257 | if min_score is not None:
258 | stmt = stmt.where(MaverickStocks.combined_score >= min_score)
259 |
260 | result = await async_session.execute(stmt)
261 | recommendations = result.scalars().all()
262 |
263 | formatted: list[dict[str, Any]] = []
264 | for rec in recommendations:
265 | stock = getattr(rec, "stock", None)
266 | analysis_date = getattr(rec, "date_analyzed", None)
267 | isoformatted = (
268 | analysis_date.isoformat()
269 | if analysis_date is not None
270 | and hasattr(analysis_date, "isoformat")
271 | else None
272 | )
273 |
274 | formatted.append(
275 | {
276 | "symbol": getattr(stock, "ticker_symbol", None),
277 | "name": getattr(stock, "company_name", None),
278 | "sector": getattr(stock, "sector", None),
279 | "industry": getattr(stock, "industry", None),
280 | "score": float(getattr(rec, "combined_score", 0) or 0),
281 | "rank": getattr(rec, "pattern_detected", None),
282 | "date_analyzed": isoformatted,
283 | "analysis_data": getattr(rec, "analysis_data", None),
284 | }
285 | )
286 |
287 | return formatted
288 |
289 | @cached(data_type="screening", ttl=7200)
290 | @query_optimizer.monitor_query("get_trending_recommendations")
291 | async def get_trending_recommendations(
292 | self,
293 | limit: int = 50,
294 | min_momentum_score: float | None = None,
295 | ) -> list[dict[str, Any]]:
296 | """
297 | Get trending supply/demand breakout recommendations with optimized queries.
298 |
299 | Args:
300 | limit: Maximum number of results
301 | min_momentum_score: Minimum momentum score threshold
302 |
303 | Returns:
304 | List of recommendation dictionaries
305 | """
306 | async with monitored_db_session("get_trending_recommendations") as session:
307 | async_session: AsyncSession = session
308 | # Use optimized raw SQL query
309 | where_clause = ""
310 | params: dict[str, Any] = {"limit": limit}
311 |
312 | if min_momentum_score is not None:
313 | where_clause = "WHERE ms.momentum_score >= :min_momentum_score"
314 | params["min_momentum_score"] = min_momentum_score
315 |
316 | query = text(
317 | f"""
318 | SELECT
319 | s.ticker_symbol,
320 | s.company_name,
321 | s.sector,
322 | s.industry,
323 | ms.momentum_score,
324 | ms.stage,
325 | ms.date_analyzed,
326 | ms.analysis_data
327 | FROM stocks_supply_demand_breakouts ms
328 | INNER JOIN stocks_stock s ON ms.stock_id = s.stock_id
329 | {where_clause}
330 | ORDER BY ms.momentum_score DESC
331 | LIMIT :limit
332 | """
333 | )
334 |
335 | result = await async_session.execute(query, params)
336 | rows = result.fetchall()
337 |
338 | return [
339 | {
340 | "symbol": row.ticker_symbol,
341 | "name": row.company_name,
342 | "sector": row.sector,
343 | "industry": row.industry,
344 | "momentum_score": (
345 | float(row.momentum_score) if row.momentum_score else 0
346 | ),
347 | "stage": row.stage,
348 | "date_analyzed": (
349 | row.date_analyzed.isoformat() if row.date_analyzed else None
350 | ),
351 | "analysis_data": row.analysis_data,
352 | }
353 | for row in rows
354 | ]
355 |
356 | @cached(data_type="market_data", ttl=300)
357 | @query_optimizer.monitor_query("get_high_volume_stocks")
358 | async def get_high_volume_stocks(
359 | self,
360 | date: str | None = None,
361 | limit: int = 100,
362 | min_volume: int = 1000000,
363 | ) -> list[dict[str, Any]]:
364 | """
365 | Get high volume stocks for a specific date with optimized query.
366 |
367 | Args:
368 | date: Date to filter (default: latest available)
369 | limit: Maximum number of results
370 | min_volume: Minimum volume threshold
371 |
372 | Returns:
373 | List of high volume stock data
374 | """
375 | if not date:
376 | date = datetime.now().strftime("%Y-%m-%d")
377 |
378 | async with monitored_db_session("get_high_volume_stocks") as session:
379 | async_session: AsyncSession = session
380 | # Use optimized query with volume index
381 | query = text(
382 | """
383 | SELECT
384 | s.ticker_symbol,
385 | s.company_name,
386 | s.sector,
387 | pc.volume,
388 | pc.close_price,
389 | pc.date
390 | FROM stocks_pricecache pc
391 | INNER JOIN stocks_stock s ON pc.stock_id = s.stock_id
392 | WHERE pc.date = :date::date
393 | AND pc.volume >= :min_volume
394 | ORDER BY pc.volume DESC
395 | LIMIT :limit
396 | """
397 | )
398 |
399 | result = await async_session.execute(
400 | query,
401 | {
402 | "date": date,
403 | "min_volume": min_volume,
404 | "limit": limit,
405 | },
406 | )
407 |
408 | rows = result.fetchall()
409 |
410 | return [
411 | {
412 | "symbol": row.ticker_symbol,
413 | "name": row.company_name,
414 | "sector": row.sector,
415 | "volume": int(row.volume) if row.volume else 0,
416 | "close_price": float(row.close_price) if row.close_price else 0,
417 | "date": row.date.isoformat() if row.date else None,
418 | }
419 | for row in rows
420 | ]
421 |
422 | @query_optimizer.monitor_query("bulk_get_stock_data")
423 | async def bulk_get_stock_data(
424 | self,
425 | symbols: list[str],
426 | start_date: str,
427 | end_date: str | None = None,
428 | ) -> dict[str, pd.DataFrame]:
429 | """
430 | Efficiently fetch stock data for multiple symbols using bulk operations.
431 |
432 | Args:
433 | symbols: List of stock symbols
434 | start_date: Start date in YYYY-MM-DD format
435 | end_date: End date in YYYY-MM-DD format
436 |
437 | Returns:
438 | Dictionary mapping symbols to DataFrames
439 | """
440 | if not end_date:
441 | end_date = datetime.now().strftime("%Y-%m-%d")
442 |
443 | # Convert symbols to uppercase for consistency
444 | symbols = [s.upper() for s in symbols]
445 |
446 | async with monitored_db_session("bulk_get_stock_data") as session:
447 | async_session: AsyncSession = session
448 | # Use bulk query with IN clause for efficiency
449 | query = text(
450 | """
451 | SELECT
452 | s.ticker_symbol,
453 | pc.date,
454 | pc.open_price as "open",
455 | pc.high_price as "high",
456 | pc.low_price as "low",
457 | pc.close_price as "close",
458 | pc.volume
459 | FROM stocks_pricecache pc
460 | INNER JOIN stocks_stock s ON pc.stock_id = s.stock_id
461 | WHERE s.ticker_symbol = ANY(:symbols)
462 | AND pc.date >= :start_date::date
463 | AND pc.date <= :end_date::date
464 | ORDER BY s.ticker_symbol, pc.date
465 | """
466 | )
467 |
468 | result = await async_session.execute(
469 | query,
470 | {
471 | "symbols": symbols,
472 | "start_date": start_date,
473 | "end_date": end_date,
474 | },
475 | )
476 |
477 | # Group results by symbol
478 | symbol_data = {}
479 | for row in result.fetchall():
480 | symbol = row.ticker_symbol
481 | if symbol not in symbol_data:
482 | symbol_data[symbol] = []
483 |
484 | symbol_data[symbol].append(
485 | {
486 | "date": row.date,
487 | "open": row.open,
488 | "high": row.high,
489 | "low": row.low,
490 | "close": row.close,
491 | "volume": row.volume,
492 | }
493 | )
494 |
495 | # Convert to DataFrames
496 | result_dfs = {}
497 | for symbol in symbols:
498 | if symbol in symbol_data:
499 | df = pd.DataFrame(symbol_data[symbol])
500 | df["date"] = pd.to_datetime(df["date"])
501 | df.set_index("date", inplace=True)
502 |
503 | # Convert decimal types to float
504 | for col in ["open", "high", "low", "close"]:
505 | df[col] = pd.to_numeric(df[col], errors="coerce")
506 |
507 | df["volume"] = pd.to_numeric(df["volume"], errors="coerce")
508 | result_dfs[symbol] = df
509 | else:
510 | # Return empty DataFrame for missing symbols
511 | result_dfs[symbol] = pd.DataFrame(
512 | columns=pd.Index(["open", "high", "low", "close", "volume"])
513 | )
514 |
515 | return result_dfs
516 |
517 | async def invalidate_cache_for_symbol(self, symbol: str) -> None:
518 | """
519 | Invalidate all cached data for a specific symbol.
520 |
521 | Args:
522 | symbol: Stock symbol to invalidate
523 | """
524 | invalidate_basic_info: Callable[[str], Awaitable[None]] | None = getattr(
525 | self.get_stock_basic_info, "invalidate_cache", None
526 | )
527 | if invalidate_basic_info is not None:
528 | await invalidate_basic_info(symbol)
529 |
530 | # Invalidate stock price data (pattern-based)
531 | await request_cache.delete_pattern(
532 | f"cache:*get_stock_price_data*{symbol.upper()}*"
533 | )
534 |
535 | logger.info(f"Cache invalidated for symbol: {symbol}")
536 |
537 | async def invalidate_screening_cache(self) -> None:
538 | """Invalidate all screening-related cache."""
539 | patterns = [
540 | "cache:*get_maverick_recommendations*",
541 | "cache:*get_trending_recommendations*",
542 | "cache:*get_high_volume_stocks*",
543 | ]
544 |
545 | for pattern in patterns:
546 | await request_cache.delete_pattern(pattern)
547 |
548 | logger.info("Screening cache invalidated")
549 |
550 | async def get_performance_metrics(self) -> dict[str, Any]:
551 | """Get performance metrics for the optimized provider."""
552 | return {
553 | "cache_metrics": request_cache.get_metrics(),
554 | "query_stats": query_optimizer.get_query_stats(),
555 | "cache_ttl_config": {
556 | "stock_data": self.cache_ttl_stock_data,
557 | "screening": self.cache_ttl_screening,
558 | "market_data": self.cache_ttl_market_data,
559 | },
560 | }
561 |
562 |
563 | # Global instance
564 | optimized_stock_provider = OptimizedStockDataProvider()
565 |
```
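
Since the module closes by exporting a global `optimized_stock_provider`, callers can use it directly. A minimal usage sketch, assuming a configured async database layer (the raw queries above use PostgreSQL casts such as `::date` and `ANY(...)`, so a Postgres backend is implied); the symbols and dates are illustrative:

```python
import asyncio

from maverick_mcp.providers.optimized_stock_data import optimized_stock_provider


async def main() -> None:
    # Cached single-symbol lookups (1 h TTL for basic info, 30 min for prices).
    info = await optimized_stock_provider.get_stock_basic_info("AAPL")
    prices = await optimized_stock_provider.get_stock_price_data(
        "AAPL", start_date="2024-01-01", end_date="2024-03-31"
    )

    # One bulk round-trip instead of one query per symbol.
    frames = await optimized_stock_provider.bulk_get_stock_data(
        ["AAPL", "MSFT", "NVDA"], start_date="2024-01-01"
    )

    # After ingesting fresh prices, drop the symbol's stale cache entries.
    await optimized_stock_provider.invalidate_cache_for_symbol("AAPL")

    print(info, len(prices), {s: len(df) for s, df in frames.items()})


asyncio.run(main())
```
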
--------------------------------------------------------------------------------
/tests/test_optimized_research_agent.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive test suite for OptimizedDeepResearchAgent.
3 |
4 | Tests the core functionality of the optimized research agent including:
5 | - Model selection logic
6 | - Token budgeting
7 | - Confidence tracking
8 | - Content filtering
9 | - Parallel processing
10 | - Error handling
11 | """
12 |
13 | import time
14 | from unittest.mock import AsyncMock, Mock, patch
15 |
16 | import pytest
17 |
18 | from maverick_mcp.agents.optimized_research import (
19 | OptimizedContentAnalyzer,
20 | OptimizedDeepResearchAgent,
21 | create_optimized_research_agent,
22 | )
23 | from maverick_mcp.providers.openrouter_provider import OpenRouterProvider, TaskType
24 | from maverick_mcp.utils.llm_optimization import (
25 | AdaptiveModelSelector,
26 | ConfidenceTracker,
27 | ModelConfiguration,
28 | ProgressiveTokenBudgeter,
29 | )
30 |
31 |
32 | class TestOptimizedContentAnalyzer:
33 | """Test the OptimizedContentAnalyzer component."""
34 |
35 | @pytest.fixture
36 | def mock_openrouter(self):
37 | """Create a mock OpenRouter provider."""
38 | provider = Mock(spec=OpenRouterProvider)
39 | provider.get_llm = Mock()
40 | return provider
41 |
42 | @pytest.fixture
43 | def analyzer(self, mock_openrouter):
44 | """Create an OptimizedContentAnalyzer instance."""
45 | return OptimizedContentAnalyzer(mock_openrouter)
46 |
47 | @pytest.mark.asyncio
48 | async def test_analyze_content_optimized_success(self, analyzer, mock_openrouter):
49 | """Test successful optimized content analysis."""
50 | # Setup mock LLM response
51 | mock_llm = AsyncMock()
52 | mock_response = Mock()
53 | mock_response.content = '{"insights": ["Test insight"], "sentiment": {"direction": "bullish", "confidence": 0.8}}'
54 | mock_llm.ainvoke.return_value = mock_response
55 | mock_openrouter.get_llm.return_value = mock_llm
56 |
57 | # Test analysis
58 | result = await analyzer.analyze_content_optimized(
59 | content="Test financial content about stocks",
60 | persona="moderate",
61 | analysis_focus="market_analysis",
62 | time_budget_seconds=30.0,
63 | current_confidence=0.5,
64 | )
65 |
66 | # Verify results
67 | assert result["insights"] == ["Test insight"]
68 | assert result["sentiment"]["direction"] == "bullish"
69 | assert result["sentiment"]["confidence"] == 0.8
70 | assert result["optimization_applied"] is True
71 | assert "model_used" in result
72 | assert "execution_time" in result
73 |
74 | @pytest.mark.asyncio
75 | async def test_analyze_empty_content(self, analyzer):
76 | """Test handling of empty content."""
77 | result = await analyzer.analyze_content_optimized(
78 | content="",
79 | persona="moderate",
80 | analysis_focus="general",
81 | time_budget_seconds=30.0,
82 | )
83 |
84 | assert result["empty_content"] is True
85 | assert result["insights"] == []
86 | assert result["sentiment"]["direction"] == "neutral"
87 | assert result["sentiment"]["confidence"] == 0.0
88 |
89 | @pytest.mark.asyncio
90 | async def test_analyze_with_timeout(self, analyzer, mock_openrouter):
91 | """Test timeout handling during analysis."""
92 | # Setup mock to timeout
93 | mock_llm = AsyncMock()
94 | mock_llm.ainvoke.side_effect = TimeoutError("Analysis timeout")
95 | mock_openrouter.get_llm.return_value = mock_llm
96 |
97 | result = await analyzer.analyze_content_optimized(
98 | content="Test content",
99 | persona="aggressive",
100 | analysis_focus="technical",
101 | time_budget_seconds=5.0,
102 | )
103 |
104 | # Should return fallback analysis
105 | assert "insights" in result
106 | assert "sentiment" in result
107 | assert result["sentiment"]["direction"] in ["bullish", "bearish", "neutral"]
108 |
109 | @pytest.mark.asyncio
110 | async def test_batch_analyze_content(self, analyzer, mock_openrouter):
111 | """Test batch content analysis with parallel processing."""
112 | # Setup mock parallel processor
113 | with patch.object(
114 | analyzer.parallel_processor,
115 | "parallel_content_analysis",
116 | new_callable=AsyncMock,
117 | ) as mock_parallel:
118 | mock_results = [
119 | {
120 | "analysis": {
121 | "insights": ["Insight 1"],
122 | "sentiment": {"direction": "bullish", "confidence": 0.7},
123 | }
124 | },
125 | {
126 | "analysis": {
127 | "insights": ["Insight 2"],
128 | "sentiment": {"direction": "neutral", "confidence": 0.6},
129 | }
130 | },
131 | ]
132 | mock_parallel.return_value = mock_results
133 |
134 | sources = [
135 | {"content": "Source 1 content", "url": "http://example1.com"},
136 | {"content": "Source 2 content", "url": "http://example2.com"},
137 | ]
138 |
139 | results = await analyzer.batch_analyze_content(
140 | sources=sources,
141 | persona="moderate",
142 | analysis_type="fundamental",
143 | time_budget_seconds=60.0,
144 | current_confidence=0.5,
145 | )
146 |
147 | assert len(results) == 2
148 | assert results[0]["analysis"]["insights"] == ["Insight 1"]
149 | assert results[1]["analysis"]["sentiment"]["direction"] == "neutral"
150 |
151 |
152 | class TestOptimizedDeepResearchAgent:
153 | """Test the main OptimizedDeepResearchAgent."""
154 |
155 | @pytest.fixture
156 | def mock_openrouter(self):
157 | """Create a mock OpenRouter provider."""
158 | provider = Mock(spec=OpenRouterProvider)
159 | provider.get_llm = Mock(return_value=AsyncMock())
160 | return provider
161 |
162 | @pytest.fixture
163 | def mock_search_provider(self):
164 | """Create a mock search provider."""
165 | provider = AsyncMock()
166 | provider.search = AsyncMock(
167 | return_value=[
168 | {
169 | "title": "Test Result 1",
170 | "url": "http://example1.com",
171 | "content": "Financial analysis content",
172 | },
173 | {
174 | "title": "Test Result 2",
175 | "url": "http://example2.com",
176 | "content": "Market research content",
177 | },
178 | ]
179 | )
180 | return provider
181 |
182 | @pytest.fixture
183 | def agent(self, mock_openrouter, mock_search_provider):
184 | """Create an OptimizedDeepResearchAgent instance."""
185 | agent = OptimizedDeepResearchAgent(
186 | openrouter_provider=mock_openrouter,
187 | persona="moderate",
188 | optimization_enabled=True,
189 | )
190 | # Add mock search provider
191 | agent.search_providers = [mock_search_provider]
192 | # Initialize confidence tracker for tests that need it
193 | agent.confidence_tracker = ConfidenceTracker()
194 | return agent
195 |
196 | @pytest.mark.asyncio
197 | async def test_research_comprehensive_success(
198 | self, agent, mock_search_provider, mock_openrouter
199 | ):
200 | """Test successful comprehensive research."""
201 | # Setup mock LLM for synthesis
202 | mock_llm = AsyncMock()
203 | mock_response = Mock()
204 | mock_response.content = "Comprehensive synthesis of research findings."
205 | mock_llm.ainvoke.return_value = mock_response
206 | mock_openrouter.get_llm.return_value = mock_llm
207 |
208 | # Mock analysis phase to return analyzed sources
209 | async def mock_analysis_phase(*args, **kwargs):
210 | return {
211 | "analyzed_sources": [
212 | {
213 | "title": "AAPL Analysis Report",
214 | "url": "http://example.com",
215 | "analysis": {
216 | "insights": ["Key insight"],
217 | "sentiment": {"direction": "bullish", "confidence": 0.8},
218 | "credibility_score": 0.9,
219 | "relevance_score": 0.85,
220 | "optimization_applied": True,
221 | },
222 | },
223 | {
224 | "title": "Technical Analysis AAPL",
225 | "url": "http://example2.com",
226 | "analysis": {
227 | "insights": ["Technical insight"],
228 | "sentiment": {"direction": "bullish", "confidence": 0.7},
229 | "credibility_score": 0.8,
230 | "relevance_score": 0.8,
231 | "optimization_applied": True,
232 | },
233 | },
234 | ],
235 | "final_confidence": 0.8,
236 | "early_terminated": False,
237 | "processing_mode": "optimized",
238 | }
239 |
240 | with patch.object(
241 | agent, "_optimized_analysis_phase", new_callable=AsyncMock
242 | ) as mock_analysis:
243 | mock_analysis.side_effect = mock_analysis_phase
244 |
245 | result = await agent.research_comprehensive(
246 | topic="AAPL stock analysis",
247 | session_id="test_session",
248 | depth="standard",
249 | focus_areas=["fundamental", "technical"],
250 | timeframe="30d",
251 | time_budget_seconds=120.0,
252 | target_confidence=0.75,
253 | )
254 |
255 | # Verify successful research
256 | assert result["status"] == "success"
257 | assert result["agent_type"] == "optimized_deep_research"
258 | assert result["optimization_enabled"] is True
259 | assert result["research_topic"] == "AAPL stock analysis"
260 | assert result["sources_analyzed"] > 0
261 | assert "findings" in result
262 | assert "citations" in result
263 | assert "optimization_metrics" in result
264 |
265 | @pytest.mark.asyncio
266 | async def test_research_with_no_providers(self, mock_openrouter):
267 | """Test research when no search providers are configured."""
268 | agent = OptimizedDeepResearchAgent(
269 | openrouter_provider=mock_openrouter,
270 | optimization_enabled=True,
271 | )
272 | agent.search_providers = [] # No providers
273 |
274 | result = await agent.research_comprehensive(
275 | topic="Test topic",
276 | session_id="test_session",
277 | time_budget_seconds=60.0,
278 | )
279 |
280 | assert "error" in result
281 | assert "no search providers configured" in result["error"].lower()
282 |
283 | @pytest.mark.asyncio
284 | async def test_research_with_early_termination(
285 | self, agent, mock_search_provider, mock_openrouter
286 | ):
287 | """Test early termination based on confidence threshold."""
288 |
289 | # Mock the entire analysis phase to return early termination
290 | async def mock_analysis_phase(*args, **kwargs):
291 | return {
292 | "analyzed_sources": [
293 | {
294 | "title": "Mock Source",
295 | "url": "http://example.com",
296 | "analysis": {
297 | "insights": ["High confidence insight"],
298 | "sentiment": {"direction": "bullish", "confidence": 0.95},
299 | "credibility_score": 0.95,
300 | "relevance_score": 0.9,
301 | },
302 | }
303 | ],
304 | "final_confidence": 0.92,
305 | "early_terminated": True,
306 | "termination_reason": "confidence_threshold_reached",
307 | "processing_mode": "optimized",
308 | }
309 |
310 | with patch.object(
311 | agent, "_optimized_analysis_phase", new_callable=AsyncMock
312 | ) as mock_analysis:
313 | mock_analysis.side_effect = mock_analysis_phase
314 |
315 | result = await agent.research_comprehensive(
316 | topic="Quick research test",
317 | session_id="test_session",
318 | time_budget_seconds=120.0,
319 | target_confidence=0.9,
320 | )
321 |
322 | assert result["findings"]["early_terminated"] is True
323 | assert (
324 | result["findings"]["termination_reason"]
325 | == "confidence_threshold_reached"
326 | )
327 |
328 | @pytest.mark.asyncio
329 | async def test_research_emergency_response(self, agent, mock_search_provider):
330 | """Test emergency response when time is critically low."""
331 | # Test with very short time budget
332 | result = agent._create_emergency_response(
333 | topic="Emergency test",
334 | search_results={"filtered_sources": [{"title": "Source 1"}]},
335 | start_time=time.time() - 1, # 1 second ago
336 | )
337 |
338 | assert result["status"] == "partial_success"
339 | assert result["emergency_mode"] is True
340 | assert "Emergency mode" in result["findings"]["synthesis"]
341 | assert result["findings"]["confidence_score"] == 0.3
342 |
343 |
344 | class TestModelSelectionLogic:
345 | """Test the adaptive model selection logic."""
346 |
347 | @pytest.fixture
348 | def model_selector(self):
349 | """Create a model selector with mock provider."""
350 | provider = Mock(spec=OpenRouterProvider)
351 | return AdaptiveModelSelector(provider)
352 |
353 | def test_calculate_task_complexity(self, model_selector):
354 | """Test task complexity calculation."""
355 | # Create content with financial complexity indicators
356 | content = (
357 | """
358 | This comprehensive financial analysis examines EBITDA, DCF valuation, and ROIC metrics.
359 | The company shows strong quarterly YoY growth with bullish sentiment from analysts.
360 | Technical analysis indicates RSI oversold conditions with MACD crossover signals.
361 | Support levels at $150 with resistance at $200. Volatility and beta measures suggest
362 | the stock outperforms relative to market. The Sharpe ratio indicates favorable
363 | risk-adjusted returns versus comparable companies in Q4 results.
364 | """
365 | * 20
366 | ) # Repeat to increase complexity
367 |
368 | complexity = model_selector.calculate_task_complexity(
369 | content, TaskType.DEEP_RESEARCH, ["fundamental", "technical"]
370 | )
371 |
372 | assert 0 <= complexity <= 1
373 | assert complexity > 0.1 # Should show some complexity with financial terms
374 |
375 | def test_select_model_for_time_budget(self, model_selector):
376 | """Test model selection based on time constraints."""
377 | # Test with short time budget - should select fast model
378 | config = model_selector.select_model_for_time_budget(
379 | task_type=TaskType.QUICK_ANSWER,
380 | time_remaining_seconds=10.0,
381 | complexity_score=0.3,
382 | content_size_tokens=100,
383 | current_confidence=0.5,
384 | )
385 |
386 | assert isinstance(config, ModelConfiguration)
387 | assert (
388 | config.timeout_seconds <= 15.0
389 | ) # Allow some flexibility for emergency models
390 | assert config.model_id is not None
391 |
392 | # Test with long time budget - can select quality model
393 | config_long = model_selector.select_model_for_time_budget(
394 | task_type=TaskType.DEEP_RESEARCH,
395 | time_remaining_seconds=300.0,
396 | complexity_score=0.8,
397 | content_size_tokens=5000,
398 | current_confidence=0.3,
399 | )
400 |
401 | assert config_long.timeout_seconds > config.timeout_seconds
402 | assert config_long.max_tokens >= config.max_tokens
403 |
404 |
405 | class TestTokenBudgetingAndConfidence:
406 | """Test token budgeting and confidence tracking."""
407 |
408 | def test_progressive_token_budgeter(self):
409 | """Test progressive token budget allocation."""
410 | budgeter = ProgressiveTokenBudgeter(
411 | total_time_budget_seconds=120.0, confidence_target=0.8
412 | )
413 |
414 | # Test initial allocation
415 | allocation = budgeter.get_next_allocation(
416 | sources_remaining=10,
417 | current_confidence=0.3,
418 | time_elapsed_seconds=10.0,
419 | )
420 |
421 | assert allocation["time_budget"] > 0
422 | assert allocation["max_tokens"] > 0
423 | assert allocation["priority"] in ["low", "medium", "high"]
424 |
425 | # Test with higher confidence
426 | allocation_high = budgeter.get_next_allocation(
427 | sources_remaining=5,
428 | current_confidence=0.7,
429 | time_elapsed_seconds=60.0,
430 | )
431 |
432 | # With fewer sources and higher confidence, priority should be lower or equal
433 | assert allocation_high["priority"] in ["low", "medium"]
434 |         # Compare ordinally against the earlier low-confidence allocation
435 | priority_order = {"low": 0, "medium": 1, "high": 2}
436 | assert (
437 | priority_order[allocation_high["priority"]]
438 | <= priority_order[allocation["priority"]]
439 | )
440 |
441 | def test_confidence_tracker(self):
442 | """Test confidence tracking and early termination."""
443 | tracker = ConfidenceTracker(
444 | target_confidence=0.8, min_sources=3, max_sources=20
445 | )
446 |
447 | # Test confidence updates
448 | analysis = {
449 | "sentiment": {"confidence": 0.7},
450 | "insights": ["insight1", "insight2"],
451 | }
452 |
453 | update = tracker.update_confidence(analysis, credibility_score=0.8)
454 |
455 | assert "current_confidence" in update
456 | assert "should_continue" in update
457 | assert update["sources_analyzed"] == 1
458 |
459 | # Test minimum sources requirement
460 | for _i in range(2):
461 | update = tracker.update_confidence(analysis, credibility_score=0.9)
462 |
463 | # Should continue even with high confidence if min sources not met
464 | if tracker.sources_analyzed < tracker.min_sources:
465 | assert update["should_continue"] is True
466 |
467 |
468 | class TestErrorHandlingAndRecovery:
469 | """Test error handling and recovery mechanisms."""
470 |
471 | @pytest.mark.asyncio
472 | async def test_search_timeout_handling(self):
473 | """Test handling of search provider timeouts."""
474 | agent = OptimizedDeepResearchAgent(
475 | openrouter_provider=Mock(spec=OpenRouterProvider),
476 | optimization_enabled=True,
477 | )
478 |
479 | # Mock search provider that times out
480 | mock_provider = AsyncMock()
481 | mock_provider.search.side_effect = TimeoutError("Search timeout")
482 |
483 | results = await agent._search_with_timeout(
484 | mock_provider, "test query", timeout=1.0
485 | )
486 |
487 | assert results == [] # Should return empty list on timeout
488 |
489 | @pytest.mark.asyncio
490 | async def test_synthesis_fallback(self):
491 | """Test fallback synthesis when LLM fails."""
492 | agent = OptimizedDeepResearchAgent(
493 | openrouter_provider=Mock(spec=OpenRouterProvider),
494 | optimization_enabled=True,
495 | )
496 |
497 | # Mock LLM failure
498 | with patch.object(
499 | agent.openrouter_provider,
500 | "get_llm",
501 | side_effect=Exception("LLM unavailable"),
502 | ):
503 | result = await agent._optimized_synthesis_phase(
504 | analyzed_sources=[{"analysis": {"insights": ["test"]}}],
505 | topic="Test topic",
506 | time_budget_seconds=10.0,
507 | )
508 |
509 | assert "fallback_used" in result
510 | assert result["fallback_used"] is True
511 | assert "basic processing" in result["synthesis"]
512 |
513 |
514 | class TestIntegrationWithParallelProcessing:
515 | """Test integration with parallel processing capabilities."""
516 |
517 | @pytest.mark.asyncio
518 | async def test_parallel_batch_processing(self):
519 | """Test parallel batch processing of sources."""
520 | analyzer = OptimizedContentAnalyzer(Mock(spec=OpenRouterProvider))
521 |
522 | # Mock parallel processor
523 | with patch.object(
524 | analyzer.parallel_processor,
525 | "parallel_content_analysis",
526 | new_callable=AsyncMock,
527 | ) as mock_parallel:
528 | mock_parallel.return_value = [
529 | {"analysis": {"insights": [f"Insight {i}"]}} for i in range(5)
530 | ]
531 |
532 | sources = [{"content": f"Source {i}"} for i in range(5)]
533 |
534 | results = await analyzer.batch_analyze_content(
535 | sources=sources,
536 | persona="moderate",
537 | analysis_type="general",
538 | time_budget_seconds=30.0,
539 | )
540 |
541 | assert len(results) == 5
542 | mock_parallel.assert_called_once()
543 |
544 |
545 | class TestFactoryFunction:
546 | """Test the factory function for creating optimized agents."""
547 |
548 | def test_create_optimized_research_agent(self):
549 | """Test agent creation through factory function."""
550 | with patch.dict("os.environ", {"OPENROUTER_API_KEY": "test_key"}):
551 | agent = create_optimized_research_agent(
552 | openrouter_api_key="test_key",
553 | persona="aggressive",
554 | time_budget_seconds=180.0,
555 | target_confidence=0.85,
556 | )
557 |
558 | assert isinstance(agent, OptimizedDeepResearchAgent)
559 | assert agent.optimization_enabled is True
560 | assert agent.persona.name == "Aggressive"
561 |
562 |
563 | if __name__ == "__main__":
564 | pytest.main([__file__, "-v"])
565 |
```
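These tests pin down the contract of the optimization primitives: `ConfidenceTracker.update_confidence(analysis, credibility_score)` returns a dict with `current_confidence`, `should_continue`, and `sources_analyzed`, and keeps `should_continue` truthy until `min_sources` is reached. A minimal sketch of the source-analysis loop that contract implies (the import path is assumed from `maverick_mcp/agents/optimized_research.py`; `analyze` and `sources` are hypothetical stand-ins):

```python
from maverick_mcp.agents.optimized_research import ConfidenceTracker  # assumed module path


def analyze_until_confident(sources, analyze):
    """Analyze sources until the tracker signals the confidence target is met."""
    tracker = ConfidenceTracker(target_confidence=0.8, min_sources=3, max_sources=20)
    analyzed, confidence = [], 0.0
    for source in sources:
        analysis = analyze(source)  # hypothetical: returns {"sentiment": ..., "insights": ...}
        analyzed.append({**source, "analysis": analysis})
        update = tracker.update_confidence(analysis, credibility_score=0.8)
        confidence = update["current_confidence"]
        if not update["should_continue"]:  # early termination, never before min_sources
            break
    return analyzed, confidence
```

The real `_optimized_analysis_phase` additionally consults `ProgressiveTokenBudgeter.get_next_allocation(...)` per source, as the budgeting tests above exercise.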
--------------------------------------------------------------------------------
/tests/domain/test_portfolio_entities.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for portfolio domain entities.
3 |
4 | Tests the pure business logic of Position and Portfolio entities without
5 | any database or infrastructure dependencies.
6 | """
7 |
8 | from datetime import UTC, datetime, timedelta
9 | from decimal import Decimal
10 |
11 | import pytest
12 |
13 | from maverick_mcp.domain.portfolio import Portfolio, Position
14 |
15 |
16 | class TestPosition:
17 | """Test suite for Position value object."""
18 |
19 | def test_position_creation(self):
20 | """Test creating a valid position."""
21 | pos = Position(
22 | ticker="AAPL",
23 | shares=Decimal("10"),
24 | average_cost_basis=Decimal("150.00"),
25 | total_cost=Decimal("1500.00"),
26 | purchase_date=datetime.now(UTC),
27 | )
28 |
29 | assert pos.ticker == "AAPL"
30 | assert pos.shares == Decimal("10")
31 | assert pos.average_cost_basis == Decimal("150.00")
32 | assert pos.total_cost == Decimal("1500.00")
33 |
34 | def test_position_normalizes_ticker(self):
35 | """Test that ticker is normalized to uppercase."""
36 | pos = Position(
37 | ticker="aapl",
38 | shares=Decimal("10"),
39 | average_cost_basis=Decimal("150.00"),
40 | total_cost=Decimal("1500.00"),
41 | purchase_date=datetime.now(UTC),
42 | )
43 |
44 | assert pos.ticker == "AAPL"
45 |
46 | def test_position_rejects_zero_shares(self):
47 | """Test that positions cannot have zero shares."""
48 | with pytest.raises(ValueError, match="Shares must be positive"):
49 | Position(
50 | ticker="AAPL",
51 | shares=Decimal("0"),
52 | average_cost_basis=Decimal("150.00"),
53 | total_cost=Decimal("1500.00"),
54 | purchase_date=datetime.now(UTC),
55 | )
56 |
57 | def test_position_rejects_negative_shares(self):
58 | """Test that positions cannot have negative shares."""
59 | with pytest.raises(ValueError, match="Shares must be positive"):
60 | Position(
61 | ticker="AAPL",
62 | shares=Decimal("-10"),
63 | average_cost_basis=Decimal("150.00"),
64 | total_cost=Decimal("1500.00"),
65 | purchase_date=datetime.now(UTC),
66 | )
67 |
68 | def test_position_rejects_zero_cost_basis(self):
69 | """Test that positions cannot have zero cost basis."""
70 | with pytest.raises(ValueError, match="Average cost basis must be positive"):
71 | Position(
72 | ticker="AAPL",
73 | shares=Decimal("10"),
74 | average_cost_basis=Decimal("0"),
75 | total_cost=Decimal("1500.00"),
76 | purchase_date=datetime.now(UTC),
77 | )
78 |
79 | def test_position_rejects_negative_total_cost(self):
80 | """Test that positions cannot have negative total cost."""
81 | with pytest.raises(ValueError, match="Total cost must be positive"):
82 | Position(
83 | ticker="AAPL",
84 | shares=Decimal("10"),
85 | average_cost_basis=Decimal("150.00"),
86 | total_cost=Decimal("-1500.00"),
87 | purchase_date=datetime.now(UTC),
88 | )
89 |
90 | def test_add_shares_averages_cost_basis(self):
91 | """Test that adding shares correctly averages the cost basis."""
92 | # Start with 10 shares @ $150
93 | pos = Position(
94 | ticker="AAPL",
95 | shares=Decimal("10"),
96 | average_cost_basis=Decimal("150.00"),
97 | total_cost=Decimal("1500.00"),
98 | purchase_date=datetime.now(UTC),
99 | )
100 |
101 | # Add 10 shares @ $170
102 | pos = pos.add_shares(Decimal("10"), Decimal("170.00"), datetime.now(UTC))
103 |
104 | # Should have 20 shares @ $160 average
105 | assert pos.shares == Decimal("20")
106 | assert pos.average_cost_basis == Decimal("160.0000")
107 | assert pos.total_cost == Decimal("3200.00")
108 |
109 | def test_add_shares_updates_purchase_date(self):
110 | """Test that adding shares updates purchase date to earliest."""
111 | later_date = datetime.now(UTC)
112 | earlier_date = later_date - timedelta(days=30)
113 |
114 | pos = Position(
115 | ticker="AAPL",
116 | shares=Decimal("10"),
117 | average_cost_basis=Decimal("150.00"),
118 | total_cost=Decimal("1500.00"),
119 | purchase_date=later_date,
120 | )
121 |
122 | pos = pos.add_shares(Decimal("10"), Decimal("170.00"), earlier_date)
123 |
124 | assert pos.purchase_date == earlier_date
125 |
126 | def test_add_shares_rejects_zero_shares(self):
127 | """Test that adding zero shares raises error."""
128 | pos = Position(
129 | ticker="AAPL",
130 | shares=Decimal("10"),
131 | average_cost_basis=Decimal("150.00"),
132 | total_cost=Decimal("1500.00"),
133 | purchase_date=datetime.now(UTC),
134 | )
135 |
136 | with pytest.raises(ValueError, match="Shares to add must be positive"):
137 | pos.add_shares(Decimal("0"), Decimal("170.00"), datetime.now(UTC))
138 |
139 | def test_add_shares_rejects_zero_price(self):
140 | """Test that adding shares at zero price raises error."""
141 | pos = Position(
142 | ticker="AAPL",
143 | shares=Decimal("10"),
144 | average_cost_basis=Decimal("150.00"),
145 | total_cost=Decimal("1500.00"),
146 | purchase_date=datetime.now(UTC),
147 | )
148 |
149 | with pytest.raises(ValueError, match="Price must be positive"):
150 | pos.add_shares(Decimal("10"), Decimal("0"), datetime.now(UTC))
151 |
152 | def test_remove_shares_partial(self):
153 | """Test removing part of a position."""
154 | pos = Position(
155 | ticker="AAPL",
156 | shares=Decimal("20"),
157 | average_cost_basis=Decimal("160.00"),
158 | total_cost=Decimal("3200.00"),
159 | purchase_date=datetime.now(UTC),
160 | )
161 |
162 | pos = pos.remove_shares(Decimal("10"))
163 |
164 | assert pos is not None
165 | assert pos.shares == Decimal("10")
166 | assert pos.average_cost_basis == Decimal("160.00") # Unchanged
167 | assert pos.total_cost == Decimal("1600.00")
168 |
169 | def test_remove_shares_full(self):
170 | """Test removing entire position returns None."""
171 | pos = Position(
172 | ticker="AAPL",
173 | shares=Decimal("20"),
174 | average_cost_basis=Decimal("160.00"),
175 | total_cost=Decimal("3200.00"),
176 | purchase_date=datetime.now(UTC),
177 | )
178 |
179 | result = pos.remove_shares(Decimal("20"))
180 |
181 | assert result is None
182 |
183 | def test_remove_shares_more_than_held(self):
184 | """Test removing more shares than held closes position."""
185 | pos = Position(
186 | ticker="AAPL",
187 | shares=Decimal("20"),
188 | average_cost_basis=Decimal("160.00"),
189 | total_cost=Decimal("3200.00"),
190 | purchase_date=datetime.now(UTC),
191 | )
192 |
193 | result = pos.remove_shares(Decimal("25"))
194 |
195 | assert result is None
196 |
197 | def test_remove_shares_rejects_zero(self):
198 | """Test that removing zero shares raises error."""
199 | pos = Position(
200 | ticker="AAPL",
201 | shares=Decimal("20"),
202 | average_cost_basis=Decimal("160.00"),
203 | total_cost=Decimal("3200.00"),
204 | purchase_date=datetime.now(UTC),
205 | )
206 |
207 | with pytest.raises(ValueError, match="Shares to remove must be positive"):
208 | pos.remove_shares(Decimal("0"))
209 |
210 | def test_calculate_current_value_with_gain(self):
211 | """Test calculating current value with unrealized gain."""
212 | pos = Position(
213 | ticker="AAPL",
214 | shares=Decimal("20"),
215 | average_cost_basis=Decimal("160.00"),
216 | total_cost=Decimal("3200.00"),
217 | purchase_date=datetime.now(UTC),
218 | )
219 |
220 | metrics = pos.calculate_current_value(Decimal("175.50"))
221 |
222 |         assert metrics["current_value"] == Decimal("3510.00")  # 20 * 175.50
223 |         assert metrics["unrealized_pnl"] == Decimal("310.00")  # 3510 - 3200
224 |         assert metrics["pnl_percentage"] == Decimal("9.69")  # 310 / 3200, rounded
225 |
226 | def test_calculate_current_value_with_loss(self):
227 | """Test calculating current value with unrealized loss."""
228 | pos = Position(
229 | ticker="AAPL",
230 | shares=Decimal("20"),
231 | average_cost_basis=Decimal("160.00"),
232 | total_cost=Decimal("3200.00"),
233 | purchase_date=datetime.now(UTC),
234 | )
235 |
236 | metrics = pos.calculate_current_value(Decimal("145.00"))
237 |
238 |         assert metrics["current_value"] == Decimal("2900.00")  # 20 * 145.00
239 |         assert metrics["unrealized_pnl"] == Decimal("-300.00")  # 2900 - 3200
240 |         assert metrics["pnl_percentage"] == Decimal("-9.38")  # -300 / 3200, rounded
241 |
242 | def test_calculate_current_value_unchanged(self):
243 | """Test calculating current value when price unchanged."""
244 | pos = Position(
245 | ticker="AAPL",
246 | shares=Decimal("20"),
247 | average_cost_basis=Decimal("160.00"),
248 | total_cost=Decimal("3200.00"),
249 | purchase_date=datetime.now(UTC),
250 | )
251 |
252 | metrics = pos.calculate_current_value(Decimal("160.00"))
253 |
254 | assert metrics["current_value"] == Decimal("3200.00")
255 | assert metrics["unrealized_pnl"] == Decimal("0.00")
256 | assert metrics["pnl_percentage"] == Decimal("0.00")
257 |
258 | def test_fractional_shares(self):
259 | """Test that fractional shares are supported."""
260 | pos = Position(
261 | ticker="AAPL",
262 | shares=Decimal("10.5"),
263 | average_cost_basis=Decimal("150.25"),
264 | total_cost=Decimal("1577.625"),
265 | purchase_date=datetime.now(UTC),
266 | )
267 |
268 | assert pos.shares == Decimal("10.5")
269 | metrics = pos.calculate_current_value(Decimal("175.50"))
270 |         assert metrics["current_value"] == Decimal("1842.75")  # 10.5 * 175.50
271 |
272 | def test_to_dict(self):
273 | """Test converting position to dictionary."""
274 | date = datetime.now(UTC)
275 | pos = Position(
276 | ticker="AAPL",
277 | shares=Decimal("10"),
278 | average_cost_basis=Decimal("150.00"),
279 | total_cost=Decimal("1500.00"),
280 | purchase_date=date,
281 | notes="Long-term hold",
282 | )
283 |
284 | result = pos.to_dict()
285 |
286 | assert result["ticker"] == "AAPL"
287 | assert result["shares"] == 10.0
288 | assert result["average_cost_basis"] == 150.0
289 | assert result["total_cost"] == 1500.0
290 | assert result["purchase_date"] == date.isoformat()
291 | assert result["notes"] == "Long-term hold"
292 |
293 |
294 | class TestPortfolio:
295 | """Test suite for Portfolio aggregate root."""
296 |
297 | def test_portfolio_creation(self):
298 | """Test creating an empty portfolio."""
299 | portfolio = Portfolio(
300 | portfolio_id="test-id",
301 | user_id="default",
302 | name="My Portfolio",
303 | )
304 |
305 | assert portfolio.portfolio_id == "test-id"
306 | assert portfolio.user_id == "default"
307 | assert portfolio.name == "My Portfolio"
308 | assert len(portfolio.positions) == 0
309 |
310 | def test_add_position_new(self):
311 | """Test adding a new position."""
312 | portfolio = Portfolio(
313 | portfolio_id="test-id",
314 | user_id="default",
315 | name="My Portfolio",
316 | )
317 |
318 | portfolio.add_position(
319 | ticker="AAPL",
320 | shares=Decimal("10"),
321 | price=Decimal("150.00"),
322 | date=datetime.now(UTC),
323 | )
324 |
325 | assert len(portfolio.positions) == 1
326 | assert portfolio.positions[0].ticker == "AAPL"
327 | assert portfolio.positions[0].shares == Decimal("10")
328 |
329 | def test_add_position_existing_averages(self):
330 | """Test that adding to existing position averages cost basis."""
331 | portfolio = Portfolio(
332 | portfolio_id="test-id",
333 | user_id="default",
334 | name="My Portfolio",
335 | )
336 |
337 | # First purchase
338 | portfolio.add_position(
339 | ticker="AAPL",
340 | shares=Decimal("10"),
341 | price=Decimal("150.00"),
342 | date=datetime.now(UTC),
343 | )
344 |
345 | # Second purchase
346 | portfolio.add_position(
347 | ticker="AAPL",
348 | shares=Decimal("10"),
349 | price=Decimal("170.00"),
350 | date=datetime.now(UTC),
351 | )
352 |
353 | assert len(portfolio.positions) == 1 # Still one position
354 | assert portfolio.positions[0].shares == Decimal("20")
355 | assert portfolio.positions[0].average_cost_basis == Decimal("160.0000")
356 |
357 | def test_add_position_case_insensitive(self):
358 | """Test that ticker matching is case-insensitive."""
359 | portfolio = Portfolio(
360 | portfolio_id="test-id",
361 | user_id="default",
362 | name="My Portfolio",
363 | )
364 |
365 | portfolio.add_position(
366 | ticker="aapl",
367 | shares=Decimal("10"),
368 | price=Decimal("150.00"),
369 | date=datetime.now(UTC),
370 | )
371 |
372 | portfolio.add_position(
373 | ticker="AAPL",
374 | shares=Decimal("10"),
375 | price=Decimal("170.00"),
376 | date=datetime.now(UTC),
377 | )
378 |
379 | assert len(portfolio.positions) == 1
380 | assert portfolio.positions[0].ticker == "AAPL"
381 |
382 | def test_remove_position_partial(self):
383 | """Test partially removing a position."""
384 | portfolio = Portfolio(
385 | portfolio_id="test-id",
386 | user_id="default",
387 | name="My Portfolio",
388 | )
389 |
390 | portfolio.add_position(
391 | ticker="AAPL",
392 | shares=Decimal("20"),
393 | price=Decimal("150.00"),
394 | date=datetime.now(UTC),
395 | )
396 |
397 | result = portfolio.remove_position("AAPL", Decimal("10"))
398 |
399 | assert result is True
400 | assert len(portfolio.positions) == 1
401 | assert portfolio.positions[0].shares == Decimal("10")
402 |
403 | def test_remove_position_full(self):
404 | """Test fully removing a position."""
405 | portfolio = Portfolio(
406 | portfolio_id="test-id",
407 | user_id="default",
408 | name="My Portfolio",
409 | )
410 |
411 | portfolio.add_position(
412 | ticker="AAPL",
413 | shares=Decimal("20"),
414 | price=Decimal("150.00"),
415 | date=datetime.now(UTC),
416 | )
417 |
418 | result = portfolio.remove_position("AAPL")
419 |
420 | assert result is True
421 | assert len(portfolio.positions) == 0
422 |
423 | def test_remove_position_nonexistent(self):
424 | """Test removing non-existent position returns False."""
425 | portfolio = Portfolio(
426 | portfolio_id="test-id",
427 | user_id="default",
428 | name="My Portfolio",
429 | )
430 |
431 | result = portfolio.remove_position("AAPL")
432 |
433 | assert result is False
434 |
435 | def test_get_position(self):
436 | """Test getting a position by ticker."""
437 | portfolio = Portfolio(
438 | portfolio_id="test-id",
439 | user_id="default",
440 | name="My Portfolio",
441 | )
442 |
443 | portfolio.add_position(
444 | ticker="AAPL",
445 | shares=Decimal("10"),
446 | price=Decimal("150.00"),
447 | date=datetime.now(UTC),
448 | )
449 |
450 | pos = portfolio.get_position("AAPL")
451 |
452 | assert pos is not None
453 | assert pos.ticker == "AAPL"
454 |
455 | def test_get_position_case_insensitive(self):
456 | """Test that get_position is case-insensitive."""
457 | portfolio = Portfolio(
458 | portfolio_id="test-id",
459 | user_id="default",
460 | name="My Portfolio",
461 | )
462 |
463 | portfolio.add_position(
464 | ticker="AAPL",
465 | shares=Decimal("10"),
466 | price=Decimal("150.00"),
467 | date=datetime.now(UTC),
468 | )
469 |
470 | pos = portfolio.get_position("aapl")
471 |
472 | assert pos is not None
473 | assert pos.ticker == "AAPL"
474 |
475 | def test_get_position_nonexistent(self):
476 | """Test getting non-existent position returns None."""
477 | portfolio = Portfolio(
478 | portfolio_id="test-id",
479 | user_id="default",
480 | name="My Portfolio",
481 | )
482 |
483 | pos = portfolio.get_position("AAPL")
484 |
485 | assert pos is None
486 |
487 | def test_get_total_invested(self):
488 | """Test calculating total capital invested."""
489 | portfolio = Portfolio(
490 | portfolio_id="test-id",
491 | user_id="default",
492 | name="My Portfolio",
493 | )
494 |
495 | portfolio.add_position(
496 | ticker="AAPL",
497 | shares=Decimal("10"),
498 | price=Decimal("150.00"),
499 | date=datetime.now(UTC),
500 | )
501 |
502 | portfolio.add_position(
503 | ticker="MSFT",
504 | shares=Decimal("5"),
505 | price=Decimal("300.00"),
506 | date=datetime.now(UTC),
507 | )
508 |
509 | total = portfolio.get_total_invested()
510 |
511 |         assert total == Decimal("3000.00")  # (10 * 150) + (5 * 300)
512 |
513 | def test_calculate_portfolio_metrics(self):
514 | """Test calculating comprehensive portfolio metrics."""
515 | portfolio = Portfolio(
516 | portfolio_id="test-id",
517 | user_id="default",
518 | name="My Portfolio",
519 | )
520 |
521 | portfolio.add_position(
522 | ticker="AAPL",
523 | shares=Decimal("10"),
524 | price=Decimal("150.00"),
525 | date=datetime.now(UTC),
526 | )
527 |
528 | portfolio.add_position(
529 | ticker="MSFT",
530 | shares=Decimal("5"),
531 | price=Decimal("300.00"),
532 | date=datetime.now(UTC),
533 | )
534 |
535 | current_prices = {
536 | "AAPL": Decimal("175.50"),
537 | "MSFT": Decimal("320.00"),
538 | }
539 |
540 | metrics = portfolio.calculate_portfolio_metrics(current_prices)
541 |
542 | assert metrics["total_value"] == 3355.0 # (10 * 175.50) + (5 * 320)
543 | assert metrics["total_invested"] == 3000.0
544 |         assert metrics["total_pnl"] == 355.0  # 3355 - 3000
545 |         assert metrics["total_pnl_percentage"] == 11.83  # 355 / 3000, rounded
546 | assert metrics["position_count"] == 2
547 | assert len(metrics["positions"]) == 2
548 |
549 | def test_calculate_portfolio_metrics_uses_fallback_price(self):
550 | """Test that missing prices fall back to cost basis."""
551 | portfolio = Portfolio(
552 | portfolio_id="test-id",
553 | user_id="default",
554 | name="My Portfolio",
555 | )
556 |
557 | portfolio.add_position(
558 | ticker="AAPL",
559 | shares=Decimal("10"),
560 | price=Decimal("150.00"),
561 | date=datetime.now(UTC),
562 | )
563 |
564 | # No current price provided
565 | metrics = portfolio.calculate_portfolio_metrics({})
566 |
567 | # Should use cost basis as current price
568 | assert metrics["total_value"] == 1500.0
569 | assert metrics["total_pnl"] == 0.0
570 |
571 | def test_clear_all_positions(self):
572 | """Test clearing all positions."""
573 | portfolio = Portfolio(
574 | portfolio_id="test-id",
575 | user_id="default",
576 | name="My Portfolio",
577 | )
578 |
579 | portfolio.add_position(
580 | ticker="AAPL",
581 | shares=Decimal("10"),
582 | price=Decimal("150.00"),
583 | date=datetime.now(UTC),
584 | )
585 |
586 | portfolio.add_position(
587 | ticker="MSFT",
588 | shares=Decimal("5"),
589 | price=Decimal("300.00"),
590 | date=datetime.now(UTC),
591 | )
592 |
593 | portfolio.clear_all_positions()
594 |
595 | assert len(portfolio.positions) == 0
596 |
597 | def test_to_dict(self):
598 | """Test converting portfolio to dictionary."""
599 | portfolio = Portfolio(
600 | portfolio_id="test-id",
601 | user_id="default",
602 | name="My Portfolio",
603 | )
604 |
605 | portfolio.add_position(
606 | ticker="AAPL",
607 | shares=Decimal("10"),
608 | price=Decimal("150.00"),
609 | date=datetime.now(UTC),
610 | )
611 |
612 | result = portfolio.to_dict()
613 |
614 | assert result["portfolio_id"] == "test-id"
615 | assert result["user_id"] == "default"
616 | assert result["name"] == "My Portfolio"
617 | assert result["position_count"] == 1
618 | assert result["total_invested"] == 1500.0
619 | assert len(result["positions"]) == 1
620 |
621 | def test_multiple_positions_with_different_performance(self):
622 | """Test portfolio with positions having different performance."""
623 | portfolio = Portfolio(
624 | portfolio_id="test-id",
625 | user_id="default",
626 | name="My Portfolio",
627 | )
628 |
629 | # Winner
630 | portfolio.add_position(
631 | ticker="NVDA",
632 | shares=Decimal("5"),
633 | price=Decimal("450.00"),
634 | date=datetime.now(UTC),
635 | )
636 |
637 | # Loser
638 | portfolio.add_position(
639 | ticker="MARA",
640 | shares=Decimal("50"),
641 | price=Decimal("18.50"),
642 | date=datetime.now(UTC),
643 | )
644 |
645 | current_prices = {
646 | "NVDA": Decimal("520.00"), # +15.6%
647 | "MARA": Decimal("13.50"), # -27.0%
648 | }
649 |
650 | metrics = portfolio.calculate_portfolio_metrics(current_prices)
651 |
652 | # Check individual positions
653 | nvda_pos = next(p for p in metrics["positions"] if p["ticker"] == "NVDA")
654 | mara_pos = next(p for p in metrics["positions"] if p["ticker"] == "MARA")
655 |
656 | assert nvda_pos["unrealized_pnl"] == 350.0 # (520 - 450) * 5
657 | assert mara_pos["unrealized_pnl"] == -250.0 # (13.50 - 18.50) * 50
658 |
659 | # Overall portfolio
660 | assert metrics["total_pnl"] == 100.0 # 350 - 250
661 |
```
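The cost-basis behaviour asserted above is plain weighted averaging: buying 10 shares at $150 and 10 more at $170 yields 20 shares at (10 * 150 + 10 * 170) / 20 = $160. A self-contained sketch using only the public API these tests exercise:

```python
from datetime import UTC, datetime
from decimal import Decimal

from maverick_mcp.domain.portfolio import Position

# 10 shares @ $150, then 10 more @ $170 -> 20 shares @ $160 average
pos = Position(
    ticker="AAPL",
    shares=Decimal("10"),
    average_cost_basis=Decimal("150.00"),
    total_cost=Decimal("1500.00"),
    purchase_date=datetime.now(UTC),
)
pos = pos.add_shares(Decimal("10"), Decimal("170.00"), datetime.now(UTC))
assert pos.average_cost_basis == Decimal("160.0000")

# Unrealized P&L at $175.50: 20 * 175.50 = 3510.00, a $310.00 gain (+9.69%)
metrics = pos.calculate_current_value(Decimal("175.50"))
assert metrics["unrealized_pnl"] == Decimal("310.00")
```

Note that the tests rebind the result of `add_shares`/`remove_shares`, so `Position` behaves as an immutable value object: mutations return a new instance (or `None` when a removal closes the position) rather than changing the original in place.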
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/introspection.py:
--------------------------------------------------------------------------------
```python
1 | """MCP Introspection Tools for Better Discovery and Understanding."""
2 |
3 | from typing import Any
4 |
5 | from fastmcp import FastMCP
6 |
7 |
8 | def register_introspection_tools(mcp: FastMCP) -> None:
9 | """Register introspection tools for better discovery."""
10 |
11 | @mcp.tool(name="discover_capabilities")
12 | async def discover_capabilities() -> dict[str, Any]:
13 | """
14 | Discover all available capabilities of the MaverickMCP server.
15 |
16 | This tool provides comprehensive information about:
17 | - Available strategies (traditional and ML)
18 | - Tool categories and their functions
19 | - Parameter requirements for each strategy
20 | - Example usage patterns
21 |
22 | Use this as your first tool to understand what's available.
23 | """
24 | return {
25 | "server_info": {
26 | "name": "MaverickMCP",
27 | "version": "1.0.0",
28 | "description": "Advanced stock analysis and backtesting MCP server",
29 | },
30 | "capabilities": {
31 | "backtesting": {
32 | "description": "Run and optimize trading strategies",
33 | "strategies_available": 15,
34 | "ml_strategies": ["online_learning", "regime_aware", "ensemble"],
35 | "traditional_strategies": [
36 | "sma_cross",
37 | "rsi",
38 | "macd",
39 | "bollinger",
40 | "momentum",
41 | "ema_cross",
42 | "mean_reversion",
43 | "breakout",
44 | "volume_momentum",
45 | ],
46 | "features": [
47 | "parameter_optimization",
48 | "strategy_comparison",
49 | "walk_forward_analysis",
50 | ],
51 | },
52 | "technical_analysis": {
53 | "description": "Calculate technical indicators and patterns",
54 | "indicators": [
55 | "SMA",
56 | "EMA",
57 | "RSI",
58 | "MACD",
59 | "Bollinger Bands",
60 | "Support/Resistance",
61 | ],
62 | "chart_analysis": True,
63 | "pattern_recognition": True,
64 | },
65 | "screening": {
66 | "description": "Pre-calculated S&P 500 screening results",
67 | "strategies": [
68 | "maverick_bullish",
69 | "maverick_bearish",
70 | "supply_demand_breakouts",
71 | ],
72 | "database": "520 S&P 500 stocks pre-seeded",
73 | },
74 | "research": {
75 | "description": "AI-powered research with parallel execution",
76 | "features": [
77 | "comprehensive_research",
78 | "company_analysis",
79 | "sentiment_analysis",
80 | ],
81 | "performance": "7-256x speedup with parallel agents",
82 | "ai_models": "400+ models via OpenRouter",
83 | },
84 | },
85 | "quick_start": {
86 | "first_command": "Run: discover_capabilities() to see this info",
87 | "simple_backtest": "run_backtest(symbol='AAPL', strategy_type='sma_cross')",
88 | "ml_strategy": "run_backtest(symbol='TSLA', strategy_type='online_learning')",
89 | "get_help": "Use prompts like 'backtest_strategy_guide' for detailed guides",
90 | },
91 | }
92 |
93 | @mcp.tool(name="list_all_strategies")
94 | async def list_all_strategies() -> list[dict[str, Any]]:
95 | """
96 | List all available backtesting strategies with their parameters.
97 |
98 | Returns detailed information about each strategy including:
99 | - Strategy name and description
100 | - Required and optional parameters
101 | - Default parameter values
102 | - Example usage
103 | """
104 | strategies = [] # Return as array
105 |
106 | # Traditional strategies
107 | strategies.extend(
108 | [
109 | {
110 | "name": "sma_cross",
111 | "description": "Simple Moving Average Crossover",
112 | "parameters": {
113 | "fast_period": {
114 | "type": "int",
115 | "default": 10,
116 | "description": "Fast MA period",
117 | },
118 | "slow_period": {
119 | "type": "int",
120 | "default": 20,
121 | "description": "Slow MA period",
122 | },
123 | },
124 | "example": "run_backtest(symbol='AAPL', strategy_type='sma_cross', fast_period=10, slow_period=20)",
125 | },
126 | {
127 | "name": "rsi",
128 | "description": "RSI Mean Reversion",
129 | "parameters": {
130 | "period": {
131 | "type": "int",
132 | "default": 14,
133 | "description": "RSI calculation period",
134 | },
135 | "oversold": {
136 | "type": "int",
137 | "default": 30,
138 | "description": "Oversold threshold",
139 | },
140 | "overbought": {
141 | "type": "int",
142 | "default": 70,
143 | "description": "Overbought threshold",
144 | },
145 | },
146 | "example": "run_backtest(symbol='MSFT', strategy_type='rsi', period=14)",
147 | },
148 | {
149 | "name": "macd",
150 | "description": "MACD Signal Line Crossover",
151 | "parameters": {
152 | "fast_period": {
153 | "type": "int",
154 | "default": 12,
155 | "description": "Fast EMA period",
156 | },
157 | "slow_period": {
158 | "type": "int",
159 | "default": 26,
160 | "description": "Slow EMA period",
161 | },
162 | "signal_period": {
163 | "type": "int",
164 | "default": 9,
165 | "description": "Signal line period",
166 | },
167 | },
168 | "example": "run_backtest(symbol='GOOGL', strategy_type='macd')",
169 | },
170 | {
171 | "name": "bollinger",
172 | "description": "Bollinger Bands Mean Reversion",
173 | "parameters": {
174 | "period": {
175 | "type": "int",
176 | "default": 20,
177 | "description": "BB calculation period",
178 | },
179 | "std_dev": {
180 | "type": "float",
181 | "default": 2,
182 | "description": "Standard deviations",
183 | },
184 | },
185 | "example": "run_backtest(symbol='AMZN', strategy_type='bollinger')",
186 | },
187 | {
188 | "name": "momentum",
189 | "description": "Momentum Trading Strategy",
190 | "parameters": {
191 | "period": {
192 | "type": "int",
193 | "default": 10,
194 | "description": "Momentum period",
195 | },
196 | "threshold": {
197 | "type": "float",
198 | "default": 0.02,
199 | "description": "Entry threshold",
200 | },
201 | },
202 | "example": "run_backtest(symbol='NVDA', strategy_type='momentum')",
203 | },
204 | {
205 | "name": "ema_cross",
206 | "description": "Exponential Moving Average Crossover",
207 | "parameters": {
208 | "fast_period": {
209 | "type": "int",
210 | "default": 12,
211 | "description": "Fast EMA period",
212 | },
213 | "slow_period": {
214 | "type": "int",
215 | "default": 26,
216 | "description": "Slow EMA period",
217 | },
218 | },
219 | "example": "run_backtest(symbol='META', strategy_type='ema_cross')",
220 | },
221 | {
222 | "name": "mean_reversion",
223 | "description": "Statistical Mean Reversion",
224 | "parameters": {
225 | "lookback": {
226 | "type": "int",
227 | "default": 20,
228 | "description": "Lookback period",
229 | },
230 | "entry_z": {
231 | "type": "float",
232 | "default": -2,
233 | "description": "Entry z-score",
234 | },
235 | "exit_z": {
236 | "type": "float",
237 | "default": 0,
238 | "description": "Exit z-score",
239 | },
240 | },
241 | "example": "run_backtest(symbol='SPY', strategy_type='mean_reversion')",
242 | },
243 | {
244 | "name": "breakout",
245 | "description": "Channel Breakout Strategy",
246 | "parameters": {
247 | "lookback": {
248 | "type": "int",
249 | "default": 20,
250 | "description": "Channel period",
251 | },
252 | "breakout_factor": {
253 | "type": "float",
254 | "default": 1.5,
255 | "description": "Breakout multiplier",
256 | },
257 | },
258 | "example": "run_backtest(symbol='QQQ', strategy_type='breakout')",
259 | },
260 | {
261 | "name": "volume_momentum",
262 | "description": "Volume-Weighted Momentum",
263 | "parameters": {
264 | "period": {
265 | "type": "int",
266 | "default": 10,
267 | "description": "Momentum period",
268 | },
269 | "volume_factor": {
270 | "type": "float",
271 | "default": 1.5,
272 | "description": "Volume multiplier",
273 | },
274 | },
275 | "example": "run_backtest(symbol='TSLA', strategy_type='volume_momentum')",
276 | },
277 | ]
278 | )
279 |
280 | # ML strategies
281 | strategies.extend(
282 | [
283 | {
284 | "name": "ml_predictor",
285 | "description": "Machine Learning predictor using Random Forest",
286 | "parameters": {
287 | "model_type": {
288 | "type": "str",
289 | "default": "random_forest",
290 | "description": "ML model type",
291 | },
292 | "n_estimators": {
293 | "type": "int",
294 | "default": 100,
295 | "description": "Number of trees",
296 | },
297 | "max_depth": {
298 | "type": "int",
299 | "default": None,
300 | "description": "Max tree depth",
301 | },
302 | },
303 | "example": "run_ml_strategy_backtest(symbol='AAPL', strategy_type='ml_predictor', model_type='random_forest')",
304 | },
305 | {
306 | "name": "online_learning",
307 | "description": "Online learning adaptive strategy (alias for adaptive)",
308 | "parameters": {
309 | "learning_rate": {
310 | "type": "float",
311 | "default": 0.01,
312 | "description": "Adaptation rate",
313 | },
314 | "adaptation_method": {
315 | "type": "str",
316 | "default": "gradient",
317 | "description": "Method for adaptation",
318 | },
319 | },
320 | "example": "run_ml_strategy_backtest(symbol='AAPL', strategy_type='online_learning')",
321 | },
322 | {
323 | "name": "regime_aware",
324 | "description": "Market regime detection and adaptation",
325 | "parameters": {
326 | "regime_window": {
327 | "type": "int",
328 | "default": 50,
329 | "description": "Regime detection window",
330 | },
331 | "threshold": {
332 | "type": "float",
333 | "default": 0.02,
334 | "description": "Regime change threshold",
335 | },
336 | },
337 | "example": "run_backtest(symbol='SPY', strategy_type='regime_aware')",
338 | },
339 | {
340 | "name": "ensemble",
341 | "description": "Ensemble voting with multiple strategies",
342 | "parameters": {
343 | "fast_period": {
344 | "type": "int",
345 | "default": 10,
346 | "description": "Fast MA for ensemble",
347 | },
348 | "slow_period": {
349 | "type": "int",
350 | "default": 20,
351 | "description": "Slow MA for ensemble",
352 | },
353 | "rsi_period": {
354 | "type": "int",
355 | "default": 14,
356 | "description": "RSI period for ensemble",
357 | },
358 | },
359 | "example": "run_ml_strategy_backtest(symbol='MSFT', strategy_type='ensemble')",
360 | },
361 | {
362 | "name": "adaptive",
363 | "description": "Adaptive strategy that adjusts based on performance",
364 | "parameters": {
365 | "learning_rate": {
366 | "type": "float",
367 | "default": 0.01,
368 | "description": "How quickly to adapt",
369 | },
370 | "adaptation_method": {
371 | "type": "str",
372 | "default": "gradient",
373 | "description": "Method for adaptation",
374 | },
375 | },
376 | "example": "run_ml_strategy_backtest(symbol='GOOGL', strategy_type='adaptive')",
377 | },
378 | ]
379 | )
380 |
381 | return strategies # Return array
382 |
383 | @mcp.tool(name="get_strategy_help")
384 | async def get_strategy_help(strategy_type: str) -> dict[str, Any]:
385 | """
386 | Get detailed help for a specific strategy.
387 |
388 | Args:
389 | strategy_type: Name of the strategy (e.g., 'sma_cross', 'online_learning')
390 |
391 | Returns:
392 | Detailed information about the strategy including theory, parameters, and best practices.
393 | """
394 | strategy_help = {
395 | "sma_cross": {
396 | "name": "Simple Moving Average Crossover",
397 | "theory": "Generates buy signals when fast SMA crosses above slow SMA, sell when opposite occurs",
398 | "best_for": "Trending markets with clear directional moves",
399 | "parameters": {
400 | "fast_period": "Typically 10-20 days for short-term trends",
401 | "slow_period": "Typically 20-50 days for medium-term trends",
402 | },
403 | "tips": [
404 | "Works best in trending markets",
405 | "Consider adding volume confirmation",
406 | "Use wider periods for less noise",
407 | ],
408 | },
409 | "ml_predictor": {
410 | "name": "Machine Learning Predictor",
411 | "theory": "Uses Random Forest or other ML models to predict price movements",
412 | "best_for": "Complex markets with multiple factors",
413 | "parameters": {
414 | "model_type": "Type of ML model (random_forest)",
415 | "n_estimators": "Number of trees in forest (50-200)",
416 | "max_depth": "Maximum tree depth (None or 5-20)",
417 | },
418 | "tips": [
419 | "More estimators for better accuracy but slower",
420 | "Limit depth to prevent overfitting",
421 | "Requires sufficient historical data",
422 | ],
423 | },
424 | "online_learning": {
425 | "name": "Online Learning Strategy",
426 | "theory": "Continuously adapts strategy parameters based on recent performance",
427 | "best_for": "Dynamic markets with changing patterns",
428 | "parameters": {
429 | "learning_rate": "How quickly to adapt (0.001-0.1)",
430 | "adaptation_method": "Method for adaptation (gradient, bayesian)",
431 | },
432 | "tips": [
433 | "Lower learning rates for stable adaptation",
434 | "Works well in volatile markets",
435 | "This is an alias for the adaptive strategy",
436 | ],
437 | },
438 | "adaptive": {
439 | "name": "Adaptive Strategy",
440 | "theory": "Dynamically adjusts strategy parameters based on performance",
441 | "best_for": "Markets with changing characteristics",
442 | "parameters": {
443 | "learning_rate": "How quickly to adapt (0.001-0.1)",
444 | "adaptation_method": "Method for adaptation (gradient, bayesian)",
445 | },
446 | "tips": [
447 | "Start with lower learning rates",
448 | "Monitor for overfitting",
449 | "Works best with stable base strategy",
450 | ],
451 | },
452 | "ensemble": {
453 | "name": "Strategy Ensemble",
454 | "theory": "Combines multiple strategies with weighted voting",
455 | "best_for": "Risk reduction through diversification",
456 | "parameters": {
457 | "base_strategies": "List of strategies to combine",
458 | "weighting_method": "How to weight strategies (equal, performance, volatility)",
459 | },
460 | "tips": [
461 | "Combine uncorrelated strategies",
462 | "Performance weighting adapts to market",
463 | "More strategies reduce single-point failure",
464 | ],
465 | },
466 | "regime_aware": {
467 | "name": "Market Regime Detection Strategy",
468 | "theory": "Identifies market regimes (trending vs ranging) and adapts strategy accordingly",
469 | "best_for": "Markets that alternate between trending and sideways movement",
470 | "parameters": {
471 | "regime_window": "Period for regime detection (30-100 days)",
472 | "threshold": "Sensitivity to regime changes (0.01-0.05)",
473 | },
474 | "tips": [
475 | "Longer windows for major regime shifts",
476 | "Lower thresholds for more sensitive detection",
477 | "Combines well with other indicators",
478 | ],
479 | },
480 | }
481 |
482 | if strategy_type in strategy_help:
483 | return strategy_help[strategy_type]
484 | else:
485 | return {
486 | "error": f"Strategy '{strategy_type}' not found",
487 | "available_strategies": list(strategy_help.keys()),
488 | "tip": "Use list_all_strategies() to see all available strategies",
489 | }
490 |
```
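Because these introspection tools are plain `@mcp.tool` registrations, they can be exercised in-memory with the same FastMCP `Client` pattern the router tests below use. A minimal sketch (the unprefixed tool name assumes the tools are registered directly on the server rather than mounted under a path):

```python
import asyncio

from fastmcp import Client, FastMCP

from maverick_mcp.api.routers.introspection import register_introspection_tools


async def main():
    mcp = FastMCP("IntrospectionDemo")
    register_introspection_tools(mcp)
    async with Client(mcp) as client:
        result = await client.call_tool(
            "get_strategy_help", {"strategy_type": "sma_cross"}
        )
        print(result[0].text)  # theory, parameters, and tips for sma_cross


asyncio.run(main())
```

Passing an unknown strategy name returns the error payload with `available_strategies`, so a client can recover by falling back to `list_all_strategies()`.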
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_in_memory_routers.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | In-memory tests for domain-specific routers using FastMCP patterns.
3 |
4 | Tests individual router functionality in isolation using FastMCP's
5 | router mounting and in-memory testing capabilities.
6 | """
7 |
8 | import asyncio
9 | from unittest.mock import Mock, patch
10 |
11 | import pytest
12 | from fastmcp import Client, FastMCP
13 | from sqlalchemy import create_engine
14 | from sqlalchemy.orm import Session
15 |
16 | from maverick_mcp.api.routers.data import data_router
17 | from maverick_mcp.api.routers.portfolio import portfolio_router
18 | from maverick_mcp.api.routers.screening import screening_router
19 | from maverick_mcp.api.routers.technical import technical_router
20 | from maverick_mcp.data.models import (
21 | Base,
22 | MaverickStocks,
23 | Stock,
24 | SupplyDemandBreakoutStocks,
25 | )
26 |
27 |
28 | @pytest.fixture
29 | def test_server():
30 | """Create a test server with only specific routers mounted."""
31 | test_mcp: FastMCP = FastMCP("TestMaverick-MCP")
32 | return test_mcp
33 |
34 |
35 | @pytest.fixture
36 | def screening_db():
37 | """Create test database with screening data."""
38 | engine = create_engine("sqlite:///:memory:")
39 | Base.metadata.create_all(engine)
40 |
41 | with Session(engine) as session:
42 | # Add test stocks
43 | stocks = [
44 | Stock(
45 | ticker_symbol="AAPL",
46 | company_name="Apple Inc.",
47 | sector="Technology",
48 | industry="Consumer Electronics",
49 | ),
50 | Stock(
51 | ticker_symbol="MSFT",
52 | company_name="Microsoft Corp.",
53 | sector="Technology",
54 | industry="Software",
55 | ),
56 | Stock(
57 | ticker_symbol="GOOGL",
58 | company_name="Alphabet Inc.",
59 | sector="Technology",
60 | industry="Internet",
61 | ),
62 | Stock(
63 | ticker_symbol="AMZN",
64 | company_name="Amazon.com Inc.",
65 | sector="Consumer Cyclical",
66 | industry="Internet Retail",
67 | ),
68 | Stock(
69 | ticker_symbol="TSLA",
70 | company_name="Tesla Inc.",
71 | sector="Consumer Cyclical",
72 | industry="Auto Manufacturers",
73 | ),
74 | ]
75 | session.add_all(stocks)
76 | session.commit()
77 |
78 | # Add Maverick screening results
79 | maverick_stocks = [
80 | MaverickStocks(
81 | id=1,
82 | stock="AAPL",
83 | close=150.0,
84 | open=148.0,
85 | high=152.0,
86 | low=147.0,
87 | volume=10000000,
88 | combined_score=92,
89 | momentum_score=88,
90 | adr_pct=2.5,
91 | atr=3.2,
92 | pat="Cup and Handle",
93 | sqz="Yes",
94 | consolidation="trending",
95 | entry="151.50",
96 | compression_score=85,
97 | pattern_detected=1,
98 | ema_21=149.0,
99 | sma_50=148.0,
100 | sma_150=145.0,
101 | sma_200=140.0,
102 | avg_vol_30d=9500000,
103 | ),
104 | MaverickStocks(
105 | id=2,
106 | stock="MSFT",
107 | close=300.0,
108 | open=298.0,
109 | high=302.0,
110 | low=297.0,
111 | volume=8000000,
112 | combined_score=89,
113 | momentum_score=82,
114 | adr_pct=2.1,
115 | atr=4.5,
116 | pat="Ascending Triangle",
117 | sqz="No",
118 | consolidation="trending",
119 | entry="301.00",
120 | compression_score=80,
121 | pattern_detected=1,
122 | ema_21=299.0,
123 | sma_50=298.0,
124 | sma_150=295.0,
125 | sma_200=290.0,
126 | avg_vol_30d=7500000,
127 | ),
128 | ]
129 | session.add_all(maverick_stocks)
130 |
131 | # Add trending screening results
132 | trending_stocks = [
133 | SupplyDemandBreakoutStocks(
134 | id=1,
135 | stock="GOOGL",
136 | close=140.0,
137 | open=138.0,
138 | high=142.0,
139 | low=137.0,
140 | volume=5000000,
141 | momentum_score=91,
142 | adr_pct=2.8,
143 | atr=3.5,
144 | pat="Base Breakout",
145 | sqz="Yes",
146 | consolidation="trending",
147 | entry="141.00",
148 | ema_21=139.0,
149 | sma_50=138.0,
150 | sma_150=135.0,
151 | sma_200=130.0,
152 | avg_volume_30d=4800000,
153 | ),
154 | ]
155 | session.add_all(trending_stocks)
156 | session.commit()
157 |
158 | with patch("maverick_mcp.data.models.engine", engine):
159 | with patch("maverick_mcp.data.models.SessionLocal", lambda: Session(engine)):
160 | yield engine
161 |
162 |
163 | class TestTechnicalRouter:
164 | """Test technical analysis router functionality."""
165 |
166 | @pytest.mark.asyncio
167 | async def test_rsi_calculation(self, test_server, screening_db):
168 | """Test RSI calculation through the router."""
169 | test_server.mount("/technical", technical_router)
170 |
171 | # Mock price data for RSI calculation
172 | with patch(
173 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
174 | ) as mock_data:
175 | # Create 30 days of price data
176 | import pandas as pd
177 |
178 | dates = pd.date_range(end="2024-01-31", periods=30)
179 | prices = pd.DataFrame(
180 | {
181 | "Close": [
182 | 100 + (i % 5) - 2 for i in range(30)
183 | ], # Oscillating prices
184 | "High": [101 + (i % 5) - 2 for i in range(30)],
185 | "Low": [99 + (i % 5) - 2 for i in range(30)],
186 | "Open": [100 + (i % 5) - 2 for i in range(30)],
187 | "Volume": [1000000] * 30,
188 | },
189 | index=dates,
190 | )
191 | mock_data.return_value = prices
192 |
193 | async with Client(test_server) as client:
194 | result = await client.call_tool(
195 | "/technical_get_rsi_analysis", {"ticker": "AAPL", "period": 14}
196 | )
197 |
198 | assert len(result) > 0
199 | assert result[0].text is not None
200 | # RSI should be calculated
201 | assert "rsi" in result[0].text.lower()
202 |
203 | @pytest.mark.asyncio
204 | async def test_macd_analysis(self, test_server, screening_db):
205 | """Test MACD analysis with custom parameters."""
206 | test_server.mount("/technical", technical_router)
207 |
208 | with patch(
209 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
210 | ) as mock_data:
211 | # Create trending price data
212 | import pandas as pd
213 |
214 | dates = pd.date_range(end="2024-01-31", periods=50)
215 | prices = pd.DataFrame(
216 | {
217 | "Close": [100 + (i * 0.5) for i in range(50)], # Upward trend
218 | "High": [101 + (i * 0.5) for i in range(50)],
219 | "Low": [99 + (i * 0.5) for i in range(50)],
220 | "Open": [100 + (i * 0.5) for i in range(50)],
221 | "Volume": [1000000] * 50,
222 | },
223 | index=dates,
224 | )
225 | mock_data.return_value = prices
226 |
227 | async with Client(test_server) as client:
228 | result = await client.call_tool(
229 | "/technical_get_macd_analysis",
230 | {
231 | "ticker": "MSFT",
232 | "fast_period": 12,
233 | "slow_period": 26,
234 | "signal_period": 9,
235 | },
236 | )
237 |
238 | assert len(result) > 0
239 | assert result[0].text is not None
240 | data = eval(result[0].text)
241 | assert "analysis" in data
242 | assert "histogram" in data["analysis"]
243 | assert "indicator" in data["analysis"]
244 |
245 | @pytest.mark.asyncio
246 | async def test_support_resistance(self, test_server, screening_db):
247 | """Test support and resistance level detection."""
248 | test_server.mount("/technical", technical_router)
249 |
250 | with patch(
251 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
252 | ) as mock_data:
253 | # Create price data with clear levels
254 | import pandas as pd
255 |
256 | dates = pd.date_range(end="2024-01-31", periods=100)
257 | prices = []
258 | for i in range(100):
259 | if i % 20 < 10:
260 | price = 100 # Support level
261 | else:
262 | price = 110 # Resistance level
263 | prices.append(
264 | {
265 | "High": price + 1,
266 | "Low": price - 1,
267 | "Close": price,
268 | "Open": price,
269 | "Volume": 1000000,
270 | }
271 | )
272 | prices_df = pd.DataFrame(prices, index=dates)
273 | mock_data.return_value = prices_df
274 |
275 | async with Client(test_server) as client:
276 | result = await client.call_tool(
277 | "/technical_get_support_resistance",
278 | {"ticker": "GOOGL", "days": 90},
279 | )
280 |
281 | assert len(result) > 0
282 | assert result[0].text is not None
283 | data = eval(result[0].text)
284 | assert "support_levels" in data
285 | assert "resistance_levels" in data
286 | assert len(data["support_levels"]) > 0
287 | assert len(data["resistance_levels"]) > 0
288 |
289 |
290 | class TestScreeningRouter:
291 | """Test stock screening router functionality."""
292 |
293 | @pytest.mark.asyncio
294 | async def test_maverick_screening(self, test_server, screening_db):
295 | """Test Maverick bullish screening."""
296 | test_server.mount("/screening", screening_router)
297 |
298 | async with Client(test_server) as client:
299 | result = await client.call_tool(
300 | "/screening_get_maverick_stocks", {"limit": 10}
301 | )
302 |
303 | assert len(result) > 0
304 | assert result[0].text is not None
305 | data = eval(result[0].text)
306 |
307 | assert "stocks" in data
308 | assert len(data["stocks"]) == 2 # AAPL and MSFT
309 | assert (
310 | data["stocks"][0]["combined_score"]
311 | > data["stocks"][1]["combined_score"]
312 | ) # Sorted by combined score
313 | assert all(
314 | stock["combined_score"] > 0 for stock in data["stocks"]
315 | ) # Score should be positive
316 |
317 | @pytest.mark.asyncio
318 | async def test_trending_screening(self, test_server, screening_db):
319 | """Test trending screening."""
320 | test_server.mount("/screening", screening_router)
321 |
322 | async with Client(test_server) as client:
323 | result = await client.call_tool(
324 | "/screening_get_trending_stocks", {"limit": 5}
325 | )
326 |
327 | assert len(result) > 0
328 | assert result[0].text is not None
329 | data = eval(result[0].text)
330 |
331 | assert "stocks" in data
332 | assert len(data["stocks"]) == 1 # Only GOOGL
333 | assert data["stocks"][0]["stock"] == "GOOGL"
334 | assert (
335 | data["stocks"][0]["momentum_score"] > 0
336 | ) # Momentum score should be positive
337 |
338 | @pytest.mark.asyncio
339 | async def test_all_screenings(self, test_server, screening_db):
340 | """Test combined screening results."""
341 | test_server.mount("/screening", screening_router)
342 |
343 | async with Client(test_server) as client:
344 | result = await client.call_tool(
345 | "/screening_get_all_screening_recommendations", {}
346 | )
347 |
348 | assert len(result) > 0
349 | assert result[0].text is not None
350 | data = eval(result[0].text)
351 |
352 | assert "maverick_stocks" in data
353 | assert "maverick_bear_stocks" in data
354 | assert "trending_stocks" in data
355 | assert len(data["maverick_stocks"]) == 2
356 | assert len(data["trending_stocks"]) == 1
357 |
358 |
359 | class TestPortfolioRouter:
360 | """Test portfolio analysis router functionality."""
361 |
362 | @pytest.mark.asyncio
363 | async def test_risk_analysis(self, test_server, screening_db):
364 | """Test portfolio risk analysis."""
365 | test_server.mount("/portfolio", portfolio_router)
366 |
367 | # Mock stock data for risk calculations
368 | with patch(
369 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
370 | ) as mock_data:
371 | # Create price data with volatility
372 | import numpy as np
373 | import pandas as pd
374 |
375 | prices = []
376 | base_price = 100.0
377 | for _ in range(252): # One year of trading days
378 | # Add some random walk
379 | change = np.random.normal(0, 2)
380 | base_price = float(base_price * (1 + change / 100))
381 | prices.append(
382 | {
383 | "close": base_price,
384 | "high": base_price + 1,
385 | "low": base_price - 1,
386 | "open": base_price,
387 | "volume": 1000000,
388 | }
389 | )
390 | dates = pd.date_range(end="2024-01-31", periods=252)
391 | prices_df = pd.DataFrame(prices, index=dates)
392 | mock_data.return_value = prices_df
393 |
394 | async with Client(test_server) as client:
395 | result = await client.call_tool(
396 | "/portfolio_risk_adjusted_analysis",
397 | {"ticker": "AAPL", "risk_level": 50.0},
398 | )
399 |
400 | assert len(result) > 0
401 | assert result[0].text is not None
402 | data = eval(result[0].text)
403 |
404 | assert "risk_level" in data or "analysis" in data
405 | assert "ticker" in data
406 |
407 | @pytest.mark.asyncio
408 | async def test_correlation_analysis(self, test_server, screening_db):
409 | """Test correlation analysis between stocks."""
410 | test_server.mount("/portfolio", portfolio_router)
411 |
412 | # Mock correlated stock data
413 | with patch(
414 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
415 | ) as mock_data:
416 | import numpy as np
417 |
418 | def create_correlated_data(base_return, correlation):
419 | import pandas as pd
420 |
421 | prices = []
422 | base_price = 100
423 | for _ in range(100):
424 | # Create correlated returns
425 | return_pct = base_return + (correlation * np.random.normal(0, 1))
426 | base_price = base_price * (1 + return_pct / 100)
427 | prices.append(
428 | {
429 | "close": base_price,
430 | "high": base_price + 1,
431 | "low": base_price - 1,
432 | "open": base_price,
433 | "volume": 1000000,
434 | }
435 | )
436 | dates = pd.date_range(end="2024-01-31", periods=100)
437 | return pd.DataFrame(prices, index=dates)
438 |
439 | # Return different data for different tickers
440 | mock_data.side_effect = [
441 | create_correlated_data(0.1, 0), # AAPL
442 | create_correlated_data(0.1, 0.8), # MSFT (high correlation)
443 | create_correlated_data(0.1, -0.3), # GOOGL (negative correlation)
444 | ]
445 |
446 | async with Client(test_server) as client:
447 | result = await client.call_tool(
448 | "/portfolio_portfolio_correlation_analysis",
449 | {"tickers": ["AAPL", "MSFT", "GOOGL"]},
450 | )
451 |
452 | assert len(result) > 0
453 | assert result[0].text is not None
454 |
455 | # Handle NaN values in response
456 | result_text = result[0].text.replace("NaN", "null")
457 | import json
458 |
459 | data = json.loads(result_text.replace("'", '"'))
460 |
461 | assert "correlation_matrix" in data
462 | assert len(data["correlation_matrix"]) == 3
463 | assert "recommendation" in data
464 |
465 |
466 | class TestDataRouter:
467 | """Test data fetching router functionality."""
468 |
469 | @pytest.mark.asyncio
470 | async def test_batch_fetch_with_validation(self, test_server, screening_db):
471 | """Test batch data fetching with validation."""
472 | test_server.mount("/data", data_router)
473 |
474 | async with Client(test_server) as client:
475 | # Test with valid tickers
476 | result = await client.call_tool(
477 | "/data_fetch_stock_data_batch",
478 | {
479 | "request": {
480 | "tickers": ["AAPL", "MSFT"],
481 | "start_date": "2024-01-01",
482 | "end_date": "2024-01-31",
483 | }
484 | },
485 | )
486 |
487 | assert len(result) > 0
488 | assert result[0].text is not None
489 | data = eval(result[0].text)
490 | assert "results" in data
491 | assert len(data["results"]) == 2
492 |
493 | # Test with invalid ticker format
494 | with pytest.raises(Exception) as exc_info:
495 | await client.call_tool(
496 | "/data_fetch_stock_data_batch",
497 | {
498 | "request": {
499 | "tickers": [
500 | "AAPL",
501 | "invalid_ticker",
502 | ], # lowercase not allowed
503 | "start_date": "2024-01-01",
504 | "end_date": "2024-01-31",
505 | }
506 | },
507 | )
508 |
509 | assert "validation error" in str(exc_info.value).lower()
510 |
511 | @pytest.mark.asyncio
512 | async def test_cache_operations(self, test_server, screening_db):
513 | """Test cache management operations."""
514 | test_server.mount("/data", data_router)
515 |
516 | # Patch the _get_redis_client to test cache operations
517 | with patch("maverick_mcp.data.cache._get_redis_client") as mock_redis_client:
518 | cache_instance = Mock()
519 | cache_instance.get.return_value = '{"cached": true, "data": "test"}'
520 | cache_instance.set.return_value = True
521 | cache_instance.delete.return_value = 1
522 | cache_instance.keys.return_value = [b"stock:AAPL:1", b"stock:AAPL:2"]
523 | mock_redis_client.return_value = cache_instance
524 |
525 | async with Client(test_server) as client:
526 | # Test cache clear
527 | result = await client.call_tool(
528 | "/data_clear_cache", {"request": {"ticker": "AAPL"}}
529 | )
530 |
531 | assert len(result) > 0
532 | assert result[0].text is not None
533 | assert (
534 | "clear" in result[0].text.lower()
535 | or "success" in result[0].text.lower()
536 | )
537 | # Verify cache operations
538 | assert cache_instance.keys.called or cache_instance.delete.called
539 |
540 |
541 | class TestConcurrentOperations:
542 | """Test concurrent operations and performance."""
543 |
544 | @pytest.mark.asyncio
545 | async def test_concurrent_router_calls(self, test_server, screening_db):
546 | """Test multiple routers being called concurrently."""
547 | # Mount all routers
548 | test_server.mount("/technical", technical_router)
549 | test_server.mount("/screening", screening_router)
550 | test_server.mount("/portfolio", portfolio_router)
551 | test_server.mount("/data", data_router)
552 |
553 | with patch(
554 | "maverick_mcp.providers.stock_data.StockDataProvider.get_stock_data"
555 | ) as mock_data:
556 | import pandas as pd
557 |
558 | dates = pd.date_range(end="2024-01-31", periods=30)
559 | mock_data.return_value = pd.DataFrame(
560 | {
561 | "Close": [100 + i for i in range(30)],
562 | "High": [101 + i for i in range(30)],
563 | "Low": [99 + i for i in range(30)],
564 | "Open": [100 + i for i in range(30)],
565 | "Volume": [1000000] * 30,
566 | },
567 | index=dates,
568 | )
569 |
570 | async with Client(test_server) as client:
571 | # Create concurrent tasks across different routers
572 | tasks = [
573 | client.call_tool(
574 | "/technical_get_rsi_analysis", {"ticker": "AAPL", "period": 14}
575 | ),
576 | client.call_tool("/screening_get_maverick_stocks", {"limit": 5}),
577 | client.call_tool(
578 | "/data_fetch_stock_data_batch",
579 | {
580 | "request": {
581 | "tickers": ["AAPL", "MSFT"],
582 | "start_date": "2024-01-01",
583 | "end_date": "2024-01-31",
584 | }
585 | },
586 | ),
587 | ]
588 |
589 | results = await asyncio.gather(*tasks)
590 |
591 | # All should complete successfully
592 | assert len(results) == 3
593 | for result in results:
594 | assert len(result) > 0
595 | assert result[0].text is not None
596 |
597 |
598 | if __name__ == "__main__":
599 | pytest.main([__file__, "-v"])
600 |
```
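The tests above exercise every router through FastMCP's in-memory transport: each router is mounted on a test server and its tools are invoked via `Client(test_server)` with no network or subprocess involved. A minimal sketch of that pattern, assuming the same FastMCP version the tests target (the `demo` server and `echo` tool below are illustrative, not part of this repository):

```python
import asyncio

from fastmcp import Client, FastMCP

# Hypothetical server and tool, for illustration only.
demo = FastMCP("demo")


@demo.tool()
def echo(text: str) -> str:
    """Return the input unchanged."""
    return text


async def main() -> None:
    # Passing the server object to Client selects the in-memory transport.
    async with Client(demo) as client:
        result = await client.call_tool("echo", {"text": "hello"})
        # As in the tests above, results arrive as content blocks with .text.
        print(result[0].text)


asyncio.run(main())
```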
--------------------------------------------------------------------------------
/maverick_mcp/monitoring/health_check.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Health check module for MaverickMCP.
3 |
4 | This module provides comprehensive health checking capabilities for all system components
5 | including database, cache, APIs, and external services.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | import time
11 | from dataclasses import dataclass
12 | from datetime import UTC, datetime
13 | from enum import Enum
14 | from typing import Any
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | class HealthStatus(str, Enum):
20 | """Health status enumeration."""
21 |
22 | HEALTHY = "healthy"
23 | DEGRADED = "degraded"
24 | UNHEALTHY = "unhealthy"
25 | UNKNOWN = "unknown"
26 |
27 |
28 | @dataclass
29 | class ComponentHealth:
30 | """Health information for a component."""
31 |
32 | name: str
33 | status: HealthStatus
34 | message: str
35 | response_time_ms: float | None = None
36 | details: dict[str, Any] | None = None
37 | last_check: datetime | None = None
38 |
39 |
40 | @dataclass
41 | class SystemHealth:
42 | """Overall system health information."""
43 |
44 | status: HealthStatus
45 | components: dict[str, ComponentHealth]
46 | overall_response_time_ms: float
47 | timestamp: datetime
48 | uptime_seconds: float | None = None
49 | version: str | None = None
50 |
51 |
52 | class HealthChecker:
53 | """
54 | Comprehensive health checker for MaverickMCP system components.
55 |
56 | This class provides health checking capabilities for:
57 | - Database connections
58 | - Redis cache
59 | - External APIs (Tiingo, OpenRouter, etc.)
60 | - System resources
61 | - Application services
62 | """
63 |
64 | def __init__(self):
65 | """Initialize the health checker."""
66 | self.start_time = time.time()
67 | self._component_checkers = {}
68 | self._setup_component_checkers()
69 |
70 | def _setup_component_checkers(self):
71 | """Setup component-specific health checkers."""
72 | self._component_checkers = {
73 | "database": self._check_database_health,
74 | "cache": self._check_cache_health,
75 | "tiingo_api": self._check_tiingo_api_health,
76 | "openrouter_api": self._check_openrouter_api_health,
77 | "exa_api": self._check_exa_api_health,
78 | "system_resources": self._check_system_resources_health,
79 | }
80 |
81 | async def check_health(self, components: list[str] | None = None) -> SystemHealth:
82 | """
83 | Check health of specified components or all components.
84 |
85 | Args:
86 | components: List of component names to check. If None, checks all components.
87 |
88 | Returns:
89 | SystemHealth object with overall and component-specific health information.
90 | """
91 | start_time = time.time()
92 |
93 | # Determine which components to check
94 | components_to_check = components or list(self._component_checkers.keys())
95 |
96 | # Run health checks concurrently
97 | component_results = {}
98 | tasks = []
99 |
100 | for component_name in components_to_check:
101 | if component_name in self._component_checkers:
102 | task = asyncio.create_task(
103 | self._check_component_with_timeout(component_name),
104 | name=f"health_check_{component_name}",
105 | )
106 | tasks.append((component_name, task))
107 |
108 | # Wait for all checks to complete
109 | for component_name, task in tasks:
110 | try:
111 | component_results[component_name] = await task
112 | except Exception as e:
113 | logger.error(f"Health check failed for {component_name}: {e}")
114 | component_results[component_name] = ComponentHealth(
115 | name=component_name,
116 | status=HealthStatus.UNHEALTHY,
117 | message=f"Health check failed: {str(e)}",
118 | last_check=datetime.now(UTC),
119 | )
120 |
121 | # Calculate overall response time
122 | overall_response_time = (time.time() - start_time) * 1000
123 |
124 | # Determine overall health status
125 | overall_status = self._calculate_overall_status(component_results)
126 |
127 | return SystemHealth(
128 | status=overall_status,
129 | components=component_results,
130 | overall_response_time_ms=overall_response_time,
131 | timestamp=datetime.now(UTC),
132 | uptime_seconds=time.time() - self.start_time,
133 | version=self._get_application_version(),
134 | )
135 |
136 | async def _check_component_with_timeout(
137 | self, component_name: str, timeout: float = 10.0
138 | ) -> ComponentHealth:
139 | """
140 | Check component health with timeout protection.
141 |
142 | Args:
143 | component_name: Name of the component to check
144 | timeout: Timeout in seconds
145 |
146 | Returns:
147 | ComponentHealth for the component
148 | """
149 | try:
150 | return await asyncio.wait_for(
151 | self._component_checkers[component_name](), timeout=timeout
152 | )
153 | except TimeoutError:
154 | return ComponentHealth(
155 | name=component_name,
156 | status=HealthStatus.UNHEALTHY,
157 | message=f"Health check timed out after {timeout}s",
158 | last_check=datetime.now(UTC),
159 | )
160 |
161 | async def _check_database_health(self) -> ComponentHealth:
162 | """Check database health."""
163 | start_time = time.time()
164 |
165 | try:
166 | from sqlalchemy import text
167 |
168 | from maverick_mcp.data.database import get_db_session
169 |
170 | with get_db_session() as session:
171 | # Simple query to test database connectivity
172 | result = session.execute(text("SELECT 1 as health_check"))
173 | result.fetchone()
174 |
175 | response_time = (time.time() - start_time) * 1000
176 |
177 | return ComponentHealth(
178 | name="database",
179 | status=HealthStatus.HEALTHY,
180 | message="Database connection successful",
181 | response_time_ms=response_time,
182 | last_check=datetime.now(UTC),
183 | details={"connection_type": "SQLAlchemy"},
184 | )
185 |
186 | except Exception as e:
187 | return ComponentHealth(
188 | name="database",
189 | status=HealthStatus.UNHEALTHY,
190 | message=f"Database connection failed: {str(e)}",
191 | response_time_ms=(time.time() - start_time) * 1000,
192 | last_check=datetime.now(UTC),
193 | )
194 |
195 | async def _check_cache_health(self) -> ComponentHealth:
196 | """Check cache health."""
197 | start_time = time.time()
198 |
199 | try:
200 | from maverick_mcp.data.cache import get_cache_stats, get_redis_client
201 |
202 | # Check Redis connection if available
203 | redis_client = get_redis_client()
204 | cache_details = {"type": "memory"}
205 |
206 | if redis_client:
207 | # Test Redis connection
208 |                 await asyncio.get_running_loop().run_in_executor(None, redis_client.ping)
209 | cache_details["type"] = "redis"
210 | cache_details["redis_connected"] = True
211 |
212 | # Get cache statistics
213 | stats = get_cache_stats()
214 | cache_details.update(
215 | {
216 | "hit_rate_percent": stats.get("hit_rate_percent", 0),
217 | "total_requests": stats.get("total_requests", 0),
218 | "memory_cache_size": stats.get("memory_cache_size", 0),
219 | }
220 | )
221 |
222 | response_time = (time.time() - start_time) * 1000
223 |
224 | return ComponentHealth(
225 | name="cache",
226 | status=HealthStatus.HEALTHY,
227 | message="Cache system operational",
228 | response_time_ms=response_time,
229 | last_check=datetime.now(UTC),
230 | details=cache_details,
231 | )
232 |
233 | except Exception as e:
234 | return ComponentHealth(
235 | name="cache",
236 | status=HealthStatus.DEGRADED,
237 | message=f"Cache issues detected: {str(e)}",
238 | response_time_ms=(time.time() - start_time) * 1000,
239 | last_check=datetime.now(UTC),
240 | )
241 |
242 | async def _check_tiingo_api_health(self) -> ComponentHealth:
243 | """Check Tiingo API health."""
244 | start_time = time.time()
245 |
246 | try:
247 | from maverick_mcp.config.settings import get_settings
248 | from maverick_mcp.providers.data_provider import get_stock_provider
249 |
250 | settings = get_settings()
251 | if not settings.data_providers.tiingo_api_key:
252 | return ComponentHealth(
253 | name="tiingo_api",
254 | status=HealthStatus.UNKNOWN,
255 | message="Tiingo API key not configured",
256 | response_time_ms=(time.time() - start_time) * 1000,
257 | last_check=datetime.now(UTC),
258 | )
259 |
260 | # Test API with a simple quote request
261 | provider = get_stock_provider()
262 | quote = await provider.get_quote("AAPL")
263 |
264 | response_time = (time.time() - start_time) * 1000
265 |
266 | if quote and quote.get("price"):
267 | return ComponentHealth(
268 | name="tiingo_api",
269 | status=HealthStatus.HEALTHY,
270 | message="Tiingo API responding correctly",
271 | response_time_ms=response_time,
272 | last_check=datetime.now(UTC),
273 | details={"test_symbol": "AAPL", "price_available": True},
274 | )
275 | else:
276 | return ComponentHealth(
277 | name="tiingo_api",
278 | status=HealthStatus.DEGRADED,
279 | message="Tiingo API responding but data may be incomplete",
280 | response_time_ms=response_time,
281 | last_check=datetime.now(UTC),
282 | )
283 |
284 | except Exception as e:
285 | return ComponentHealth(
286 | name="tiingo_api",
287 | status=HealthStatus.UNHEALTHY,
288 | message=f"Tiingo API check failed: {str(e)}",
289 | response_time_ms=(time.time() - start_time) * 1000,
290 | last_check=datetime.now(UTC),
291 | )
292 |
293 | async def _check_openrouter_api_health(self) -> ComponentHealth:
294 | """Check OpenRouter API health."""
295 | start_time = time.time()
296 |
297 | try:
298 | from maverick_mcp.config.settings import get_settings
299 |
300 | settings = get_settings()
301 | if not settings.research.openrouter_api_key:
302 | return ComponentHealth(
303 | name="openrouter_api",
304 | status=HealthStatus.UNKNOWN,
305 | message="OpenRouter API key not configured",
306 | response_time_ms=(time.time() - start_time) * 1000,
307 | last_check=datetime.now(UTC),
308 | )
309 |
310 | # For now, just check if the key is configured
311 | # A full API test would require making an actual request
312 | response_time = (time.time() - start_time) * 1000
313 |
314 | return ComponentHealth(
315 | name="openrouter_api",
316 | status=HealthStatus.HEALTHY,
317 | message="OpenRouter API key configured",
318 | response_time_ms=response_time,
319 | last_check=datetime.now(UTC),
320 | details={"api_key_configured": True},
321 | )
322 |
323 | except Exception as e:
324 | return ComponentHealth(
325 | name="openrouter_api",
326 | status=HealthStatus.UNHEALTHY,
327 | message=f"OpenRouter API check failed: {str(e)}",
328 | response_time_ms=(time.time() - start_time) * 1000,
329 | last_check=datetime.now(UTC),
330 | )
331 |
332 | async def _check_exa_api_health(self) -> ComponentHealth:
333 | """Check Exa API health."""
334 | start_time = time.time()
335 |
336 | try:
337 | from maverick_mcp.config.settings import get_settings
338 |
339 | settings = get_settings()
340 | if not settings.research.exa_api_key:
341 | return ComponentHealth(
342 | name="exa_api",
343 | status=HealthStatus.UNKNOWN,
344 | message="Exa API key not configured",
345 | response_time_ms=(time.time() - start_time) * 1000,
346 | last_check=datetime.now(UTC),
347 | )
348 |
349 | # For now, just check if the key is configured
350 | # A full API test would require making an actual request
351 | response_time = (time.time() - start_time) * 1000
352 |
353 | return ComponentHealth(
354 | name="exa_api",
355 | status=HealthStatus.HEALTHY,
356 | message="Exa API key configured",
357 | response_time_ms=response_time,
358 | last_check=datetime.now(UTC),
359 | details={"api_key_configured": True},
360 | )
361 |
362 | except Exception as e:
363 | return ComponentHealth(
364 | name="exa_api",
365 | status=HealthStatus.UNHEALTHY,
366 | message=f"Exa API check failed: {str(e)}",
367 | response_time_ms=(time.time() - start_time) * 1000,
368 | last_check=datetime.now(UTC),
369 | )
370 |
371 | async def _check_system_resources_health(self) -> ComponentHealth:
372 | """Check system resource health."""
373 | start_time = time.time()
374 |
375 | try:
376 | import psutil
377 |
378 | # Get system resource usage
379 | cpu_percent = psutil.cpu_percent(interval=1)
380 | memory = psutil.virtual_memory()
381 | disk = psutil.disk_usage("/")
382 |
383 | # Determine status based on resource usage
384 | status = HealthStatus.HEALTHY
385 | messages = []
386 |
387 | if cpu_percent > 80:
388 | status = (
389 | HealthStatus.DEGRADED
390 | if cpu_percent < 90
391 | else HealthStatus.UNHEALTHY
392 | )
393 | messages.append(f"High CPU usage: {cpu_percent:.1f}%")
394 |
395 | if memory.percent > 85:
396 | status = (
397 | HealthStatus.DEGRADED
398 | if memory.percent < 95
399 | else HealthStatus.UNHEALTHY
400 | )
401 | messages.append(f"High memory usage: {memory.percent:.1f}%")
402 |
403 | if disk.percent > 90:
404 | status = (
405 | HealthStatus.DEGRADED
406 | if disk.percent < 95
407 | else HealthStatus.UNHEALTHY
408 | )
409 | messages.append(f"High disk usage: {disk.percent:.1f}%")
410 |
411 | message = (
412 | "; ".join(messages)
413 | if messages
414 | else "System resources within normal limits"
415 | )
416 |
417 | response_time = (time.time() - start_time) * 1000
418 |
419 | return ComponentHealth(
420 | name="system_resources",
421 | status=status,
422 | message=message,
423 | response_time_ms=response_time,
424 | last_check=datetime.now(UTC),
425 | details={
426 | "cpu_percent": cpu_percent,
427 | "memory_percent": memory.percent,
428 | "disk_percent": disk.percent,
429 | "memory_available_gb": memory.available / (1024**3),
430 | "disk_free_gb": disk.free / (1024**3),
431 | },
432 | )
433 |
434 | except ImportError:
435 | return ComponentHealth(
436 | name="system_resources",
437 | status=HealthStatus.UNKNOWN,
438 | message="psutil not available for system monitoring",
439 | response_time_ms=(time.time() - start_time) * 1000,
440 | last_check=datetime.now(UTC),
441 | )
442 | except Exception as e:
443 | return ComponentHealth(
444 | name="system_resources",
445 | status=HealthStatus.UNHEALTHY,
446 | message=f"System resource check failed: {str(e)}",
447 | response_time_ms=(time.time() - start_time) * 1000,
448 | last_check=datetime.now(UTC),
449 | )
450 |
451 | def _calculate_overall_status(
452 | self, components: dict[str, ComponentHealth]
453 | ) -> HealthStatus:
454 | """
455 | Calculate overall system health status based on component health.
456 |
457 | Args:
458 | components: Dictionary of component health results
459 |
460 | Returns:
461 | Overall HealthStatus
462 | """
463 | if not components:
464 | return HealthStatus.UNKNOWN
465 |
466 | statuses = [comp.status for comp in components.values()]
467 |
468 | # If any component is unhealthy, system is unhealthy
469 | if HealthStatus.UNHEALTHY in statuses:
470 | return HealthStatus.UNHEALTHY
471 |
472 | # If any component is degraded, system is degraded
473 | if HealthStatus.DEGRADED in statuses:
474 | return HealthStatus.DEGRADED
475 |
476 | # If all components are healthy, system is healthy
477 | if all(status == HealthStatus.HEALTHY for status in statuses):
478 | return HealthStatus.HEALTHY
479 |
480 | # Mixed healthy/unknown status defaults to degraded
481 | return HealthStatus.DEGRADED
482 |
483 | def _get_application_version(self) -> str | None:
484 | """Get application version."""
485 | try:
486 | from maverick_mcp import __version__
487 |
488 | return __version__
489 | except ImportError:
490 | return None
491 |
492 | async def check_component(self, component_name: str) -> ComponentHealth:
493 | """
494 | Check health of a specific component.
495 |
496 | Args:
497 | component_name: Name of the component to check
498 |
499 | Returns:
500 | ComponentHealth for the specified component
501 |
502 | Raises:
503 | ValueError: If component_name is not supported
504 | """
505 | if component_name not in self._component_checkers:
506 | raise ValueError(
507 | f"Unknown component: {component_name}. "
508 | f"Supported components: {list(self._component_checkers.keys())}"
509 | )
510 |
511 | return await self._check_component_with_timeout(component_name)
512 |
513 | def get_supported_components(self) -> list[str]:
514 | """
515 | Get list of supported component names.
516 |
517 | Returns:
518 | List of component names that can be checked
519 | """
520 | return list(self._component_checkers.keys())
521 |
522 | def get_health_status(self) -> dict[str, Any]:
523 | """
524 | Get comprehensive health status (synchronous wrapper).
525 |
526 | Returns:
527 | Dictionary with health status and component information
528 | """
529 | import asyncio
530 |
531 | try:
532 | # Try to get the current event loop
533 | loop = asyncio.get_event_loop()
534 | if loop.is_running():
535 | # We're already in an async context, return simplified status
536 | return {
537 |                     "status": HealthStatus.HEALTHY.value,
538 | "components": {
539 |                         name: {"status": HealthStatus.UNKNOWN.value, "message": "Check pending"}
540 | for name in self._component_checkers.keys()
541 | },
542 | "timestamp": datetime.now(UTC).isoformat(),
543 | "message": "Health check in async context",
544 | "uptime_seconds": time.time() - self.start_time,
545 | }
546 | else:
547 | # Run the async check in the existing loop
548 | result = loop.run_until_complete(self.check_health())
549 | return self._health_to_dict(result)
550 | except RuntimeError:
551 | # No event loop exists, create one
552 | result = asyncio.run(self.check_health())
553 | return self._health_to_dict(result)
554 |
555 | async def check_overall_health(self) -> dict[str, Any]:
556 | """
557 | Async method to check overall health.
558 |
559 | Returns:
560 | Dictionary with health status information
561 | """
562 | result = await self.check_health()
563 | return self._health_to_dict(result)
564 |
565 | def _health_to_dict(self, health: SystemHealth) -> dict[str, Any]:
566 | """
567 | Convert SystemHealth object to dictionary.
568 |
569 | Args:
570 | health: SystemHealth object
571 |
572 | Returns:
573 | Dictionary representation
574 | """
575 | return {
576 | "status": health.status.value,
577 | "components": {
578 | name: {
579 | "status": comp.status.value,
580 | "message": comp.message,
581 | "response_time_ms": comp.response_time_ms,
582 | "details": comp.details,
583 | "last_check": comp.last_check.isoformat()
584 | if comp.last_check
585 | else None,
586 | }
587 | for name, comp in health.components.items()
588 | },
589 | "overall_response_time_ms": health.overall_response_time_ms,
590 | "timestamp": health.timestamp.isoformat(),
591 | "uptime_seconds": health.uptime_seconds,
592 | "version": health.version,
593 | }
594 |
595 |
596 | # Convenience function for quick health checks
597 | async def check_system_health(components: list[str] | None = None) -> SystemHealth:
598 | """
599 | Convenience function to check system health.
600 |
601 | Args:
602 | components: Optional list of component names to check
603 |
604 | Returns:
605 | SystemHealth object
606 | """
607 | checker = HealthChecker()
608 | return await checker.check_health(components)
609 |
610 |
611 | # Global health checker instance
612 | _global_health_checker: HealthChecker | None = None
613 |
614 |
615 | def get_health_checker() -> HealthChecker:
616 | """
617 | Get or create the global health checker instance.
618 |
619 | Returns:
620 | HealthChecker instance
621 | """
622 | global _global_health_checker
623 | if _global_health_checker is None:
624 | _global_health_checker = HealthChecker()
625 | return _global_health_checker
626 |
```
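A short usage sketch for the checker above, using only the public surface defined in this module (the component names come from `_setup_component_checkers`):

```python
import asyncio

from maverick_mcp.monitoring.health_check import (
    HealthStatus,
    check_system_health,
    get_health_checker,
)


async def main() -> None:
    # Check a subset of components; omitting the list checks everything.
    health = await check_system_health(components=["database", "cache"])
    print(health.status.value, f"{health.overall_response_time_ms:.1f}ms")

    # The shared checker keeps its start time, so uptime stays meaningful.
    checker = get_health_checker()
    db = await checker.check_component("database")
    if db.status is not HealthStatus.HEALTHY:
        print(f"database: {db.message}")


asyncio.run(main())
```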
--------------------------------------------------------------------------------
/maverick_mcp/utils/monitoring.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Monitoring and observability integration for MaverickMCP.
3 |
4 | This module provides Sentry error tracking and Prometheus metrics integration
5 | for production monitoring and alerting.
6 | """
7 |
8 | import os
9 | import time
10 | from contextlib import contextmanager
11 | from typing import Any
12 |
13 | from maverick_mcp.config.settings import settings
14 | from maverick_mcp.utils.logging import get_logger
15 |
16 | # Optional prometheus integration
17 | try:
18 | from prometheus_client import Counter, Gauge, Histogram, generate_latest
19 |
20 | PROMETHEUS_AVAILABLE = True
21 | except ImportError:
22 | logger = get_logger(__name__)
23 | logger.warning("Prometheus client not available. Metrics will be disabled.")
24 | PROMETHEUS_AVAILABLE = False
25 |
26 | # Create stub classes for when prometheus is not available
27 | class _MetricStub:
28 | def __init__(self, *args, **kwargs):
29 | pass
30 |
31 | def inc(self, *args, **kwargs):
32 | pass
33 |
34 | def observe(self, *args, **kwargs):
35 | pass
36 |
37 | def set(self, *args, **kwargs):
38 | pass
39 |
40 | def dec(self, *args, **kwargs):
41 | pass
42 |
43 | def labels(self, *args, **kwargs):
44 | return self
45 |
46 | Counter = Gauge = Histogram = _MetricStub
47 |
48 | def generate_latest():
49 | return b"# Prometheus not available"
50 |
51 |
52 | logger = get_logger(__name__)
53 |
54 | # HTTP Request metrics
55 | request_counter = Counter(
56 | "maverick_requests_total",
57 | "Total number of API requests",
58 | ["method", "endpoint", "status", "user_type"],
59 | )
60 |
61 | request_duration = Histogram(
62 | "maverick_request_duration_seconds",
63 | "Request duration in seconds",
64 | ["method", "endpoint", "user_type"],
65 | buckets=(
66 | 0.01,
67 | 0.025,
68 | 0.05,
69 | 0.1,
70 | 0.25,
71 | 0.5,
72 | 1.0,
73 | 2.5,
74 | 5.0,
75 | 10.0,
76 | 30.0,
77 | 60.0,
78 | float("inf"),
79 | ),
80 | )
81 |
82 | request_size_bytes = Histogram(
83 | "maverick_request_size_bytes",
84 | "HTTP request size in bytes",
85 | ["method", "endpoint"],
86 | buckets=(1024, 4096, 16384, 65536, 262144, 1048576, 4194304, float("inf")),
87 | )
88 |
89 | response_size_bytes = Histogram(
90 | "maverick_response_size_bytes",
91 | "HTTP response size in bytes",
92 | ["method", "endpoint", "status"],
93 | buckets=(1024, 4096, 16384, 65536, 262144, 1048576, 4194304, float("inf")),
94 | )
95 |
96 | # Connection metrics
97 | active_connections = Gauge(
98 | "maverick_active_connections", "Number of active connections"
99 | )
100 |
101 | concurrent_requests = Gauge(
102 | "maverick_concurrent_requests", "Number of concurrent requests being processed"
103 | )
104 |
105 | # Tool execution metrics
106 | tool_usage_counter = Counter(
107 | "maverick_tool_usage_total",
108 | "Total tool usage count",
109 | ["tool_name", "user_id", "status"],
110 | )
111 |
112 | tool_duration = Histogram(
113 | "maverick_tool_duration_seconds",
114 | "Tool execution duration in seconds",
115 | ["tool_name", "complexity"],
116 | buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0, 300.0, float("inf")),
117 | )
118 |
119 | tool_errors = Counter(
120 | "maverick_tool_errors_total",
121 | "Total tool execution errors",
122 | ["tool_name", "error_type", "complexity"],
123 | )
124 |
125 | # Error metrics
126 | error_counter = Counter(
127 | "maverick_errors_total",
128 | "Total number of errors",
129 | ["error_type", "endpoint", "severity"],
130 | )
131 |
132 | rate_limit_hits = Counter(
133 | "maverick_rate_limit_hits_total",
134 | "Rate limit violations",
135 | ["user_id", "endpoint", "limit_type"],
136 | )
137 |
138 | # Cache metrics
139 | cache_hits = Counter(
140 | "maverick_cache_hits_total", "Total cache hits", ["cache_type", "key_prefix"]
141 | )
142 |
143 | cache_misses = Counter(
144 | "maverick_cache_misses_total", "Total cache misses", ["cache_type", "key_prefix"]
145 | )
146 |
147 | cache_evictions = Counter(
148 | "maverick_cache_evictions_total", "Total cache evictions", ["cache_type", "reason"]
149 | )
150 |
151 | cache_size_bytes = Gauge(
152 | "maverick_cache_size_bytes", "Cache size in bytes", ["cache_type"]
153 | )
154 |
155 | cache_keys_total = Gauge(
156 | "maverick_cache_keys_total", "Total number of keys in cache", ["cache_type"]
157 | )
158 |
159 | # Database metrics
160 | db_connection_pool_size = Gauge(
161 | "maverick_db_connection_pool_size", "Database connection pool size"
162 | )
163 |
164 | db_active_connections = Gauge(
165 | "maverick_db_active_connections", "Active database connections"
166 | )
167 |
168 | db_idle_connections = Gauge("maverick_db_idle_connections", "Idle database connections")
169 |
170 | db_query_duration = Histogram(
171 | "maverick_db_query_duration_seconds",
172 | "Database query duration in seconds",
173 | ["query_type", "table"],
174 | buckets=(
175 | 0.001,
176 | 0.005,
177 | 0.01,
178 | 0.025,
179 | 0.05,
180 | 0.1,
181 | 0.25,
182 | 0.5,
183 | 1.0,
184 | 2.5,
185 | 5.0,
186 | float("inf"),
187 | ),
188 | )
189 |
190 | db_queries_total = Counter(
191 | "maverick_db_queries_total",
192 | "Total database queries",
193 | ["query_type", "table", "status"],
194 | )
195 |
196 | db_connections_created = Counter(
197 | "maverick_db_connections_created_total", "Total database connections created"
198 | )
199 |
200 | db_connections_closed = Counter(
201 | "maverick_db_connections_closed_total",
202 | "Total database connections closed",
203 | ["reason"],
204 | )
205 |
206 | # Redis metrics
207 | redis_connections = Gauge("maverick_redis_connections", "Number of Redis connections")
208 |
209 | redis_operations = Counter(
210 | "maverick_redis_operations_total", "Total Redis operations", ["operation", "status"]
211 | )
212 |
213 | redis_operation_duration = Histogram(
214 | "maverick_redis_operation_duration_seconds",
215 | "Redis operation duration in seconds",
216 | ["operation"],
217 | buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, float("inf")),
218 | )
219 |
220 | redis_memory_usage = Gauge(
221 | "maverick_redis_memory_usage_bytes", "Redis memory usage in bytes"
222 | )
223 |
224 | redis_keyspace_hits = Counter(
225 | "maverick_redis_keyspace_hits_total", "Redis keyspace hits"
226 | )
227 |
228 | redis_keyspace_misses = Counter(
229 | "maverick_redis_keyspace_misses_total", "Redis keyspace misses"
230 | )
231 |
232 | # External API metrics
233 | external_api_calls = Counter(
234 | "maverick_external_api_calls_total",
235 | "External API calls",
236 | ["service", "endpoint", "method", "status"],
237 | )
238 |
239 | external_api_duration = Histogram(
240 | "maverick_external_api_duration_seconds",
241 | "External API call duration in seconds",
242 | ["service", "endpoint"],
243 | buckets=(0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, float("inf")),
244 | )
245 |
246 | external_api_errors = Counter(
247 | "maverick_external_api_errors_total",
248 | "External API errors",
249 | ["service", "endpoint", "error_type"],
250 | )
251 |
252 | # Business metrics
253 | daily_active_users = Gauge("maverick_daily_active_users", "Daily active users count")
254 |
255 | monthly_active_users = Gauge(
256 | "maverick_monthly_active_users", "Monthly active users count"
257 | )
258 |
259 | user_sessions = Counter(
260 | "maverick_user_sessions_total", "Total user sessions", ["user_type", "auth_method"]
261 | )
262 |
263 | user_session_duration = Histogram(
264 | "maverick_user_session_duration_seconds",
265 | "User session duration in seconds",
266 | ["user_type"],
267 | buckets=(60, 300, 900, 1800, 3600, 7200, 14400, 28800, 86400, float("inf")),
268 | )
269 |
270 | # Performance metrics
271 | memory_usage_bytes = Gauge(
272 | "maverick_memory_usage_bytes", "Process memory usage in bytes"
273 | )
274 |
275 | cpu_usage_percent = Gauge("maverick_cpu_usage_percent", "Process CPU usage percentage")
276 |
277 | open_file_descriptors = Gauge(
278 | "maverick_open_file_descriptors", "Number of open file descriptors"
279 | )
280 |
281 | garbage_collections = Counter(
282 | "maverick_garbage_collections_total", "Garbage collection events", ["generation"]
283 | )
284 |
285 | # Security metrics
286 | authentication_attempts = Counter(
287 | "maverick_authentication_attempts_total",
288 | "Authentication attempts",
289 | ["method", "status", "user_agent"],
290 | )
291 |
292 | authorization_checks = Counter(
293 | "maverick_authorization_checks_total",
294 | "Authorization checks",
295 | ["resource", "action", "status"],
296 | )
297 |
298 | security_violations = Counter(
299 | "maverick_security_violations_total",
300 | "Security violations detected",
301 | ["violation_type", "severity"],
302 | )
303 |
304 |
305 | class MonitoringService:
306 | """Service for monitoring and observability."""
307 |
308 | def __init__(self):
309 | self.sentry_enabled = False
310 | self._initialize_sentry()
311 |
312 | def _initialize_sentry(self):
313 | """Initialize Sentry error tracking."""
314 | sentry_dsn = os.getenv("SENTRY_DSN")
315 |
316 | if not sentry_dsn:
317 | if settings.environment == "production":
318 | logger.warning("Sentry DSN not configured in production")
319 | return
320 |
321 | try:
322 | import sentry_sdk
323 | from sentry_sdk.integrations.asyncio import AsyncioIntegration
324 | from sentry_sdk.integrations.logging import LoggingIntegration
325 | from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
326 |
327 | # Configure Sentry
328 | sentry_sdk.init(
329 | dsn=sentry_dsn,
330 | environment=settings.environment,
331 | traces_sample_rate=0.1 if settings.environment == "production" else 1.0,
332 | profiles_sample_rate=0.1
333 | if settings.environment == "production"
334 | else 1.0,
335 | integrations=[
336 | AsyncioIntegration(),
337 | LoggingIntegration(
338 |                         level=None,  # Don't record log breadcrumbs
339 | event_level=None, # Don't create events from logs
340 | ),
341 | SqlalchemyIntegration(),
342 | ],
343 | before_send=self._before_send_sentry,
344 | attach_stacktrace=True,
345 | send_default_pii=False, # Don't send PII
346 | release=os.getenv("RELEASE_VERSION", "unknown"),
347 | )
348 |
349 | # Set user context if available
350 | sentry_sdk.set_context(
351 | "app",
352 | {
353 | "name": settings.app_name,
354 | "environment": settings.environment,
355 | "auth_enabled": settings.auth.enabled,
356 | },
357 | )
358 |
359 | self.sentry_enabled = True
360 | logger.info("Sentry error tracking initialized")
361 |
362 | except ImportError:
363 | logger.warning("Sentry SDK not installed. Run: pip install sentry-sdk")
364 | except Exception as e:
365 | logger.error(f"Failed to initialize Sentry: {e}")
366 |
367 | def _before_send_sentry(
368 | self, event: dict[str, Any], hint: dict[str, Any]
369 | ) -> dict[str, Any] | None:
370 | """Filter events before sending to Sentry."""
371 | # Don't send certain errors
372 | if "exc_info" in hint:
373 | _, exc_value, _ = hint["exc_info"]
374 |
375 | # Skip client errors
376 | error_message = str(exc_value).lower()
377 | if any(
378 | skip in error_message
379 | for skip in [
380 | "client disconnected",
381 | "connection reset",
382 | "broken pipe",
383 | ]
384 | ):
385 | return None
386 |
387 | # Remove sensitive data
388 | if "request" in event:
389 | request = event["request"]
390 | # Remove auth headers
391 | if "headers" in request:
392 | request["headers"] = {
393 | k: v
394 | for k, v in request["headers"].items()
395 | if k.lower() not in ["authorization", "cookie", "x-api-key"]
396 | }
397 | # Remove sensitive query params
398 | if "query_string" in request:
399 | # Parse and filter query string
400 | pass
401 |
402 | return event
403 |
404 | def capture_exception(self, error: Exception, **context):
405 | """Capture exception with Sentry."""
406 | if not self.sentry_enabled:
407 | return
408 |
409 | try:
410 | import sentry_sdk
411 |
412 | # Add context
413 | for key, value in context.items():
414 | sentry_sdk.set_context(key, value)
415 |
416 | # Capture the exception
417 | sentry_sdk.capture_exception(error)
418 |
419 | except Exception as e:
420 | logger.error(f"Failed to capture exception with Sentry: {e}")
421 |
422 | def capture_message(self, message: str, level: str = "info", **context):
423 | """Capture message with Sentry."""
424 | if not self.sentry_enabled:
425 | return
426 |
427 | try:
428 | import sentry_sdk
429 |
430 | # Add context
431 | for key, value in context.items():
432 | sentry_sdk.set_context(key, value)
433 |
434 | # Capture the message
435 | sentry_sdk.capture_message(message, level=level)
436 |
437 | except Exception as e:
438 | logger.error(f"Failed to capture message with Sentry: {e}")
439 |
440 | def set_user_context(self, user_id: str | None, email: str | None = None):
441 | """Set user context for Sentry."""
442 | if not self.sentry_enabled:
443 | return
444 |
445 | try:
446 | import sentry_sdk
447 |
448 | if user_id:
449 | sentry_sdk.set_user(
450 | {
451 | "id": user_id,
452 | "email": email,
453 | }
454 | )
455 | else:
456 | sentry_sdk.set_user(None)
457 |
458 | except Exception as e:
459 | logger.error(f"Failed to set user context: {e}")
460 |
461 | @contextmanager
462 | def transaction(self, name: str, op: str = "task"):
463 | """Create a Sentry transaction."""
464 | if not self.sentry_enabled:
465 | yield
466 | return
467 |
468 | try:
469 | import sentry_sdk
470 |
471 | with sentry_sdk.start_transaction(name=name, op=op) as transaction:
472 | yield transaction
473 |
474 | except Exception as e:
475 | logger.error(f"Failed to create transaction: {e}")
476 | yield
477 |
478 | def add_breadcrumb(
479 | self, message: str, category: str = "app", level: str = "info", **data
480 | ):
481 | """Add breadcrumb for Sentry."""
482 | if not self.sentry_enabled:
483 | return
484 |
485 | try:
486 | import sentry_sdk
487 |
488 | sentry_sdk.add_breadcrumb(
489 | message=message,
490 | category=category,
491 | level=level,
492 | data=data,
493 | )
494 |
495 | except Exception as e:
496 | logger.error(f"Failed to add breadcrumb: {e}")
497 |
498 |
499 | # Global monitoring instance
500 | _monitoring_service: MonitoringService | None = None
501 |
502 |
503 | def get_monitoring_service() -> MonitoringService:
504 | """Get or create the global monitoring service."""
505 | global _monitoring_service
506 | if _monitoring_service is None:
507 | _monitoring_service = MonitoringService()
508 | return _monitoring_service
509 |
510 |
511 | @contextmanager
512 | def track_request(method: str, endpoint: str, user_type: str = "unknown"):
513 | """Track request metrics."""
514 | start_time = time.time()
515 | active_connections.inc()
516 |
517 | status = "unknown"
518 | try:
519 | yield
520 | status = "success"
521 | except Exception as e:
522 | status = "error"
523 | error_type = type(e).__name__
524 |         error_counter.labels(error_type=error_type, endpoint=endpoint, severity="error").inc()
525 |
526 | # Capture with Sentry
527 | monitoring = get_monitoring_service()
528 | monitoring.capture_exception(
529 | e,
530 | request={
531 | "method": method,
532 | "endpoint": endpoint,
533 | },
534 | )
535 | raise
536 | finally:
537 | # Record metrics
538 | duration = time.time() - start_time
539 |         request_counter.labels(method=method, endpoint=endpoint, status=status, user_type=user_type).inc()
540 |         request_duration.labels(method=method, endpoint=endpoint, user_type=user_type).observe(duration)
541 | active_connections.dec()
542 |
543 |
544 | def track_tool_usage(
545 | tool_name: str,
546 | user_id: str,
547 | duration: float,
548 | status: str = "success",
549 | complexity: str = "standard",
550 | ):
551 | """Track comprehensive tool usage metrics."""
552 | tool_usage_counter.labels(
553 | tool_name=tool_name, user_id=str(user_id), status=status
554 | ).inc()
555 | tool_duration.labels(tool_name=tool_name, complexity=complexity).observe(duration)
556 |
557 |
558 | def track_tool_error(tool_name: str, error_type: str, complexity: str = "standard"):
559 | """Track tool execution errors."""
560 | tool_errors.labels(
561 | tool_name=tool_name, error_type=error_type, complexity=complexity
562 | ).inc()
563 |
564 |
565 | def track_cache_operation(
566 | cache_type: str = "default",
567 | operation: str = "get",
568 | hit: bool = False,
569 | key_prefix: str = "unknown",
570 | ):
571 | """Track cache operations with detailed metrics."""
572 | if hit:
573 | cache_hits.labels(cache_type=cache_type, key_prefix=key_prefix).inc()
574 | else:
575 | cache_misses.labels(cache_type=cache_type, key_prefix=key_prefix).inc()
576 |
577 |
578 | def track_cache_eviction(cache_type: str, reason: str):
579 | """Track cache evictions."""
580 | cache_evictions.labels(cache_type=cache_type, reason=reason).inc()
581 |
582 |
583 | def update_cache_metrics(cache_type: str, size_bytes: int, key_count: int):
584 | """Update cache size and key count metrics."""
585 | cache_size_bytes.labels(cache_type=cache_type).set(size_bytes)
586 | cache_keys_total.labels(cache_type=cache_type).set(key_count)
587 |
588 |
589 | def track_database_query(
590 | query_type: str, table: str, duration: float, status: str = "success"
591 | ):
592 | """Track database query metrics."""
593 | db_query_duration.labels(query_type=query_type, table=table).observe(duration)
594 | db_queries_total.labels(query_type=query_type, table=table, status=status).inc()
595 |
596 |
597 | def update_database_metrics(
598 | pool_size: int, active_connections: int, idle_connections: int
599 | ):
600 | """Update database connection metrics."""
601 | db_connection_pool_size.set(pool_size)
602 | db_active_connections.set(active_connections)
603 | db_idle_connections.set(idle_connections)
604 |
605 |
606 | def track_database_connection_event(event_type: str, reason: str = "normal"):
607 | """Track database connection lifecycle events."""
608 | if event_type == "created":
609 | db_connections_created.inc()
610 | elif event_type == "closed":
611 | db_connections_closed.labels(reason=reason).inc()
612 |
613 |
614 | def track_redis_operation(operation: str, duration: float, status: str = "success"):
615 | """Track Redis operation metrics."""
616 | redis_operations.labels(operation=operation, status=status).inc()
617 | redis_operation_duration.labels(operation=operation).observe(duration)
618 |
619 |
620 | def update_redis_metrics(connections: int, memory_bytes: int, hits: int, misses: int):
621 | """Update Redis metrics."""
622 | redis_connections.set(connections)
623 | redis_memory_usage.set(memory_bytes)
624 | if hits > 0:
625 | redis_keyspace_hits.inc(hits)
626 | if misses > 0:
627 | redis_keyspace_misses.inc(misses)
628 |
629 |
630 | def track_external_api_call(
631 | service: str,
632 | endpoint: str,
633 | method: str,
634 | status_code: int,
635 | duration: float,
636 | error_type: str | None = None,
637 | ):
638 | """Track external API call metrics."""
639 | status = "success" if 200 <= status_code < 300 else "error"
640 | external_api_calls.labels(
641 | service=service, endpoint=endpoint, method=method, status=status
642 | ).inc()
643 | external_api_duration.labels(service=service, endpoint=endpoint).observe(duration)
644 |
645 | if error_type:
646 | external_api_errors.labels(
647 | service=service, endpoint=endpoint, error_type=error_type
648 | ).inc()
649 |
650 |
651 | def track_user_session(user_type: str, auth_method: str, duration: float | None = None):
652 | """Track user session metrics."""
653 | user_sessions.labels(user_type=user_type, auth_method=auth_method).inc()
654 |     if duration is not None:
655 | user_session_duration.labels(user_type=user_type).observe(duration)
656 |
657 |
658 | def update_active_users(daily_count: int, monthly_count: int):
659 | """Update active user counts."""
660 | daily_active_users.set(daily_count)
661 | monthly_active_users.set(monthly_count)
662 |
663 |
664 | def track_authentication(method: str, status: str, user_agent: str = "unknown"):
665 | """Track authentication attempts."""
666 | authentication_attempts.labels(
667 | method=method,
668 | status=status,
669 | user_agent=user_agent[:50], # Truncate user agent
670 | ).inc()
671 |
672 |
673 | def track_authorization(resource: str, action: str, status: str):
674 | """Track authorization checks."""
675 | authorization_checks.labels(resource=resource, action=action, status=status).inc()
676 |
677 |
678 | def track_security_violation(violation_type: str, severity: str = "medium"):
679 | """Track security violations."""
680 | security_violations.labels(violation_type=violation_type, severity=severity).inc()
681 |
682 |
683 | def track_rate_limit_hit(user_id: str, endpoint: str, limit_type: str):
684 | """Track rate limit violations."""
685 | rate_limit_hits.labels(
686 | user_id=str(user_id), endpoint=endpoint, limit_type=limit_type
687 | ).inc()
688 |
689 |
690 | def update_performance_metrics():
691 | """Update system performance metrics."""
692 | import gc
693 |
694 | import psutil
695 |
696 | process = psutil.Process()
697 |
698 | # Memory usage
699 | memory_info = process.memory_info()
700 | memory_usage_bytes.set(memory_info.rss)
701 |
702 | # CPU usage
703 | cpu_usage_percent.set(process.cpu_percent())
704 |
705 | # File descriptors
706 | try:
707 | open_file_descriptors.set(process.num_fds())
708 | except AttributeError:
709 | # Windows doesn't support num_fds
710 | pass
711 |
712 |     # Garbage collection stats (note: gc.get_stats() counts are cumulative)
713 | gc_stats = gc.get_stats()
714 | for i, stat in enumerate(gc_stats):
715 | if "collections" in stat:
716 | garbage_collections.labels(generation=str(i)).inc(stat["collections"])
717 |
718 |
719 | def get_metrics() -> str:
720 | """Get Prometheus metrics in text format."""
721 | if PROMETHEUS_AVAILABLE:
722 | return generate_latest().decode("utf-8")
723 | return "# Prometheus not available"
724 |
725 |
726 | def initialize_monitoring():
727 | """Initialize monitoring systems."""
728 | logger.info("Initializing monitoring systems...")
729 |
730 | # Initialize global monitoring service
731 | monitoring = get_monitoring_service()
732 |
733 | if monitoring.sentry_enabled:
734 | logger.info("Sentry error tracking initialized")
735 | else:
736 | logger.info("Sentry error tracking disabled (no DSN configured)")
737 |
738 | if PROMETHEUS_AVAILABLE:
739 | logger.info("Prometheus metrics initialized")
740 | else:
741 | logger.info("Prometheus metrics disabled (client not available)")
742 |
743 | logger.info("Monitoring systems initialization complete")
744 |
```
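A sketch of how the helpers above compose in a request path, assuming `prometheus_client` is installed (otherwise the stub metrics make every call a no-op); note that `track_request` records the error counter and re-raises when the wrapped block fails:

```python
import time

from maverick_mcp.utils.monitoring import (
    get_metrics,
    initialize_monitoring,
    track_request,
    track_tool_usage,
)

initialize_monitoring()

# Time one logical request; metrics are recorded even if the body raises.
with track_request(method="POST", endpoint="/screening_get_maverick_stocks"):
    start = time.time()
    ...  # invoke the tool here
    track_tool_usage(
        tool_name="get_maverick_stocks",
        user_id="local",
        duration=time.time() - start,
    )

# Prometheus text exposition, suitable for serving from a /metrics endpoint.
print(get_metrics().splitlines()[0])
```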