This is page 17 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/infrastructure/screening/repositories.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Screening infrastructure repositories.
3 |
4 | This module contains concrete implementations of repository interfaces
5 | for accessing stock screening data from various persistence layers.
6 | """
7 |
8 | import logging
9 | from decimal import Decimal
10 | from typing import Any
11 |
12 | from sqlalchemy.exc import SQLAlchemyError
13 | from sqlalchemy.orm import Session
14 |
15 | from maverick_mcp.data.models import (
16 | MaverickBearStocks,
17 | MaverickStocks,
18 | SessionLocal,
19 | SupplyDemandBreakoutStocks,
20 | )
21 | from maverick_mcp.domain.screening.services import IStockRepository
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
class PostgresStockRepository(IStockRepository):
    """
    PostgreSQL implementation of the stock repository.

    This repository adapter provides access to stock screening data
    stored in PostgreSQL database tables. Rows are returned as plain
    dictionaries so domain code does not depend on ORM model classes.
    """

    def __init__(self, session: Session | None = None):
        """
        Initialize the repository.

        Args:
            session: Optional SQLAlchemy session. If not provided,
                a new session will be created for each operation.
        """
        self._session = session
        # Track ownership so close() never closes a caller-supplied session.
        self._owns_session = session is None

    def _get_session(self) -> tuple[Session, bool]:
        """
        Get a database session.

        Returns:
            Tuple of (session, should_close) where should_close indicates
            whether the caller should close the session.
        """
        if self._session:
            return self._session, False
        return SessionLocal(), True

    @staticmethod
    def _as_float(value: Any) -> float:
        """Coerce a nullable numeric column to float; None/falsy values become 0.0."""
        return float(value) if value else 0.0

    @staticmethod
    def _as_int(value: Any) -> int:
        """Coerce a nullable numeric column to int; None/falsy values become 0."""
        return int(value) if value else 0

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Coerce a nullable boolean column to bool; None becomes False."""
        return bool(value) if value is not None else False

    def get_maverick_stocks(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Get Maverick bullish stocks from the database.

        Args:
            limit: Maximum number of stocks to return
            min_score: Minimum combined score filter

        Returns:
            List of stock data dictionaries

        Raises:
            RuntimeError: If the database query fails.
        """
        session, should_close = self._get_session()

        try:
            # Build query with optional filtering
            query = session.query(MaverickStocks)

            if min_score is not None:
                query = query.filter(MaverickStocks.combined_score >= min_score)

            # Order by combined score descending and limit results
            stocks = (
                query.order_by(MaverickStocks.combined_score.desc()).limit(limit).all()
            )

            # Convert ORM rows to plain dictionaries, skipping malformed rows.
            result = []
            for stock in stocks:
                try:
                    stock_dict = {
                        "stock": stock.stock,
                        "open": self._as_float(stock.open),
                        "high": self._as_float(stock.high),
                        "low": self._as_float(stock.low),
                        "close": self._as_float(stock.close),
                        "volume": self._as_int(stock.volume),
                        "ema_21": self._as_float(stock.ema_21),
                        "sma_50": self._as_float(stock.sma_50),
                        "sma_150": self._as_float(stock.sma_150),
                        "sma_200": self._as_float(stock.sma_200),
                        "momentum_score": self._as_float(stock.momentum_score),
                        "avg_vol_30d": self._as_float(stock.avg_vol_30d),
                        "adr_pct": self._as_float(stock.adr_pct),
                        "atr": self._as_float(stock.atr),
                        "pat": stock.pat,
                        "sqz": stock.sqz,
                        "vcp": stock.vcp,
                        "entry": stock.entry,
                        "compression_score": self._as_int(stock.compression_score),
                        "pattern_detected": self._as_int(stock.pattern_detected),
                        "combined_score": self._as_int(stock.combined_score),
                    }
                    result.append(stock_dict)
                except (ValueError, TypeError) as e:
                    # A bad row should not fail the whole batch.
                    logger.warning(
                        f"Error processing maverick stock {stock.stock}: {e}"
                    )
                    continue

            logger.info(
                f"Retrieved {len(result)} Maverick bullish stocks (limit: {limit})"
            )
            return result

        except SQLAlchemyError as e:
            logger.error(f"Database error retrieving Maverick stocks: {e}")
            # Chain the cause so the original DB error is preserved in tracebacks.
            raise RuntimeError(f"Failed to retrieve Maverick stocks: {e}") from e

        except Exception as e:
            logger.error(f"Unexpected error retrieving Maverick stocks: {e}")
            raise RuntimeError(
                f"Unexpected error retrieving Maverick stocks: {e}"
            ) from e

        finally:
            if should_close:
                session.close()

    def get_maverick_bear_stocks(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Get Maverick bearish stocks from the database.

        Args:
            limit: Maximum number of stocks to return
            min_score: Minimum bear score filter

        Returns:
            List of stock data dictionaries

        Raises:
            RuntimeError: If the database query fails.
        """
        session, should_close = self._get_session()

        try:
            # Build query with optional filtering
            query = session.query(MaverickBearStocks)

            if min_score is not None:
                query = query.filter(MaverickBearStocks.score >= min_score)

            # Order by score descending and limit results
            stocks = query.order_by(MaverickBearStocks.score.desc()).limit(limit).all()

            # Convert ORM rows to plain dictionaries, skipping malformed rows.
            result = []
            for stock in stocks:
                try:
                    stock_dict = {
                        "stock": stock.stock,
                        "open": self._as_float(stock.open),
                        "high": self._as_float(stock.high),
                        "low": self._as_float(stock.low),
                        "close": self._as_float(stock.close),
                        # NOTE: volume is exposed as float here (unlike the
                        # bullish table) to preserve the existing contract.
                        "volume": self._as_float(stock.volume),
                        "momentum_score": self._as_float(stock.momentum_score),
                        "ema_21": self._as_float(stock.ema_21),
                        "sma_50": self._as_float(stock.sma_50),
                        "sma_200": self._as_float(stock.sma_200),
                        "rsi_14": self._as_float(stock.rsi_14),
                        "macd": self._as_float(stock.macd),
                        "macd_s": self._as_float(stock.macd_s),
                        "macd_h": self._as_float(stock.macd_h),
                        "dist_days_20": self._as_int(stock.dist_days_20),
                        "adr_pct": self._as_float(stock.adr_pct),
                        "atr_contraction": self._as_bool(stock.atr_contraction),
                        "atr": self._as_float(stock.atr),
                        "avg_vol_30d": self._as_float(stock.avg_vol_30d),
                        "big_down_vol": self._as_bool(stock.big_down_vol),
                        "score": self._as_int(stock.score),
                        "sqz": stock.sqz,
                        "vcp": stock.vcp,
                    }
                    result.append(stock_dict)
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"Error processing maverick bear stock {stock.stock}: {e}"
                    )
                    continue

            logger.info(
                f"Retrieved {len(result)} Maverick bearish stocks (limit: {limit})"
            )
            return result

        except SQLAlchemyError as e:
            logger.error(f"Database error retrieving Maverick bear stocks: {e}")
            raise RuntimeError(f"Failed to retrieve Maverick bear stocks: {e}") from e

        except Exception as e:
            logger.error(f"Unexpected error retrieving Maverick bear stocks: {e}")
            raise RuntimeError(
                f"Unexpected error retrieving Maverick bear stocks: {e}"
            ) from e

        finally:
            if should_close:
                session.close()

    def get_trending_stocks(
        self,
        limit: int = 20,
        min_momentum_score: Decimal | None = None,
        filter_moving_averages: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Get trending stocks from the database.

        Args:
            limit: Maximum number of stocks to return
            min_momentum_score: Minimum momentum score filter
            filter_moving_averages: If True, apply moving average filters

        Returns:
            List of stock data dictionaries

        Raises:
            RuntimeError: If the database query fails.
        """
        session, should_close = self._get_session()

        try:
            # Build query with optional filtering
            query = session.query(SupplyDemandBreakoutStocks)

            if min_momentum_score is not None:
                query = query.filter(
                    SupplyDemandBreakoutStocks.momentum_score
                    >= float(min_momentum_score)
                )

            # Apply moving average filters if requested: price above all MAs
            # and the MAs properly stacked (50 > 150 > 200).
            if filter_moving_averages:
                query = query.filter(
                    SupplyDemandBreakoutStocks.close_price
                    > SupplyDemandBreakoutStocks.sma_50,
                    SupplyDemandBreakoutStocks.close_price
                    > SupplyDemandBreakoutStocks.sma_150,
                    SupplyDemandBreakoutStocks.close_price
                    > SupplyDemandBreakoutStocks.sma_200,
                    SupplyDemandBreakoutStocks.sma_50
                    > SupplyDemandBreakoutStocks.sma_150,
                    SupplyDemandBreakoutStocks.sma_150
                    > SupplyDemandBreakoutStocks.sma_200,
                )

            # Order by momentum score descending and limit results
            stocks = (
                query.order_by(SupplyDemandBreakoutStocks.momentum_score.desc())
                .limit(limit)
                .all()
            )

            # Convert ORM rows to plain dictionaries, skipping malformed rows.
            # Column names differ from the Maverick tables, but the output
            # keys are normalized to the same shape.
            result = []
            for stock in stocks:
                try:
                    stock_dict = {
                        "stock": stock.stock,
                        "open": self._as_float(stock.open_price),
                        "high": self._as_float(stock.high_price),
                        "low": self._as_float(stock.low_price),
                        "close": self._as_float(stock.close_price),
                        "volume": self._as_int(stock.volume),
                        "ema_21": self._as_float(stock.ema_21),
                        "sma_50": self._as_float(stock.sma_50),
                        "sma_150": self._as_float(stock.sma_150),
                        "sma_200": self._as_float(stock.sma_200),
                        "momentum_score": self._as_float(stock.momentum_score),
                        "avg_volume_30d": self._as_float(stock.avg_volume_30d),
                        "adr_pct": self._as_float(stock.adr_pct),
                        "atr": self._as_float(stock.atr),
                        "pat": stock.pattern_type,
                        "sqz": stock.squeeze_status,
                        "vcp": stock.consolidation_status,
                        "entry": stock.entry_signal,
                    }
                    result.append(stock_dict)
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"Error processing trending stock {stock.stock}: {e}"
                    )
                    continue

            logger.info(
                f"Retrieved {len(result)} trending stocks "
                f"(limit: {limit}, MA filter: {filter_moving_averages})"
            )
            return result

        except SQLAlchemyError as e:
            logger.error(f"Database error retrieving trending stocks: {e}")
            raise RuntimeError(f"Failed to retrieve trending stocks: {e}") from e

        except Exception as e:
            logger.error(f"Unexpected error retrieving trending stocks: {e}")
            raise RuntimeError(
                f"Unexpected error retrieving trending stocks: {e}"
            ) from e

        finally:
            if should_close:
                session.close()

    def close(self) -> None:
        """
        Close the repository and cleanup resources.

        This method should be called when the repository is no longer needed.
        Only closes sessions this repository created itself.
        """
        if self._session and self._owns_session:
            try:
                self._session.close()
                logger.debug("Closed repository session")
            except Exception as e:
                logger.warning(f"Error closing repository session: {e}")
350 |
351 |
class CachedStockRepository(IStockRepository):
    """
    Cached implementation of the stock repository.

    This repository decorator adds in-memory TTL caching to any
    underlying stock repository implementation. Expired entries are
    evicted lazily when they are next accessed.
    """

    def __init__(
        self, underlying_repository: IStockRepository, cache_ttl_seconds: int = 300
    ):
        """
        Initialize the cached repository.

        Args:
            underlying_repository: The repository to wrap with caching
            cache_ttl_seconds: Time-to-live for cache entries in seconds
        """
        self._repository = underlying_repository
        self._cache_ttl = cache_ttl_seconds
        # Maps cache key -> (result, unix timestamp when cached).
        self._cache: dict[str, tuple[Any, float]] = {}

    def _get_cache_key(self, method: str, **kwargs) -> str:
        """Generate a deterministic cache key from the method name and parameters."""
        sorted_params = sorted(kwargs.items())
        param_str = "&".join(f"{k}={v}" for k, v in sorted_params)
        return f"{method}?{param_str}"

    def _is_cache_valid(self, timestamp: float) -> bool:
        """Check if a cache entry is still valid based on TTL."""
        import time

        return (time.time() - timestamp) < self._cache_ttl

    def _get_from_cache_or_execute(self, cache_key: str, func, *args, **kwargs):
        """Return a fresh cached result, or execute func and cache its result."""
        import time

        # Check cache first
        if cache_key in self._cache:
            result, timestamp = self._cache[cache_key]
            if self._is_cache_valid(timestamp):
                logger.debug(f"Cache hit for {cache_key}")
                return result
            # Remove expired entry before recomputing.
            del self._cache[cache_key]

        # Execute function and cache result
        logger.debug(f"Cache miss for {cache_key}, executing function")
        result = func(*args, **kwargs)
        self._cache[cache_key] = (result, time.time())

        return result

    def get_maverick_stocks(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """Get Maverick stocks with caching."""
        cache_key = self._get_cache_key(
            "maverick_stocks", limit=limit, min_score=min_score
        )
        return self._get_from_cache_or_execute(
            cache_key,
            self._repository.get_maverick_stocks,
            limit=limit,
            min_score=min_score,
        )

    def get_maverick_bear_stocks(
        self, limit: int = 20, min_score: int | None = None
    ) -> list[dict[str, Any]]:
        """Get Maverick bear stocks with caching."""
        cache_key = self._get_cache_key(
            "maverick_bear_stocks", limit=limit, min_score=min_score
        )
        return self._get_from_cache_or_execute(
            cache_key,
            self._repository.get_maverick_bear_stocks,
            limit=limit,
            min_score=min_score,
        )

    def get_trending_stocks(
        self,
        limit: int = 20,
        min_momentum_score: Decimal | None = None,
        filter_moving_averages: bool = False,
    ) -> list[dict[str, Any]]:
        """Get trending stocks with caching."""
        # Explicit None check: Decimal("0") is falsy, and collapsing it to
        # None would make a zero-score filter share a cache entry with the
        # unfiltered query, returning wrong cached results.
        cache_key = self._get_cache_key(
            "trending_stocks",
            limit=limit,
            min_momentum_score=(
                str(min_momentum_score) if min_momentum_score is not None else None
            ),
            filter_moving_averages=filter_moving_averages,
        )
        return self._get_from_cache_or_execute(
            cache_key,
            self._repository.get_trending_stocks,
            limit=limit,
            min_momentum_score=min_momentum_score,
            filter_moving_averages=filter_moving_averages,
        )

    def clear_cache(self) -> None:
        """Clear all cached entries."""
        self._cache.clear()
        logger.info("Cleared repository cache")

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics for monitoring."""
        import time

        current_time = time.time()

        total_entries = len(self._cache)
        valid_entries = sum(
            1
            for _, timestamp in self._cache.values()
            if self._is_cache_valid(timestamp)
        )

        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            "cache_ttl_seconds": self._cache_ttl,
            "oldest_entry_age": (
                # The oldest entry has the earliest timestamp, i.e. the
                # *largest* age — so this must be max(), not min().
                max(current_time - timestamp for _, timestamp in self._cache.values())
                if self._cache
                else 0
            ),
        }
485 |
```
--------------------------------------------------------------------------------
/maverick_mcp/tools/sentiment_analysis.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Sentiment analysis tools for news, social media, and market sentiment.
3 | """
4 |
5 | import logging
6 | from datetime import datetime, timedelta
7 | from typing import Any
8 |
9 | from pydantic import BaseModel, Field
10 |
11 | from maverick_mcp.agents.base import PersonaAwareTool
12 | from maverick_mcp.config.settings import get_settings
13 | from maverick_mcp.providers.market_data import MarketDataProvider
14 |
15 | logger = logging.getLogger(__name__)
16 | settings = get_settings()
17 |
18 |
class SentimentInput(BaseModel):
    """Input for sentiment analysis."""

    # Ticker symbol to analyze; passed through to the tool unmodified.
    symbol: str = Field(description="Stock symbol to analyze")
    # Look-back window in days; news articles published before this cutoff
    # are skipped by the sentiment tools.
    days_back: int = Field(default=7, description="Days of history to analyze")
24 |
25 |
class MarketBreadthInput(BaseModel):
    """Input for market breadth analysis."""

    # Market index symbol to analyze; defaults to "SPY".
    index: str = Field(default="SPY", description="Market index to analyze")
30 |
31 |
class NewsSentimentTool(PersonaAwareTool):
    """Analyze news sentiment for stocks.

    Scores recent articles with a simple keyword heuristic (positive vs.
    negative keyword counts over title + description) and aggregates the
    per-article scores into an overall sentiment category, momentum, and
    distribution summary.

    NOTE: the news feed is currently a placeholder that returns no articles,
    so until a news API is implemented this tool reports that no recent
    articles were found for the symbol.
    """

    name: str = "analyze_news_sentiment"
    description: str = "Analyze recent news sentiment and its impact on stock price"
    args_schema: type[BaseModel] = SentimentInput  # type: ignore[assignment]

    def _run(self, symbol: str, days_back: int = 7) -> str:
        """Analyze news sentiment synchronously.

        Args:
            symbol: Stock symbol to analyze.
            days_back: Only articles published within this many days are scored.

        Returns:
            A persona-formatted summary string, or a human-readable message
            when no data is available or an error occurs.
        """
        # Local import keeps the timezone fix self-contained (the module
        # header is not guaranteed to import `timezone`).
        from datetime import timezone

        try:
            MarketDataProvider()

            # Get recent news (placeholder - would need to implement news API)
            # news_data = provider.get_stock_news(symbol, limit=settings.agent.sentiment_news_limit)
            news_data: dict[str, Any] = {"articles": []}

            if not news_data or "articles" not in news_data:
                return f"No news data available for {symbol}"

            articles = news_data.get("articles", [])
            if not articles:
                return f"No recent news articles found for {symbol}"

            # Simple sentiment scoring based on keyword hits.
            positive_keywords = [
                "beat", "exceed", "upgrade", "strong", "growth", "profit",
                "revenue", "bullish", "buy", "outperform", "surge", "rally",
                "breakthrough", "innovation", "expansion", "record",
            ]
            negative_keywords = [
                "miss", "downgrade", "weak", "loss", "decline", "bearish",
                "sell", "underperform", "fall", "cut", "concern", "risk",
                "lawsuit", "investigation", "recall", "bankruptcy",
            ]

            sentiment_scores: list[float] = []
            analyzed_articles: list[dict[str, Any]] = []

            # BUG FIX: the cutoff must be timezone-aware. ISO timestamps
            # ending in "Z" parse to aware datetimes, and comparing an aware
            # datetime against a naive one raises TypeError — which the old
            # code swallowed, silently skipping every such article.
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)

            for article in articles[:20]:  # Analyze top 20 most recent
                title = article.get("title", "").lower()
                description = article.get("description", "").lower()
                published = article.get("publishedAt", "")

                # Skip articles that are old or have unparseable timestamps.
                try:
                    pub_date = datetime.fromisoformat(published.replace("Z", "+00:00"))
                    if pub_date.tzinfo is None:
                        # Assume UTC for naive timestamps so the comparison
                        # below stays valid.
                        pub_date = pub_date.replace(tzinfo=timezone.utc)
                    if pub_date < cutoff_date:
                        continue
                except Exception:
                    continue

                text = f"{title} {description}"

                # Count keyword occurrences
                positive_count = sum(1 for word in positive_keywords if word in text)
                negative_count = sum(1 for word in negative_keywords if word in text)

                # Net score in [-1, 1]; 0 when no keywords matched.
                if positive_count + negative_count > 0:
                    score = (positive_count - negative_count) / (
                        positive_count + negative_count
                    )
                else:
                    score = 0

                sentiment_scores.append(score)
                analyzed_articles.append(
                    {
                        "title": article.get("title", ""),
                        "published": published,
                        "sentiment_score": round(score, 2),
                        "source": article.get("source", {}).get("name", "Unknown"),
                    }
                )

            if not sentiment_scores:
                return f"No recent news articles found for {symbol} in the last {days_back} days"

            # Calculate aggregate sentiment
            avg_sentiment = sum(sentiment_scores) / len(sentiment_scores)

            # Determine sentiment category
            if avg_sentiment > 0.2:
                sentiment_category = "Positive"
                sentiment_impact = "Bullish"
            elif avg_sentiment < -0.2:
                sentiment_category = "Negative"
                sentiment_impact = "Bearish"
            else:
                sentiment_category = "Neutral"
                sentiment_impact = "Mixed"

            # Momentum: newest five articles vs. the rest. Assumes the feed
            # is ordered newest-first — TODO confirm once a provider exists.
            # BUG FIX: require strictly more than 5 scores; with exactly 5
            # the "older" slice is empty and the old code divided by zero.
            if len(sentiment_scores) > 5:
                recent_sentiment = sum(sentiment_scores[:5]) / 5
                older_sentiment = sum(sentiment_scores[5:]) / len(sentiment_scores[5:])
                sentiment_momentum = recent_sentiment - older_sentiment
            else:
                sentiment_momentum = 0

            result = {
                "status": "success",
                "symbol": symbol,
                "sentiment_analysis": {
                    "overall_sentiment": sentiment_category,
                    "sentiment_score": round(avg_sentiment, 3),
                    "sentiment_impact": sentiment_impact,
                    "sentiment_momentum": round(sentiment_momentum, 3),
                    "articles_analyzed": len(analyzed_articles),
                    "analysis_period": f"{days_back} days",
                },
                "recent_articles": analyzed_articles[:5],  # Top 5 most recent
                "sentiment_distribution": {
                    "positive": sum(1 for s in sentiment_scores if s > 0.2),
                    "neutral": sum(1 for s in sentiment_scores if -0.2 <= s <= 0.2),
                    "negative": sum(1 for s in sentiment_scores if s < -0.2),
                },
            }

            # Add trading recommendations based on sentiment and persona
            if self.persona:
                if sentiment_category == "Positive" and sentiment_momentum > 0:
                    if self.persona.name == "Aggressive":
                        result["recommendation"] = "Strong momentum - consider entry"
                    elif self.persona.name == "Conservative":
                        result["recommendation"] = (
                            "Positive sentiment but wait for pullback"
                        )
                    else:
                        result["recommendation"] = (
                            "Favorable sentiment for gradual entry"
                        )
                elif sentiment_category == "Negative":
                    if self.persona.name == "Conservative":
                        result["recommendation"] = "Avoid - negative sentiment"
                    else:
                        result["recommendation"] = "Monitor for reversal signals"

            # Format for persona
            formatted = self.format_for_persona(result)
            return str(formatted)

        except Exception as e:
            logger.error(f"Error analyzing news sentiment for {symbol}: {e}")
            return f"Error analyzing news sentiment: {str(e)}"
205 |
206 |
class MarketBreadthTool(PersonaAwareTool):
    """Analyze overall market breadth and sentiment.

    Derives an advance/decline ratio, average move sizes, and a sentiment
    label from the day's top gainers/losers/most-active lists, plus a
    (currently placeholder) VIX fear gauge.
    """

    name: str = "analyze_market_breadth"
    description: str = "Analyze market breadth indicators and overall market sentiment"
    args_schema: type[BaseModel] = MarketBreadthInput  # type: ignore[assignment]

    def _run(self, index: str = "SPY") -> str:
        """Analyze market breadth synchronously."""
        try:
            provider = MarketDataProvider()
            agent_cfg = settings.agent

            # Pull today's market movers.
            gainers = {
                "movers": provider.get_top_gainers(
                    limit=agent_cfg.market_movers_gainers_limit
                )
            }
            losers = {
                "movers": provider.get_top_losers(
                    limit=agent_cfg.market_movers_losers_limit
                )
            }
            most_active = {
                "movers": provider.get_most_active(
                    limit=agent_cfg.market_movers_active_limit
                )
            }

            # Breadth: share of advancers among advancers + decliners.
            total_gainers = len(gainers.get("movers", []))
            total_losers = len(losers.get("movers", []))
            breadth_total = total_gainers + total_losers
            advance_decline_ratio = (
                total_gainers / breadth_total if breadth_total > 0 else 0.5
            )

            # Average percentage move on each side.
            gain_moves = [
                m.get("change_percent", 0) for m in (gainers.get("movers") or [])
            ]
            avg_gain = sum(gain_moves) / len(gain_moves) if gain_moves else 0

            loss_moves = [
                abs(m.get("change_percent", 0)) for m in (losers.get("movers") or [])
            ]
            avg_loss = sum(loss_moves) / len(loss_moves) if loss_moves else 0

            # Classify overall sentiment from the advance/decline ratio.
            if advance_decline_ratio > 0.65:
                market_sentiment = "Bullish"
                strength = "Strong" if advance_decline_ratio > 0.75 else "Moderate"
            elif advance_decline_ratio < 0.35:
                market_sentiment = "Bearish"
                strength = "Strong" if advance_decline_ratio < 0.25 else "Moderate"
            else:
                market_sentiment = "Neutral"
                strength = "Mixed"

            # Get VIX if available (fear gauge) - placeholder
            # vix_data = provider.get_quote("VIX")
            vix_data = None
            vix_level = None
            fear_gauge = "Unknown"

            if vix_data and "price" in vix_data:
                vix_level = vix_data["price"]
                for ceiling, label in (
                    (15, "Low (Complacent)"),
                    (20, "Normal"),
                    (30, "Elevated (Cautious)"),
                ):
                    if vix_level < ceiling:
                        fear_gauge = label
                        break
                else:
                    fear_gauge = "High (Fearful)"

            def brief(mover: dict) -> dict:
                # Compact summary of a single mover entry.
                return {
                    "symbol": mover.get("symbol"),
                    "change_pct": round(mover.get("change_percent", 0), 2),
                    "volume": mover.get("volume"),
                }

            result = {
                "status": "success",
                "market_breadth": {
                    "sentiment": market_sentiment,
                    "strength": strength,
                    "advance_decline_ratio": round(advance_decline_ratio, 3),
                    "gainers": total_gainers,
                    "losers": total_losers,
                    "most_active": most_active,
                    "avg_gain_pct": round(avg_gain, 2),
                    "avg_loss_pct": round(avg_loss, 2),
                },
                "fear_gauge": {
                    "vix_level": round(vix_level, 2) if vix_level else None,
                    "fear_level": fear_gauge,
                },
                "market_leaders": [brief(m) for m in gainers.get("movers", [])[:5]],
                "market_laggards": [brief(m) for m in losers.get("movers", [])[:5]],
            }

            # Persona-specific interpretation of the breadth reading.
            if self.persona:
                persona_name = self.persona.name
                if market_sentiment == "Bullish" and persona_name == "Conservative":
                    result["interpretation"] = (
                        "Market is bullish but be cautious of extended moves"
                    )
                elif market_sentiment == "Bearish" and persona_name == "Aggressive":
                    result["interpretation"] = (
                        "Market weakness presents buying opportunities in oversold stocks"
                    )
                elif market_sentiment == "Neutral":
                    result["interpretation"] = (
                        "Mixed market - focus on individual stock selection"
                    )

            return str(self.format_for_persona(result))

        except Exception as e:
            logger.error(f"Error analyzing market breadth: {e}")
            return f"Error analyzing market breadth: {str(e)}"
345 |
346 |
class SectorSentimentTool(PersonaAwareTool):
    """Analyze sector rotation and sentiment.

    Ranks the major sector ETFs by daily change, classifies the market
    regime (risk-on / risk-off / neutral), and surfaces rotation signals.
    The per-ETF quote source is currently a placeholder.
    """

    name: str = "analyze_sector_sentiment"
    description: str = (
        "Analyze sector rotation patterns and identify leading/lagging sectors"
    )

    def _run(self) -> str:
        """Analyze sector sentiment synchronously."""
        try:
            MarketDataProvider()

            # Major sector ETFs
            sector_etfs = {
                "Technology": "XLK",
                "Healthcare": "XLV",
                "Financials": "XLF",
                "Energy": "XLE",
                "Consumer Discretionary": "XLY",
                "Consumer Staples": "XLP",
                "Industrials": "XLI",
                "Materials": "XLB",
                "Real Estate": "XLRE",
                "Utilities": "XLU",
                "Communications": "XLC",
            }

            performance: dict[str, dict[str, Any]] = {}
            for name, etf in sector_etfs.items():
                # quote = provider.get_quote(etf)
                quote = None  # Placeholder - would need quote provider
                if not quote or "change_percent" not in quote:
                    continue
                performance[name] = {
                    "symbol": etf,
                    "change_pct": round(quote["change_percent"], 2),
                    "price": quote.get("price", 0),
                    "volume": quote.get("volume", 0),
                }

            if not performance:
                return "Error: Unable to fetch sector performance data"

            # Rank sectors best-to-worst by daily change.
            ranked = sorted(
                performance.items(),
                key=lambda item: item[1]["change_pct"],
                reverse=True,
            )
            leading_sectors = ranked[:3]
            lagging_sectors = ranked[-3:]

            def change(sector: str) -> float:
                return performance.get(sector, {}).get("change_pct", 0)

            # Regime call: growth (tech) leadership vs defensive leadership.
            tech_performance = change("Technology")
            defensive_avg = (change("Utilities") + change("Consumer Staples")) / 2

            if tech_performance > 1 and defensive_avg < 0:
                market_regime = "Risk-On (Growth Leading)"
            elif defensive_avg > 1 and tech_performance < 0:
                market_regime = "Risk-Off (Defensive Leading)"
            else:
                market_regime = "Neutral/Transitioning"

            result = {
                "status": "success",
                "sector_rotation": {
                    "market_regime": market_regime,
                    "leading_sectors": [
                        {"name": name, **data} for name, data in leading_sectors
                    ],
                    "lagging_sectors": [
                        {"name": name, **data} for name, data in lagging_sectors
                    ],
                },
                "all_sectors": dict(ranked),
                "rotation_signals": self._identify_rotation_signals(performance),
            }

            # Persona-specific sector recommendations.
            if self.persona:
                if self.persona.name == "Conservative":
                    defensive_present = [
                        s
                        for s in ["Utilities", "Consumer Staples", "Healthcare"]
                        if s in performance
                    ]
                    result["recommendations"] = (
                        "Focus on defensive sectors: " + ", ".join(defensive_present)
                    )
                elif self.persona.name == "Aggressive":
                    result["recommendations"] = (
                        "Target high-momentum sectors: "
                        + ", ".join(name for name, _ in leading_sectors)
                    )

            return str(self.format_for_persona(result))

        except Exception as e:
            logger.error(f"Error analyzing sector sentiment: {e}")
            return f"Error analyzing sector sentiment: {str(e)}"

    def _identify_rotation_signals(
        self, sector_performance: dict[str, dict]
    ) -> list[str]:
        """Identify sector rotation signals."""

        def change(sector: str) -> float:
            return sector_performance.get(sector, {}).get("change_pct", 0)

        signals: list[str] = []

        # Tech leadership suggests a growth-driven tape.
        tech_perf = change("Technology")
        if tech_perf > 2:
            signals.append("Strong tech leadership - growth environment")

        # All defensives green while tech is red = risk-off rotation.
        defensive_perfs = [
            change(s) for s in ["Utilities", "Consumer Staples", "Healthcare"]
        ]
        if all(p > 0 for p in defensive_perfs) and tech_perf < 0:
            signals.append("Defensive rotation - risk-off environment")

        # Energy/materials strength points to an inflation/growth theme.
        cyclical_strength = (change("Energy") + change("Materials")) / 2
        if cyclical_strength > 2:
            signals.append("Cyclical strength - inflation/growth theme")

        return signals
489 |
```
--------------------------------------------------------------------------------
/maverick_mcp/backtesting/model_manager.py:
--------------------------------------------------------------------------------
```python
1 | """ML Model Manager for backtesting strategies with versioning and persistence."""
2 |
3 | import json
4 | import logging
5 | from datetime import datetime, timedelta
6 | from pathlib import Path
7 | from typing import Any
8 |
9 | import joblib
10 | import pandas as pd
11 | from sklearn.base import BaseEstimator
12 | from sklearn.preprocessing import StandardScaler
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
class ModelVersion:
    """Represents a specific version of an ML model with metadata."""

    def __init__(
        self,
        model_id: str,
        version: str,
        model: BaseEstimator,
        scaler: StandardScaler | None = None,
        metadata: dict[str, Any] | None = None,
        performance_metrics: dict[str, float] | None = None,
    ):
        """Initialize model version.

        Args:
            model_id: Unique identifier for the model
            version: Version string (e.g., "1.0.0")
            model: The trained ML model
            scaler: Feature scaler (if used)
            metadata: Additional metadata about the model
            performance_metrics: Performance metrics from training/validation
        """
        self.model_id = model_id
        self.version = version
        self.model = model
        self.scaler = scaler
        # BUG FIX: copy the caller-supplied dicts. Previously the caller's
        # metadata dict was stored (and then mutated) directly.
        self.metadata = dict(metadata) if metadata else {}
        self.performance_metrics = dict(performance_metrics) if performance_metrics else {}
        self.created_at = datetime.now()
        self.last_used = None
        self.usage_count = 0

        # BUG FIX: fill in default metadata without clobbering persisted
        # values. ModelManager._load_registry reconstructs ModelVersion with
        # model=None; the old metadata.update(...) overwrote the stored
        # model_type/created_at/sklearn_version with bogus values ("NoneType",
        # load time, "unknown") that were then written back to the registry.
        self.metadata.setdefault("model_type", type(model).__name__)
        self.metadata.setdefault("created_at", self.created_at.isoformat())
        self.metadata.setdefault(
            "sklearn_version", getattr(model, "_sklearn_version", "unknown")
        )

    def increment_usage(self):
        """Increment usage counter and update last used timestamp."""
        self.usage_count += 1
        self.last_used = datetime.now()

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-serializable dictionary representation."""
        return {
            "model_id": self.model_id,
            "version": self.version,
            "metadata": self.metadata,
            "performance_metrics": self.performance_metrics,
            "created_at": self.created_at.isoformat(),
            "last_used": self.last_used.isoformat() if self.last_used else None,
            "usage_count": self.usage_count,
        }
74 |
75 |
class ModelManager:
    """Manages ML models with versioning, persistence, and performance tracking.

    On-disk layout under ``base_path``:
      * ``<model_id>/<version>.pkl``           - joblib bundle {"model", "scaler"}
      * ``<model_id>/<version>_metadata.json`` - metadata snapshot
      * ``registry.json``                      - index of all versions + active ones
    Model objects are lazy-loaded: after ``_load_registry`` the registry holds
    only metadata until ``load_model`` is called for a version.
    """

    def __init__(self, base_path: str | Path = "./models"):
        """Initialize model manager.

        Args:
            base_path: Base directory for storing models
        """
        self.base_path = Path(base_path)
        self.base_path.mkdir(parents=True, exist_ok=True)

        # Model registry: model_id -> {version -> ModelVersion}. Entries
        # restored from disk have model=None until loaded on demand.
        self.models: dict[str, dict[str, ModelVersion]] = {}
        self.active_models: dict[str, str] = {}  # model_id -> active_version

        # Performance tracking
        # NOTE(review): in-memory only — never written to disk here, so the
        # history is lost when the process restarts.
        self.performance_history: dict[str, list[dict[str, Any]]] = {}

        # Load existing models
        self._load_registry()

    def _get_model_path(self, model_id: str, version: str) -> Path:
        """Get file path for a specific model version."""
        return self.base_path / model_id / f"{version}.pkl"

    def _get_metadata_path(self, model_id: str, version: str) -> Path:
        """Get metadata file path for a specific model version."""
        return self.base_path / model_id / f"{version}_metadata.json"

    def _get_registry_path(self) -> Path:
        """Get registry file path."""
        return self.base_path / "registry.json"

    def _load_registry(self):
        """Load model registry from disk.

        Restores per-version metadata and usage statistics; the pickled models
        themselves stay on disk until requested via ``load_model``. A corrupt
        or unreadable registry is logged and ignored (empty registry).
        """
        registry_path = self._get_registry_path()
        if registry_path.exists():
            try:
                with open(registry_path) as f:
                    registry_data = json.load(f)

                self.active_models = registry_data.get("active_models", {})
                models_info = registry_data.get("models", {})

                # Lazy load model metadata (don't load actual models until needed)
                for model_id, versions in models_info.items():
                    self.models[model_id] = {}
                    for version, version_info in versions.items():
                        # Create placeholder ModelVersion (model will be loaded on demand)
                        # NOTE(review): ModelVersion.__init__ applies default
                        # metadata using model=None here — verify the persisted
                        # model_type/created_at/sklearn_version keys survive
                        # this reconstruction.
                        model_version = ModelVersion(
                            model_id=model_id,
                            version=version,
                            model=None,  # Will be loaded on demand
                            metadata=version_info.get("metadata", {}),
                            performance_metrics=version_info.get(
                                "performance_metrics", {}
                            ),
                        )
                        # Restore timestamps/usage recorded in the registry
                        # (overriding the freshly-constructed defaults).
                        model_version.created_at = datetime.fromisoformat(
                            version_info.get("created_at", datetime.now().isoformat())
                        )
                        model_version.last_used = (
                            datetime.fromisoformat(version_info["last_used"])
                            if version_info.get("last_used")
                            else None
                        )
                        model_version.usage_count = version_info.get("usage_count", 0)
                        self.models[model_id][version] = model_version

                logger.info(
                    f"Loaded model registry with {len(self.models)} model types"
                )

            except Exception as e:
                logger.error(f"Error loading model registry: {e}")

    def _save_registry(self):
        """Save model registry to disk (best effort; errors are logged)."""
        try:
            registry_data = {"active_models": self.active_models, "models": {}}

            for model_id, versions in self.models.items():
                registry_data["models"][model_id] = {}
                for version, model_version in versions.items():
                    registry_data["models"][model_id][version] = model_version.to_dict()

            registry_path = self._get_registry_path()
            with open(registry_path, "w") as f:
                json.dump(registry_data, f, indent=2)

            logger.debug("Saved model registry")

        except Exception as e:
            logger.error(f"Error saving model registry: {e}")

    def save_model(
        self,
        model_id: str,
        version: str,
        model: BaseEstimator,
        scaler: StandardScaler | None = None,
        metadata: dict[str, Any] | None = None,
        performance_metrics: dict[str, float] | None = None,
        set_as_active: bool = True,
    ) -> bool:
        """Save a model version to disk.

        Args:
            model_id: Unique identifier for the model
            version: Version string
            model: Trained ML model
            scaler: Feature scaler (if used)
            metadata: Additional metadata
            performance_metrics: Performance metrics
            set_as_active: Whether to set this as the active version

        Returns:
            True if successful
        """
        try:
            # Create model directory
            model_dir = self.base_path / model_id
            model_dir.mkdir(parents=True, exist_ok=True)

            # Save model and scaler using joblib (better for sklearn models)
            model_path = self._get_model_path(model_id, version)
            model_data = {
                "model": model,
                "scaler": scaler,
            }
            joblib.dump(model_data, model_path)

            # Create ModelVersion instance
            model_version = ModelVersion(
                model_id=model_id,
                version=version,
                model=model,
                scaler=scaler,
                metadata=metadata,
                performance_metrics=performance_metrics,
            )

            # Save metadata separately (human-readable JSON next to the pickle)
            metadata_path = self._get_metadata_path(model_id, version)
            with open(metadata_path, "w") as f:
                json.dump(model_version.to_dict(), f, indent=2)

            # Update registry
            if model_id not in self.models:
                self.models[model_id] = {}
            self.models[model_id][version] = model_version

            # Set as active if requested
            if set_as_active:
                self.active_models[model_id] = version

            # Save registry
            self._save_registry()

            logger.info(
                f"Saved model {model_id} v{version} ({'active' if set_as_active else 'inactive'})"
            )
            return True

        except Exception as e:
            logger.error(f"Error saving model {model_id} v{version}: {e}")
            return False

    def load_model(
        self, model_id: str, version: str | None = None
    ) -> ModelVersion | None:
        """Load a specific model version.

        Args:
            model_id: Model identifier
            version: Version to load (defaults to active version)

        Returns:
            ModelVersion instance or None if not found
        """
        try:
            if version is None:
                version = self.active_models.get(model_id)
                if version is None:
                    logger.warning(f"No active version found for model {model_id}")
                    return None

            if model_id not in self.models or version not in self.models[model_id]:
                logger.warning(f"Model {model_id} v{version} not found in registry")
                return None

            model_version = self.models[model_id][version]

            # Load actual model if not already loaded (lazy materialization
            # of the joblib bundle written by save_model)
            if model_version.model is None:
                model_path = self._get_model_path(model_id, version)
                if not model_path.exists():
                    logger.error(f"Model file not found: {model_path}")
                    return None

                model_data = joblib.load(model_path)
                model_version.model = model_data["model"]
                model_version.scaler = model_data.get("scaler")

            # Update usage statistics
            # Note: every successful load rewrites registry.json to persist
            # the usage counter / last-used timestamp.
            model_version.increment_usage()
            self._save_registry()

            logger.debug(f"Loaded model {model_id} v{version}")
            return model_version

        except Exception as e:
            logger.error(f"Error loading model {model_id} v{version}: {e}")
            return None

    def list_models(self) -> dict[str, list[str]]:
        """List all available models and their versions.

        Returns:
            Dictionary mapping model_id to list of versions
        """
        return {
            model_id: list(versions.keys())
            for model_id, versions in self.models.items()
        }

    def list_model_versions(self, model_id: str) -> list[dict[str, Any]]:
        """List all versions of a specific model with metadata.

        Args:
            model_id: Model identifier

        Returns:
            List of version information dictionaries (newest first), each
            annotated with an ``is_active`` flag.
        """
        if model_id not in self.models:
            return []

        versions_info = []
        for version, model_version in self.models[model_id].items():
            info = model_version.to_dict()
            info["is_active"] = self.active_models.get(model_id) == version
            versions_info.append(info)

        # Sort by creation date (newest first)
        versions_info.sort(key=lambda x: x["created_at"], reverse=True)

        return versions_info

    def set_active_version(self, model_id: str, version: str) -> bool:
        """Set the active version for a model.

        Args:
            model_id: Model identifier
            version: Version to set as active

        Returns:
            True if successful
        """
        if model_id not in self.models or version not in self.models[model_id]:
            logger.error(f"Model {model_id} v{version} not found")
            return False

        self.active_models[model_id] = version
        self._save_registry()
        logger.info(f"Set {model_id} v{version} as active")
        return True

    def delete_model_version(self, model_id: str, version: str) -> bool:
        """Delete a specific model version.

        Refuses to delete the currently active version. Removes both the
        pickle and metadata files, then prunes empty registry entries.

        Args:
            model_id: Model identifier
            version: Version to delete

        Returns:
            True if successful
        """
        try:
            if model_id not in self.models or version not in self.models[model_id]:
                logger.warning(f"Model {model_id} v{version} not found")
                return False

            # Don't delete active version
            if self.active_models.get(model_id) == version:
                logger.error(f"Cannot delete active version {model_id} v{version}")
                return False

            # Delete files
            model_path = self._get_model_path(model_id, version)
            metadata_path = self._get_metadata_path(model_id, version)

            if model_path.exists():
                model_path.unlink()
            if metadata_path.exists():
                metadata_path.unlink()

            # Remove from registry
            del self.models[model_id][version]

            # Clean up empty model entry
            if not self.models[model_id]:
                del self.models[model_id]
                if model_id in self.active_models:
                    del self.active_models[model_id]

            self._save_registry()
            logger.info(f"Deleted model {model_id} v{version}")
            return True

        except Exception as e:
            logger.error(f"Error deleting model {model_id} v{version}: {e}")
            return False

    def cleanup_old_versions(
        self, keep_versions: int = 5, min_age_days: int = 30
    ) -> int:
        """Clean up old model versions.

        A version is deleted only if it is not active, falls outside the
        newest ``keep_versions`` entries, AND is older than ``min_age_days``.

        Args:
            keep_versions: Number of versions to keep per model
            min_age_days: Minimum age in days before deletion

        Returns:
            Number of versions deleted
        """
        deleted_count = 0
        cutoff_date = datetime.now() - timedelta(days=min_age_days)

        for model_id, versions in list(self.models.items()):
            # Sort versions by creation date (newest first)
            sorted_versions = sorted(
                versions.items(), key=lambda x: x[1].created_at, reverse=True
            )

            # Keep active version and recent versions
            active_version = self.active_models.get(model_id)
            versions_to_delete = []

            for i, (version, model_version) in enumerate(sorted_versions):
                # Skip if it's the active version
                if version == active_version:
                    continue

                # Skip if we haven't kept enough versions yet
                # (the index counts the active version too, so it occupies
                # one of the keep_versions slots when it is among the newest)
                if i < keep_versions:
                    continue

                # Skip if it's too new
                if model_version.created_at > cutoff_date:
                    continue

                versions_to_delete.append(version)

            # Delete old versions
            for version in versions_to_delete:
                if self.delete_model_version(model_id, version):
                    deleted_count += 1

        if deleted_count > 0:
            logger.info(f"Cleaned up {deleted_count} old model versions")

        return deleted_count

    def get_model_performance_history(self, model_id: str) -> list[dict[str, Any]]:
        """Get performance history for a model.

        Args:
            model_id: Model identifier

        Returns:
            List of performance records (empty if none logged this session)
        """
        return self.performance_history.get(model_id, [])

    def log_model_performance(
        self,
        model_id: str,
        version: str,
        metrics: dict[str, float],
        additional_data: dict[str, Any] | None = None,
    ):
        """Log performance metrics for a model.

        Records are kept in memory only (see ``performance_history``) and
        trimmed to the most recent 1000 entries per model.

        Args:
            model_id: Model identifier
            version: Model version
            metrics: Performance metrics
            additional_data: Additional data to log
        """
        if model_id not in self.performance_history:
            self.performance_history[model_id] = []

        performance_record = {
            "timestamp": datetime.now().isoformat(),
            "version": version,
            "metrics": metrics,
            "additional_data": additional_data or {},
        }

        self.performance_history[model_id].append(performance_record)

        # Keep only recent performance records (last 1000)
        if len(self.performance_history[model_id]) > 1000:
            self.performance_history[model_id] = self.performance_history[model_id][
                -1000:
            ]

        logger.debug(f"Logged performance for {model_id} v{version}")

    def compare_model_versions(
        self, model_id: str, versions: list[str] | None = None
    ) -> pd.DataFrame:
        """Compare performance metrics across model versions.

        Args:
            model_id: Model identifier
            versions: Versions to compare (defaults to all versions)

        Returns:
            DataFrame with one row per version: creation date, usage count,
            active flag, and the stored performance metric columns.
        """
        if model_id not in self.models:
            return pd.DataFrame()

        if versions is None:
            versions = list(self.models[model_id].keys())

        comparison_data = []
        for version in versions:
            if version in self.models[model_id]:
                model_version = self.models[model_id][version]
                row_data = {
                    "version": version,
                    "created_at": model_version.created_at,
                    "usage_count": model_version.usage_count,
                    "is_active": self.active_models.get(model_id) == version,
                }
                # Flatten the stored metrics into DataFrame columns.
                row_data.update(model_version.performance_metrics)
                comparison_data.append(row_data)

        return pd.DataFrame(comparison_data)

    def get_storage_stats(self) -> dict[str, Any]:
        """Get storage statistics for the model manager.

        Returns:
            Dictionary with storage statistics (counts and on-disk size of
            the pickled model bundles; metadata/registry JSON not included)
        """
        total_size = 0
        total_models = 0
        total_versions = 0

        for model_id, versions in self.models.items():
            total_models += 1
            for version in versions:
                total_versions += 1
                model_path = self._get_model_path(model_id, version)
                if model_path.exists():
                    total_size += model_path.stat().st_size

        return {
            "total_models": total_models,
            "total_versions": total_versions,
            "total_size_bytes": total_size,
            "total_size_mb": total_size / (1024 * 1024),
            "base_path": str(self.base_path),
        }
545 |
```
--------------------------------------------------------------------------------
/tests/test_runner_validation.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Test runner validation for parallel research functionality test suites.
3 |
4 | This module validates that all test suites follow pytest best practices and async patterns
5 | without triggering circular imports during validation.
6 | """
7 |
8 | import ast
9 | import re
10 | from pathlib import Path
11 | from typing import Any
12 |
13 |
14 | class TestSuiteValidator:
15 | """Validator for test suite structure and patterns."""
16 |
17 | def __init__(self, test_file_path: str):
18 | self.test_file_path = Path(test_file_path)
19 | self.content = self.test_file_path.read_text()
20 | self.tree = ast.parse(self.content)
21 |
22 | def validate_pytest_patterns(self) -> dict[str, Any]:
23 | """Validate pytest patterns and best practices."""
24 | results = {
25 | "has_pytest_markers": False,
26 | "has_async_tests": False,
27 | "has_fixtures": False,
28 | "has_proper_imports": False,
29 | "has_class_based_tests": False,
30 | "test_count": 0,
31 | "async_test_count": 0,
32 | "fixture_count": 0,
33 | }
34 |
35 | # Check imports
36 | for node in ast.walk(self.tree):
37 | if isinstance(node, ast.ImportFrom):
38 | if node.module == "pytest":
39 | results["has_proper_imports"] = True
40 | elif isinstance(node, ast.Import):
41 | for alias in node.names:
42 | if alias.name == "pytest":
43 | results["has_proper_imports"] = True
44 |
45 | # Check for pytest markers, fixtures, and test functions
46 | for node in ast.walk(self.tree):
47 | if isinstance(node, ast.FunctionDef):
48 | # Check for test functions
49 | if node.name.startswith("test_"):
50 | results["test_count"] += 1
51 |
52 | # Check for async tests
53 | if isinstance(node, ast.AsyncFunctionDef):
54 | results["has_async_tests"] = True
55 | results["async_test_count"] += 1
56 |
57 | # Check for fixtures
58 | for decorator in node.decorator_list:
59 | if isinstance(decorator, ast.Attribute):
60 | if decorator.attr == "fixture":
61 | results["has_fixtures"] = True
62 | results["fixture_count"] += 1
63 | elif isinstance(decorator, ast.Name):
64 | if decorator.id == "fixture":
65 | results["has_fixtures"] = True
66 | results["fixture_count"] += 1
67 |
68 | elif isinstance(node, ast.AsyncFunctionDef):
69 | if node.name.startswith("test_"):
70 | results["test_count"] += 1
71 | results["has_async_tests"] = True
72 | results["async_test_count"] += 1
73 |
74 | # Check for pytest markers
75 | marker_pattern = r"@pytest\.mark\.\w+"
76 | if re.search(marker_pattern, self.content):
77 | results["has_pytest_markers"] = True
78 |
79 | # Check for class-based tests
80 | for node in ast.walk(self.tree):
81 | if isinstance(node, ast.ClassDef):
82 | if node.name.startswith("Test"):
83 | results["has_class_based_tests"] = True
84 | break
85 |
86 | return results
87 |
88 | def validate_async_patterns(self) -> dict[str, Any]:
89 | """Validate async/await patterns."""
90 | results = {
91 | "proper_async_await": True,
92 | "has_asyncio_imports": False,
93 | "async_fixtures_marked": True,
94 | "issues": [],
95 | }
96 |
97 | # Check for asyncio imports
98 | if "import asyncio" in self.content or "from asyncio" in self.content:
99 | results["has_asyncio_imports"] = True
100 |
101 | # Check async function patterns
102 | for node in ast.walk(self.tree):
103 | if isinstance(node, ast.AsyncFunctionDef):
104 | # Check if async test functions are properly marked
105 | if node.name.startswith("test_"):
106 | for decorator in node.decorator_list:
107 | if isinstance(decorator, ast.Attribute):
108 | if (
109 | hasattr(decorator.value, "attr")
110 | and decorator.value.attr == "mark"
111 | and decorator.attr == "asyncio"
112 | ):
113 | pass
114 | elif isinstance(decorator, ast.Call):
115 | if (
116 | isinstance(decorator.func, ast.Attribute)
117 | and hasattr(decorator.func.value, "attr")
118 | and decorator.func.value.attr == "mark"
119 | and decorator.func.attr == "asyncio"
120 | ):
121 | pass
122 |
123 | # Not all test environments require explicit asyncio marking
124 | # Modern pytest-asyncio auto-detects async tests
125 |
126 | return results
127 |
128 | def validate_mock_usage(self) -> dict[str, Any]:
129 | """Validate mock usage patterns."""
130 | results = {
131 | "has_mocks": False,
132 | "has_async_mocks": False,
133 | "has_patch_usage": False,
134 | "proper_mock_imports": False,
135 | }
136 |
137 | # Check mock imports
138 | mock_imports = ["Mock", "AsyncMock", "MagicMock", "patch"]
139 | for imp in mock_imports:
140 | if (
141 | f"from unittest.mock import {imp}" in self.content
142 | or f"import {imp}" in self.content
143 | ):
144 | results["proper_mock_imports"] = True
145 | results["has_mocks"] = True
146 | if imp == "AsyncMock":
147 | results["has_async_mocks"] = True
148 | if imp == "patch":
149 | results["has_patch_usage"] = True
150 |
151 | return results
152 |
153 |
154 | class TestParallelResearchTestSuites:
155 | """Test the test suites for parallel research functionality."""
156 |
157 | def test_parallel_research_orchestrator_tests_structure(self):
158 | """Test structure of ParallelResearchOrchestrator test suite."""
159 | test_file = Path(__file__).parent / "test_parallel_research_orchestrator.py"
160 | assert test_file.exists(), "ParallelResearchOrchestrator test file should exist"
161 |
162 | validator = TestSuiteValidator(str(test_file))
163 | results = validator.validate_pytest_patterns()
164 |
165 | assert results["test_count"] > 0, "Should have test functions"
166 | assert results["has_async_tests"], "Should have async tests"
167 | assert results["has_fixtures"], "Should have fixtures"
168 | assert results["has_class_based_tests"], "Should have class-based tests"
169 | assert results["async_test_count"] > 0, "Should have async test functions"
170 |
171 | def test_deep_research_parallel_execution_tests_structure(self):
172 | """Test structure of DeepResearchAgent parallel execution test suite."""
173 | test_file = Path(__file__).parent / "test_deep_research_parallel_execution.py"
174 | assert test_file.exists(), "DeepResearchAgent parallel test file should exist"
175 |
176 | validator = TestSuiteValidator(str(test_file))
177 | results = validator.validate_pytest_patterns()
178 |
179 | assert results["test_count"] > 0, "Should have test functions"
180 | assert results["has_async_tests"], "Should have async tests"
181 | assert results["has_fixtures"], "Should have fixtures"
182 | assert results["has_class_based_tests"], "Should have class-based tests"
183 |
184 | def test_orchestration_logging_tests_structure(self):
185 | """Test structure of OrchestrationLogger test suite."""
186 | test_file = Path(__file__).parent / "test_orchestration_logging.py"
187 | assert test_file.exists(), "OrchestrationLogger test file should exist"
188 |
189 | validator = TestSuiteValidator(str(test_file))
190 | results = validator.validate_pytest_patterns()
191 |
192 | assert results["test_count"] > 0, "Should have test functions"
193 | assert results["has_async_tests"], "Should have async tests"
194 | assert results["has_fixtures"], "Should have fixtures"
195 | assert results["has_class_based_tests"], "Should have class-based tests"
196 |
197 | def test_parallel_research_integration_tests_structure(self):
198 | """Test structure of parallel research integration test suite."""
199 | test_file = Path(__file__).parent / "test_parallel_research_integration.py"
200 | assert test_file.exists(), (
201 | "Parallel research integration test file should exist"
202 | )
203 |
204 | validator = TestSuiteValidator(str(test_file))
205 | results = validator.validate_pytest_patterns()
206 |
207 | assert results["test_count"] > 0, "Should have test functions"
208 | assert results["has_async_tests"], "Should have async tests"
209 | assert results["has_fixtures"], "Should have fixtures"
210 | assert results["has_class_based_tests"], "Should have class-based tests"
211 | assert results["has_pytest_markers"], (
212 | "Should have pytest markers (like @pytest.mark.integration)"
213 | )
214 |
215 | def test_async_patterns_validation(self):
216 | """Test that async patterns are properly implemented across all test suites."""
217 | test_files = [
218 | "test_parallel_research_orchestrator.py",
219 | "test_deep_research_parallel_execution.py",
220 | "test_orchestration_logging.py",
221 | "test_parallel_research_integration.py",
222 | ]
223 |
224 | for test_file in test_files:
225 | file_path = Path(__file__).parent / test_file
226 | if file_path.exists():
227 | validator = TestSuiteValidator(str(file_path))
228 | results = validator.validate_async_patterns()
229 |
230 | assert results["proper_async_await"], (
231 | f"Async patterns should be correct in {test_file}"
232 | )
233 | assert results["has_asyncio_imports"], (
234 | f"Should import asyncio in {test_file}"
235 | )
236 |
237 | def test_mock_usage_patterns(self):
238 | """Test that mock usage patterns are consistent across test suites."""
239 | test_files = [
240 | "test_parallel_research_orchestrator.py",
241 | "test_deep_research_parallel_execution.py",
242 | "test_orchestration_logging.py",
243 | "test_parallel_research_integration.py",
244 | ]
245 |
246 | for test_file in test_files:
247 | file_path = Path(__file__).parent / test_file
248 | if file_path.exists():
249 | validator = TestSuiteValidator(str(file_path))
250 | results = validator.validate_mock_usage()
251 |
252 | assert results["has_mocks"], f"Should use mocks in {test_file}"
253 | assert results["proper_mock_imports"], (
254 | f"Should have proper mock imports in {test_file}"
255 | )
256 |
257 | # For async-heavy test files, should use AsyncMock
258 | if test_file in [
259 | "test_parallel_research_orchestrator.py",
260 | "test_deep_research_parallel_execution.py",
261 | "test_parallel_research_integration.py",
262 | ]:
263 | assert results["has_async_mocks"], (
264 | f"Should use AsyncMock in {test_file}"
265 | )
266 |
267 | def test_test_coverage_completeness(self):
268 | """Test that test coverage is comprehensive for parallel research functionality."""
269 | # Define expected test categories for each component
270 | expected_test_categories = {
271 | "test_parallel_research_orchestrator.py": [
272 | "config",
273 | "task",
274 | "orchestrator",
275 | "distribution",
276 | "result",
277 | "integration",
278 | ],
279 | "test_deep_research_parallel_execution.py": [
280 | "agent",
281 | "subagent",
282 | "execution",
283 | "synthesis",
284 | "integration",
285 | ],
286 | "test_orchestration_logging.py": [
287 | "logger",
288 | "decorator",
289 | "context",
290 | "utility",
291 | "integrated",
292 | "load",
293 | ],
294 | "test_parallel_research_integration.py": [
295 | "endtoend",
296 | "scalability",
297 | "logging",
298 | "error",
299 | "data",
300 | ],
301 | }
302 |
303 | for test_file, expected_categories in expected_test_categories.items():
304 | file_path = Path(__file__).parent / test_file
305 | if file_path.exists():
306 | content = file_path.read_text().lower()
307 |
308 | for category in expected_categories:
309 | assert category in content, (
310 | f"Should have {category} tests in {test_file}"
311 | )
312 |
313 | def test_docstring_quality(self):
314 | """Test that test files have proper docstrings."""
315 | test_files = [
316 | "test_parallel_research_orchestrator.py",
317 | "test_deep_research_parallel_execution.py",
318 | "test_orchestration_logging.py",
319 | "test_parallel_research_integration.py",
320 | ]
321 |
322 | for test_file in test_files:
323 | file_path = Path(__file__).parent / test_file
324 | if file_path.exists():
325 | content = file_path.read_text()
326 |
327 | # Should have module docstring
328 | assert '"""' in content, f"Should have docstrings in {test_file}"
329 |
330 | # Should describe what is being tested
331 | docstring_keywords = ["test", "functionality", "cover", "suite"]
332 | first_docstring = content.split('"""')[1].lower()
333 | assert any(
334 | keyword in first_docstring for keyword in docstring_keywords
335 | ), f"Module docstring should describe testing purpose in {test_file}"
336 |
337 | def test_import_safety(self):
338 | """Test that imports are safe and avoid circular dependencies."""
339 | test_files = [
340 | "test_parallel_research_orchestrator.py",
341 | "test_deep_research_parallel_execution.py",
342 | "test_orchestration_logging.py",
343 | "test_parallel_research_integration.py",
344 | ]
345 |
346 | for test_file in test_files:
347 | file_path = Path(__file__).parent / test_file
348 | if file_path.exists():
349 | content = file_path.read_text()
350 |
351 | # Should not have circular import patterns
352 | lines = content.split("\n")
353 | import_lines = [
354 | line
355 | for line in lines
356 | if line.strip().startswith(("import ", "from "))
357 | ]
358 |
359 | # Basic validation that imports are structured
360 | assert len(import_lines) > 0, (
361 | f"Should have import statements in {test_file}"
362 | )
363 |
364 | # Should import pytest
365 | pytest_imported = any("pytest" in line for line in import_lines)
366 | assert pytest_imported, f"Should import pytest in {test_file}"
367 |
368 | def test_fixture_best_practices(self):
369 | """Test that fixtures follow best practices."""
370 | test_files = [
371 | "test_parallel_research_orchestrator.py",
372 | "test_deep_research_parallel_execution.py",
373 | "test_orchestration_logging.py",
374 | "test_parallel_research_integration.py",
375 | ]
376 |
377 | for test_file in test_files:
378 | file_path = Path(__file__).parent / test_file
379 | if file_path.exists():
380 | content = file_path.read_text()
381 |
382 | # If file has fixtures, they should be properly structured
383 | if "@pytest.fixture" in content:
384 | # Should have fixture decorators
385 | assert "def " in content, (
386 | f"Fixtures should be functions in {test_file}"
387 | )
388 |
389 | # Common fixture patterns should be present
390 | fixture_patterns = ["yield", "return", "Mock", "config"]
391 | has_fixture_pattern = any(
392 | pattern in content for pattern in fixture_patterns
393 | )
394 | assert has_fixture_pattern, (
395 | f"Should have proper fixture patterns in {test_file}"
396 | )
397 |
398 | def test_error_handling_coverage(self):
399 | """Test that error handling scenarios are covered."""
400 | test_files = [
401 | "test_parallel_research_orchestrator.py",
402 | "test_deep_research_parallel_execution.py",
403 | "test_parallel_research_integration.py",
404 | ]
405 |
406 | for test_file in test_files:
407 | file_path = Path(__file__).parent / test_file
408 | if file_path.exists():
409 | content = file_path.read_text().lower()
410 |
411 | # Should test error scenarios
412 | error_keywords = [
413 | "error",
414 | "exception",
415 | "timeout",
416 | "failure",
417 | "fallback",
418 | ]
419 | has_error_tests = any(keyword in content for keyword in error_keywords)
420 | assert has_error_tests, f"Should test error scenarios in {test_file}"
421 |
422 | def test_performance_testing_coverage(self):
423 | """Test that performance characteristics are tested."""
424 | performance_test_files = [
425 | "test_parallel_research_orchestrator.py",
426 | "test_parallel_research_integration.py",
427 | ]
428 |
429 | for test_file in performance_test_files:
430 | file_path = Path(__file__).parent / test_file
431 | if file_path.exists():
432 | content = file_path.read_text().lower()
433 |
434 | # Should test performance characteristics
435 | perf_keywords = [
436 | "performance",
437 | "timing",
438 | "efficiency",
439 | "concurrent",
440 | "parallel",
441 | ]
442 | has_perf_tests = any(keyword in content for keyword in perf_keywords)
443 | assert has_perf_tests, (
444 | f"Should test performance characteristics in {test_file}"
445 | )
446 |
447 | def test_integration_test_markers(self):
448 | """Test that integration tests are properly marked."""
449 | integration_file = (
450 | Path(__file__).parent / "test_parallel_research_integration.py"
451 | )
452 |
453 | if integration_file.exists():
454 | content = integration_file.read_text()
455 |
456 | # Should have integration markers
457 | assert "@pytest.mark.integration" in content, (
458 | "Should mark integration tests"
459 | )
460 |
461 | # Should have integration test classes
462 | integration_patterns = ["TestParallel", "Integration", "EndToEnd"]
463 | has_integration_classes = any(
464 | pattern in content for pattern in integration_patterns
465 | )
466 | assert has_integration_classes, "Should have integration test classes"
467 |
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/openrouter_provider.py:
--------------------------------------------------------------------------------
```python
1 | """OpenRouter LLM provider with intelligent model selection.
2 |
3 | This module provides integration with OpenRouter API for accessing various LLMs
4 | with automatic model selection based on task requirements.
5 | """
6 |
7 | import logging
8 | from enum import Enum
9 | from typing import Any
10 |
11 | from langchain_openai import ChatOpenAI
12 | from pydantic import BaseModel, Field
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
class TaskType(str, Enum):
    """Task types for model selection.

    Members tag categories of LLM workload. ``ModelProfile.best_for`` lists
    the task types a model is suited to, and ``OpenRouterProvider`` matches
    against that list when scoring candidates. Inherits ``str`` so values
    serialize cleanly (e.g. as usage-stats keys via ``.value``).
    """

    # Analysis tasks
    DEEP_RESEARCH = "deep_research"
    MARKET_ANALYSIS = "market_analysis"
    TECHNICAL_ANALYSIS = "technical_analysis"
    SENTIMENT_ANALYSIS = "sentiment_analysis"
    RISK_ASSESSMENT = "risk_assessment"

    # Synthesis tasks
    RESULT_SYNTHESIS = "result_synthesis"
    PORTFOLIO_OPTIMIZATION = "portfolio_optimization"

    # Query processing (typically cheaper/faster models)
    QUERY_CLASSIFICATION = "query_classification"
    QUICK_ANSWER = "quick_answer"

    # Complex reasoning (typically premium models)
    COMPLEX_REASONING = "complex_reasoning"
    MULTI_AGENT_ORCHESTRATION = "multi_agent_orchestration"

    # Default when the caller has no more specific task type
    GENERAL = "general"
41 |
42 |
class ModelProfile(BaseModel):
    """Profile for an LLM model with capabilities and costs.

    Instances populate ``MODEL_PROFILES`` and are scored by
    ``OpenRouterProvider._select_model``. Ratings use a 1-10 scale and costs
    are USD per million tokens, as described on each field.
    """

    model_id: str = Field(description="OpenRouter model identifier")
    name: str = Field(description="Human-readable model name")
    provider: str = Field(description="Model provider (e.g., anthropic, openai)")
    context_length: int = Field(description="Maximum context length in tokens")
    cost_per_million_input: float = Field(
        description="Cost per million input tokens in USD"
    )
    cost_per_million_output: float = Field(
        description="Cost per million output tokens in USD"
    )
    speed_rating: int = Field(description="Speed rating 1-10 (10 being fastest)")
    quality_rating: int = Field(description="Quality rating 1-10 (10 being best)")
    best_for: list[TaskType] = Field(description="Task types this model excels at")
    # Default used by get_llm when the caller does not pass a temperature.
    temperature: float = Field(
        default=0.3, description="Default temperature for this model"
    )
62 |
63 |
# Model profiles for intelligent selection.
# Keyed by OpenRouter model id; consumed by OpenRouterProvider._select_model
# (score-based selection) and _select_emergency_model (hard-coded fast picks).
# NOTE(review): prices, context lengths, and speed/quality ratings are
# hard-coded snapshots — verify against OpenRouter's current model catalog.
MODEL_PROFILES = {
    # Premium models (use sparingly for critical tasks)
    "anthropic/claude-opus-4.1": ModelProfile(
        model_id="anthropic/claude-opus-4.1",
        name="Claude Opus 4.1",
        provider="anthropic",
        context_length=200000,
        cost_per_million_input=15.0,
        cost_per_million_output=75.0,
        speed_rating=7,
        quality_rating=10,
        best_for=[
            TaskType.COMPLEX_REASONING,  # Only for the most complex tasks
        ],
        temperature=0.3,
    ),
    # Cost-effective high-quality models (primary workhorses)
    "anthropic/claude-sonnet-4": ModelProfile(
        model_id="anthropic/claude-sonnet-4",
        name="Claude Sonnet 4",
        provider="anthropic",
        context_length=1000000,  # 1M token context capability!
        cost_per_million_input=3.0,
        cost_per_million_output=15.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.MULTI_AGENT_ORCHESTRATION,
            TaskType.RESULT_SYNTHESIS,
            TaskType.PORTFOLIO_OPTIMIZATION,
        ],
        temperature=0.3,
    ),
    "openai/gpt-5": ModelProfile(
        model_id="openai/gpt-5",
        name="GPT-5",
        provider="openai",
        context_length=400000,
        cost_per_million_input=1.25,
        cost_per_million_output=10.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
        ],
        temperature=0.3,
    ),
    # Excellent cost-performance ratio models
    "google/gemini-2.5-pro": ModelProfile(
        model_id="google/gemini-2.5-pro",
        name="Gemini 2.5 Pro",
        provider="google",
        context_length=1000000,  # 1M token context!
        cost_per_million_input=2.0,
        cost_per_million_output=8.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
        ],
        temperature=0.3,
    ),
    "deepseek/deepseek-r1": ModelProfile(
        model_id="deepseek/deepseek-r1",
        name="DeepSeek R1",
        provider="deepseek",
        context_length=128000,
        cost_per_million_input=0.5,
        cost_per_million_output=1.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.RISK_ASSESSMENT,
        ],
        temperature=0.3,
    ),
    # Fast, cost-effective models for simpler tasks
    # Speed-optimized models for research timeouts
    "google/gemini-2.5-flash": ModelProfile(
        model_id="google/gemini-2.5-flash",
        name="Gemini 2.5 Flash",
        provider="google",
        context_length=1000000,
        cost_per_million_input=0.075,  # Ultra low cost
        cost_per_million_output=0.30,
        speed_rating=10,  # 199 tokens/sec - FASTEST available; emergency default
        quality_rating=8,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.QUICK_ANSWER,
            TaskType.SENTIMENT_ANALYSIS,
        ],
        temperature=0.2,
    ),
    "openai/gpt-4o-mini": ModelProfile(
        model_id="openai/gpt-4o-mini",
        name="GPT-4o Mini",
        provider="openai",
        context_length=128000,
        cost_per_million_input=0.15,
        cost_per_million_output=0.60,
        speed_rating=9,  # 126 tokens/sec - Excellent speed/cost balance
        quality_rating=8,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.QUICK_ANSWER,
        ],
        temperature=0.2,
    ),
    "anthropic/claude-3.5-haiku": ModelProfile(
        model_id="anthropic/claude-3.5-haiku",
        name="Claude 3.5 Haiku",
        provider="anthropic",
        context_length=200000,
        cost_per_million_input=0.25,
        cost_per_million_output=1.25,
        speed_rating=7,  # 65.6 tokens/sec - Updated with actual speed rating
        quality_rating=8,
        best_for=[
            TaskType.QUERY_CLASSIFICATION,
            TaskType.QUICK_ANSWER,
            TaskType.SENTIMENT_ANALYSIS,
        ],
        temperature=0.2,
    ),
    "openai/gpt-5-nano": ModelProfile(
        model_id="openai/gpt-5-nano",
        name="GPT-5 Nano",
        provider="openai",
        context_length=400000,
        cost_per_million_input=0.05,
        cost_per_million_output=0.40,
        speed_rating=9,  # 180 tokens/sec - Very fast; also the _select_model fallback
        quality_rating=7,
        best_for=[
            TaskType.QUICK_ANSWER,
            TaskType.QUERY_CLASSIFICATION,
            TaskType.DEEP_RESEARCH,  # Added for emergency research
        ],
        temperature=0.2,
    ),
    # Specialized models
    "xai/grok-4": ModelProfile(
        model_id="xai/grok-4",
        name="Grok 4",
        provider="xai",
        context_length=128000,
        cost_per_million_input=3.0,
        cost_per_million_output=12.0,
        speed_rating=7,
        quality_rating=9,
        best_for=[
            TaskType.MARKET_ANALYSIS,
            TaskType.SENTIMENT_ANALYSIS,
            TaskType.PORTFOLIO_OPTIMIZATION,
        ],
        temperature=0.3,
    ),
}
235 |
236 |
class OpenRouterProvider:
    """Provider for OpenRouter API with intelligent model selection.

    Wraps LangChain's ``ChatOpenAI`` pointed at the OpenRouter endpoint,
    selects a model from ``MODEL_PROFILES`` based on task type and
    cost/speed/quality preferences, and records per-model usage counts.
    """

    def __init__(self, api_key: str):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key
        """
        self.api_key = api_key
        self.base_url = "https://openrouter.ai/api/v1"
        # Maps model_id -> {task_type value -> invocation count}; populated
        # by _track_usage and exposed (as a copy) via get_usage_stats.
        self._model_usage_stats: dict[str, dict[str, int]] = {}

    def get_llm(
        self,
        task_type: TaskType = TaskType.GENERAL,
        prefer_fast: bool = False,
        prefer_cheap: bool = True,  # Default to cost-effective
        prefer_quality: bool = False,  # Override for premium models
        model_override: str | None = None,
        temperature: float | None = None,
        max_tokens: int = 4096,
        timeout_budget: float | None = None,  # Emergency mode for timeouts
    ) -> ChatOpenAI:
        """Get an LLM instance optimized for the task.

        Selection precedence: an explicit ``model_override`` wins, then the
        emergency fast path when ``timeout_budget`` is under 30 seconds,
        otherwise preference-based scoring via ``_select_model``.

        Args:
            task_type: Type of task to optimize for
            prefer_fast: Prioritize speed over quality
            prefer_cheap: Prioritize cost over quality (default True)
            prefer_quality: Use premium models regardless of cost
            model_override: Override model selection
            temperature: Override default temperature
            max_tokens: Maximum tokens for response
            timeout_budget: Available time budget - triggers emergency mode if < 30s

        Returns:
            Configured ChatOpenAI instance
        """
        # Use override if provided
        if model_override:
            model_id = model_override
            # Unknown override ids get a neutral mid-range profile so the
            # rest of the flow (logging, temperature default) still works.
            model_profile = MODEL_PROFILES.get(
                model_id,
                ModelProfile(
                    model_id=model_id,
                    name=model_id,
                    provider="unknown",
                    context_length=128000,
                    cost_per_million_input=1.0,
                    cost_per_million_output=1.0,
                    speed_rating=5,
                    quality_rating=5,
                    best_for=[TaskType.GENERAL],
                    temperature=0.3,
                ),
            )
        # Emergency mode for tight timeout budgets
        elif timeout_budget is not None and timeout_budget < 30:
            model_profile = self._select_emergency_model(task_type, timeout_budget)
            model_id = model_profile.model_id
            logger.warning(
                f"EMERGENCY MODE: Selected ultra-fast model '{model_profile.name}' "
                f"for {timeout_budget}s timeout budget"
            )
        else:
            model_profile = self._select_model(
                task_type, prefer_fast, prefer_cheap, prefer_quality
            )
            model_id = model_profile.model_id

        # Use provided temperature or model default
        final_temperature = (
            temperature if temperature is not None else model_profile.temperature
        )

        # Log model selection
        logger.info(
            f"Selected model '{model_profile.name}' for task '{task_type}' "
            f"(speed={model_profile.speed_rating}/10, quality={model_profile.quality_rating}/10, "
            f"cost=${model_profile.cost_per_million_input}/{model_profile.cost_per_million_output} per 1M tokens)"
        )

        # Track usage
        self._track_usage(model_id, task_type)

        # Create LangChain ChatOpenAI instance. NOTE(review): openai_api_base
        # and openai_api_key are the legacy aliases for base_url/api_key in
        # langchain-openai — confirm they remain supported in the pinned version.
        return ChatOpenAI(
            model=model_id,
            temperature=final_temperature,
            max_tokens=max_tokens,
            openai_api_base=self.base_url,
            openai_api_key=self.api_key,
            default_headers={
                "HTTP-Referer": "https://github.com/wshobson/maverick-mcp",
                "X-Title": "Maverick MCP",
            },
            streaming=True,
        )

    def _select_model(
        self,
        task_type: TaskType,
        prefer_fast: bool = False,
        prefer_cheap: bool = True,
        prefer_quality: bool = False,
    ) -> ModelProfile:
        """Select the best model for the task with cost-efficiency in mind.

        Args:
            task_type: Type of task
            prefer_fast: Prioritize speed
            prefer_cheap: Prioritize cost (default True)
            prefer_quality: Use premium models regardless of cost

        Returns:
            Selected model profile (highest score among candidates)
        """
        # Find models suitable for this task; GENERAL accepts every model.
        candidates = [
            profile
            for profile in MODEL_PROFILES.values()
            if task_type in profile.best_for or task_type == TaskType.GENERAL
        ]

        if not candidates:
            # Fallback to GPT-5 Nano for general tasks
            return MODEL_PROFILES["openai/gpt-5-nano"]

        # Score and rank candidates
        scored_candidates = []
        for profile in candidates:
            score = 0

            # Average of input/output cost (USD per 1M tokens) for this model
            avg_cost = (
                profile.cost_per_million_input + profile.cost_per_million_output
            ) / 2

            # Quality preference overrides cost considerations
            if prefer_quality:
                # Heavily weight quality for premium mode
                score += profile.quality_rating * 20
                # Task fitness is critical
                if task_type in profile.best_for:
                    score += 40
                # Minimal cost consideration
                score += max(0, 20 - avg_cost)
            else:
                # Cost-efficiency focused scoring
                # Calculate cost-efficiency ratio (quality per dollar)
                cost_efficiency = profile.quality_rating / max(1, avg_cost)
                score += cost_efficiency * 30

                # Task fitness bonus
                if task_type in profile.best_for:
                    score += 25

                # Base quality (reduced weight)
                score += profile.quality_rating * 5

            # Speed preference
            if prefer_fast:
                score += profile.speed_rating * 5
            else:
                score += profile.speed_rating * 2

            # Cost preference adjustment
            if prefer_cheap:
                # Strong cost preference (this branch is the default:
                # prefer_cheap defaults to True)
                cost_score = max(0, 100 - avg_cost * 5)
                score += cost_score
            else:
                # Milder cost consideration when cost preference is disabled
                cost_score = max(0, 60 - avg_cost * 3)
                score += cost_score

            scored_candidates.append((score, profile))

        # Sort by score and return best
        scored_candidates.sort(key=lambda x: x[0], reverse=True)
        return scored_candidates[0][1]

    def _select_emergency_model(
        self, task_type: TaskType, timeout_budget: float
    ) -> ModelProfile:
        """Select the fastest model available for emergency timeout situations.

        Emergency mode prioritizes speed above all other considerations.
        Used when timeout_budget < 30 seconds.

        Args:
            task_type: Type of task
            timeout_budget: Available time in seconds (< 30s)

        Returns:
            Fastest available model profile
        """
        # Emergency model priority (by actual tokens per second)

        # For ultra-tight budgets (< 15s), use only the absolute fastest
        if timeout_budget < 15:
            return MODEL_PROFILES["google/gemini-2.5-flash"]

        # For tight budgets (< 25s), use fastest available models
        if timeout_budget < 25:
            if task_type in [TaskType.SENTIMENT_ANALYSIS, TaskType.QUICK_ANSWER]:
                # Fastest model for lightweight tasks
                return MODEL_PROFILES["google/gemini-2.5-flash"]
            return MODEL_PROFILES["openai/gpt-4o-mini"]  # Speed + quality balance

        # For moderate emergency (< 30s), use speed-optimized models for complex tasks
        if task_type in [
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
        ]:
            # Best speed/quality trade-off for research-style tasks
            return MODEL_PROFILES["openai/gpt-4o-mini"]

        # Default to fastest model
        return MODEL_PROFILES["google/gemini-2.5-flash"]

    def _track_usage(self, model_id: str, task_type: TaskType) -> None:
        """Track model usage for analytics.

        Args:
            model_id: Model identifier
            task_type: Task type
        """
        # setdefault collapses the previous two-step "init if missing" dance.
        per_task = self._model_usage_stats.setdefault(model_id, {})
        task_key = task_type.value
        per_task[task_key] = per_task.get(task_key, 0) + 1

    def get_usage_stats(self) -> dict[str, dict[str, int]]:
        """Get model usage statistics.

        Returns:
            Dictionary of model usage by task type. This is a defensive copy
            (including the nested per-task dicts) so callers cannot mutate
            the provider's internal counters.
        """
        return {
            model_id: task_counts.copy()
            for model_id, task_counts in self._model_usage_stats.items()
        }

    def recommend_models_for_workload(
        self, workload: dict[TaskType, int]
    ) -> dict[str, Any]:
        """Recommend optimal model mix for a given workload.

        Args:
            workload: Dictionary of task types and their frequencies

        Returns:
            Recommendations including models and estimated costs
        """
        recommendations = {}
        total_cost = 0.0

        # Rough per-request token estimates used for all tasks.
        avg_input_tokens = 2000
        avg_output_tokens = 1000

        for task_type, frequency in workload.items():
            # Select best model for this task (default preferences)
            model = self._select_model(task_type)

            # Cost = tokens * requests * (USD per 1M tokens) / 1M
            input_cost = (
                avg_input_tokens * frequency * model.cost_per_million_input
            ) / 1_000_000
            output_cost = (
                avg_output_tokens * frequency * model.cost_per_million_output
            ) / 1_000_000
            task_cost = input_cost + output_cost

            recommendations[task_type.value] = {
                "model": model.name,
                "model_id": model.model_id,
                "frequency": frequency,
                "estimated_cost": task_cost,
            }

            total_cost += task_cost

        return {
            "recommendations": recommendations,
            "total_estimated_cost": total_cost,
            "cost_per_request": total_cost / sum(workload.values()) if workload else 0,
        }
531 |
532 |
533 | # Convenience function for backward compatibility
def get_openrouter_llm(
    api_key: str,
    task_type: TaskType = TaskType.GENERAL,
    prefer_fast: bool = False,
    prefer_cheap: bool = True,
    prefer_quality: bool = False,
    **kwargs,
) -> ChatOpenAI:
    """Build an OpenRouter-backed LLM, favoring cost-efficiency by default.

    Thin convenience wrapper kept for backward compatibility: it simply
    instantiates an ``OpenRouterProvider`` and delegates to ``get_llm``.

    Args:
        api_key: OpenRouter API key
        task_type: Task type for model selection
        prefer_fast: Prioritize speed
        prefer_cheap: Prioritize cost (default True)
        prefer_quality: Use premium models regardless of cost
        **kwargs: Additional arguments for get_llm

    Returns:
        Configured ChatOpenAI instance
    """
    return OpenRouterProvider(api_key).get_llm(
        task_type=task_type,
        prefer_fast=prefer_fast,
        prefer_cheap=prefer_cheap,
        prefer_quality=prefer_quality,
        **kwargs,
    )
563 |
```
--------------------------------------------------------------------------------
/tests/utils/test_logging.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for maverick_mcp.utils.logging module.
3 |
4 | This module contains comprehensive tests for the structured logging system
5 | to ensure proper logging functionality and context management.
6 | """
7 |
8 | import asyncio
9 | import json
10 | import logging
11 | import time
12 | from unittest.mock import Mock, patch
13 |
14 | import pytest
15 |
16 | from maverick_mcp.utils.logging import (
17 | PerformanceMonitor,
18 | RequestContextLogger,
19 | StructuredFormatter,
20 | _get_query_type,
21 | _sanitize_params,
22 | get_logger,
23 | log_cache_operation,
24 | log_database_query,
25 | log_external_api_call,
26 | log_tool_execution,
27 | request_id_var,
28 | request_start_var,
29 | setup_structured_logging,
30 | tool_name_var,
31 | user_id_var,
32 | )
33 |
34 |
class TestStructuredFormatter:
    """Test the StructuredFormatter class."""

    def test_basic_format(self):
        """Test basic log formatting."""
        formatter = StructuredFormatter()
        record = logging.LogRecord(
            name="test_logger",
            level=logging.INFO,
            pathname="/test/path.py",
            lineno=42,
            msg="Test message",
            args=(),
            exc_info=None,
        )

        result = formatter.format(record)

        # Parse the JSON output
        log_data = json.loads(result)

        assert log_data["level"] == "INFO"
        assert log_data["logger"] == "test_logger"
        assert log_data["message"] == "Test message"
        assert log_data["line"] == 42
        assert "timestamp" in log_data

    def test_format_with_context(self):
        """Test formatting with request context."""
        formatter = StructuredFormatter()

        # Set context variables
        request_id_var.set("test-request-123")
        user_id_var.set("user-456")
        tool_name_var.set("test_tool")
        request_start_var.set(time.time() - 0.5)  # 500ms ago

        try:
            record = logging.LogRecord(
                name="test_logger",
                level=logging.INFO,
                pathname="/test/path.py",
                lineno=42,
                msg="Test message",
                args=(),
                exc_info=None,
            )

            result = formatter.format(record)
            log_data = json.loads(result)

            assert log_data["request_id"] == "test-request-123"
            assert log_data["user_id"] == "user-456"
            assert log_data["tool_name"] == "test_tool"
            assert "duration_ms" in log_data
            assert log_data["duration_ms"] >= 400  # Should be around 500ms
        finally:
            # Clean up in a finally block so a failing assertion cannot
            # leak request context into subsequently-run tests.
            request_id_var.set(None)
            user_id_var.set(None)
            tool_name_var.set(None)
            request_start_var.set(None)

    def test_format_with_exception(self):
        """Test formatting with exception information."""
        formatter = StructuredFormatter()

        try:
            raise ValueError("Test error")
        except ValueError:
            import sys

            exc_info = sys.exc_info()

        record = logging.LogRecord(
            name="test_logger",
            level=logging.ERROR,
            pathname="/test/path.py",
            lineno=42,
            msg="Error occurred",
            args=(),
            exc_info=exc_info,
        )

        result = formatter.format(record)
        log_data = json.loads(result)

        assert "exception" in log_data
        assert log_data["exception"]["type"] == "ValueError"
        assert log_data["exception"]["message"] == "Test error"
        assert isinstance(log_data["exception"]["traceback"], list)

    def test_format_with_extra_fields(self):
        """Test formatting with extra fields."""
        formatter = StructuredFormatter()

        record = logging.LogRecord(
            name="test_logger",
            level=logging.INFO,
            pathname="/test/path.py",
            lineno=42,
            msg="Test message",
            args=(),
            exc_info=None,
        )

        # Add extra fields
        record.custom_field = "custom_value"
        record.user_action = "button_click"

        result = formatter.format(record)
        log_data = json.loads(result)

        assert log_data["custom_field"] == "custom_value"
        assert log_data["user_action"] == "button_click"
150 |
class TestRequestContextLogger:
    """Test the RequestContextLogger class."""

    @pytest.fixture
    def mock_logger(self):
        """Create a mock logger."""
        return Mock(spec=logging.Logger)

    @pytest.fixture
    def context_logger(self, mock_logger):
        """Create a RequestContextLogger with mocked system metrics."""
        with patch("maverick_mcp.utils.logging.psutil.Process") as mock_process:
            mock_process.return_value.memory_info.return_value.rss = (
                100 * 1024 * 1024
            )  # 100MB
            mock_process.return_value.cpu_percent.return_value = 15.5
            return RequestContextLogger(mock_logger)

    def _assert_logged(self, mock_logger, expected_level, expected_message):
        """Assert the wrapped logger got exactly one call at the given level."""
        mock_logger.log.assert_called_once()
        call_args = mock_logger.log.call_args
        assert call_args[0][0] == expected_level
        assert call_args[0][1] == expected_message
        return call_args

    def test_info_logging(self, context_logger, mock_logger):
        """Test info level logging."""
        context_logger.info("Test message", extra={"custom": "value"})

        call_args = self._assert_logged(mock_logger, logging.INFO, "Test message")
        assert "extra" in call_args[1]
        assert call_args[1]["extra"]["custom"] == "value"
        assert "memory_mb" in call_args[1]["extra"]
        assert "cpu_percent" in call_args[1]["extra"]

    def test_error_logging(self, context_logger, mock_logger):
        """Test error level logging."""
        context_logger.error("Error message")
        self._assert_logged(mock_logger, logging.ERROR, "Error message")

    def test_debug_logging(self, context_logger, mock_logger):
        """Test debug level logging."""
        context_logger.debug("Debug message")
        self._assert_logged(mock_logger, logging.DEBUG, "Debug message")

    def test_warning_logging(self, context_logger, mock_logger):
        """Test warning level logging."""
        context_logger.warning("Warning message")
        self._assert_logged(mock_logger, logging.WARNING, "Warning message")

    def test_critical_logging(self, context_logger, mock_logger):
        """Test critical level logging."""
        context_logger.critical("Critical message")
        self._assert_logged(mock_logger, logging.CRITICAL, "Critical message")
223 |
class TestLoggingSetup:
    """Test logging setup functions."""

    def test_setup_structured_logging_json_format(self):
        """Test setting up structured logging with JSON format."""
        with patch("maverick_mcp.utils.logging.logging.getLogger") as mock_get_logger:
            root = Mock()
            root.handlers = []  # No pre-existing handlers
            mock_get_logger.return_value = root

            setup_structured_logging(log_level="DEBUG", log_format="json")

            root.setLevel.assert_called_with(logging.DEBUG)
            root.addHandler.assert_called()

    def test_setup_structured_logging_text_format(self):
        """Test setting up structured logging with text format."""
        with patch("maverick_mcp.utils.logging.logging.getLogger") as mock_get_logger:
            root = Mock()
            root.handlers = []  # No pre-existing handlers
            mock_get_logger.return_value = root

            setup_structured_logging(log_level="INFO", log_format="text")

            root.setLevel.assert_called_with(logging.INFO)

    def test_setup_structured_logging_with_file(self):
        """Test setting up structured logging with file output."""
        with patch(
            "maverick_mcp.utils.logging.logging.getLogger"
        ) as mock_get_logger, patch(
            "maverick_mcp.utils.logging.logging.FileHandler"
        ) as mock_file_handler:
            root = Mock()
            root.handlers = []  # No pre-existing handlers
            mock_get_logger.return_value = root

            setup_structured_logging(log_file="/tmp/test.log")

            mock_file_handler.assert_called_with("/tmp/test.log")
            assert root.addHandler.call_count == 2  # Console + File

    def test_get_logger(self):
        """Test getting a logger with context support."""
        assert isinstance(get_logger("test_module"), RequestContextLogger)
271 |
class TestToolExecutionLogging:
    """Test the log_tool_execution decorator.

    These tests patch ``get_logger`` before invoking decorated coroutines,
    then assert both the logging calls and that the request-context
    contextvars are cleared afterwards.
    """

    @pytest.mark.asyncio
    async def test_successful_tool_execution(self):
        """Test logging for successful tool execution."""

        @log_tool_execution
        async def test_tool(param1, param2="default"):
            await asyncio.sleep(0.1)  # Simulate work
            return {"result": "success"}

        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            result = await test_tool("test_value", param2="custom")

            # Decorator must be transparent: return value passes through.
            assert result == {"result": "success"}
            assert mock_logger.info.call_count >= 2  # Start + Success

            # Check that request context was set and cleared
            assert request_id_var.get() is None
            assert tool_name_var.get() is None
            assert request_start_var.get() is None

    @pytest.mark.asyncio
    async def test_failed_tool_execution(self):
        """Test logging for failed tool execution."""

        @log_tool_execution
        async def failing_tool():
            raise ValueError("Test error")

        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            # The original exception must propagate to the caller.
            with pytest.raises(ValueError, match="Test error"):
                await failing_tool()

            mock_logger.error.assert_called_once()

            # Check that context was cleared even after exception
            assert request_id_var.get() is None
            assert tool_name_var.get() is None
            assert request_start_var.get() is None
320 |
class TestParameterSanitization:
    """Test parameter sanitization for logging."""

    def test_sanitize_sensitive_params(self):
        """Test sanitization of sensitive parameters."""
        params = {
            "username": "testuser",
            "password": "secret123",
            "api_key": "key_secret",
            "auth_token": "token_value",
            "normal_param": "normal_value",
        }

        sanitized = _sanitize_params(params)

        # Non-sensitive values pass through untouched.
        assert sanitized["username"] == "testuser"
        assert sanitized["normal_param"] == "normal_value"
        # Secret-like keys are redacted.
        for secret_key in ("password", "api_key", "auth_token"):
            assert sanitized[secret_key] == "***REDACTED***"

    def test_sanitize_nested_params(self):
        """Test sanitization of nested parameters."""
        params = {
            "config": {
                "database_url": "postgresql://user:pass@host/db",
                "secret_key": "secret",
                "debug": True,
            },
            "normal": "value",
        }

        sanitized = _sanitize_params(params)

        nested = sanitized["config"]
        assert nested["database_url"] == "postgresql://user:pass@host/db"
        assert nested["secret_key"] == "***REDACTED***"
        assert nested["debug"] is True
        assert sanitized["normal"] == "value"

    def test_sanitize_long_lists(self):
        """Test sanitization of long lists."""
        sanitized = _sanitize_params(
            {"short_list": [1, 2, 3], "long_list": list(range(100))}
        )

        assert sanitized["short_list"] == [1, 2, 3]
        assert sanitized["long_list"] == "[100 items]"

    def test_sanitize_long_strings(self):
        """Test sanitization of long strings."""
        sanitized = _sanitize_params(
            {"short_string": "hello", "long_string": "x" * 2000}
        )

        assert sanitized["short_string"] == "hello"
        assert "... (2000 chars total)" in sanitized["long_string"]
        assert len(sanitized["long_string"]) < 200
386 |
class TestDatabaseQueryLogging:
    """Test database query logging."""

    def test_log_database_query_basic(self):
        """Test basic database query logging."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            log_database_query("SELECT * FROM users", {"user_id": 123}, 250)

            # One summary line plus one debug line with query details.
            mock_logger.info.assert_called_once()
            mock_logger.debug.assert_called_once()

    def test_get_query_type(self):
        """Test query type detection."""
        expectations = [
            ("SELECT * FROM users", "SELECT"),
            ("INSERT INTO users VALUES (1, 'test')", "INSERT"),
            ("UPDATE users SET name = 'test'", "UPDATE"),
            ("DELETE FROM users WHERE id = 1", "DELETE"),
            ("CREATE TABLE test (id INT)", "CREATE"),
            ("DROP TABLE test", "DROP"),
            ("EXPLAIN SELECT * FROM users", "OTHER"),
        ]
        for query, expected in expectations:
            assert _get_query_type(query) == expected

    def test_slow_query_detection(self):
        """Test slow query detection."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            log_database_query("SELECT * FROM large_table", duration_ms=1500)

            # Queries over the threshold carry the slow_query flag.
            extra = mock_logger.info.call_args[1]["extra"]
            assert extra["slow_query"] is True
423 |
class TestCacheOperationLogging:
    """Test cache operation logging."""

    def _capture_log(self, key, **kwargs):
        """Run log_cache_operation with a patched logger; return the info call."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            log_cache_operation("get", key, **kwargs)

            mock_logger.info.assert_called_once()
            return mock_logger.info.call_args

    def test_log_cache_hit(self):
        """Test logging cache hit."""
        call_args = self._capture_log("stock_data:AAPL", hit=True, duration_ms=5)
        assert "hit" in call_args[0][0]
        assert call_args[1]["extra"]["cache_hit"] is True

    def test_log_cache_miss(self):
        """Test logging cache miss."""
        call_args = self._capture_log("stock_data:MSFT", hit=False)
        assert "miss" in call_args[0][0]
        assert call_args[1]["extra"]["cache_hit"] is False
453 |
class TestExternalAPILogging:
    """Test external API call logging."""

    def test_log_successful_api_call(self):
        """Test logging successful API call."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            log_external_api_call(
                service="yahoo_finance",
                endpoint="/v8/finance/chart/AAPL",
                method="GET",
                status_code=200,
                duration_ms=150,
            )

            # 2xx responses are logged at info level with success=True.
            mock_logger.info.assert_called_once()
            assert mock_logger.info.call_args[1]["extra"]["success"] is True

    def test_log_failed_api_call(self):
        """Test logging failed API call."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            log_external_api_call(
                service="yahoo_finance",
                endpoint="/v8/finance/chart/INVALID",
                method="GET",
                status_code=404,
                duration_ms=1000,
                error="Symbol not found",
            )

            # Failures are logged at error level with the failure details.
            mock_logger.error.assert_called_once()
            extra = mock_logger.error.call_args[1]["extra"]
            assert extra["success"] is False
            assert extra["error"] == "Symbol not found"
495 |
class TestPerformanceMonitor:
    """Test the PerformanceMonitor context manager."""

    def test_successful_operation(self):
        """Test monitoring successful operation."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            with PerformanceMonitor("test_operation"):
                time.sleep(0.1)  # Simulate work

            mock_logger.info.assert_called_once()
            logged = mock_logger.info.call_args
            assert "completed" in logged[0][0]
            extra = logged[1]["extra"]
            assert extra["success"] is True
            assert extra["duration_ms"] >= 100

    def test_failed_operation(self):
        """Test monitoring failed operation."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            with pytest.raises(ValueError):
                with PerformanceMonitor("failing_operation"):
                    raise ValueError("Test error")

            mock_logger.error.assert_called_once()
            logged = mock_logger.error.call_args
            assert "failed" in logged[0][0]
            extra = logged[1]["extra"]
            assert extra["success"] is False
            assert extra["error_type"] == "ValueError"

    def test_memory_tracking(self):
        """Test memory usage tracking."""
        with patch("maverick_mcp.utils.logging.get_logger") as mock_get_logger:
            mock_logger = Mock()
            mock_get_logger.return_value = mock_logger

            with PerformanceMonitor("memory_test"):
                # Allocate and release to exercise the memory-delta path.
                scratch = list(range(1000))
                del scratch

            mock_logger.info.assert_called_once()
            assert "memory_delta_mb" in mock_logger.info.call_args[1]["extra"]
545 |
if __name__ == "__main__":
    # Propagate pytest's exit status so failures are not masked when the
    # file is run directly (pytest.main's return code was previously ignored).
    raise SystemExit(pytest.main([__file__]))
548 |
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/memory_profiler.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Memory profiling and management utilities for the backtesting system.
3 | Provides decorators, monitoring, and optimization tools for memory-efficient operations.
4 | """
5 |
6 | import functools
7 | import gc
8 | import logging
9 | import time
10 | import tracemalloc
11 | import warnings
12 | from collections.abc import Callable, Iterator
13 | from contextlib import contextmanager
14 | from dataclasses import dataclass
15 | from typing import Any
16 |
17 | import numpy as np
18 | import pandas as pd
19 | import psutil
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | # Memory threshold constants (in bytes)
24 | MEMORY_WARNING_THRESHOLD = 1024 * 1024 * 1024 # 1GB
25 | MEMORY_CRITICAL_THRESHOLD = 2 * 1024 * 1024 * 1024 # 2GB
26 | DATAFRAME_SIZE_THRESHOLD = 100 * 1024 * 1024 # 100MB
27 |
28 | # Global memory tracking
29 | _memory_stats = {
30 | "peak_memory": 0,
31 | "current_memory": 0,
32 | "allocation_count": 0,
33 | "gc_count": 0,
34 | "warning_count": 0,
35 | "critical_count": 0,
36 | "dataframe_optimizations": 0,
37 | }
38 |
39 |
@dataclass
class MemorySnapshot:
    """Memory usage snapshot taken by :class:`MemoryProfiler`."""

    timestamp: float  # time.time() when the snapshot was taken
    rss_memory: int  # process resident set size, bytes
    vms_memory: int  # process virtual memory size, bytes
    available_memory: int  # system-wide available memory, bytes
    memory_percent: float  # process memory as a percentage of total
    peak_memory: int  # tracemalloc peak at snapshot time (0 if tracing disabled)
    tracemalloc_current: int  # tracemalloc current allocations, bytes (0 if disabled)
    tracemalloc_peak: int  # tracemalloc peak allocations, bytes (0 if disabled)
    function_name: str = ""  # label for the operation that took the snapshot
53 |
54 |
class MemoryProfiler:
    """Advanced memory profiler with tracking and optimization features."""

    def __init__(self, enable_tracemalloc: bool = True):
        """Initialize memory profiler.

        Args:
            enable_tracemalloc: Whether to enable detailed memory tracking
        """
        self.enable_tracemalloc = enable_tracemalloc
        self.snapshots: list[MemorySnapshot] = []
        self.process = psutil.Process()

        # tracemalloc is process-global; start it only if not already running.
        if self.enable_tracemalloc and not tracemalloc.is_tracing():
            tracemalloc.start()

    def get_memory_info(self) -> dict[str, Any]:
        """Get current memory information."""
        proc_mem = self.process.memory_info()
        sys_mem = psutil.virtual_memory()

        info: dict[str, Any] = {
            "rss_memory": proc_mem.rss,
            "vms_memory": proc_mem.vms,
            "available_memory": sys_mem.available,
            "memory_percent": self.process.memory_percent(),
            "total_memory": sys_mem.total,
        }

        # Tracemalloc figures are only meaningful while tracing is active.
        if self.enable_tracemalloc and tracemalloc.is_tracing():
            current, peak = tracemalloc.get_traced_memory()
            info["tracemalloc_current"] = current
            info["tracemalloc_peak"] = peak

        return info

    def take_snapshot(self, function_name: str = "") -> MemorySnapshot:
        """Take a memory snapshot."""
        info = self.get_memory_info()

        snapshot = MemorySnapshot(
            timestamp=time.time(),
            rss_memory=info["rss_memory"],
            vms_memory=info["vms_memory"],
            available_memory=info["available_memory"],
            memory_percent=info["memory_percent"],
            peak_memory=info.get("tracemalloc_peak", 0),
            tracemalloc_current=info.get("tracemalloc_current", 0),
            tracemalloc_peak=info.get("tracemalloc_peak", 0),
            function_name=function_name,
        )
        self.snapshots.append(snapshot)

        # Keep the module-level counters in sync with the latest reading.
        _memory_stats["current_memory"] = snapshot.rss_memory
        _memory_stats["peak_memory"] = max(
            _memory_stats["peak_memory"], snapshot.rss_memory
        )

        # Emit warnings/criticals when configured thresholds are crossed.
        self._check_memory_thresholds(snapshot)

        return snapshot

    def _check_memory_thresholds(self, snapshot: MemorySnapshot) -> None:
        """Check memory thresholds and log warnings."""
        if snapshot.rss_memory > MEMORY_CRITICAL_THRESHOLD:
            _memory_stats["critical_count"] += 1
            logger.critical(
                f"CRITICAL: Memory usage {snapshot.rss_memory / (1024**3):.2f}GB "
                f"exceeds critical threshold in {snapshot.function_name or 'unknown'}"
            )
        elif snapshot.rss_memory > MEMORY_WARNING_THRESHOLD:
            _memory_stats["warning_count"] += 1
            logger.warning(
                f"WARNING: High memory usage {snapshot.rss_memory / (1024**3):.2f}GB "
                f"in {snapshot.function_name or 'unknown'}"
            )

    def get_memory_report(self) -> dict[str, Any]:
        """Generate comprehensive memory report."""
        if not self.snapshots:
            return {"error": "No memory snapshots available"}

        newest = self.snapshots[-1]
        oldest = self.snapshots[0]
        mb = 1024**2

        report: dict[str, Any] = {
            "current_memory_mb": newest.rss_memory / mb,
            "peak_memory_mb": max(s.rss_memory for s in self.snapshots) / mb,
            "memory_growth_mb": (newest.rss_memory - oldest.rss_memory) / mb,
            "memory_percent": newest.memory_percent,
            "available_memory_gb": newest.available_memory / (1024**3),
            "snapshots_count": len(self.snapshots),
            "warning_count": _memory_stats["warning_count"],
            "critical_count": _memory_stats["critical_count"],
            "gc_count": _memory_stats["gc_count"],
            "dataframe_optimizations": _memory_stats["dataframe_optimizations"],
        }

        if self.enable_tracemalloc:
            report["tracemalloc_current_mb"] = newest.tracemalloc_current / mb
            report["tracemalloc_peak_mb"] = newest.tracemalloc_peak / mb

        return report
168 |
169 |
170 | # Global profiler instance
171 | _global_profiler = MemoryProfiler()
172 |
173 |
def get_memory_stats() -> dict[str, Any]:
    """Return combined global counters plus the profiler's current report."""
    combined = dict(_memory_stats)
    # Report values override any same-named counter keys, matching the
    # merge order of the previous implementation.
    combined.update(_global_profiler.get_memory_report())
    return combined
177 |
178 |
def reset_memory_stats() -> None:
    """Reset global memory statistics.

    Mutates the existing ``_memory_stats`` dict in place (instead of
    rebinding the module global) so any code holding a reference to the
    dict observes the reset, and clears all recorded snapshots.
    """
    _memory_stats.clear()
    _memory_stats.update(
        {
            "peak_memory": 0,
            "current_memory": 0,
            "allocation_count": 0,
            "gc_count": 0,
            "warning_count": 0,
            "critical_count": 0,
            "dataframe_optimizations": 0,
        }
    )
    _global_profiler.snapshots.clear()
192 |
193 |
def profile_memory(
    func: Callable = None,
    *,
    log_results: bool = True,
    enable_gc: bool = True,
    threshold_mb: float = 100.0,
):
    """Decorator to profile memory usage of a function.

    Usable both bare (``@profile_memory``) and with arguments
    (``@profile_memory(threshold_mb=50)``).

    Args:
        func: Function to decorate
        log_results: Whether to log memory usage results
        enable_gc: Whether to trigger garbage collection
        threshold_mb: Memory usage threshold to log warnings (MB)
    """

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            function_name = f.__name__

            # Take initial snapshot
            initial = _global_profiler.take_snapshot(f"start_{function_name}")

            try:
                # Execute function
                result = f(*args, **kwargs)

                # Take final snapshot
                final = _global_profiler.take_snapshot(f"end_{function_name}")

                # Calculate memory usage
                memory_diff_mb = (final.rss_memory - initial.rss_memory) / (1024**2)

                if log_results:
                    if memory_diff_mb > threshold_mb:
                        logger.warning(
                            f"High memory usage in {function_name}: "
                            f"{memory_diff_mb:.2f}MB (threshold: {threshold_mb}MB)"
                        )
                    else:
                        logger.debug(
                            f"Memory usage in {function_name}: {memory_diff_mb:.2f}MB"
                        )

                # Trigger garbage collection if enabled
                if enable_gc and memory_diff_mb > threshold_mb:
                    force_garbage_collection()

                return result

            except Exception:
                # Take error snapshot, then re-raise with a bare `raise` so
                # the original traceback is preserved unmodified
                # (`raise e` would append this frame to the traceback).
                _global_profiler.take_snapshot(f"error_{function_name}")
                raise

        return wrapper

    if func is None:
        return decorator
    else:
        return decorator(func)
256 |
257 |
@contextmanager
def memory_context(
    name: str = "operation", cleanup_after: bool = True
) -> Iterator[MemoryProfiler]:
    """Profile the memory usage of a code block.

    Args:
        name: Name of the operation
        cleanup_after: Whether to run garbage collection after

    Yields:
        MemoryProfiler instance for manual snapshots
    """
    profiler = MemoryProfiler()
    before = profiler.take_snapshot(f"start_{name}")

    try:
        yield profiler
    finally:
        # Always record the closing snapshot, even if the block raised.
        after = profiler.take_snapshot(f"end_{name}")
        delta_mb = (after.rss_memory - before.rss_memory) / (1024**2)
        logger.debug(f"Memory usage in {name}: {delta_mb:.2f}MB")

        if cleanup_after:
            force_garbage_collection()
284 |
285 |
def optimize_dataframe(
    df: pd.DataFrame, aggressive: bool = False, categorical_threshold: float = 0.5
) -> pd.DataFrame:
    """Optimize DataFrame memory usage.

    Args:
        df: DataFrame to optimize
        aggressive: Whether to use aggressive optimizations (float downcasts)
        categorical_threshold: Threshold for converting to categorical

    Returns:
        Optimized DataFrame; the input is never mutated.
    """
    initial_memory = df.memory_usage(deep=True).sum()

    if initial_memory < DATAFRAME_SIZE_THRESHOLD:
        return df  # Skip optimization for small DataFrames

    df_optimized = df.copy()

    for col in df_optimized.columns:
        col_type = df_optimized[col].dtype

        if col_type == "object":
            # Convert to categorical when there are many duplicate values
            unique_ratio = df_optimized[col].nunique() / len(df_optimized[col])
            if unique_ratio < categorical_threshold:
                try:
                    df_optimized[col] = df_optimized[col].astype("category")
                except Exception:
                    pass

        elif "int" in str(col_type):
            # Downcast integers to the smallest dtype holding the full range.
            # Bounds are inclusive: e.g. 127 fits int8, so compare with >=/<=
            # (the previous strict </> wrongly excluded exact boundary values).
            c_min = df_optimized[col].min()
            c_max = df_optimized[col].max()

            if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                df_optimized[col] = df_optimized[col].astype(np.int8)
            elif c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max:
                df_optimized[col] = df_optimized[col].astype(np.int16)
            elif c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max:
                df_optimized[col] = df_optimized[col].astype(np.int32)

        elif "float" in str(col_type):
            # Downcast floats only in aggressive mode, and only when the
            # float32 round-trip stays within a small relative tolerance.
            if aggressive:
                try:
                    temp = df_optimized[col].astype(np.float32)
                    if np.allclose(
                        df_optimized[col].fillna(0),
                        temp.fillna(0),
                        rtol=1e-6,
                        equal_nan=True,
                    ):
                        df_optimized[col] = temp
                except Exception:
                    pass

    final_memory = df_optimized.memory_usage(deep=True).sum()
    memory_saved = initial_memory - final_memory

    if memory_saved > 0:
        _memory_stats["dataframe_optimizations"] += 1
        logger.debug(
            f"DataFrame optimized: {memory_saved / (1024**2):.2f}MB saved "
            f"({memory_saved / initial_memory * 100:.1f}% reduction)"
        )

    return df_optimized
357 |
358 |
def force_garbage_collection() -> dict[str, int]:
    """Run a full GC pass and report per-generation object counts."""
    freed = gc.collect()
    _memory_stats["gc_count"] += 1

    # Count live tracked objects per generation plus the overall total.
    stats = {
        f"generation_{gen}": len(gc.get_objects(gen)) for gen in range(3)
    }
    stats["collected"] = freed
    stats["total_objects"] = len(gc.get_objects())

    logger.debug(f"Garbage collection: {freed} objects collected")
    return stats
374 |
375 |
def check_memory_leak(threshold_mb: float = 100.0) -> bool:
    """Check for potential memory leaks.

    Args:
        threshold_mb: Memory growth threshold to consider a leak

    Returns:
        True if potential leak detected
    """
    snaps = _global_profiler.snapshots
    if len(snaps) < 10:
        # Need two full five-snapshot windows before comparing.
        return False

    def _avg_rss(window):
        # Mean RSS over a window of snapshots.
        return sum(s.rss_memory for s in window) / len(window)

    growth_mb = (_avg_rss(snaps[-5:]) - _avg_rss(snaps[-10:-5])) / (1024**2)

    if growth_mb > threshold_mb:
        logger.warning(f"Potential memory leak detected: {growth_mb:.2f}MB growth")
        return True

    return False
402 |
403 |
404 | class DataFrameChunker:
405 | """Utility for processing DataFrames in memory-efficient chunks."""
406 |
407 | def __init__(self, chunk_size_mb: float = 50.0):
408 | """Initialize chunker.
409 |
410 | Args:
411 | chunk_size_mb: Maximum chunk size in MB
412 | """
413 | self.chunk_size_mb = chunk_size_mb
414 | self.chunk_size_bytes = int(chunk_size_mb * 1024 * 1024)
415 |
416 | def chunk_dataframe(self, df: pd.DataFrame) -> Iterator[pd.DataFrame]:
417 | """Yield DataFrame chunks based on memory size.
418 |
419 | Args:
420 | df: DataFrame to chunk
421 |
422 | Yields:
423 | DataFrame chunks
424 | """
425 | total_memory = df.memory_usage(deep=True).sum()
426 |
427 | if total_memory <= self.chunk_size_bytes:
428 | yield df
429 | return
430 |
431 | # Calculate approximate rows per chunk
432 | memory_per_row = total_memory / len(df)
433 | rows_per_chunk = max(1, int(self.chunk_size_bytes / memory_per_row))
434 |
435 | logger.debug(
436 | f"Chunking DataFrame: {len(df)} rows, ~{rows_per_chunk} rows per chunk"
437 | )
438 |
439 | for i in range(0, len(df), rows_per_chunk):
440 | chunk = df.iloc[i : i + rows_per_chunk]
441 | yield chunk
442 |
443 | def process_in_chunks(
444 | self,
445 | df: pd.DataFrame,
446 | processor: Callable[[pd.DataFrame], Any],
447 | combine_results: Callable = None,
448 | ) -> Any:
449 | """Process DataFrame in chunks and optionally combine results.
450 |
451 | Args:
452 | df: DataFrame to process
453 | processor: Function to apply to each chunk
454 | combine_results: Function to combine chunk results
455 |
456 | Returns:
457 | Combined results or list of chunk results
458 | """
459 | results = []
460 |
461 | with memory_context("chunk_processing"):
462 | for i, chunk in enumerate(self.chunk_dataframe(df)):
463 | logger.debug(f"Processing chunk {i + 1}")
464 |
465 | with memory_context(f"chunk_{i}"):
466 | result = processor(chunk)
467 | results.append(result)
468 |
469 | if combine_results:
470 | return combine_results(results)
471 |
472 | return results
473 |
474 |
def cleanup_dataframes(*dfs: pd.DataFrame) -> None:
    """Clean up DataFrames and force garbage collection.

    Note:
        This clears the pandas internal block manager (a private
        attribute), which invalidates the passed DataFrames for any
        remaining references — callers must not use them afterwards.

    Args:
        *dfs: DataFrames to clean up
    """
    for df in dfs:
        if hasattr(df, "_mgr"):
            # Drop the internal block manager so the underlying data
            # blocks become collectable even while callers still hold the
            # wrapper object. (A `del` of the loop variable here would be
            # a no-op: it only unbinds the local name, not the caller's
            # reference.)
            df._mgr = None

    force_garbage_collection()
488 |
489 |
def get_dataframe_memory_usage(df: pd.DataFrame) -> dict[str, Any]:
    """Get detailed memory usage information for a DataFrame.

    Args:
        df: DataFrame to analyze

    Returns:
        Memory usage statistics: totals in MB, per-column breakdown,
        shape, dtypes, and average bytes per row
    """
    usage = df.memory_usage(deep=True)
    total_bytes = usage.sum()
    bytes_per_mb = 1024**2

    # Per-column breakdown; usage is indexed by "Index" then column names.
    per_column: dict[Any, float] = {}
    for col in df.columns:
        per_column[col] = usage.loc[col] / bytes_per_mb

    row_count = len(df)
    return {
        "total_memory_mb": total_bytes / bytes_per_mb,
        "index_memory_mb": usage.iloc[0] / bytes_per_mb,
        "columns_memory_mb": per_column,
        "shape": df.shape,
        "dtypes": df.dtypes.to_dict(),
        "memory_per_row_bytes": total_bytes / row_count if row_count > 0 else 0,
    }
511 |
512 |
@contextmanager
def memory_limit_context(limit_mb: float) -> Iterator[None]:
    """Context manager to monitor memory usage against a soft limit.

    Records RSS at entry and compares against RSS at exit. If the wrapped
    block grew memory beyond ``limit_mb``, an error is logged and a
    garbage-collection pass is forced.

    Note:
        Despite monitoring a "limit", this does NOT raise and does not
        interrupt the block; the check happens only after the block has
        finished. (The original docstring claimed MemoryError was raised,
        which the code never does.)

    Args:
        limit_mb: Memory limit in MB
    """
    initial_memory = psutil.Process().memory_info().rss
    limit_bytes = limit_mb * 1024 * 1024

    try:
        yield
    finally:
        # Compare end-of-block RSS against the entry baseline.
        current_memory = psutil.Process().memory_info().rss
        memory_used = current_memory - initial_memory

        if memory_used > limit_bytes:
            logger.error(
                f"Memory limit exceeded: {memory_used / (1024**2):.2f}MB > {limit_mb}MB"
            )
            # Force cleanup
            force_garbage_collection()
538 |
539 |
def suggest_memory_optimizations(df: pd.DataFrame) -> list[str]:
    """Suggest memory optimizations for a DataFrame.

    Args:
        df: DataFrame to analyze

    Returns:
        List of human-readable optimization suggestions; empty when the
        frame has no rows or nothing worth optimizing
    """
    # Guard: the per-column ratios below divide by len(df), so an empty
    # frame would raise ZeroDivisionError. Nothing to suggest anyway.
    if len(df) == 0:
        return []

    suggestions = []
    memory_info = get_dataframe_memory_usage(df)

    # Check for object columns that could be categorical
    for col in df.columns:
        if df[col].dtype == "object":
            unique_ratio = df[col].nunique() / len(df)
            if unique_ratio < 0.5:
                memory_savings = memory_info["columns_memory_mb"][col] * (
                    1 - unique_ratio
                )
                suggestions.append(
                    f"Convert '{col}' to categorical (potential savings: "
                    f"{memory_savings:.2f}MB, {unique_ratio:.1%} unique values)"
                )

    # Check for float64 that could be float32 without precision loss
    for col in df.columns:
        if df[col].dtype == "float64":
            try:
                temp = df[col].astype(np.float32)
                if np.allclose(df[col].fillna(0), temp.fillna(0), rtol=1e-6):
                    savings = memory_info["columns_memory_mb"][col] * 0.5
                    suggestions.append(
                        f"Convert '{col}' from float64 to float32 "
                        f"(potential savings: {savings:.2f}MB)"
                    )
            except Exception:
                # Conversion failure simply means no suggestion for this column.
                pass

    # Check for integer downcasting opportunities
    for col in df.columns:
        if "int" in str(df[col].dtype):
            c_min = df[col].min()
            c_max = df[col].max()
            # NOTE(review): Series.memory_usage(deep=True) includes the
            # index by default, so this per-row estimate is slightly
            # pessimistic — confirm if exact figures matter.
            current_bytes = df[col].memory_usage(deep=True) / len(df)

            if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                if current_bytes > 1:
                    savings = (current_bytes - 1) * len(df) / (1024**2)
                    suggestions.append(
                        f"Convert '{col}' to int8 (potential savings: {savings:.2f}MB)"
                    )

    return suggestions
594 |
595 |
# Memory profiling tends to surface noisy ResourceWarnings; silence them.
def _suppress_resource_warnings() -> None:
    """Install a warnings filter that ignores ResourceWarning messages."""
    warnings.filterwarnings(action="ignore", category=ResourceWarning)


# Apply the filter as soon as the module is imported.
_suppress_resource_warnings()
604 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/agents.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Agent router for LangGraph-based financial analysis agents.
3 |
4 | This router exposes the LangGraph agents as MCP tools while maintaining
5 | compatibility with the existing infrastructure.
6 | """
7 |
8 | import logging
9 | import os
10 | from typing import Any
11 |
12 | from fastmcp import FastMCP
13 |
14 | from maverick_mcp.agents.deep_research import DeepResearchAgent
15 | from maverick_mcp.agents.market_analysis import MarketAnalysisAgent
16 | from maverick_mcp.agents.supervisor import SupervisorAgent
17 |
# Module-level logger for this router's diagnostics.
logger = logging.getLogger(__name__)

# Create the agents router
agents_router: FastMCP = FastMCP("Financial_Analysis_Agents")


# Cache for agent instances to avoid recreation; keyed by
# "{agent_type}:{persona}" (see get_or_create_agent).
_agent_cache: dict[str, Any] = {}
26 |
27 |
def get_or_create_agent(agent_type: str, persona: str = "moderate") -> Any:
    """Get or create an agent instance with caching.

    Instances are cached per ``"{agent_type}:{persona}"`` key so repeated
    tool invocations reuse the same agent.

    Args:
        agent_type: One of "market", "supervisor", or "deep_research"
        persona: Investor persona used to configure the agent

    Returns:
        The cached or newly constructed agent instance

    Raises:
        ValueError: If ``agent_type`` is not recognized
    """
    cache_key = f"{agent_type}:{persona}"
    if cache_key in _agent_cache:
        return _agent_cache[cache_key]

    # Local imports keep provider loading off the module import path.
    from maverick_mcp.providers.llm_factory import get_llm
    from maverick_mcp.providers.openrouter_provider import TaskType

    # Map agent types to task types for optimal model selection
    task_by_agent = {
        "market": TaskType.MARKET_ANALYSIS,
        "technical": TaskType.TECHNICAL_ANALYSIS,
        "supervisor": TaskType.MULTI_AGENT_ORCHESTRATION,
        "deep_research": TaskType.DEEP_RESEARCH,
    }

    # Get an LLM tuned for this agent's task profile
    llm = get_llm(task_type=task_by_agent.get(agent_type, TaskType.GENERAL))

    if agent_type == "market":
        agent: Any = MarketAnalysisAgent(llm=llm, persona=persona, ttl_hours=1)
    elif agent_type == "supervisor":
        # Create mock agents for supervisor
        subagents = {
            "market": get_or_create_agent("market", persona),
            "technical": None,  # Would be actual technical agent in full implementation
        }
        agent = SupervisorAgent(llm=llm, agents=subagents, persona=persona, ttl_hours=1)
    elif agent_type == "deep_research":
        # Get web search API keys from environment
        agent = DeepResearchAgent(
            llm=llm,
            persona=persona,
            ttl_hours=1,
            exa_api_key=os.getenv("EXA_API_KEY"),
        )
        # Mark for initialization - will be initialized on first use
        agent._needs_initialization = True
    else:
        raise ValueError(f"Unknown agent type: {agent_type}")

    _agent_cache[cache_key] = agent
    return agent
81 |
82 |
async def analyze_market_with_agent(
    query: str,
    persona: str = "moderate",
    screening_strategy: str = "momentum",
    max_results: int = 20,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Analyze market using LangGraph agent with persona-aware recommendations.

    The underlying AI agent tailors its analysis to the caller's investor
    risk profile (conservative, moderate, aggressive).

    Args:
        query: Market analysis query (e.g., "Find top momentum stocks")
        persona: Investor persona (conservative, moderate, aggressive)
        screening_strategy: Strategy to use (momentum, maverick, supply_demand_breakout)
        max_results: Maximum number of results
        session_id: Optional session ID for conversation continuity

    Returns:
        Persona-adjusted market analysis with recommendations
    """
    try:
        import uuid

        # Keep conversation continuity even when no session was supplied.
        session_id = session_id or str(uuid.uuid4())

        agent = get_or_create_agent("market", persona)

        analysis = await agent.analyze_market(
            query=query,
            session_id=session_id,
            screening_strategy=screening_strategy,
            max_results=max_results,
        )

        # Agent result keys intentionally override the static metadata.
        response: dict[str, Any] = {
            "status": "success",
            "agent_type": "market_analysis",
            "persona": persona,
            "session_id": session_id,
        }
        response.update(analysis)
        return response

    except Exception as e:
        logger.error(f"Error in market agent analysis: {str(e)}")
        return {"status": "error", "error": str(e), "agent_type": "market_analysis"}
135 |
136 |
async def get_agent_streaming_analysis(
    query: str,
    persona: str = "moderate",
    stream_mode: str = "updates",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Get streaming market analysis with real-time updates.

    Demonstrates LangGraph's streaming capabilities by collecting a
    handful of streamed chunks; a production deployment would expose this
    over a true streaming endpoint instead.

    Args:
        query: Analysis query
        persona: Investor persona
        stream_mode: Streaming mode (updates, values, messages)
        session_id: Optional session ID

    Returns:
        Streaming configuration and initial results
    """
    try:
        import uuid

        session_id = session_id or str(uuid.uuid4())

        agent = get_or_create_agent("market", persona)

        # MCP tools return a single payload, so gather a few streamed
        # chunks here rather than exposing a real streaming response.
        collected: list[Any] = []
        async for chunk in agent.stream_analysis(
            query=query, session_id=session_id, stream_mode=stream_mode
        ):
            collected.append(chunk)
            if len(collected) >= 5:  # cap collection for the demo
                break

        return {
            "status": "success",
            "stream_mode": stream_mode,
            "persona": persona,
            "session_id": session_id,
            "updates_collected": len(collected),
            "sample_updates": collected[:3],
            "note": "Full streaming requires WebSocket or SSE endpoint",
        }

    except Exception as e:
        logger.error(f"Error in streaming analysis: {str(e)}")
        return {"status": "error", "error": str(e)}
191 |
192 |
async def orchestrated_analysis(
    query: str,
    persona: str = "moderate",
    routing_strategy: str = "llm_powered",
    max_agents: int = 3,
    parallel_execution: bool = True,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Run orchestrated multi-agent analysis using the SupervisorAgent.

    The supervisor routes the query to appropriate specialized agents and
    synthesizes their results into one comprehensive answer.

    Args:
        query: Financial analysis query
        persona: Investor persona (conservative, moderate, aggressive, day_trader)
        routing_strategy: How to route tasks (llm_powered, rule_based, hybrid)
        max_agents: Maximum number of agents to use
        parallel_execution: Whether to run agents in parallel
        session_id: Optional session ID for conversation continuity

    Returns:
        Orchestrated analysis with synthesized recommendations
    """
    try:
        import uuid

        session_id = session_id or str(uuid.uuid4())

        supervisor = get_or_create_agent("supervisor", persona)

        result = await supervisor.coordinate_agents(
            query=query,
            session_id=session_id,
            routing_strategy=routing_strategy,
            max_agents=max_agents,
            parallel_execution=parallel_execution,
        )

        # Supervisor result keys intentionally override the metadata below.
        response: dict[str, Any] = {
            "status": "success",
            "agent_type": "supervisor_orchestrated",
            "persona": persona,
            "session_id": session_id,
            "routing_strategy": routing_strategy,
            "agents_used": result.get("agents_used", []),
            "execution_time_ms": result.get("execution_time_ms"),
            "synthesis_confidence": result.get("synthesis_confidence"),
        }
        response.update(result)
        return response

    except Exception as e:
        logger.error(f"Error in orchestrated analysis: {str(e)}")
        return {
            "status": "error",
            "error": str(e),
            "agent_type": "supervisor_orchestrated",
        }
256 |
257 |
async def deep_research_financial(
    research_topic: str,
    persona: str = "moderate",
    research_depth: str = "comprehensive",
    focus_areas: list[str] | None = None,
    timeframe: str = "30d",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Conduct comprehensive financial research using web search and AI analysis.

    Performs deep research on a company, symbol, or market theme via
    multiple web search providers combined with AI-powered content
    analysis and source validation.

    Args:
        research_topic: Main research topic (company, symbol, or market theme)
        persona: Investor persona affecting research focus
        research_depth: Depth level (basic, standard, comprehensive, exhaustive)
        focus_areas: Specific areas to focus on (e.g., ["fundamentals", "technicals"])
        timeframe: Time range for research (7d, 30d, 90d, 1y)
        session_id: Optional session ID for conversation continuity

    Returns:
        Comprehensive research report with validated sources and analysis
    """
    try:
        import uuid

        session_id = session_id or str(uuid.uuid4())

        # Default focus areas cover the broad research dimensions.
        if focus_areas is None:
            focus_areas = ["fundamentals", "market_sentiment", "competitive_landscape"]

        researcher = get_or_create_agent("deep_research", persona)

        result = await researcher.research_comprehensive(
            topic=research_topic,
            session_id=session_id,
            depth=research_depth,
            focus_areas=focus_areas,
            timeframe=timeframe,
        )

        # Research result keys intentionally override the metadata below.
        response: dict[str, Any] = {
            "status": "success",
            "agent_type": "deep_research",
            "persona": persona,
            "session_id": session_id,
            "research_topic": research_topic,
            "research_depth": research_depth,
            "focus_areas": focus_areas,
            "sources_analyzed": result.get("total_sources_processed", 0),
            "research_confidence": result.get("research_confidence"),
            "validation_checks_passed": result.get("validation_checks_passed"),
        }
        response.update(result)
        return response

    except Exception as e:
        logger.error(f"Error in deep research: {str(e)}")
        return {"status": "error", "error": str(e), "agent_type": "deep_research"}
321 |
322 |
async def compare_multi_agent_analysis(
    query: str,
    agent_types: list[str] | None = None,
    persona: str = "moderate",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Compare analysis results across multiple agent types.

    Runs the same query through each requested agent so their differing
    approaches and insights can be compared side by side.

    Args:
        query: Analysis query to run across multiple agents
        agent_types: List of agent types to compare (default: ["market", "supervisor"])
        persona: Investor persona for all agents
        session_id: Optional session ID prefix

    Returns:
        Comparative analysis showing different agent perspectives
    """
    try:
        import uuid

        session_id = session_id or str(uuid.uuid4())

        if agent_types is None:
            agent_types = ["market", "supervisor"]

        comparison: dict[str, Any] = {}
        timings: dict[str, Any] = {}

        for agent_type in agent_types:
            try:
                agent = get_or_create_agent(agent_type, persona)
                scoped_session = f"{session_id}_{agent_type}"

                # Dispatch to the agent-type-specific entry point.
                if agent_type == "market":
                    result = await agent.analyze_market(
                        query=query,
                        session_id=scoped_session,
                        max_results=10,
                    )
                elif agent_type == "supervisor":
                    result = await agent.coordinate_agents(
                        query=query,
                        session_id=scoped_session,
                        max_agents=2,
                    )
                else:
                    # Unsupported agent types are skipped silently.
                    continue

                comparison[agent_type] = {
                    "summary": result.get("summary", ""),
                    "key_findings": result.get("key_findings", []),
                    "confidence": result.get("confidence", 0.0),
                    "methodology": result.get("methodology", f"{agent_type} analysis"),
                }
                timings[agent_type] = result.get("execution_time_ms", 0)

            except Exception as e:
                # One failing agent should not abort the whole comparison.
                logger.warning(f"Error with {agent_type} agent: {str(e)}")
                comparison[agent_type] = {"error": str(e), "status": "failed"}

        return {
            "status": "success",
            "query": query,
            "persona": persona,
            "agents_compared": list(comparison.keys()),
            "comparison": comparison,
            "execution_times_ms": timings,
            "insights": "Each agent brings unique analytical perspectives and methodologies",
        }

    except Exception as e:
        logger.error(f"Error in multi-agent comparison: {str(e)}")
        return {"status": "error", "error": str(e)}
401 |
402 |
def list_available_agents() -> dict[str, Any]:
    """
    List all available LangGraph agents and their capabilities.

    Returns:
        Information about available agents and personas
    """
    # Shared option lists (copied into each slot so callers can mutate
    # one entry without affecting the others).
    all_personas = ["conservative", "moderate", "aggressive", "day_trader"]
    routing_strategies = ["llm_powered", "rule_based", "hybrid"]
    research_depths = ["basic", "standard", "comprehensive", "exhaustive"]

    agents: dict[str, Any] = {
        "market_analysis": {
            "description": "Market screening and sector analysis",
            "personas": ["conservative", "moderate", "aggressive"],
            "capabilities": [
                "Momentum screening",
                "Sector rotation analysis",
                "Market breadth indicators",
                "Risk-adjusted recommendations",
            ],
            "streaming_modes": ["updates", "values", "messages", "debug"],
            "status": "active",
        },
        "supervisor_orchestrated": {
            "description": "Multi-agent orchestration and coordination",
            "personas": list(all_personas),
            "capabilities": [
                "Intelligent query routing",
                "Multi-agent coordination",
                "Result synthesis and conflict resolution",
                "Parallel and sequential execution",
                "Comprehensive analysis workflows",
            ],
            "routing_strategies": list(routing_strategies),
            "status": "active",
        },
        "deep_research": {
            "description": "Comprehensive financial research with web search",
            "personas": list(all_personas),
            "capabilities": [
                "Multi-provider web search",
                "AI-powered content analysis",
                "Source validation and credibility scoring",
                "Citation and reference management",
                "Comprehensive research reports",
            ],
            "research_depths": list(research_depths),
            "focus_areas": [
                "fundamentals",
                "technicals",
                "market_sentiment",
                "competitive_landscape",
            ],
            "status": "active",
        },
        "technical_analysis": {
            "description": "Chart patterns and technical indicators",
            "status": "coming_soon",
        },
        "risk_management": {
            "description": "Position sizing and portfolio risk",
            "status": "coming_soon",
        },
        "portfolio_optimization": {
            "description": "Rebalancing and allocation",
            "status": "coming_soon",
        },
    }

    orchestrated_tools = {
        "orchestrated_analysis": "Coordinate multiple agents for comprehensive analysis",
        "deep_research_financial": "Conduct thorough research with web search",
        "compare_multi_agent_analysis": "Compare different agent perspectives",
    }

    features = {
        "persona_adaptation": "Agents adjust recommendations based on risk profile",
        "conversation_memory": "Maintains context within sessions",
        "streaming_support": "Real-time updates during analysis",
        "tool_integration": "Access to all MCP financial tools",
        "multi_agent_orchestration": "Coordinate multiple specialized agents",
        "web_search_research": "AI-powered research with source validation",
        "intelligent_routing": "LLM-powered task routing and optimization",
    }

    return {
        "status": "success",
        "agents": agents,
        "orchestrated_tools": orchestrated_tools,
        "features": features,
        "personas": list(all_personas),
        "routing_strategies": list(routing_strategies),
        "research_depths": list(research_depths),
    }
488 |
489 |
async def compare_personas_analysis(
    query: str, session_id: str | None = None
) -> dict[str, Any]:
    """
    Compare analysis across different investor personas.

    Runs the same query for the conservative, moderate, and aggressive
    personas to illustrate how recommendations shift with risk profile.

    Args:
        query: Analysis query to run
        session_id: Optional session ID prefix

    Returns:
        Comparative analysis across all personas
    """
    try:
        import uuid

        session_id = session_id or str(uuid.uuid4())

        comparison: dict[str, Any] = {}

        for persona in ("conservative", "moderate", "aggressive"):
            agent = get_or_create_agent("market", persona)

            # Run the same analysis under this persona's risk settings.
            result = await agent.analyze_market(
                query=query, session_id=f"{session_id}_{persona}", max_results=10
            )

            persona_results = result.get("results", {})
            comparison[persona] = {
                "summary": persona_results.get("summary", ""),
                "top_picks": persona_results.get("screened_symbols", [])[:5],
                "risk_parameters": {
                    "risk_tolerance": agent.persona.risk_tolerance,
                    "max_position_size": f"{agent.persona.position_size_max * 100:.1f}%",
                    "stop_loss_multiplier": agent.persona.stop_loss_multiplier,
                },
            }

        return {
            "status": "success",
            "query": query,
            "comparison": comparison,
            "insights": "Notice how recommendations vary by risk profile",
        }

    except Exception as e:
        logger.error(f"Error in persona comparison: {str(e)}")
        return {"status": "error", "error": str(e)}
542 |
```