This is page 13 of 29. Use http://codebase.md/wshobson/maverick-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_runner_validation.py:
--------------------------------------------------------------------------------
```python
"""
Test runner validation for parallel research functionality test suites.
This module validates that all test suites follow pytest best practices and async patterns
without triggering circular imports during validation.
"""
import ast
import re
from pathlib import Path
from typing import Any
class TestSuiteValidator:
"""Validator for test suite structure and patterns."""
def __init__(self, test_file_path: str):
self.test_file_path = Path(test_file_path)
self.content = self.test_file_path.read_text()
self.tree = ast.parse(self.content)
def validate_pytest_patterns(self) -> dict[str, Any]:
"""Validate pytest patterns and best practices."""
results = {
"has_pytest_markers": False,
"has_async_tests": False,
"has_fixtures": False,
"has_proper_imports": False,
"has_class_based_tests": False,
"test_count": 0,
"async_test_count": 0,
"fixture_count": 0,
}
# Check imports
for node in ast.walk(self.tree):
if isinstance(node, ast.ImportFrom):
if node.module == "pytest":
results["has_proper_imports"] = True
elif isinstance(node, ast.Import):
for alias in node.names:
if alias.name == "pytest":
results["has_proper_imports"] = True
# Check for pytest markers, fixtures, and test functions
for node in ast.walk(self.tree):
if isinstance(node, ast.FunctionDef):
# Check for test functions
if node.name.startswith("test_"):
results["test_count"] += 1
# Check for async tests
if isinstance(node, ast.AsyncFunctionDef):
results["has_async_tests"] = True
results["async_test_count"] += 1
# Check for fixtures
for decorator in node.decorator_list:
if isinstance(decorator, ast.Attribute):
if decorator.attr == "fixture":
results["has_fixtures"] = True
results["fixture_count"] += 1
elif isinstance(decorator, ast.Name):
if decorator.id == "fixture":
results["has_fixtures"] = True
results["fixture_count"] += 1
elif isinstance(node, ast.AsyncFunctionDef):
if node.name.startswith("test_"):
results["test_count"] += 1
results["has_async_tests"] = True
results["async_test_count"] += 1
# Check for pytest markers
marker_pattern = r"@pytest\.mark\.\w+"
if re.search(marker_pattern, self.content):
results["has_pytest_markers"] = True
# Check for class-based tests
for node in ast.walk(self.tree):
if isinstance(node, ast.ClassDef):
if node.name.startswith("Test"):
results["has_class_based_tests"] = True
break
return results
def validate_async_patterns(self) -> dict[str, Any]:
"""Validate async/await patterns."""
results = {
"proper_async_await": True,
"has_asyncio_imports": False,
"async_fixtures_marked": True,
"issues": [],
}
# Check for asyncio imports
if "import asyncio" in self.content or "from asyncio" in self.content:
results["has_asyncio_imports"] = True
# Check async function patterns
for node in ast.walk(self.tree):
if isinstance(node, ast.AsyncFunctionDef):
# Check if async test functions are properly marked
if node.name.startswith("test_"):
for decorator in node.decorator_list:
if isinstance(decorator, ast.Attribute):
if (
hasattr(decorator.value, "attr")
and decorator.value.attr == "mark"
and decorator.attr == "asyncio"
):
pass
elif isinstance(decorator, ast.Call):
if (
isinstance(decorator.func, ast.Attribute)
and hasattr(decorator.func.value, "attr")
and decorator.func.value.attr == "mark"
and decorator.func.attr == "asyncio"
):
pass
# Not all test environments require explicit asyncio marking
# Modern pytest-asyncio auto-detects async tests
return results
def validate_mock_usage(self) -> dict[str, Any]:
"""Validate mock usage patterns."""
results = {
"has_mocks": False,
"has_async_mocks": False,
"has_patch_usage": False,
"proper_mock_imports": False,
}
# Check mock imports
mock_imports = ["Mock", "AsyncMock", "MagicMock", "patch"]
for imp in mock_imports:
if (
f"from unittest.mock import {imp}" in self.content
or f"import {imp}" in self.content
):
results["proper_mock_imports"] = True
results["has_mocks"] = True
if imp == "AsyncMock":
results["has_async_mocks"] = True
if imp == "patch":
results["has_patch_usage"] = True
return results
class TestParallelResearchTestSuites:
    """Validate structure and conventions of the parallel research test suites."""

    # The four suites that make up the parallel research test coverage.
    _SUITE_FILES = (
        "test_parallel_research_orchestrator.py",
        "test_deep_research_parallel_execution.py",
        "test_orchestration_logging.py",
        "test_parallel_research_integration.py",
    )

    @staticmethod
    def _suite_path(filename: str) -> Path:
        """Resolve *filename* relative to this test module's directory."""
        return Path(__file__).with_name(filename)

    def _pytest_results(self, filename: str, exists_msg: str) -> dict[str, Any]:
        """Assert the suite file exists, then return its pytest-pattern report."""
        path = self._suite_path(filename)
        assert path.exists(), exists_msg
        return TestSuiteValidator(str(path)).validate_pytest_patterns()

    @staticmethod
    def _assert_core_structure(results: dict[str, Any]) -> None:
        """Assert the structural flags every suite is expected to satisfy."""
        assert results["test_count"] > 0, "Should have test functions"
        assert results["has_async_tests"], "Should have async tests"
        assert results["has_fixtures"], "Should have fixtures"
        assert results["has_class_based_tests"], "Should have class-based tests"

    def test_parallel_research_orchestrator_tests_structure(self):
        """Test structure of ParallelResearchOrchestrator test suite."""
        results = self._pytest_results(
            "test_parallel_research_orchestrator.py",
            "ParallelResearchOrchestrator test file should exist",
        )
        self._assert_core_structure(results)
        assert results["async_test_count"] > 0, "Should have async test functions"

    def test_deep_research_parallel_execution_tests_structure(self):
        """Test structure of DeepResearchAgent parallel execution test suite."""
        results = self._pytest_results(
            "test_deep_research_parallel_execution.py",
            "DeepResearchAgent parallel test file should exist",
        )
        self._assert_core_structure(results)

    def test_orchestration_logging_tests_structure(self):
        """Test structure of OrchestrationLogger test suite."""
        results = self._pytest_results(
            "test_orchestration_logging.py",
            "OrchestrationLogger test file should exist",
        )
        self._assert_core_structure(results)

    def test_parallel_research_integration_tests_structure(self):
        """Test structure of parallel research integration test suite."""
        results = self._pytest_results(
            "test_parallel_research_integration.py",
            "Parallel research integration test file should exist",
        )
        self._assert_core_structure(results)
        assert results["has_pytest_markers"], (
            "Should have pytest markers (like @pytest.mark.integration)"
        )

    def test_async_patterns_validation(self):
        """Test that async patterns are properly implemented across all test suites."""
        for test_file in self._SUITE_FILES:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            report = TestSuiteValidator(str(path)).validate_async_patterns()
            assert report["proper_async_await"], (
                f"Async patterns should be correct in {test_file}"
            )
            assert report["has_asyncio_imports"], (
                f"Should import asyncio in {test_file}"
            )

    def test_mock_usage_patterns(self):
        """Test that mock usage patterns are consistent across test suites."""
        # Suites exercising coroutines are additionally expected to use AsyncMock.
        async_heavy = {
            "test_parallel_research_orchestrator.py",
            "test_deep_research_parallel_execution.py",
            "test_parallel_research_integration.py",
        }
        for test_file in self._SUITE_FILES:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            report = TestSuiteValidator(str(path)).validate_mock_usage()
            assert report["has_mocks"], f"Should use mocks in {test_file}"
            assert report["proper_mock_imports"], (
                f"Should have proper mock imports in {test_file}"
            )
            if test_file in async_heavy:
                assert report["has_async_mocks"], (
                    f"Should use AsyncMock in {test_file}"
                )

    def test_test_coverage_completeness(self):
        """Test that test coverage is comprehensive for parallel research functionality."""
        # Expected test categories (matched as case-insensitive substrings).
        expected_test_categories = {
            "test_parallel_research_orchestrator.py": [
                "config",
                "task",
                "orchestrator",
                "distribution",
                "result",
                "integration",
            ],
            "test_deep_research_parallel_execution.py": [
                "agent",
                "subagent",
                "execution",
                "synthesis",
                "integration",
            ],
            "test_orchestration_logging.py": [
                "logger",
                "decorator",
                "context",
                "utility",
                "integrated",
                "load",
            ],
            "test_parallel_research_integration.py": [
                "endtoend",
                "scalability",
                "logging",
                "error",
                "data",
            ],
        }
        for test_file, expected_categories in expected_test_categories.items():
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            lowered = path.read_text().lower()
            for category in expected_categories:
                assert category in lowered, (
                    f"Should have {category} tests in {test_file}"
                )

    def test_docstring_quality(self):
        """Test that test files have proper docstrings."""
        for test_file in self._SUITE_FILES:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            content = path.read_text()
            assert '"""' in content, f"Should have docstrings in {test_file}"
            # The first triple-quoted block is taken as the module docstring.
            module_doc = content.split('"""')[1].lower()
            described = any(
                keyword in module_doc
                for keyword in ("test", "functionality", "cover", "suite")
            )
            assert described, (
                f"Module docstring should describe testing purpose in {test_file}"
            )

    def test_import_safety(self):
        """Test that imports are safe and avoid circular dependencies."""
        for test_file in self._SUITE_FILES:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            content = path.read_text()
            import_lines = [
                line
                for line in content.split("\n")
                if line.strip().startswith(("import ", "from "))
            ]
            assert len(import_lines) > 0, (
                f"Should have import statements in {test_file}"
            )
            assert any("pytest" in line for line in import_lines), (
                f"Should import pytest in {test_file}"
            )

    def test_fixture_best_practices(self):
        """Test that fixtures follow best practices."""
        for test_file in self._SUITE_FILES:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            content = path.read_text()
            if "@pytest.fixture" not in content:
                continue
            assert "def " in content, (
                f"Fixtures should be functions in {test_file}"
            )
            patterned = any(
                marker in content
                for marker in ("yield", "return", "Mock", "config")
            )
            assert patterned, (
                f"Should have proper fixture patterns in {test_file}"
            )

    def test_error_handling_coverage(self):
        """Test that error handling scenarios are covered."""
        targets = (
            "test_parallel_research_orchestrator.py",
            "test_deep_research_parallel_execution.py",
            "test_parallel_research_integration.py",
        )
        for test_file in targets:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            lowered = path.read_text().lower()
            covered = any(
                keyword in lowered
                for keyword in ("error", "exception", "timeout", "failure", "fallback")
            )
            assert covered, f"Should test error scenarios in {test_file}"

    def test_performance_testing_coverage(self):
        """Test that performance characteristics are tested."""
        targets = (
            "test_parallel_research_orchestrator.py",
            "test_parallel_research_integration.py",
        )
        for test_file in targets:
            path = self._suite_path(test_file)
            if not path.exists():
                continue
            lowered = path.read_text().lower()
            measured = any(
                keyword in lowered
                for keyword in (
                    "performance",
                    "timing",
                    "efficiency",
                    "concurrent",
                    "parallel",
                )
            )
            assert measured, (
                f"Should test performance characteristics in {test_file}"
            )

    def test_integration_test_markers(self):
        """Test that integration tests are properly marked."""
        path = self._suite_path("test_parallel_research_integration.py")
        if not path.exists():
            return
        content = path.read_text()
        assert "@pytest.mark.integration" in content, (
            "Should mark integration tests"
        )
        classy = any(
            token in content for token in ("TestParallel", "Integration", "EndToEnd")
        )
        assert classy, "Should have integration test classes"
```
--------------------------------------------------------------------------------
/maverick_mcp/providers/openrouter_provider.py:
--------------------------------------------------------------------------------
```python
"""OpenRouter LLM provider with intelligent model selection.
This module provides integration with OpenRouter API for accessing various LLMs
with automatic model selection based on task requirements.
"""
import logging
from enum import Enum
from typing import Any
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class TaskType(str, Enum):
    """Task types for model selection.

    Inherits from ``str`` so members compare and serialize as plain strings.
    Members appear in ``ModelProfile.best_for`` lists to route each kind of
    work to suitable models.
    """

    # Analysis tasks
    DEEP_RESEARCH = "deep_research"
    MARKET_ANALYSIS = "market_analysis"
    TECHNICAL_ANALYSIS = "technical_analysis"
    SENTIMENT_ANALYSIS = "sentiment_analysis"
    RISK_ASSESSMENT = "risk_assessment"

    # Synthesis tasks
    RESULT_SYNTHESIS = "result_synthesis"
    PORTFOLIO_OPTIMIZATION = "portfolio_optimization"

    # Query processing
    QUERY_CLASSIFICATION = "query_classification"
    QUICK_ANSWER = "quick_answer"

    # Complex reasoning
    COMPLEX_REASONING = "complex_reasoning"
    MULTI_AGENT_ORCHESTRATION = "multi_agent_orchestration"

    # Default / fallback when no more specific task type applies
    GENERAL = "general"
class ModelProfile(BaseModel):
    """Profile for an LLM model with capabilities and costs.

    Static metadata used by the provider to choose a model for a task:
    pricing (USD per million tokens), 1-10 speed/quality ratings, and the
    task types the model is considered well suited for.
    """

    model_id: str = Field(description="OpenRouter model identifier")
    name: str = Field(description="Human-readable model name")
    provider: str = Field(description="Model provider (e.g., anthropic, openai)")
    context_length: int = Field(description="Maximum context length in tokens")
    cost_per_million_input: float = Field(
        description="Cost per million input tokens in USD"
    )
    cost_per_million_output: float = Field(
        description="Cost per million output tokens in USD"
    )
    speed_rating: int = Field(description="Speed rating 1-10 (10 being fastest)")
    quality_rating: int = Field(description="Quality rating 1-10 (10 being best)")
    best_for: list[TaskType] = Field(description="Task types this model excels at")
    temperature: float = Field(
        default=0.3, description="Default temperature for this model"
    )
# Model profiles for intelligent selection.
# Keyed by OpenRouter model id; the provider's selection logic picks among
# these based on task type and speed/cost/quality preferences.
# NOTE(review): prices and token-per-second figures are point-in-time values —
# verify against current OpenRouter pricing before relying on them.
MODEL_PROFILES = {
    # Premium models (use sparingly for critical tasks)
    "anthropic/claude-opus-4.1": ModelProfile(
        model_id="anthropic/claude-opus-4.1",
        name="Claude Opus 4.1",
        provider="anthropic",
        context_length=200000,
        cost_per_million_input=15.0,
        cost_per_million_output=75.0,
        speed_rating=7,
        quality_rating=10,
        best_for=[
            TaskType.COMPLEX_REASONING,  # Only for the most complex tasks
        ],
        temperature=0.3,
    ),
    # Cost-effective high-quality models (primary workhorses)
    "anthropic/claude-sonnet-4": ModelProfile(
        model_id="anthropic/claude-sonnet-4",
        name="Claude Sonnet 4",
        provider="anthropic",
        context_length=1000000,  # 1M token context capability!
        cost_per_million_input=3.0,
        cost_per_million_output=15.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.MULTI_AGENT_ORCHESTRATION,
            TaskType.RESULT_SYNTHESIS,
            TaskType.PORTFOLIO_OPTIMIZATION,
        ],
        temperature=0.3,
    ),
    "openai/gpt-5": ModelProfile(
        model_id="openai/gpt-5",
        name="GPT-5",
        provider="openai",
        context_length=400000,
        cost_per_million_input=1.25,
        cost_per_million_output=10.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
        ],
        temperature=0.3,
    ),
    # Excellent cost-performance ratio models
    "google/gemini-2.5-pro": ModelProfile(
        model_id="google/gemini-2.5-pro",
        name="Gemini 2.5 Pro",
        provider="google",
        context_length=1000000,  # 1M token context!
        cost_per_million_input=2.0,
        cost_per_million_output=8.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
        ],
        temperature=0.3,
    ),
    "deepseek/deepseek-r1": ModelProfile(
        model_id="deepseek/deepseek-r1",
        name="DeepSeek R1",
        provider="deepseek",
        context_length=128000,
        cost_per_million_input=0.5,
        cost_per_million_output=1.0,
        speed_rating=8,
        quality_rating=9,
        best_for=[
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.RISK_ASSESSMENT,
        ],
        temperature=0.3,
    ),
    # Fast, cost-effective models for simpler tasks
    # Speed-optimized models for research timeouts
    "google/gemini-2.5-flash": ModelProfile(
        model_id="google/gemini-2.5-flash",
        name="Gemini 2.5 Flash",
        provider="google",
        context_length=1000000,
        cost_per_million_input=0.075,  # Ultra low cost
        cost_per_million_output=0.30,
        speed_rating=10,  # 199 tokens/sec - FASTEST available
        quality_rating=8,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.QUICK_ANSWER,
            TaskType.SENTIMENT_ANALYSIS,
        ],
        temperature=0.2,
    ),
    "openai/gpt-4o-mini": ModelProfile(
        model_id="openai/gpt-4o-mini",
        name="GPT-4o Mini",
        provider="openai",
        context_length=128000,
        cost_per_million_input=0.15,
        cost_per_million_output=0.60,
        speed_rating=9,  # 126 tokens/sec - Excellent speed/cost balance
        quality_rating=8,
        best_for=[
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
            TaskType.QUICK_ANSWER,
        ],
        temperature=0.2,
    ),
    "anthropic/claude-3.5-haiku": ModelProfile(
        model_id="anthropic/claude-3.5-haiku",
        name="Claude 3.5 Haiku",
        provider="anthropic",
        context_length=200000,
        cost_per_million_input=0.25,
        cost_per_million_output=1.25,
        speed_rating=7,  # 65.6 tokens/sec - Updated with actual speed rating
        quality_rating=8,
        best_for=[
            TaskType.QUERY_CLASSIFICATION,
            TaskType.QUICK_ANSWER,
            TaskType.SENTIMENT_ANALYSIS,
        ],
        temperature=0.2,
    ),
    "openai/gpt-5-nano": ModelProfile(
        model_id="openai/gpt-5-nano",
        name="GPT-5 Nano",
        provider="openai",
        context_length=400000,
        cost_per_million_input=0.05,
        cost_per_million_output=0.40,
        speed_rating=9,  # 180 tokens/sec - Very fast
        quality_rating=7,
        best_for=[
            TaskType.QUICK_ANSWER,
            TaskType.QUERY_CLASSIFICATION,
            TaskType.DEEP_RESEARCH,  # Added for emergency research
        ],
        temperature=0.2,
    ),
    # Specialized models
    "xai/grok-4": ModelProfile(
        model_id="xai/grok-4",
        name="Grok 4",
        provider="xai",
        context_length=128000,
        cost_per_million_input=3.0,
        cost_per_million_output=12.0,
        speed_rating=7,
        quality_rating=9,
        best_for=[
            TaskType.MARKET_ANALYSIS,
            TaskType.SENTIMENT_ANALYSIS,
            TaskType.PORTFOLIO_OPTIMIZATION,
        ],
        temperature=0.3,
    ),
}
class OpenRouterProvider:
    """Provider for OpenRouter API with intelligent model selection."""

    def __init__(self, api_key: str):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key
        """
        self.api_key = api_key
        self.base_url = "https://openrouter.ai/api/v1"
        # Usage counters keyed model_id -> {task_type value -> call count};
        # populated by _track_usage and exposed via get_usage_stats.
        self._model_usage_stats: dict[str, dict[str, int]] = {}

    def get_llm(
        self,
        task_type: TaskType = TaskType.GENERAL,
        prefer_fast: bool = False,
        prefer_cheap: bool = True,  # Default to cost-effective
        prefer_quality: bool = False,  # Override for premium models
        model_override: str | None = None,
        temperature: float | None = None,
        max_tokens: int = 4096,
        timeout_budget: float | None = None,  # Emergency mode for timeouts
    ) -> ChatOpenAI:
        """Get an LLM instance optimized for the task.

        Args:
            task_type: Type of task to optimize for
            prefer_fast: Prioritize speed over quality
            prefer_cheap: Prioritize cost over quality (default True)
            prefer_quality: Use premium models regardless of cost
            model_override: Override model selection
            temperature: Override default temperature
            max_tokens: Maximum tokens for response
            timeout_budget: Available time budget - triggers emergency mode if < 30s

        Returns:
            Configured ChatOpenAI instance
        """
        # Use override if provided
        if model_override:
            model_id = model_override
            # Unknown override ids get a neutral mid-range profile so the
            # logging and default-temperature code below still works.
            model_profile = MODEL_PROFILES.get(
                model_id,
                ModelProfile(
                    model_id=model_id,
                    name=model_id,
                    provider="unknown",
                    context_length=128000,
                    cost_per_million_input=1.0,
                    cost_per_million_output=1.0,
                    speed_rating=5,
                    quality_rating=5,
                    best_for=[TaskType.GENERAL],
                    temperature=0.3,
                ),
            )
        # Emergency mode for tight timeout budgets
        elif timeout_budget is not None and timeout_budget < 30:
            model_profile = self._select_emergency_model(task_type, timeout_budget)
            model_id = model_profile.model_id
            logger.warning(
                f"EMERGENCY MODE: Selected ultra-fast model '{model_profile.name}' "
                f"for {timeout_budget}s timeout budget"
            )
        else:
            model_profile = self._select_model(
                task_type, prefer_fast, prefer_cheap, prefer_quality
            )
            model_id = model_profile.model_id

        # Use provided temperature or model default
        final_temperature = (
            temperature if temperature is not None else model_profile.temperature
        )

        # Log model selection
        logger.info(
            f"Selected model '{model_profile.name}' for task '{task_type}' "
            f"(speed={model_profile.speed_rating}/10, quality={model_profile.quality_rating}/10, "
            f"cost=${model_profile.cost_per_million_input}/{model_profile.cost_per_million_output} per 1M tokens)"
        )

        # Track usage
        self._track_usage(model_id, task_type)

        # Create LangChain ChatOpenAI instance
        return ChatOpenAI(
            model=model_id,
            temperature=final_temperature,
            max_tokens=max_tokens,
            openai_api_base=self.base_url,
            openai_api_key=self.api_key,
            # Attribution headers used by OpenRouter for app identification.
            default_headers={
                "HTTP-Referer": "https://github.com/wshobson/maverick-mcp",
                "X-Title": "Maverick MCP",
            },
            streaming=True,
        )

    def _select_model(
        self,
        task_type: TaskType,
        prefer_fast: bool = False,
        prefer_cheap: bool = True,
        prefer_quality: bool = False,
    ) -> ModelProfile:
        """Select the best model for the task with cost-efficiency in mind.

        Args:
            task_type: Type of task
            prefer_fast: Prioritize speed
            prefer_cheap: Prioritize cost (default True)
            prefer_quality: Use premium models regardless of cost

        Returns:
            Selected model profile
        """
        candidates = []

        # Find models suitable for this task
        # (GENERAL matches every profile, so candidates is only empty for a
        # task no profile advertises in best_for).
        for profile in MODEL_PROFILES.values():
            if task_type in profile.best_for or task_type == TaskType.GENERAL:
                candidates.append(profile)

        if not candidates:
            # Fallback to GPT-5 Nano for general tasks
            return MODEL_PROFILES["openai/gpt-5-nano"]

        # Score and rank candidates
        scored_candidates = []
        for profile in candidates:
            score = 0

            # Calculate average cost for this model
            avg_cost = (
                profile.cost_per_million_input + profile.cost_per_million_output
            ) / 2

            # Quality preference overrides cost considerations
            if prefer_quality:
                # Heavily weight quality for premium mode
                score += profile.quality_rating * 20
                # Task fitness is critical
                if task_type in profile.best_for:
                    score += 40
                # Minimal cost consideration
                score += max(0, 20 - avg_cost)
            else:
                # Cost-efficiency focused scoring (default)
                # Calculate cost-efficiency ratio
                cost_efficiency = profile.quality_rating / max(1, avg_cost)
                score += cost_efficiency * 30

                # Task fitness bonus
                if task_type in profile.best_for:
                    score += 25

                # Base quality (reduced weight)
                score += profile.quality_rating * 5

                # Speed preference
                if prefer_fast:
                    score += profile.speed_rating * 5
                else:
                    score += profile.speed_rating * 2

                # Cost preference adjustment
                if prefer_cheap:
                    # Strong cost preference
                    cost_score = max(0, 100 - avg_cost * 5)
                    score += cost_score
                else:
                    # Balanced cost consideration (default)
                    cost_score = max(0, 60 - avg_cost * 3)
                    score += cost_score

            scored_candidates.append((score, profile))

        # Sort by score and return best
        # (list.sort is stable, so ties keep MODEL_PROFILES declaration order)
        scored_candidates.sort(key=lambda x: x[0], reverse=True)
        return scored_candidates[0][1]

    def _select_emergency_model(
        self, task_type: TaskType, timeout_budget: float
    ) -> ModelProfile:
        """Select the fastest model available for emergency timeout situations.

        Emergency mode prioritizes speed above all other considerations.
        Used when timeout_budget < 30 seconds.

        Args:
            task_type: Type of task
            timeout_budget: Available time in seconds (< 30s)

        Returns:
            Fastest available model profile
        """
        # Emergency model priority (by actual tokens per second)
        # For ultra-tight budgets (< 15s), use only the absolute fastest
        if timeout_budget < 15:
            return MODEL_PROFILES["google/gemini-2.5-flash"]

        # For tight budgets (< 25s), use fastest available models
        if timeout_budget < 25:
            if task_type in [TaskType.SENTIMENT_ANALYSIS, TaskType.QUICK_ANSWER]:
                return MODEL_PROFILES[
                    "google/gemini-2.5-flash"
                ]  # Fastest for all tasks
            return MODEL_PROFILES["openai/gpt-4o-mini"]  # Speed + quality balance

        # For moderate emergency (< 30s), use speed-optimized models for complex tasks
        if task_type in [
            TaskType.DEEP_RESEARCH,
            TaskType.MARKET_ANALYSIS,
            TaskType.TECHNICAL_ANALYSIS,
        ]:
            return MODEL_PROFILES[
                "openai/gpt-4o-mini"
            ]  # Best speed/quality for research

        # Default to fastest model
        return MODEL_PROFILES["google/gemini-2.5-flash"]

    def _track_usage(self, model_id: str, task_type: TaskType):
        """Track model usage for analytics.

        Args:
            model_id: Model identifier
            task_type: Task type
        """
        # Lazily create the nested counter dict for this model.
        if model_id not in self._model_usage_stats:
            self._model_usage_stats[model_id] = {}

        task_key = task_type.value
        if task_key not in self._model_usage_stats[model_id]:
            self._model_usage_stats[model_id][task_key] = 0

        self._model_usage_stats[model_id][task_key] += 1

    def get_usage_stats(self) -> dict[str, dict[str, int]]:
        """Get model usage statistics.

        Returns:
            Dictionary of model usage by task type
        """
        # Shallow copy: inner per-model dicts are still shared with the
        # live counters.
        return self._model_usage_stats.copy()

    def recommend_models_for_workload(
        self, workload: dict[TaskType, int]
    ) -> dict[str, Any]:
        """Recommend optimal model mix for a given workload.

        Args:
            workload: Dictionary of task types and their frequencies

        Returns:
            Recommendations including models and estimated costs
        """
        recommendations = {}
        total_cost = 0.0

        for task_type, frequency in workload.items():
            # Select best model for this task
            model = self._select_model(task_type)

            # Estimate tokens (rough approximation)
            avg_input_tokens = 2000
            avg_output_tokens = 1000

            # Calculate cost (per-million pricing scaled down to actual tokens)
            input_cost = (
                avg_input_tokens * frequency * model.cost_per_million_input
            ) / 1_000_000
            output_cost = (
                avg_output_tokens * frequency * model.cost_per_million_output
            ) / 1_000_000
            task_cost = input_cost + output_cost

            recommendations[task_type.value] = {
                "model": model.name,
                "model_id": model.model_id,
                "frequency": frequency,
                "estimated_cost": task_cost,
            }
            total_cost += task_cost

        return {
            "recommendations": recommendations,
            "total_estimated_cost": total_cost,
            "cost_per_request": total_cost / sum(workload.values()) if workload else 0,
        }
# Convenience function for backward compatibility
def get_openrouter_llm(
    api_key: str,
    task_type: TaskType = TaskType.GENERAL,
    prefer_fast: bool = False,
    prefer_cheap: bool = True,
    prefer_quality: bool = False,
    **kwargs,
) -> ChatOpenAI:
    """Build an OpenRouter-backed LLM, defaulting to cost-efficient models.

    Thin wrapper kept for backward compatibility: it constructs a transient
    OpenRouterProvider and delegates straight to its ``get_llm``.

    Args:
        api_key: OpenRouter API key
        task_type: Task type used for model selection
        prefer_fast: Prioritize speed
        prefer_cheap: Prioritize cost (default True)
        prefer_quality: Use premium models regardless of cost
        **kwargs: Additional arguments forwarded to get_llm

    Returns:
        Configured ChatOpenAI instance
    """
    return OpenRouterProvider(api_key).get_llm(
        task_type=task_type,
        prefer_fast=prefer_fast,
        prefer_cheap=prefer_cheap,
        prefer_quality=prefer_quality,
        **kwargs,
    )
```
--------------------------------------------------------------------------------
/tests/utils/test_logging.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for maverick_mcp.utils.logging module.
This module contains comprehensive tests for the structured logging system
to ensure proper logging functionality and context management.
"""
import asyncio
import json
import logging
import time
from unittest.mock import Mock, patch
import pytest
from maverick_mcp.utils.logging import (
PerformanceMonitor,
RequestContextLogger,
StructuredFormatter,
_get_query_type,
_sanitize_params,
get_logger,
log_cache_operation,
log_database_query,
log_external_api_call,
log_tool_execution,
request_id_var,
request_start_var,
setup_structured_logging,
tool_name_var,
user_id_var,
)
class TestStructuredFormatter:
    """Test the StructuredFormatter class."""

    def test_basic_format(self):
        """Test basic log formatting."""
        formatter = StructuredFormatter()
        record = logging.LogRecord(
            name="test_logger",
            level=logging.INFO,
            pathname="/test/path.py",
            lineno=42,
            msg="Test message",
            args=(),
            exc_info=None,
        )

        result = formatter.format(record)

        # Parse the JSON output
        log_data = json.loads(result)

        assert log_data["level"] == "INFO"
        assert log_data["logger"] == "test_logger"
        assert log_data["message"] == "Test message"
        assert log_data["line"] == 42
        assert "timestamp" in log_data

    def test_format_with_context(self):
        """Test formatting with request context."""
        formatter = StructuredFormatter()

        # Set context variables, keeping the reset tokens. The previous
        # version cleared the variables with var.set(None) *after* the
        # asserts, so any assertion failure leaked this request context
        # into every subsequent test. try/finally + reset() guarantees
        # cleanup and restores the prior values exactly.
        tokens = [
            (request_id_var, request_id_var.set("test-request-123")),
            (user_id_var, user_id_var.set("user-456")),
            (tool_name_var, tool_name_var.set("test_tool")),
            (request_start_var, request_start_var.set(time.time() - 0.5)),  # 500ms ago
        ]
        try:
            record = logging.LogRecord(
                name="test_logger",
                level=logging.INFO,
                pathname="/test/path.py",
                lineno=42,
                msg="Test message",
                args=(),
                exc_info=None,
            )

            result = formatter.format(record)
            log_data = json.loads(result)

            assert log_data["request_id"] == "test-request-123"
            assert log_data["user_id"] == "user-456"
            assert log_data["tool_name"] == "test_tool"
            assert "duration_ms" in log_data
            assert log_data["duration_ms"] >= 400  # Should be around 500ms
        finally:
            # Restore previous context values regardless of test outcome.
            for var, token in tokens:
                var.reset(token)

    def test_format_with_exception(self):
        """Test formatting with exception information."""
        formatter = StructuredFormatter()

        try:
            raise ValueError("Test error")
        except ValueError:
            import sys

            exc_info = sys.exc_info()

        record = logging.LogRecord(
            name="test_logger",
            level=logging.ERROR,
            pathname="/test/path.py",
            lineno=42,
            msg="Error occurred",
            args=(),
            exc_info=exc_info,
        )

        result = formatter.format(record)
        log_data = json.loads(result)

        assert "exception" in log_data
        assert log_data["exception"]["type"] == "ValueError"
        assert log_data["exception"]["message"] == "Test error"
        assert isinstance(log_data["exception"]["traceback"], list)

    def test_format_with_extra_fields(self):
        """Test formatting with extra fields."""
        formatter = StructuredFormatter()
        record = logging.LogRecord(
            name="test_logger",
            level=logging.INFO,
            pathname="/test/path.py",
            lineno=42,
            msg="Test message",
            args=(),
            exc_info=None,
        )

        # Add extra fields beyond the standard LogRecord attributes;
        # the formatter should surface them in the JSON payload.
        record.custom_field = "custom_value"
        record.user_action = "button_click"

        result = formatter.format(record)
        log_data = json.loads(result)

        assert log_data["custom_field"] == "custom_value"
        assert log_data["user_action"] == "button_click"
class TestRequestContextLogger:
    """Tests for the RequestContextLogger wrapper."""

    @pytest.fixture
    def mock_logger(self):
        """Provide a mock standard-library logger."""
        return Mock(spec=logging.Logger)

    @pytest.fixture
    def context_logger(self, mock_logger):
        """Build a RequestContextLogger with psutil patched out."""
        with patch("maverick_mcp.utils.logging.psutil.Process") as process_cls:
            proc = process_cls.return_value
            proc.memory_info.return_value.rss = 100 * 1024 * 1024  # 100MB
            proc.cpu_percent.return_value = 15.5
            return RequestContextLogger(mock_logger)

    def test_info_logging(self, context_logger, mock_logger):
        """INFO messages are forwarded with enriched extra fields."""
        context_logger.info("Test message", extra={"custom": "value"})

        mock_logger.log.assert_called_once()
        args, kwargs = mock_logger.log.call_args
        assert args[0] == logging.INFO
        assert args[1] == "Test message"
        assert "extra" in kwargs
        extra = kwargs["extra"]
        assert extra["custom"] == "value"
        assert "memory_mb" in extra
        assert "cpu_percent" in extra

    def test_error_logging(self, context_logger, mock_logger):
        """ERROR messages are forwarded at the right level."""
        context_logger.error("Error message")

        mock_logger.log.assert_called_once()
        args, _ = mock_logger.log.call_args
        assert args[0] == logging.ERROR
        assert args[1] == "Error message"

    def test_debug_logging(self, context_logger, mock_logger):
        """DEBUG messages are forwarded at the right level."""
        context_logger.debug("Debug message")

        mock_logger.log.assert_called_once()
        args, _ = mock_logger.log.call_args
        assert args[0] == logging.DEBUG
        assert args[1] == "Debug message"

    def test_warning_logging(self, context_logger, mock_logger):
        """WARNING messages are forwarded at the right level."""
        context_logger.warning("Warning message")

        mock_logger.log.assert_called_once()
        args, _ = mock_logger.log.call_args
        assert args[0] == logging.WARNING
        assert args[1] == "Warning message"

    def test_critical_logging(self, context_logger, mock_logger):
        """CRITICAL messages are forwarded at the right level."""
        context_logger.critical("Critical message")

        mock_logger.log.assert_called_once()
        args, _ = mock_logger.log.call_args
        assert args[0] == logging.CRITICAL
        assert args[1] == "Critical message"
class TestLoggingSetup:
    """Tests for the structured-logging setup helpers."""

    def test_setup_structured_logging_json_format(self):
        """JSON setup applies the requested level and installs a handler."""
        with patch("maverick_mcp.utils.logging.logging.getLogger") as get_logger_mock:
            root = Mock()
            root.handlers = []  # no pre-existing handlers
            get_logger_mock.return_value = root

            setup_structured_logging(log_level="DEBUG", log_format="json")

            root.setLevel.assert_called_with(logging.DEBUG)
            root.addHandler.assert_called()

    def test_setup_structured_logging_text_format(self):
        """Text setup still applies the requested level."""
        with patch("maverick_mcp.utils.logging.logging.getLogger") as get_logger_mock:
            root = Mock()
            root.handlers = []  # no pre-existing handlers
            get_logger_mock.return_value = root

            setup_structured_logging(log_level="INFO", log_format="text")

            root.setLevel.assert_called_with(logging.INFO)

    def test_setup_structured_logging_with_file(self):
        """File output adds a FileHandler alongside the console handler."""
        with patch("maverick_mcp.utils.logging.logging.getLogger") as get_logger_mock:
            with patch(
                "maverick_mcp.utils.logging.logging.FileHandler"
            ) as file_handler_mock:
                root = Mock()
                root.handlers = []  # no pre-existing handlers
                get_logger_mock.return_value = root

                setup_structured_logging(log_file="/tmp/test.log")

                file_handler_mock.assert_called_with("/tmp/test.log")
                assert root.addHandler.call_count == 2  # Console + File

    def test_get_logger(self):
        """get_logger returns a context-aware logger wrapper."""
        logger = get_logger("test_module")
        assert isinstance(logger, RequestContextLogger)
class TestToolExecutionLogging:
    """Tests for the log_tool_execution decorator."""

    @pytest.mark.asyncio
    async def test_successful_tool_execution(self):
        """A successful tool call logs start/success and clears context."""

        @log_tool_execution
        async def test_tool(param1, param2="default"):
            await asyncio.sleep(0.1)  # simulate async work
            return {"result": "success"}

        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            outcome = await test_tool("test_value", param2="custom")

            assert outcome == {"result": "success"}
            assert fake_logger.info.call_count >= 2  # start + success entries

            # Request context variables must be cleared once the call ends.
            assert request_id_var.get() is None
            assert tool_name_var.get() is None
            assert request_start_var.get() is None

    @pytest.mark.asyncio
    async def test_failed_tool_execution(self):
        """A failing tool call logs the error and still clears context."""

        @log_tool_execution
        async def failing_tool():
            raise ValueError("Test error")

        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            with pytest.raises(ValueError, match="Test error"):
                await failing_tool()

            fake_logger.error.assert_called_once()

            # Context is cleared even when the tool raised.
            assert request_id_var.get() is None
            assert tool_name_var.get() is None
            assert request_start_var.get() is None
class TestParameterSanitization:
    """Tests for parameter sanitization used before logging."""

    def test_sanitize_sensitive_params(self):
        """Credential-like keys are redacted; other keys pass through."""
        sanitized = _sanitize_params(
            {
                "username": "testuser",
                "password": "secret123",
                "api_key": "key_secret",
                "auth_token": "token_value",
                "normal_param": "normal_value",
            }
        )

        assert sanitized["username"] == "testuser"
        assert sanitized["password"] == "***REDACTED***"
        assert sanitized["api_key"] == "***REDACTED***"
        assert sanitized["auth_token"] == "***REDACTED***"
        assert sanitized["normal_param"] == "normal_value"

    def test_sanitize_nested_params(self):
        """Sanitization recurses into nested dictionaries."""
        sanitized = _sanitize_params(
            {
                "config": {
                    "database_url": "postgresql://user:pass@host/db",
                    "secret_key": "secret",
                    "debug": True,
                },
                "normal": "value",
            }
        )

        assert sanitized["config"]["database_url"] == "postgresql://user:pass@host/db"
        assert sanitized["config"]["secret_key"] == "***REDACTED***"
        assert sanitized["config"]["debug"] is True
        assert sanitized["normal"] == "value"

    def test_sanitize_long_lists(self):
        """Long lists are summarized instead of logged verbatim."""
        sanitized = _sanitize_params(
            {
                "short_list": [1, 2, 3],
                "long_list": list(range(100)),
            }
        )

        assert sanitized["short_list"] == [1, 2, 3]
        assert sanitized["long_list"] == "[100 items]"

    def test_sanitize_long_strings(self):
        """Long strings are truncated with a total-length marker."""
        sanitized = _sanitize_params(
            {
                "short_string": "hello",
                "long_string": "x" * 2000,
            }
        )

        assert sanitized["short_string"] == "hello"
        assert "... (2000 chars total)" in sanitized["long_string"]
        assert len(sanitized["long_string"]) < 200
class TestDatabaseQueryLogging:
    """Tests for database query logging helpers."""

    def test_log_database_query_basic(self):
        """A query produces one INFO entry plus one DEBUG entry."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_database_query("SELECT * FROM users", {"user_id": 123}, 250)

            fake_logger.info.assert_called_once()
            fake_logger.debug.assert_called_once()

    def test_get_query_type(self):
        """The leading SQL keyword determines the detected query type."""
        expectations = {
            "SELECT * FROM users": "SELECT",
            "INSERT INTO users VALUES (1, 'test')": "INSERT",
            "UPDATE users SET name = 'test'": "UPDATE",
            "DELETE FROM users WHERE id = 1": "DELETE",
            "CREATE TABLE test (id INT)": "CREATE",
            "DROP TABLE test": "DROP",
            "EXPLAIN SELECT * FROM users": "OTHER",  # unrecognized keyword
        }
        for query, expected in expectations.items():
            assert _get_query_type(query) == expected

    def test_slow_query_detection(self):
        """Queries above the latency threshold are flagged as slow."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_database_query("SELECT * FROM large_table", duration_ms=1500)

            # The slow_query marker rides along in the extra payload.
            _, kwargs = fake_logger.info.call_args
            assert kwargs["extra"]["slow_query"] is True
class TestCacheOperationLogging:
    """Tests for cache operation logging."""

    def test_log_cache_hit(self):
        """A cache hit is logged with cache_hit=True and 'hit' in the message."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_cache_operation("get", "stock_data:AAPL", hit=True, duration_ms=5)

            fake_logger.info.assert_called_once()
            args, kwargs = fake_logger.info.call_args
            assert "hit" in args[0]
            assert kwargs["extra"]["cache_hit"] is True

    def test_log_cache_miss(self):
        """A cache miss is logged with cache_hit=False and 'miss' in the message."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_cache_operation("get", "stock_data:MSFT", hit=False)

            fake_logger.info.assert_called_once()
            args, kwargs = fake_logger.info.call_args
            assert "miss" in args[0]
            assert kwargs["extra"]["cache_hit"] is False
class TestExternalAPILogging:
    """Tests for external API call logging."""

    def test_log_successful_api_call(self):
        """A 2xx response is logged as successful at INFO level."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_external_api_call(
                service="yahoo_finance",
                endpoint="/v8/finance/chart/AAPL",
                method="GET",
                status_code=200,
                duration_ms=150,
            )

            fake_logger.info.assert_called_once()
            _, kwargs = fake_logger.info.call_args
            assert kwargs["extra"]["success"] is True

    def test_log_failed_api_call(self):
        """An error response is logged at ERROR with the failure reason."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            log_external_api_call(
                service="yahoo_finance",
                endpoint="/v8/finance/chart/INVALID",
                method="GET",
                status_code=404,
                duration_ms=1000,
                error="Symbol not found",
            )

            fake_logger.error.assert_called_once()
            _, kwargs = fake_logger.error.call_args
            assert kwargs["extra"]["success"] is False
            assert kwargs["extra"]["error"] == "Symbol not found"
class TestPerformanceMonitor:
    """Tests for the PerformanceMonitor context manager."""

    def test_successful_operation(self):
        """A successful block logs completion with timing information."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            with PerformanceMonitor("test_operation"):
                time.sleep(0.1)  # simulate work

            fake_logger.info.assert_called_once()
            args, kwargs = fake_logger.info.call_args
            assert "completed" in args[0]
            assert kwargs["extra"]["success"] is True
            assert kwargs["extra"]["duration_ms"] >= 100

    def test_failed_operation(self):
        """A raising block logs failure details and re-raises the error."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            with pytest.raises(ValueError):
                with PerformanceMonitor("failing_operation"):
                    raise ValueError("Test error")

            fake_logger.error.assert_called_once()
            args, kwargs = fake_logger.error.call_args
            assert "failed" in args[0]
            assert kwargs["extra"]["success"] is False
            assert kwargs["extra"]["error_type"] == "ValueError"

    def test_memory_tracking(self):
        """The memory delta is reported for the monitored block."""
        with patch("maverick_mcp.utils.logging.get_logger") as get_logger_mock:
            fake_logger = Mock()
            get_logger_mock.return_value = fake_logger

            with PerformanceMonitor("memory_test"):
                # Simulate a short-lived allocation inside the block.
                scratch = list(range(1000))
                del scratch

            fake_logger.info.assert_called_once()
            _, kwargs = fake_logger.info.call_args
            assert "memory_delta_mb" in kwargs["extra"]
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__])
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/memory_profiler.py:
--------------------------------------------------------------------------------
```python
"""
Memory profiling and management utilities for the backtesting system.
Provides decorators, monitoring, and optimization tools for memory-efficient operations.
"""
import functools
import gc
import logging
import time
import tracemalloc
import warnings
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
import numpy as np
import pandas as pd
import psutil
logger = logging.getLogger(__name__)
# Memory threshold constants (in bytes)
MEMORY_WARNING_THRESHOLD = 1024 * 1024 * 1024  # 1GB
MEMORY_CRITICAL_THRESHOLD = 2 * 1024 * 1024 * 1024  # 2GB
DATAFRAME_SIZE_THRESHOLD = 100 * 1024 * 1024  # 100MB

# Global memory tracking
# Module-wide counters updated by MemoryProfiler snapshots and the helper
# functions below; read back via get_memory_stats().
_memory_stats = {
    "peak_memory": 0,  # highest RSS observed so far (bytes)
    "current_memory": 0,  # RSS at the most recent snapshot (bytes)
    "allocation_count": 0,
    "gc_count": 0,  # forced garbage collections performed
    "warning_count": 0,  # snapshots that exceeded the warning threshold
    "critical_count": 0,  # snapshots that exceeded the critical threshold
    "dataframe_optimizations": 0,  # DataFrames shrunk by optimize_dataframe
}
@dataclass
class MemorySnapshot:
    """Memory usage snapshot."""

    timestamp: float  # time.time() when the snapshot was taken
    rss_memory: int  # process resident set size (bytes)
    vms_memory: int  # process virtual memory size (bytes)
    available_memory: int  # system-wide available memory (bytes)
    memory_percent: float  # process memory as a percentage of total
    peak_memory: int  # tracemalloc peak, 0 when tracing is disabled
    tracemalloc_current: int  # tracemalloc current allocations (bytes)
    tracemalloc_peak: int  # tracemalloc peak allocations (bytes)
    function_name: str = ""  # label of the operation being profiled
class MemoryProfiler:
    """Advanced memory profiler with tracking and optimization features."""

    def __init__(self, enable_tracemalloc: bool = True):
        """Initialize memory profiler.

        Args:
            enable_tracemalloc: Whether to enable detailed memory tracking
        """
        self.enable_tracemalloc = enable_tracemalloc
        self.snapshots: list[MemorySnapshot] = []
        self.process = psutil.Process()
        # Start tracemalloc at most once; another profiler instance (or the
        # application itself) may already have enabled tracing.
        if self.enable_tracemalloc and not tracemalloc.is_tracing():
            tracemalloc.start()

    def get_memory_info(self) -> dict[str, Any]:
        """Get current memory information."""
        memory_info = self.process.memory_info()
        virtual_memory = psutil.virtual_memory()
        result = {
            "rss_memory": memory_info.rss,
            "vms_memory": memory_info.vms,
            "available_memory": virtual_memory.available,
            "memory_percent": self.process.memory_percent(),
            "total_memory": virtual_memory.total,
        }
        # Allocation-level detail is only available while tracing is active.
        if self.enable_tracemalloc and tracemalloc.is_tracing():
            current, peak = tracemalloc.get_traced_memory()
            result.update(
                {
                    "tracemalloc_current": current,
                    "tracemalloc_peak": peak,
                }
            )
        return result

    def take_snapshot(self, function_name: str = "") -> MemorySnapshot:
        """Take a memory snapshot."""
        memory_info = self.get_memory_info()
        snapshot = MemorySnapshot(
            timestamp=time.time(),
            rss_memory=memory_info["rss_memory"],
            vms_memory=memory_info["vms_memory"],
            available_memory=memory_info["available_memory"],
            memory_percent=memory_info["memory_percent"],
            # tracemalloc keys are absent when tracing is disabled.
            peak_memory=memory_info.get("tracemalloc_peak", 0),
            tracemalloc_current=memory_info.get("tracemalloc_current", 0),
            tracemalloc_peak=memory_info.get("tracemalloc_peak", 0),
            function_name=function_name,
        )
        self.snapshots.append(snapshot)
        # Update global stats
        _memory_stats["current_memory"] = snapshot.rss_memory
        if snapshot.rss_memory > _memory_stats["peak_memory"]:
            _memory_stats["peak_memory"] = snapshot.rss_memory
        # Check thresholds
        self._check_memory_thresholds(snapshot)
        return snapshot

    def _check_memory_thresholds(self, snapshot: MemorySnapshot) -> None:
        """Check memory thresholds and log warnings."""
        if snapshot.rss_memory > MEMORY_CRITICAL_THRESHOLD:
            _memory_stats["critical_count"] += 1
            logger.critical(
                f"CRITICAL: Memory usage {snapshot.rss_memory / (1024**3):.2f}GB "
                f"exceeds critical threshold in {snapshot.function_name or 'unknown'}"
            )
        elif snapshot.rss_memory > MEMORY_WARNING_THRESHOLD:
            _memory_stats["warning_count"] += 1
            logger.warning(
                f"WARNING: High memory usage {snapshot.rss_memory / (1024**3):.2f}GB "
                f"in {snapshot.function_name or 'unknown'}"
            )

    def get_memory_report(self) -> dict[str, Any]:
        """Generate comprehensive memory report."""
        if not self.snapshots:
            return {"error": "No memory snapshots available"}
        latest = self.snapshots[-1]
        first = self.snapshots[0]
        report = {
            "current_memory_mb": latest.rss_memory / (1024**2),
            "peak_memory_mb": max(s.rss_memory for s in self.snapshots) / (1024**2),
            # Growth is measured across this profiler's lifetime.
            "memory_growth_mb": (latest.rss_memory - first.rss_memory) / (1024**2),
            "memory_percent": latest.memory_percent,
            "available_memory_gb": latest.available_memory / (1024**3),
            "snapshots_count": len(self.snapshots),
            "warning_count": _memory_stats["warning_count"],
            "critical_count": _memory_stats["critical_count"],
            "gc_count": _memory_stats["gc_count"],
            "dataframe_optimizations": _memory_stats["dataframe_optimizations"],
        }
        if self.enable_tracemalloc:
            report.update(
                {
                    "tracemalloc_current_mb": latest.tracemalloc_current / (1024**2),
                    "tracemalloc_peak_mb": latest.tracemalloc_peak / (1024**2),
                }
            )
        return report
# Global profiler instance
# Shared, module-level profiler used by profile_memory and the stats helpers.
_global_profiler = MemoryProfiler()
def get_memory_stats() -> dict[str, Any]:
    """Return global memory counters merged with the profiler's report.

    Report values overwrite same-named counter keys, mirroring dict-merge
    precedence.
    """
    merged = dict(_memory_stats)
    merged.update(_global_profiler.get_memory_report())
    return merged
def reset_memory_stats() -> None:
    """Reset global memory statistics.

    Mutates the existing ``_memory_stats`` dict in place (clear + update)
    instead of rebinding the module global. Rebinding would orphan any
    reference obtained via ``from memory_profiler import _memory_stats``
    or held by long-lived code, leaving it pointing at stale counters.
    Also discards all snapshots accumulated by the global profiler.
    """
    _memory_stats.clear()
    _memory_stats.update(
        {
            "peak_memory": 0,
            "current_memory": 0,
            "allocation_count": 0,
            "gc_count": 0,
            "warning_count": 0,
            "critical_count": 0,
            "dataframe_optimizations": 0,
        }
    )
    _global_profiler.snapshots.clear()
def profile_memory(
    func: Callable = None,
    *,
    log_results: bool = True,
    enable_gc: bool = True,
    threshold_mb: float = 100.0,
):
    """Decorator that profiles the memory usage of a function.

    Works both bare (``@profile_memory``) and parameterized
    (``@profile_memory(threshold_mb=50)``).

    Args:
        func: Function to decorate (set when used without parentheses)
        log_results: Whether to log memory usage results
        enable_gc: Whether to trigger garbage collection on high usage
        threshold_mb: Memory growth threshold (MB) for warnings and GC
    """

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            name = f.__name__
            # Snapshot before the call to measure growth across it.
            before = _global_profiler.take_snapshot(f"start_{name}")
            try:
                result = f(*args, **kwargs)
                after = _global_profiler.take_snapshot(f"end_{name}")
                grown_mb = (after.rss_memory - before.rss_memory) / (1024**2)
                if log_results:
                    if grown_mb > threshold_mb:
                        logger.warning(
                            f"High memory usage in {name}: "
                            f"{grown_mb:.2f}MB (threshold: {threshold_mb}MB)"
                        )
                    else:
                        logger.debug(
                            f"Memory usage in {name}: {grown_mb:.2f}MB"
                        )
                # Reclaim memory eagerly after an unusually hungry call.
                if enable_gc and grown_mb > threshold_mb:
                    force_garbage_collection()
                return result
            except Exception as e:
                # Record where the failure happened, then propagate.
                _global_profiler.take_snapshot(f"error_{name}")
                raise e

        return wrapper

    # Support both @profile_memory and @profile_memory(...).
    return decorator if func is None else decorator(func)
@contextmanager
def memory_context(
    name: str = "operation", cleanup_after: bool = True
) -> Iterator[MemoryProfiler]:
    """Context manager that profiles memory across an operation.

    Args:
        name: Label for the operation
        cleanup_after: Run garbage collection when the block exits

    Yields:
        A fresh MemoryProfiler, usable for manual snapshots inside the block
    """
    profiler = MemoryProfiler()
    start = profiler.take_snapshot(f"start_{name}")
    try:
        yield profiler
    finally:
        # Measure and log growth even when the block raised.
        end = profiler.take_snapshot(f"end_{name}")
        delta_mb = (end.rss_memory - start.rss_memory) / (1024**2)
        logger.debug(f"Memory usage in {name}: {delta_mb:.2f}MB")
        if cleanup_after:
            force_garbage_collection()
def optimize_dataframe(
    df: pd.DataFrame, aggressive: bool = False, categorical_threshold: float = 0.5
) -> pd.DataFrame:
    """Optimize DataFrame memory usage.

    Object columns dominated by duplicates become categoricals, integer
    columns are downcast to the smallest dtype that fits their value range,
    and (optionally) float columns are downcast to float32 when the
    conversion is effectively lossless.

    Args:
        df: DataFrame to optimize.
        aggressive: Whether to attempt float -> float32 downcasting.
        categorical_threshold: Max unique-value ratio for categorical conversion.

    Returns:
        Optimized copy of the DataFrame; the original object is returned
        unchanged when it is small or empty.
    """
    initial_memory = df.memory_usage(deep=True).sum()
    # Skip optimization for small DataFrames; the empty-frame guard also
    # prevents a ZeroDivisionError in the unique-ratio computation below.
    if initial_memory < DATAFRAME_SIZE_THRESHOLD or len(df) == 0:
        return df
    df_optimized = df.copy()
    for col in df_optimized.columns:
        col_type = df_optimized[col].dtype
        if col_type == "object":
            # Try to convert to categorical if many duplicates
            unique_ratio = df_optimized[col].nunique() / len(df_optimized[col])
            if unique_ratio < categorical_threshold:
                try:
                    df_optimized[col] = df_optimized[col].astype("category")
                except Exception:
                    pass  # Unhashable/mixed values cannot be categorized
        elif "int" in str(col_type):
            # Downcast integers to the narrowest dtype that holds the range.
            # Bounds must be inclusive (>=, <=): boundary values such as 127
            # or -128 are valid int8 and were previously (wrongly) excluded.
            c_min = df_optimized[col].min()
            c_max = df_optimized[col].max()
            if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                df_optimized[col] = df_optimized[col].astype(np.int8)
            elif c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max:
                df_optimized[col] = df_optimized[col].astype(np.int16)
            elif c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max:
                df_optimized[col] = df_optimized[col].astype(np.int32)
        elif "float" in str(col_type):
            # Downcast floats only in aggressive mode
            if aggressive:
                # Accept float32 only when values round-trip within tolerance
                try:
                    temp = df_optimized[col].astype(np.float32)
                    if np.allclose(
                        df_optimized[col].fillna(0),
                        temp.fillna(0),
                        rtol=1e-6,
                        equal_nan=True,
                    ):
                        df_optimized[col] = temp
                except Exception:
                    pass
    final_memory = df_optimized.memory_usage(deep=True).sum()
    memory_saved = initial_memory - final_memory
    if memory_saved > 0:
        _memory_stats["dataframe_optimizations"] += 1
        logger.debug(
            f"DataFrame optimized: {memory_saved / (1024**2):.2f}MB saved "
            f"({memory_saved / initial_memory * 100:.1f}% reduction)"
        )
    return df_optimized
def force_garbage_collection() -> dict[str, int]:
    """Run a full garbage-collection pass and report object counts."""
    num_collected = gc.collect()
    _memory_stats["gc_count"] += 1
    # Per-generation live-object counts plus the overall total
    stats: dict[str, int] = {"collected": num_collected}
    for generation in range(3):
        stats[f"generation_{generation}"] = len(gc.get_objects(generation))
    stats["total_objects"] = len(gc.get_objects())
    logger.debug(f"Garbage collection: {num_collected} objects collected")
    return stats
def check_memory_leak(threshold_mb: float = 100.0) -> bool:
    """Heuristically detect sustained memory growth.

    Compares the mean RSS of the five most recent global-profiler
    snapshots against the mean of the five before them.

    Args:
        threshold_mb: Memory growth threshold to consider a leak.

    Returns:
        True if potential leak detected.
    """
    snapshots = _global_profiler.snapshots
    if len(snapshots) < 10:
        return False  # Not enough history for a meaningful comparison
    recent_window = snapshots[-5:]
    previous_window = snapshots[-10:-5]
    recent_mean = sum(s.rss_memory for s in recent_window) / len(recent_window)
    previous_mean = sum(s.rss_memory for s in previous_window) / len(previous_window)
    growth_mb = (recent_mean - previous_mean) / (1024**2)
    if growth_mb > threshold_mb:
        logger.warning(f"Potential memory leak detected: {growth_mb:.2f}MB growth")
        return True
    return False
class DataFrameChunker:
    """Utility for processing DataFrames in memory-efficient chunks."""

    def __init__(self, chunk_size_mb: float = 50.0):
        """Create a chunker bounded by an approximate memory budget.

        Args:
            chunk_size_mb: Maximum chunk size in MB.
        """
        self.chunk_size_mb = chunk_size_mb
        self.chunk_size_bytes = int(chunk_size_mb * 1024 * 1024)

    def chunk_dataframe(self, df: pd.DataFrame) -> Iterator[pd.DataFrame]:
        """Yield row slices of *df* whose estimated size fits the budget.

        Args:
            df: DataFrame to chunk.

        Yields:
            Consecutive row slices; the whole frame if it already fits.
        """
        total_memory = df.memory_usage(deep=True).sum()
        if total_memory <= self.chunk_size_bytes:
            yield df
            return
        # Derive rows-per-chunk from the average per-row footprint
        bytes_per_row = total_memory / len(df)
        rows_per_chunk = max(1, int(self.chunk_size_bytes / bytes_per_row))
        logger.debug(
            f"Chunking DataFrame: {len(df)} rows, ~{rows_per_chunk} rows per chunk"
        )
        start = 0
        while start < len(df):
            yield df.iloc[start : start + rows_per_chunk]
            start += rows_per_chunk

    def process_in_chunks(
        self,
        df: pd.DataFrame,
        processor: Callable[[pd.DataFrame], Any],
        combine_results: Callable = None,
    ) -> Any:
        """Run *processor* over each chunk, optionally combining the outputs.

        Args:
            df: DataFrame to process.
            processor: Function applied to each chunk.
            combine_results: Optional reducer over the list of chunk outputs.

        Returns:
            combine_results(outputs) when a reducer is given, otherwise the
            list of per-chunk outputs.
        """
        outputs = []
        with memory_context("chunk_processing"):
            for index, piece in enumerate(self.chunk_dataframe(df)):
                logger.debug(f"Processing chunk {index + 1}")
                with memory_context(f"chunk_{index}"):
                    outputs.append(processor(piece))
        if combine_results:
            return combine_results(outputs)
        return outputs
def cleanup_dataframes(*dfs: pd.DataFrame) -> None:
    """Release DataFrame internals and force garbage collection.

    NOTE(review): this nulls each frame's internal block manager, which
    renders the caller's DataFrame objects unusable afterwards — treat the
    arguments as consumed.

    Args:
        *dfs: DataFrames to clean up.
    """
    for frame in dfs:
        if hasattr(frame, "_mgr"):
            # Drop internal block references so the data can be reclaimed
            frame._mgr = None
        del frame
    force_garbage_collection()
def get_dataframe_memory_usage(df: pd.DataFrame) -> dict[str, Any]:
    """Summarize a DataFrame's memory footprint.

    Args:
        df: DataFrame to analyze.

    Returns:
        Dict with total/index/per-column memory in MB, shape, dtypes, and
        average bytes per row (0 for an empty frame).
    """
    usage = df.memory_usage(deep=True)
    total_bytes = usage.sum()
    mb = 1024**2
    # First entry of memory_usage() is always the index
    per_column = {col: usage.loc[col] / mb for col in df.columns}
    return {
        "total_memory_mb": total_bytes / mb,
        "index_memory_mb": usage.iloc[0] / mb,
        "columns_memory_mb": per_column,
        "shape": df.shape,
        "dtypes": df.dtypes.to_dict(),
        "memory_per_row_bytes": total_bytes / len(df) if len(df) > 0 else 0,
    }
@contextmanager
def memory_limit_context(limit_mb: float) -> Iterator[None]:
    """Watch a code block's memory growth against a soft limit.

    The limit is only checked on exit: if RSS grew by more than *limit_mb*
    during the block, an error is logged and a garbage-collection pass is
    forced. No exception is raised.

    Args:
        limit_mb: Memory growth limit in MB.
    """
    baseline_rss = psutil.Process().memory_info().rss
    limit_bytes = limit_mb * 1024 * 1024
    try:
        yield
    finally:
        grown = psutil.Process().memory_info().rss - baseline_rss
        if grown > limit_bytes:
            logger.error(
                f"Memory limit exceeded: {grown / (1024**2):.2f}MB > {limit_mb}MB"
            )
            # Best-effort cleanup after the overshoot
            force_garbage_collection()
def suggest_memory_optimizations(df: pd.DataFrame) -> list[str]:
    """Suggest memory optimizations for a DataFrame.

    Args:
        df: DataFrame to analyze.

    Returns:
        Human-readable optimization suggestions; empty for an empty frame.
    """
    suggestions: list[str] = []
    # Guard: the per-row ratios below divide by len(df); an empty frame
    # previously raised ZeroDivisionError here.
    if len(df) == 0:
        return suggestions
    memory_info = get_dataframe_memory_usage(df)
    # Object columns dominated by duplicates -> categorical candidates
    for col in df.columns:
        if df[col].dtype == "object":
            unique_ratio = df[col].nunique() / len(df)
            if unique_ratio < 0.5:
                memory_savings = memory_info["columns_memory_mb"][col] * (
                    1 - unique_ratio
                )
                suggestions.append(
                    f"Convert '{col}' to categorical (potential savings: "
                    f"{memory_savings:.2f}MB, {unique_ratio:.1%} unique values)"
                )
    # float64 columns that downcast to float32 within tolerance
    for col in df.columns:
        if df[col].dtype == "float64":
            try:
                temp = df[col].astype(np.float32)
                if np.allclose(df[col].fillna(0), temp.fillna(0), rtol=1e-6):
                    # float32 is half the width of float64
                    savings = memory_info["columns_memory_mb"][col] * 0.5
                    suggestions.append(
                        f"Convert '{col}' from float64 to float32 "
                        f"(potential savings: {savings:.2f}MB)"
                    )
            except Exception:
                pass  # Conversion/comparison failed; skip this column
    # Integer columns narrower than their current dtype
    for col in df.columns:
        if "int" in str(df[col].dtype):
            c_min = df[col].min()
            c_max = df[col].max()
            current_bytes = df[col].memory_usage(deep=True) / len(df)
            if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                # Only suggest when the column is wider than 1 byte per value
                if current_bytes > 1:
                    savings = (current_bytes - 1) * len(df) / (1024**2)
                    suggestions.append(
                        f"Convert '{col}' to int8 (potential savings: {savings:.2f}MB)"
                    )
    return suggestions
# Initialize memory monitoring with warning suppression for resource warnings
def _suppress_resource_warnings():
"""Suppress ResourceWarnings that can clutter logs during memory profiling."""
warnings.filterwarnings("ignore", category=ResourceWarning)
# Auto-initialize
_suppress_resource_warnings()
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/agents.py:
--------------------------------------------------------------------------------
```python
"""
Agent router for LangGraph-based financial analysis agents.
This router exposes the LangGraph agents as MCP tools while maintaining
compatibility with the existing infrastructure.
"""
import logging
import os
from typing import Any
from fastmcp import FastMCP
from maverick_mcp.agents.deep_research import DeepResearchAgent
from maverick_mcp.agents.market_analysis import MarketAnalysisAgent
from maverick_mcp.agents.supervisor import SupervisorAgent
logger = logging.getLogger(__name__)
# FastMCP sub-application that exposes the LangGraph agents as MCP tools.
agents_router: FastMCP = FastMCP("Financial_Analysis_Agents")
# Agent instances cached by "{agent_type}:{persona}" so repeated tool calls
# reuse the same agent instead of rebuilding it.
_agent_cache: dict[str, Any] = {}
def get_or_create_agent(agent_type: str, persona: str = "moderate") -> Any:
    """Return the cached agent for (agent_type, persona), creating it on demand.

    Args:
        agent_type: One of "market", "supervisor", or "deep_research".
        persona: Investor persona the agent should adopt.

    Raises:
        ValueError: If agent_type is not recognized.
    """
    cache_key = f"{agent_type}:{persona}"
    if cache_key in _agent_cache:
        return _agent_cache[cache_key]

    # Task-aware LLM factory, imported lazily
    from maverick_mcp.providers.llm_factory import get_llm
    from maverick_mcp.providers.openrouter_provider import TaskType

    # Map agent types to task types for optimal model selection
    task_mapping = {
        "market": TaskType.MARKET_ANALYSIS,
        "technical": TaskType.TECHNICAL_ANALYSIS,
        "supervisor": TaskType.MULTI_AGENT_ORCHESTRATION,
        "deep_research": TaskType.DEEP_RESEARCH,
    }
    llm = get_llm(task_type=task_mapping.get(agent_type, TaskType.GENERAL))

    if agent_type == "market":
        _agent_cache[cache_key] = MarketAnalysisAgent(
            llm=llm, persona=persona, ttl_hours=1
        )
    elif agent_type == "supervisor":
        # Supervisor coordinates sub-agents; technical is not yet implemented
        sub_agents = {
            "market": get_or_create_agent("market", persona),
            "technical": None,  # Would be actual technical agent in full implementation
        }
        _agent_cache[cache_key] = SupervisorAgent(
            llm=llm, agents=sub_agents, persona=persona, ttl_hours=1
        )
    elif agent_type == "deep_research":
        research_agent = DeepResearchAgent(
            llm=llm,
            persona=persona,
            ttl_hours=1,
            exa_api_key=os.getenv("EXA_API_KEY"),
        )
        # Defer heavy initialization until first use
        research_agent._needs_initialization = True
        _agent_cache[cache_key] = research_agent
    else:
        raise ValueError(f"Unknown agent type: {agent_type}")
    return _agent_cache[cache_key]
async def analyze_market_with_agent(
    query: str,
    persona: str = "moderate",
    screening_strategy: str = "momentum",
    max_results: int = 20,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Analyze market using LangGraph agent with persona-aware recommendations.

    The underlying agent adapts its analysis to the investor risk profile
    (conservative, moderate, aggressive).

    Args:
        query: Market analysis query (e.g., "Find top momentum stocks")
        persona: Investor persona (conservative, moderate, aggressive)
        screening_strategy: Strategy to use (momentum, maverick, supply_demand_breakout)
        max_results: Maximum number of results
        session_id: Optional session ID for conversation continuity

    Returns:
        Persona-adjusted market analysis with recommendations
    """
    try:
        # Every analysis runs under a session for conversation continuity
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        market_agent = get_or_create_agent("market", persona)
        analysis = await market_agent.analyze_market(
            query=query,
            session_id=session_id,
            screening_strategy=screening_strategy,
            max_results=max_results,
        )
        response: dict[str, Any] = {
            "status": "success",
            "agent_type": "market_analysis",
            "persona": persona,
            "session_id": session_id,
        }
        response.update(analysis)
        return response
    except Exception as e:
        logger.error(f"Error in market agent analysis: {str(e)}")
        return {"status": "error", "error": str(e), "agent_type": "market_analysis"}
async def get_agent_streaming_analysis(
    query: str,
    persona: str = "moderate",
    stream_mode: str = "updates",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Get streaming market analysis with real-time updates.

    Demonstrates LangGraph streaming: for MCP compatibility the streamed
    chunks are buffered into a list rather than sent over a live channel.

    Args:
        query: Analysis query
        persona: Investor persona
        stream_mode: Streaming mode (updates, values, messages)
        session_id: Optional session ID

    Returns:
        Streaming configuration and a sample of the collected updates
    """
    try:
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        market_agent = get_or_create_agent("market", persona)
        # Buffer a handful of streamed chunks; a production endpoint would
        # forward these over WebSocket/SSE instead.
        collected: list[Any] = []
        async for chunk in market_agent.stream_analysis(
            query=query, session_id=session_id, stream_mode=stream_mode
        ):
            collected.append(chunk)
            if len(collected) >= 5:  # demo cap
                break
        return {
            "status": "success",
            "stream_mode": stream_mode,
            "persona": persona,
            "session_id": session_id,
            "updates_collected": len(collected),
            "sample_updates": collected[:3],
            "note": "Full streaming requires WebSocket or SSE endpoint",
        }
    except Exception as e:
        logger.error(f"Error in streaming analysis: {str(e)}")
        return {"status": "error", "error": str(e)}
async def orchestrated_analysis(
    query: str,
    persona: str = "moderate",
    routing_strategy: str = "llm_powered",
    max_agents: int = 3,
    parallel_execution: bool = True,
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Run orchestrated multi-agent analysis using the SupervisorAgent.

    The supervisor routes the query to appropriate specialized agents and
    synthesizes their results into a single recommendation set.

    Args:
        query: Financial analysis query
        persona: Investor persona (conservative, moderate, aggressive, day_trader)
        routing_strategy: How to route tasks (llm_powered, rule_based, hybrid)
        max_agents: Maximum number of agents to use
        parallel_execution: Whether to run agents in parallel
        session_id: Optional session ID for conversation continuity

    Returns:
        Orchestrated analysis with synthesized recommendations
    """
    try:
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        supervisor = get_or_create_agent("supervisor", persona)
        coordination = await supervisor.coordinate_agents(
            query=query,
            session_id=session_id,
            routing_strategy=routing_strategy,
            max_agents=max_agents,
            parallel_execution=parallel_execution,
        )
        response: dict[str, Any] = {
            "status": "success",
            "agent_type": "supervisor_orchestrated",
            "persona": persona,
            "session_id": session_id,
            "routing_strategy": routing_strategy,
            "agents_used": coordination.get("agents_used", []),
            "execution_time_ms": coordination.get("execution_time_ms"),
            "synthesis_confidence": coordination.get("synthesis_confidence"),
        }
        # Raw coordination payload overrides the summary fields above
        response.update(coordination)
        return response
    except Exception as e:
        logger.error(f"Error in orchestrated analysis: {str(e)}")
        return {
            "status": "error",
            "error": str(e),
            "agent_type": "supervisor_orchestrated",
        }
async def deep_research_financial(
    research_topic: str,
    persona: str = "moderate",
    research_depth: str = "comprehensive",
    focus_areas: list[str] | None = None,
    timeframe: str = "30d",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Conduct comprehensive financial research using web search and AI analysis.

    Args:
        research_topic: Main research topic (company, symbol, or market theme)
        persona: Investor persona affecting research focus
        research_depth: Depth level (basic, standard, comprehensive, exhaustive)
        focus_areas: Specific areas to focus on; defaults to fundamentals,
            market sentiment, and competitive landscape
        timeframe: Time range for research (7d, 30d, 90d, 1y)
        session_id: Optional session ID for conversation continuity

    Returns:
        Comprehensive research report with validated sources and analysis
    """
    try:
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        areas = focus_areas
        if areas is None:
            areas = ["fundamentals", "market_sentiment", "competitive_landscape"]
        researcher = get_or_create_agent("deep_research", persona)
        findings = await researcher.research_comprehensive(
            topic=research_topic,
            session_id=session_id,
            depth=research_depth,
            focus_areas=areas,
            timeframe=timeframe,
        )
        report: dict[str, Any] = {
            "status": "success",
            "agent_type": "deep_research",
            "persona": persona,
            "session_id": session_id,
            "research_topic": research_topic,
            "research_depth": research_depth,
            "focus_areas": areas,
            "sources_analyzed": findings.get("total_sources_processed", 0),
            "research_confidence": findings.get("research_confidence"),
            "validation_checks_passed": findings.get("validation_checks_passed"),
        }
        # Merge the full research payload over the summary fields
        report.update(findings)
        return report
    except Exception as e:
        logger.error(f"Error in deep research: {str(e)}")
        return {"status": "error", "error": str(e), "agent_type": "deep_research"}
async def compare_multi_agent_analysis(
    query: str,
    agent_types: list[str] | None = None,
    persona: str = "moderate",
    session_id: str | None = None,
) -> dict[str, Any]:
    """
    Compare analysis results across multiple agent types.

    Runs the same query through each requested agent so their approaches
    and insights can be compared side by side.

    Args:
        query: Analysis query to run across multiple agents
        agent_types: Agent types to compare (default: ["market", "supervisor"])
        persona: Investor persona for all agents
        session_id: Optional session ID prefix

    Returns:
        Comparative analysis showing different agent perspectives
    """
    try:
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        if agent_types is None:
            agent_types = ["market", "supervisor"]
        comparison: dict[str, Any] = {}
        timings: dict[str, Any] = {}
        for kind in agent_types:
            try:
                agent = get_or_create_agent(kind, persona)
                scoped_session = f"{session_id}_{kind}"
                # Dispatch to the agent-specific entry point
                if kind == "market":
                    result = await agent.analyze_market(
                        query=query,
                        session_id=scoped_session,
                        max_results=10,
                    )
                elif kind == "supervisor":
                    result = await agent.coordinate_agents(
                        query=query,
                        session_id=scoped_session,
                        max_agents=2,
                    )
                else:
                    continue  # Unsupported agent type: skip silently
                comparison[kind] = {
                    "summary": result.get("summary", ""),
                    "key_findings": result.get("key_findings", []),
                    "confidence": result.get("confidence", 0.0),
                    "methodology": result.get("methodology", f"{kind} analysis"),
                }
                timings[kind] = result.get("execution_time_ms", 0)
            except Exception as e:
                # One failing agent should not abort the whole comparison
                logger.warning(f"Error with {kind} agent: {str(e)}")
                comparison[kind] = {"error": str(e), "status": "failed"}
        return {
            "status": "success",
            "query": query,
            "persona": persona,
            "agents_compared": list(comparison.keys()),
            "comparison": comparison,
            "execution_times_ms": timings,
            "insights": "Each agent brings unique analytical perspectives and methodologies",
        }
    except Exception as e:
        logger.error(f"Error in multi-agent comparison: {str(e)}")
        return {"status": "error", "error": str(e)}
def list_available_agents() -> dict[str, Any]:
    """
    List all available LangGraph agents and their capabilities.

    Returns:
        Static capability metadata: active and upcoming agents, the
        orchestration tools built on them, shared features, and the
        supported personas, routing strategies, and research depths.
    """
    all_personas = ["conservative", "moderate", "aggressive", "day_trader"]
    routing_strategies = ["llm_powered", "rule_based", "hybrid"]
    research_depths = ["basic", "standard", "comprehensive", "exhaustive"]
    agent_catalog: dict[str, Any] = {
        "market_analysis": {
            "description": "Market screening and sector analysis",
            "personas": ["conservative", "moderate", "aggressive"],
            "capabilities": [
                "Momentum screening",
                "Sector rotation analysis",
                "Market breadth indicators",
                "Risk-adjusted recommendations",
            ],
            "streaming_modes": ["updates", "values", "messages", "debug"],
            "status": "active",
        },
        "supervisor_orchestrated": {
            "description": "Multi-agent orchestration and coordination",
            "personas": list(all_personas),
            "capabilities": [
                "Intelligent query routing",
                "Multi-agent coordination",
                "Result synthesis and conflict resolution",
                "Parallel and sequential execution",
                "Comprehensive analysis workflows",
            ],
            "routing_strategies": list(routing_strategies),
            "status": "active",
        },
        "deep_research": {
            "description": "Comprehensive financial research with web search",
            "personas": list(all_personas),
            "capabilities": [
                "Multi-provider web search",
                "AI-powered content analysis",
                "Source validation and credibility scoring",
                "Citation and reference management",
                "Comprehensive research reports",
            ],
            "research_depths": list(research_depths),
            "focus_areas": [
                "fundamentals",
                "technicals",
                "market_sentiment",
                "competitive_landscape",
            ],
            "status": "active",
        },
        # Planned agents, not yet implemented
        "technical_analysis": {
            "description": "Chart patterns and technical indicators",
            "status": "coming_soon",
        },
        "risk_management": {
            "description": "Position sizing and portfolio risk",
            "status": "coming_soon",
        },
        "portfolio_optimization": {
            "description": "Rebalancing and allocation",
            "status": "coming_soon",
        },
    }
    return {
        "status": "success",
        "agents": agent_catalog,
        "orchestrated_tools": {
            "orchestrated_analysis": "Coordinate multiple agents for comprehensive analysis",
            "deep_research_financial": "Conduct thorough research with web search",
            "compare_multi_agent_analysis": "Compare different agent perspectives",
        },
        "features": {
            "persona_adaptation": "Agents adjust recommendations based on risk profile",
            "conversation_memory": "Maintains context within sessions",
            "streaming_support": "Real-time updates during analysis",
            "tool_integration": "Access to all MCP financial tools",
            "multi_agent_orchestration": "Coordinate multiple specialized agents",
            "web_search_research": "AI-powered research with source validation",
            "intelligent_routing": "LLM-powered task routing and optimization",
        },
        "personas": list(all_personas),
        "routing_strategies": list(routing_strategies),
        "research_depths": list(research_depths),
    }
async def compare_personas_analysis(
    query: str, session_id: str | None = None
) -> dict[str, Any]:
    """
    Compare analysis across different investor personas.

    Runs the same query through conservative, moderate, and aggressive
    personas to show how the recommendations differ.

    Args:
        query: Analysis query to run
        session_id: Optional session ID prefix

    Returns:
        Comparative analysis across all personas
    """
    try:
        if not session_id:
            import uuid

            session_id = str(uuid.uuid4())
        per_persona: dict[str, Any] = {}
        for risk_profile in ("conservative", "moderate", "aggressive"):
            agent = get_or_create_agent("market", risk_profile)
            outcome = await agent.analyze_market(
                query=query,
                session_id=f"{session_id}_{risk_profile}",
                max_results=10,
            )
            payload = outcome.get("results", {})
            per_persona[risk_profile] = {
                "summary": payload.get("summary", ""),
                "top_picks": payload.get("screened_symbols", [])[:5],
                "risk_parameters": {
                    "risk_tolerance": agent.persona.risk_tolerance,
                    "max_position_size": f"{agent.persona.position_size_max * 100:.1f}%",
                    "stop_loss_multiplier": agent.persona.stop_loss_multiplier,
                },
            }
        return {
            "status": "success",
            "query": query,
            "comparison": per_persona,
            "insights": "Notice how recommendations vary by risk profile",
        }
    except Exception as e:
        logger.error(f"Error in persona comparison: {str(e)}")
        return {"status": "error", "error": str(e)}
```
--------------------------------------------------------------------------------
/maverick_mcp/config/tool_estimation.py:
--------------------------------------------------------------------------------
```python
"""Centralised tool usage estimation configuration."""
from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
class EstimationBasis(str, Enum):
    """Describes how a tool estimate was derived.

    Inherits ``str`` so member values serialize directly in JSON/logs.
    """

    EMPIRICAL = "empirical"  # derived from observed usage
    CONSERVATIVE = "conservative"  # deliberately padded upper bound
    HEURISTIC = "heuristic"  # rule-of-thumb approximation
    SIMULATED = "simulated"  # produced from simulated runs
class ToolComplexity(str, Enum):
    """Qualitative complexity buckets used for monitoring and reporting.

    Ordered from cheapest (SIMPLE) to most expensive (PREMIUM), as
    reflected by the per-bucket defaults in ToolEstimationConfig.
    """

    SIMPLE = "simple"
    STANDARD = "standard"
    COMPLEX = "complex"
    PREMIUM = "premium"
class ToolEstimate(BaseModel):
    """Static estimate describing expected LLM usage for a tool."""

    # Frozen: estimates are immutable once constructed.
    model_config = ConfigDict(frozen=True)

    llm_calls: int = Field(ge=0)  # expected number of LLM invocations
    total_tokens: int = Field(ge=0)  # expected total token usage
    confidence: float = Field(ge=0.0, le=1.0)  # confidence in the estimate, 0..1
    based_on: EstimationBasis  # how the estimate was derived
    complexity: ToolComplexity  # bucket used for monitoring/reporting
    notes: str | None = None  # optional free-text context

    @field_validator("llm_calls", "total_tokens")
    @classmethod
    def _non_negative(cls, value: int) -> int:
        # Redundant with Field(ge=0) above, but kept as an explicit guard.
        if value < 0:
            raise ValueError("Estimates must be non-negative")
        return value
class MonitoringThresholds(BaseModel):
    """Thresholds for triggering alerting logic."""

    # Absolute LLM-call counts that trigger warning/critical alerts.
    llm_calls_warning: int = 15
    llm_calls_critical: int = 25
    # Absolute token counts that trigger warning/critical alerts.
    tokens_warning: int = 20_000
    tokens_critical: int = 35_000
    # Relative variance vs. the estimate (0.5 == 50% over expected) that
    # triggers warning/critical alerts in should_alert().
    variance_warning: float = 0.5
    variance_critical: float = 1.0

    # Re-run validators when thresholds are mutated after construction.
    model_config = ConfigDict(validate_assignment=True)

    @field_validator(
        "llm_calls_warning",
        "llm_calls_critical",
        "tokens_warning",
        "tokens_critical",
    )
    @classmethod
    def _positive(cls, value: int) -> int:
        # Count thresholds must be strictly positive to be meaningful.
        if value <= 0:
            raise ValueError("Monitoring thresholds must be positive")
        return value
class ToolEstimationConfig(BaseModel):
    """Container for all tool estimates used across the service.

    Holds per-complexity baseline estimates, a low-confidence fallback for
    unknown tools, monitoring thresholds, and a lowercase-keyed map of
    per-tool estimates (populated with defaults when left empty).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Confidence used when nothing more specific is available.
    default_confidence: float = 0.75
    # Alerting thresholds consumed by should_alert().
    monitoring: MonitoringThresholds = Field(default_factory=MonitoringThresholds)
    # Baseline estimates per complexity bucket (see get_default_for_complexity).
    simple_default: ToolEstimate = Field(
        default_factory=lambda: ToolEstimate(
            llm_calls=1,
            total_tokens=600,
            confidence=0.85,
            based_on=EstimationBasis.EMPIRICAL,
            complexity=ToolComplexity.SIMPLE,
            notes="Baseline simple operation",
        )
    )
    standard_default: ToolEstimate = Field(
        default_factory=lambda: ToolEstimate(
            llm_calls=3,
            total_tokens=4000,
            confidence=0.75,
            based_on=EstimationBasis.HEURISTIC,
            complexity=ToolComplexity.STANDARD,
            notes="Baseline standard analysis",
        )
    )
    complex_default: ToolEstimate = Field(
        default_factory=lambda: ToolEstimate(
            llm_calls=6,
            total_tokens=9000,
            confidence=0.7,
            based_on=EstimationBasis.SIMULATED,
            complexity=ToolComplexity.COMPLEX,
            notes="Baseline complex workflow",
        )
    )
    premium_default: ToolEstimate = Field(
        default_factory=lambda: ToolEstimate(
            llm_calls=10,
            total_tokens=15000,
            confidence=0.65,
            based_on=EstimationBasis.CONSERVATIVE,
            complexity=ToolComplexity.PREMIUM,
            notes="Baseline premium orchestration",
        )
    )
    # Fallback returned by get_estimate() for tools with no explicit entry.
    unknown_tool_estimate: ToolEstimate = Field(
        default_factory=lambda: ToolEstimate(
            llm_calls=3,
            total_tokens=5000,
            confidence=0.3,
            based_on=EstimationBasis.CONSERVATIVE,
            complexity=ToolComplexity.STANDARD,
            notes="Fallback estimate for unknown tools",
        )
    )
    # Per-tool estimates keyed by lowercase tool name.
    tool_estimates: dict[str, ToolEstimate] = Field(default_factory=dict)

    def model_post_init(self, _context: Any) -> None:  # noqa: D401
        """Populate default estimates, or lowercase user-supplied keys."""
        if not self.tool_estimates:
            self.tool_estimates = _build_default_estimates(self)
        else:
            # Normalise keys so get_estimate() lookups are case-insensitive.
            normalised: dict[str, ToolEstimate] = {}
            for key, estimate in self.tool_estimates.items():
                normalised[key.lower()] = estimate
            self.tool_estimates = normalised

    def get_estimate(self, tool_name: str) -> ToolEstimate:
        """Return the estimate for *tool_name* (case-insensitive), or the
        unknown-tool fallback when no entry exists."""
        key = tool_name.lower()
        return self.tool_estimates.get(key, self.unknown_tool_estimate)

    def get_default_for_complexity(self, complexity: ToolComplexity) -> ToolEstimate:
        """Return the baseline estimate for a complexity bucket."""
        mapping = {
            ToolComplexity.SIMPLE: self.simple_default,
            ToolComplexity.STANDARD: self.standard_default,
            ToolComplexity.COMPLEX: self.complex_default,
            ToolComplexity.PREMIUM: self.premium_default,
        }
        return mapping[complexity]

    def get_tools_by_complexity(self, complexity: ToolComplexity) -> list[str]:
        """Return sorted tool names whose estimate is in *complexity*."""
        return sorted(
            [
                name
                for name, estimate in self.tool_estimates.items()
                if estimate.complexity == complexity
            ]
        )

    def get_summary_stats(self) -> dict[str, Any]:
        """Aggregate statistics over all estimates (empty dict when none)."""
        if not self.tool_estimates:
            return {}
        total_tools = len(self.tool_estimates)
        by_complexity: dict[str, int] = {c.value: 0 for c in ToolComplexity}
        basis_distribution: dict[str, int] = {b.value: 0 for b in EstimationBasis}
        llm_total = 0
        token_total = 0
        confidence_total = 0.0
        for estimate in self.tool_estimates.values():
            by_complexity[estimate.complexity.value] += 1
            basis_distribution[estimate.based_on.value] += 1
            llm_total += estimate.llm_calls
            token_total += estimate.total_tokens
            confidence_total += estimate.confidence
        return {
            "total_tools": total_tools,
            "by_complexity": by_complexity,
            "avg_llm_calls": llm_total / total_tools,
            "avg_tokens": token_total / total_tools,
            "avg_confidence": confidence_total / total_tools,
            "basis_distribution": basis_distribution,
        }

    def should_alert(
        self, tool_name: str, actual_llm_calls: int, actual_tokens: int
    ) -> tuple[bool, str]:
        """Check actual usage against absolute thresholds and estimate variance.

        Returns:
            (alert_needed, message) where message joins every triggered
            warning/critical condition with "; " (empty when no alert).
        """
        estimate = self.get_estimate(tool_name)
        thresholds = self.monitoring
        alerts: list[str] = []
        # Absolute LLM-call thresholds (critical takes precedence).
        if actual_llm_calls >= thresholds.llm_calls_critical:
            alerts.append(
                f"Critical: LLM calls ({actual_llm_calls}) exceeded threshold ({thresholds.llm_calls_critical})"
            )
        elif actual_llm_calls >= thresholds.llm_calls_warning:
            alerts.append(
                f"Warning: LLM calls ({actual_llm_calls}) exceeded threshold ({thresholds.llm_calls_warning})"
            )
        # Absolute token thresholds.
        if actual_tokens >= thresholds.tokens_critical:
            alerts.append(
                f"Critical: Token usage ({actual_tokens}) exceeded threshold ({thresholds.tokens_critical})"
            )
        elif actual_tokens >= thresholds.tokens_warning:
            alerts.append(
                f"Warning: Token usage ({actual_tokens}) exceeded threshold ({thresholds.tokens_warning})"
            )
        # Relative variance vs. the static estimate; any usage against a
        # zero estimate counts as infinite variance.
        expected_llm = estimate.llm_calls
        expected_tokens = estimate.total_tokens
        llm_variance = (
            float("inf")
            if expected_llm == 0 and actual_llm_calls > 0
            else ((actual_llm_calls - expected_llm) / max(expected_llm, 1))
        )
        token_variance = (
            float("inf")
            if expected_tokens == 0 and actual_tokens > 0
            else ((actual_tokens - expected_tokens) / max(expected_tokens, 1))
        )
        if llm_variance == float("inf") or llm_variance > thresholds.variance_critical:
            alerts.append("Critical: LLM call variance exceeded acceptable range")
        elif llm_variance > thresholds.variance_warning:
            alerts.append("Warning: LLM call variance elevated")
        if (
            token_variance == float("inf")
            or token_variance > thresholds.variance_critical
        ):
            alerts.append("Critical: Token variance exceeded acceptable range")
        elif token_variance > thresholds.variance_warning:
            alerts.append("Warning: Token variance elevated")
        message = "; ".join(alerts)
        return (bool(alerts), message)
def _build_default_estimates(config: ToolEstimationConfig) -> dict[str, ToolEstimate]:
    """Build the baseline per-tool usage estimates keyed by tool name.

    Each table row is ``(llm_calls, total_tokens, confidence, basis,
    complexity, notes)`` and is expanded into a :class:`ToolEstimate`.
    ``config`` is accepted for interface compatibility but not consulted.
    """
    simple = ToolComplexity.SIMPLE
    standard = ToolComplexity.STANDARD
    complex_ = ToolComplexity.COMPLEX
    premium = ToolComplexity.PREMIUM
    empirical = EstimationBasis.EMPIRICAL
    heuristic = EstimationBasis.HEURISTIC
    simulated = EstimationBasis.SIMULATED
    conservative = EstimationBasis.CONSERVATIVE

    table: dict[str, tuple[int, int, float, Any, Any, str]] = {
        # Simple tools: direct lookups with little or no LLM usage.
        "get_stock_price": (0, 200, 0.92, empirical, simple, "Direct market data lookup"),
        "get_company_info": (1, 600, 0.88, empirical, simple, "Cached profile summary"),
        "get_stock_info": (1, 550, 0.87, empirical, simple, "Quote lookup"),
        "calculate_sma": (0, 180, 0.9, empirical, simple, "Local technical calculation"),
        "get_market_hours": (0, 120, 0.95, empirical, simple, "Static schedule lookup"),
        "get_chart_links": (1, 500, 0.85, heuristic, simple, "Generates chart URLs"),
        "list_available_agents": (1, 800, 0.82, heuristic, simple, "Lists registered AI agents"),
        "clear_cache": (0, 100, 0.9, empirical, simple, "Invalidates cache entries"),
        "get_cached_price_data": (0, 150, 0.86, empirical, simple, "Reads cached OHLC data"),
        "get_watchlist": (1, 650, 0.84, empirical, simple, "Fetches saved watchlists"),
        "generate_dev_token": (1, 700, 0.82, heuristic, simple, "Generates development API token"),
        # Standard tools: a few LLM calls to interpret indicator output.
        "get_rsi_analysis": (2, 3000, 0.78, empirical, standard, "RSI interpretation"),
        "get_macd_analysis": (3, 3200, 0.74, empirical, standard, "MACD indicator narrative"),
        "get_support_resistance": (4, 3400, 0.72, heuristic, standard, "Support/resistance summary"),
        "fetch_stock_data": (1, 2600, 0.8, empirical, standard, "Aggregates OHLC data"),
        "get_maverick_stocks": (4, 4500, 0.73, simulated, standard, "Retrieves screening candidates"),
        "get_news_sentiment": (3, 4800, 0.76, empirical, standard, "Summarises latest news sentiment"),
        "get_economic_calendar": (2, 2800, 0.79, empirical, standard, "Economic calendar summary"),
        # Complex tools: multi-step analyses with heavier token budgets.
        "get_full_technical_analysis": (6, 9200, 0.72, empirical, complex_, "Comprehensive technical package"),
        "risk_adjusted_analysis": (5, 8800, 0.7, heuristic, complex_, "Risk-adjusted metrics"),
        "compare_tickers": (6, 9400, 0.71, simulated, complex_, "Ticker comparison"),
        "portfolio_correlation_analysis": (5, 8700, 0.72, simulated, complex_, "Portfolio correlation study"),
        "get_market_overview": (4, 7800, 0.74, heuristic, complex_, "Market breadth overview"),
        "get_all_screening_recommendations": (5, 8200, 0.7, simulated, complex_, "Bulk screening results"),
        # Premium tools: agent orchestration; conservative, lower-confidence figures.
        "analyze_market_with_agent": (10, 14000, 0.65, conservative, premium, "Multi-agent orchestration"),
        "get_agent_streaming_analysis": (12, 16000, 0.6, conservative, premium, "Streaming agent analysis"),
        "compare_personas_analysis": (9, 12000, 0.62, heuristic, premium, "Persona comparison"),
    }
    return {
        name: ToolEstimate(
            llm_calls=llm_calls,
            total_tokens=total_tokens,
            confidence=confidence,
            based_on=based_on,
            complexity=complexity,
            notes=notes,
        )
        for name, (llm_calls, total_tokens, confidence, based_on, complexity, notes) in table.items()
    }
# Lazily-created singleton; populated on first call to get_tool_estimation_config().
_config: ToolEstimationConfig | None = None


def get_tool_estimation_config() -> ToolEstimationConfig:
    """Return the singleton tool estimation configuration.

    Created lazily on first use; later calls return the cached instance.
    No lock is taken, so concurrent first calls may each construct an
    instance, with the last assignment winning.
    """
    global _config
    if _config is None:
        _config = ToolEstimationConfig()
    return _config
def get_tool_estimate(tool_name: str) -> ToolEstimate:
    """Convenience helper returning the estimate for ``tool_name``."""
    config = get_tool_estimation_config()
    return config.get_estimate(tool_name)
def should_alert_for_usage(
    tool_name: str, llm_calls: int, total_tokens: int
) -> tuple[bool, str]:
    """Check whether actual usage deviates enough to raise an alert."""
    config = get_tool_estimation_config()
    return config.should_alert(tool_name, llm_calls, total_tokens)
class ToolCostEstimator:
"""Legacy cost estimator retained for backwards compatibility."""
BASE_COSTS = {
"search": {"simple": 1, "moderate": 3, "complex": 5, "very_complex": 8},
"analysis": {"simple": 2, "moderate": 4, "complex": 7, "very_complex": 12},
"data": {"simple": 1, "moderate": 2, "complex": 4, "very_complex": 6},
"research": {"simple": 3, "moderate": 6, "complex": 10, "very_complex": 15},
}
MULTIPLIERS = {
"batch_size": {"small": 1.0, "medium": 1.5, "large": 2.0},
"time_sensitivity": {"normal": 1.0, "urgent": 1.3, "real_time": 1.5},
}
@classmethod
def estimate_tool_cost(
cls,
tool_name: str,
category: str,
complexity: str = "moderate",
additional_params: dict[str, Any] | None = None,
) -> int:
additional_params = additional_params or {}
base_cost = cls.BASE_COSTS.get(category, {}).get(complexity, 3)
batch_size = additional_params.get("batch_size", 1)
if batch_size <= 10:
batch_multiplier = cls.MULTIPLIERS["batch_size"]["small"]
elif batch_size <= 50:
batch_multiplier = cls.MULTIPLIERS["batch_size"]["medium"]
else:
batch_multiplier = cls.MULTIPLIERS["batch_size"]["large"]
time_sensitivity = additional_params.get("time_sensitivity", "normal")
time_multiplier = cls.MULTIPLIERS["time_sensitivity"].get(time_sensitivity, 1.0)
total_cost = base_cost * batch_multiplier * time_multiplier
if "portfolio" in tool_name.lower():
total_cost *= 1.2
elif "screening" in tool_name.lower():
total_cost *= 1.1
elif "real_time" in tool_name.lower():
total_cost *= 1.3
return max(1, int(total_cost))
# Module-level instance shared by the estimate_tool_cost() convenience wrapper.
tool_cost_estimator = ToolCostEstimator()


def estimate_tool_cost(
    tool_name: str,
    category: str = "analysis",
    complexity: str = "moderate",
    **kwargs: Any,
) -> int:
    """Convenience wrapper around :class:`ToolCostEstimator`.

    Keyword arguments are forwarded as the estimator's ``additional_params``
    (e.g. ``batch_size``, ``time_sensitivity``).
    """
    return tool_cost_estimator.estimate_tool_cost(
        tool_name, category, complexity, additional_params=kwargs
    )
```
--------------------------------------------------------------------------------
/tests/test_rate_limiting_enhanced.py:
--------------------------------------------------------------------------------
```python
"""
Test suite for enhanced rate limiting middleware.
Tests various rate limiting scenarios including:
- Different user types (anonymous, authenticated, premium)
- Different endpoint tiers
- Multiple rate limiting strategies
- Monitoring and alerting
"""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import redis.asyncio as redis
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from maverick_mcp.api.middleware.rate_limiting_enhanced import (
EndpointClassification,
EnhancedRateLimitMiddleware,
RateLimitConfig,
RateLimiter,
RateLimitStrategy,
RateLimitTier,
rate_limit,
)
from maverick_mcp.exceptions import RateLimitError
@pytest.fixture
def rate_limit_config():
    """Create test rate limit configuration.

    Limits are deliberately tiny so the tests can exhaust them in a handful
    of requests; multipliers and strategy mirror realistic settings.
    """
    return RateLimitConfig(
        public_limit=100,
        auth_limit=5,
        data_limit=20,
        data_limit_anonymous=5,
        analysis_limit=10,
        analysis_limit_anonymous=2,
        bulk_limit_per_hour=5,
        admin_limit=10,
        premium_multiplier=5.0,
        enterprise_multiplier=10.0,
        default_strategy=RateLimitStrategy.SLIDING_WINDOW,
        burst_multiplier=1.5,
        window_size_seconds=60,
        token_refill_rate=1.0,
        max_tokens=10,
        log_violations=True,
        alert_threshold=3,
    )
@pytest.fixture
def rate_limiter(rate_limit_config):
    """Create rate limiter instance backed by the test configuration."""
    return RateLimiter(rate_limit_config)
@pytest.fixture
def mock_redis():
    """Create mock Redis client.

    Declared synchronous (it awaits nothing): the previous ``async def``
    under plain ``@pytest.fixture`` handed sync tests such as
    ``test_anonymous_rate_limiting`` an un-awaited coroutine instead of
    the mock. Returns an AsyncMock so the code under test can still await
    the client's async methods.
    """
    mock = AsyncMock(spec=redis.Redis)
    # pipeline() itself is a sync call; only pipeline.execute() is awaited,
    # so wrap it in MagicMock returning an AsyncMock pipeline.
    mock_pipeline = AsyncMock()
    mock_pipeline.execute = AsyncMock(return_value=[None, 0, None, None])
    mock.pipeline = MagicMock(return_value=mock_pipeline)
    # Mock other methods
    mock.zrange = AsyncMock(return_value=[])
    mock.hgetall = AsyncMock(return_value={})
    mock.incr = AsyncMock(return_value=1)
    return mock
@pytest.fixture
def test_app():
    """Create test FastAPI app with one endpoint per rate-limit tier."""
    app = FastAPI()

    @app.get("/health")
    async def health():  # public tier; bypassed by the middleware
        return {"status": "ok"}

    @app.post("/api/auth/login")
    async def login():  # authentication tier
        return {"token": "test"}

    @app.get("/api/data/stock/{symbol}")
    async def get_stock(symbol: str):  # data-retrieval tier
        return {"symbol": symbol, "price": 100}

    @app.post("/api/screening/bulk")
    async def bulk_screening():  # bulk-operation tier
        return {"stocks": ["AAPL", "GOOGL", "MSFT"]}

    @app.get("/api/admin/users")
    async def admin_users():  # administrative tier
        return {"users": []}

    return app
class TestEndpointClassification:
    """Test endpoint classification for every rate-limit tier."""

    @staticmethod
    def _assert_tier(expected_tier, paths):
        """Assert every path in ``paths`` classifies to ``expected_tier``."""
        for path in paths:
            assert EndpointClassification.classify_endpoint(path) == expected_tier

    def test_classify_public_endpoints(self):
        """Test classification of public endpoints."""
        self._assert_tier(
            RateLimitTier.PUBLIC,
            ("/health", "/api/docs", "/api/openapi.json"),
        )

    def test_classify_auth_endpoints(self):
        """Test classification of authentication endpoints."""
        self._assert_tier(
            RateLimitTier.AUTHENTICATION,
            ("/api/auth/login", "/api/auth/signup", "/api/auth/refresh"),
        )

    def test_classify_data_endpoints(self):
        """Test classification of data retrieval endpoints."""
        self._assert_tier(
            RateLimitTier.DATA_RETRIEVAL,
            ("/api/data/stock/AAPL", "/api/stock/quote", "/api/market/movers"),
        )

    def test_classify_analysis_endpoints(self):
        """Test classification of analysis endpoints."""
        self._assert_tier(
            RateLimitTier.ANALYSIS,
            (
                "/api/technical/indicators",
                "/api/screening/maverick",
                "/api/portfolio/optimize",
            ),
        )

    def test_classify_bulk_endpoints(self):
        """Test classification of bulk operation endpoints."""
        self._assert_tier(
            RateLimitTier.BULK_OPERATION,
            ("/api/screening/bulk", "/api/data/bulk", "/api/portfolio/batch"),
        )

    def test_classify_admin_endpoints(self):
        """Test classification of administrative endpoints."""
        self._assert_tier(
            RateLimitTier.ADMINISTRATIVE,
            ("/api/admin/users", "/api/admin/system", "/api/users/admin/delete"),
        )

    def test_default_classification(self):
        """Test unknown endpoints fall back to the data-retrieval tier."""
        self._assert_tier(
            RateLimitTier.DATA_RETRIEVAL,
            ("/api/unknown", "/random/path"),
        )
class TestRateLimiter:
    """Test rate limiter core functionality across all strategies."""

    @pytest.mark.asyncio
    async def test_sliding_window_allows_requests(self, rate_limiter, mock_redis):
        """Test sliding window allows requests within limit."""
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=10,
                window_seconds=60,
                strategy=RateLimitStrategy.SLIDING_WINDOW,
            )
            # Pipeline mock reports zero prior requests, so this one passes.
            assert is_allowed is True
            assert info["limit"] == 10
            assert info["remaining"] == 9
            assert "burst_limit" in info

    @pytest.mark.asyncio
    async def test_sliding_window_blocks_excess(self, rate_limiter, mock_redis):
        """Test sliding window blocks requests over limit."""
        # Mock pipeline to return high count (15 against a limit of 10).
        mock_pipeline = AsyncMock()
        mock_pipeline.execute = AsyncMock(return_value=[None, 15, None, None])
        mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=10,
                window_seconds=60,
                strategy=RateLimitStrategy.SLIDING_WINDOW,
            )
            assert is_allowed is False
            assert info["remaining"] == 0
            assert info["retry_after"] > 0

    @pytest.mark.asyncio
    async def test_token_bucket_allows_requests(self, rate_limiter, mock_redis):
        """Test token bucket allows requests with tokens."""
        # Bucket state: 5 tokens available, refilled just now.
        mock_redis.hgetall = AsyncMock(
            return_value={"tokens": "5.0", "last_refill": str(time.time())}
        )
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=10,
                window_seconds=60,
                strategy=RateLimitStrategy.TOKEN_BUCKET,
            )
            assert is_allowed is True
            assert "tokens" in info
            assert "refill_rate" in info

    @pytest.mark.asyncio
    async def test_token_bucket_blocks_no_tokens(self, rate_limiter, mock_redis):
        """Test token bucket blocks requests without tokens."""
        # Less than one full token available, so the request must be rejected.
        mock_redis.hgetall = AsyncMock(
            return_value={"tokens": "0.5", "last_refill": str(time.time())}
        )
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=10,
                window_seconds=60,
                strategy=RateLimitStrategy.TOKEN_BUCKET,
            )
            assert is_allowed is False
            assert info["retry_after"] > 0

    @pytest.mark.asyncio
    async def test_fixed_window_allows_requests(self, rate_limiter, mock_redis):
        """Test fixed window allows requests within limit."""
        # Pipeline's first result (current count) is 5, under the limit of 10.
        mock_pipeline = AsyncMock()
        mock_pipeline.execute = AsyncMock(return_value=[5, None])
        mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=10,
                window_seconds=60,
                strategy=RateLimitStrategy.FIXED_WINDOW,
            )
            assert is_allowed is True
            assert info["current_count"] == 5

    @pytest.mark.asyncio
    async def test_local_fallback_rate_limiting(self, rate_limiter):
        """Test local rate limiting when Redis unavailable."""
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
        ):
            # First few requests should pass
            for _i in range(5):
                is_allowed, info = await rate_limiter.check_rate_limit(
                    key="test_user",
                    tier=RateLimitTier.DATA_RETRIEVAL,
                    limit=5,
                    window_seconds=60,
                )
                assert is_allowed is True
                assert info["fallback"] is True
            # Next request should be blocked
            is_allowed, info = await rate_limiter.check_rate_limit(
                key="test_user",
                tier=RateLimitTier.DATA_RETRIEVAL,
                limit=5,
                window_seconds=60,
            )
            assert is_allowed is False

    def test_violation_recording(self, rate_limiter):
        """Test violation count recording."""
        tier = RateLimitTier.DATA_RETRIEVAL
        assert rate_limiter.get_violation_count("user1", tier=tier) == 0
        rate_limiter.record_violation("user1", tier=tier)
        assert rate_limiter.get_violation_count("user1", tier=tier) == 1
        rate_limiter.record_violation("user1", tier=tier)
        assert rate_limiter.get_violation_count("user1", tier=tier) == 2
        # Different tiers maintain independent counters
        other_tier = RateLimitTier.ANALYSIS
        assert rate_limiter.get_violation_count("user1", tier=other_tier) == 0
class TestEnhancedRateLimitMiddleware:
    """Test enhanced rate limit middleware integration."""

    @pytest.fixture
    def middleware_app(self, test_app, rate_limit_config):
        """Create app with rate limit middleware installed."""
        test_app.add_middleware(EnhancedRateLimitMiddleware, config=rate_limit_config)
        return test_app

    @pytest.fixture
    def client(self, middleware_app):
        """Create test client for the middleware-wrapped app."""
        return TestClient(middleware_app)

    def test_bypass_health_check(self, client):
        """Test health check endpoint bypasses rate limiting."""
        # Should always succeed, and no rate-limit headers are attached.
        for _ in range(10):
            response = client.get("/health")
            assert response.status_code == 200
            assert "X-RateLimit-Limit" not in response.headers

    @patch("maverick_mcp.data.performance.redis_manager.get_client")
    def test_anonymous_rate_limiting(self, mock_get_client, client, mock_redis):
        """Test rate limiting for anonymous users."""
        mock_get_client.return_value = mock_redis
        # Configure mock to allow the first 5 requests, then report an
        # over-limit count so the 6th is rejected.
        call_count = 0

        def mock_execute():
            nonlocal call_count
            call_count += 1
            if call_count <= 5:
                return [None, call_count - 1, None, None]
            else:
                return [None, 10, None, None]  # Over limit
        mock_pipeline = AsyncMock()
        mock_pipeline.execute = AsyncMock(side_effect=mock_execute)
        mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
        mock_redis.zrange = AsyncMock(return_value=[(b"1", time.time())])
        # First 5 requests should succeed
        for _i in range(5):
            response = client.get("/api/data/stock/AAPL")
            assert response.status_code == 200
            assert "X-RateLimit-Limit" in response.headers
            assert "X-RateLimit-Remaining" in response.headers
        # 6th request should be rate limited
        response = client.get("/api/data/stock/AAPL")
        assert response.status_code == 429
        assert "Rate limit exceeded" in response.json()["error"]
        assert "Retry-After" in response.headers

    def test_authenticated_user_headers(self, client):
        """Test authenticated users get proper headers."""
        # NOTE(review): placeholder — builds a mock request but makes no
        # assertions; real coverage belongs in integration tests with auth.
        request = MagicMock(spec=Request)
        request.state.user_id = "123"
        request.state.user_context = {"role": "user"}
        # Headers should be added to response
        # This would be tested in integration tests with actual auth

    def test_premium_user_multiplier(self, client):
        """Test premium users get higher limits."""
        # NOTE(review): placeholder — no assertions performed here.
        request = MagicMock(spec=Request)
        request.state.user_id = "123"
        request.state.user_context = {"role": "premium"}
        # Premium users should have 5x the limit
        # This would be tested in integration tests

    def test_endpoint_tier_headers(self, client):
        """Test different endpoints return tier information."""
        # Redis patched to None so the middleware uses its local fallback.
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
        ):
            # Test auth endpoint
            response = client.post("/api/auth/login")
            if "X-RateLimit-Tier" in response.headers:
                assert response.headers["X-RateLimit-Tier"] == "authentication"
            # Test data endpoint
            response = client.get("/api/data/stock/AAPL")
            if "X-RateLimit-Tier" in response.headers:
                assert response.headers["X-RateLimit-Tier"] == "data_retrieval"
            # Test bulk endpoint
            response = client.post("/api/screening/bulk")
            if "X-RateLimit-Tier" in response.headers:
                assert response.headers["X-RateLimit-Tier"] == "bulk_operation"
class TestRateLimitDecorator:
    """Test function-level rate limiting decorator."""

    @pytest.mark.asyncio
    async def test_decorator_allows_requests(self):
        """Test decorator allows requests within limit."""
        call_count = 0

        @rate_limit(requests_per_minute=5)
        async def test_function(request: Request):
            nonlocal call_count
            call_count += 1
            return {"count": call_count}

        # Mock request
        request = MagicMock(spec=Request)
        request.state.user_id = "test_user"
        # Redis patched to None so the decorator's local fallback is used.
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
        ):
            # Should allow first few calls
            for i in range(5):
                result = await test_function(request)
                assert result["count"] == i + 1

    @pytest.mark.asyncio
    async def test_decorator_blocks_excess(self):
        """Test decorator blocks excessive requests."""

        @rate_limit(requests_per_minute=2)
        async def test_function(request: Request):
            return {"success": True}

        # Mock request with proper attributes for rate limiting
        request = MagicMock()
        request.state = MagicMock()
        request.state.user_id = "test_user"
        request.url = MagicMock()  # Required for rate limiting detection
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client", return_value=None
        ):
            # First 2 should succeed
            await test_function(request)
            await test_function(request)
            # 3rd should raise exception
            with pytest.raises(RateLimitError) as exc_info:
                await test_function(request)
            assert "Rate limit exceeded" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_decorator_without_request(self):
        """Test decorator works without request object."""

        @rate_limit(requests_per_minute=5)
        async def test_function(value: int):
            return value * 2

        # Should work without rate limiting
        result = await test_function(5)
        assert result == 10
class TestMonitoringIntegration:
    """Test monitoring and alerting integration."""

    @pytest.mark.asyncio
    async def test_violation_monitoring(self, rate_limiter, rate_limit_config):
        """Test violations are recorded for monitoring."""
        # Record one more violation than the configured alert threshold.
        for _i in range(rate_limit_config.alert_threshold + 1):
            rate_limiter.record_violation("bad_user", tier=RateLimitTier.DATA_RETRIEVAL)
        # Check violation count crossed the alerting threshold.
        assert (
            rate_limiter.get_violation_count(
                "bad_user", tier=RateLimitTier.DATA_RETRIEVAL
            )
            > rate_limit_config.alert_threshold
        )

    @pytest.mark.asyncio
    async def test_cleanup_task(self, rate_limiter, mock_redis):
        """Test periodic cleanup of old data."""
        # SCAN returns cursor 0 (single page) with two sliding-window keys.
        mock_redis.scan = AsyncMock(
            return_value=(
                0,
                [
                    "rate_limit:sw:test1",
                    "rate_limit:sw:test2",
                ],
            )
        )
        mock_redis.type = AsyncMock(return_value="zset")
        mock_redis.zremrangebyscore = AsyncMock()
        # zcard of 0 marks the keys as empty after trimming old entries.
        mock_redis.zcard = AsyncMock(return_value=0)
        mock_redis.delete = AsyncMock()
        with patch(
            "maverick_mcp.data.performance.redis_manager.get_client",
            return_value=mock_redis,
        ):
            await rate_limiter.cleanup_old_data(older_than_hours=1)
        # Should have called delete for empty keys
        assert mock_redis.delete.called
```
--------------------------------------------------------------------------------
/tests/test_security_headers.py:
--------------------------------------------------------------------------------
```python
"""
Comprehensive Security Headers Tests for Maverick MCP.
Tests security headers configuration, middleware implementation,
environment-specific headers, and CSP/HSTS policies.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from maverick_mcp.api.middleware.security import (
SecurityHeadersMiddleware as APISecurityHeadersMiddleware,
)
from maverick_mcp.config.security import (
SecurityConfig,
SecurityHeadersConfig,
)
from maverick_mcp.config.security_utils import (
SecurityHeadersMiddleware,
apply_security_headers_to_fastapi,
)
class TestSecurityHeadersConfig:
    """Test security headers configuration."""

    def test_security_headers_default_values(self):
        """Test security headers have secure default values."""
        config = SecurityHeadersConfig()
        assert config.x_content_type_options == "nosniff"
        assert config.x_frame_options == "DENY"
        assert config.x_xss_protection == "1; mode=block"
        assert config.referrer_policy == "strict-origin-when-cross-origin"
        assert "geolocation=()" in config.permissions_policy

    def test_hsts_header_generation(self):
        """Test HSTS header value generation."""
        config = SecurityHeadersConfig()
        hsts_header = config.hsts_header_value
        assert f"max-age={config.hsts_max_age}" in hsts_header
        assert "includeSubDomains" in hsts_header
        assert "preload" not in hsts_header  # Default is False

    def test_hsts_header_with_preload(self):
        """Test HSTS header with preload enabled."""
        config = SecurityHeadersConfig(hsts_preload=True)
        hsts_header = config.hsts_header_value
        assert "preload" in hsts_header

    def test_hsts_header_without_subdomains(self):
        """Test HSTS header without subdomains."""
        config = SecurityHeadersConfig(hsts_include_subdomains=False)
        hsts_header = config.hsts_header_value
        assert "includeSubDomains" not in hsts_header

    def test_csp_header_generation(self):
        """Test CSP header value generation."""
        config = SecurityHeadersConfig()
        csp_header = config.csp_header_value
        # Check required directives
        assert "default-src 'self'" in csp_header
        assert "script-src 'self' 'unsafe-inline'" in csp_header
        assert "style-src 'self' 'unsafe-inline'" in csp_header
        assert "object-src 'none'" in csp_header
        assert "connect-src 'self'" in csp_header
        assert "frame-src 'none'" in csp_header
        assert "base-uri 'self'" in csp_header
        assert "form-action 'self'" in csp_header

    def test_csp_custom_directives(self):
        """Test CSP with custom directives."""
        config = SecurityHeadersConfig(
            csp_script_src=["'self'", "https://trusted.com"],
            csp_connect_src=["'self'", "https://api.trusted.com"],
        )
        csp_header = config.csp_header_value
        assert "script-src 'self' https://trusted.com" in csp_header
        assert "connect-src 'self' https://api.trusted.com" in csp_header

    def test_permissions_policy_default(self):
        """Test permissions policy default configuration."""
        config = SecurityHeadersConfig()
        permissions = config.permissions_policy
        # Privacy-sensitive browser features should all be disabled by default.
        assert "geolocation=()" in permissions
        assert "microphone=()" in permissions
        assert "camera=()" in permissions
        assert "usb=()" in permissions
        assert "magnetometer=()" in permissions
class TestSecurityHeadersMiddleware:
    """Test security headers middleware implementation."""

    def test_middleware_adds_headers(self):
        """Test that middleware adds security headers to responses."""
        app = FastAPI()
        # Create mock security config returning a fixed header set.
        mock_config = MagicMock()
        mock_config.get_security_headers.return_value = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Content-Security-Policy": "default-src 'self'",
        }
        app.add_middleware(SecurityHeadersMiddleware, security_config=mock_config)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")
        # Every configured header must appear on the response.
        assert response.headers["X-Content-Type-Options"] == "nosniff"
        assert response.headers["X-Frame-Options"] == "DENY"
        assert response.headers["X-XSS-Protection"] == "1; mode=block"
        assert response.headers["Content-Security-Policy"] == "default-src 'self'"

    def test_middleware_uses_default_config(self):
        """Test that middleware uses default security config when none provided."""
        app = FastAPI()
        with patch(
            "maverick_mcp.config.security_utils.get_security_config"
        ) as mock_get_config:
            mock_config = MagicMock()
            mock_config.get_security_headers.return_value = {"X-Frame-Options": "DENY"}
            mock_get_config.return_value = mock_config
            app.add_middleware(SecurityHeadersMiddleware)

            @app.get("/test")
            async def test_endpoint():
                return {"message": "test"}

            client = TestClient(app)
            response = client.get("/test")
            # Middleware fell back to get_security_config() for its headers.
            mock_get_config.assert_called_once()
            assert response.headers["X-Frame-Options"] == "DENY"

    def test_api_middleware_integration(self):
        """Test API security headers middleware integration."""
        app = FastAPI()
        app.add_middleware(APISecurityHeadersMiddleware)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")
        # Should have basic security headers
        assert "X-Content-Type-Options" in response.headers
        assert "X-Frame-Options" in response.headers
class TestEnvironmentSpecificHeaders:
    """Test environment-specific security headers."""

    def test_hsts_in_production(self):
        """Test HSTS header is included in production."""
        with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=False):
            config = SecurityConfig()
            headers = config.get_security_headers()
            assert "Strict-Transport-Security" in headers
            assert "max-age=" in headers["Strict-Transport-Security"]

    def test_hsts_in_development(self):
        """Test HSTS header is not included in development."""
        with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
            config = SecurityConfig(force_https=False)
            headers = config.get_security_headers()
            assert "Strict-Transport-Security" not in headers

    def test_hsts_with_force_https(self):
        """Test HSTS header is included when HTTPS is forced."""
        # force_https overrides the development-environment default.
        with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
            config = SecurityConfig(force_https=True)
            headers = config.get_security_headers()
            assert "Strict-Transport-Security" in headers

    def test_production_security_validation(self):
        """Test production security validation."""
        with patch.dict(os.environ, {"ENVIRONMENT": "production"}, clear=False):
            with patch(
                "maverick_mcp.config.security._get_cors_origins"
            ) as mock_origins:
                mock_origins.return_value = ["https://app.maverick-mcp.com"]
                with patch("logging.getLogger") as mock_logger:
                    mock_logger_instance = MagicMock()
                    mock_logger.return_value = mock_logger_instance
                    # Test with HTTPS not forced (should warn)
                    SecurityConfig(force_https=False)
                    # Should log warning about HTTPS
                    mock_logger_instance.warning.assert_called()

    def test_development_security_permissive(self):
        """Test development security is more permissive."""
        with patch.dict(os.environ, {"ENVIRONMENT": "development"}, clear=False):
            config = SecurityConfig()
            assert config.is_development() is True
            assert config.is_production() is False
class TestCSPConfiguration:
    """Test Content Security Policy configuration."""

    def test_csp_avoids_checkout_domains(self):
        """Test CSP excludes third-party checkout provider domains."""
        config = SecurityHeadersConfig()
        assert config.csp_script_src == ["'self'", "'unsafe-inline'"]
        assert config.csp_connect_src == ["'self'"]
        assert config.csp_frame_src == ["'none'"]

    def test_csp_blocks_inline_scripts_by_default(self):
        """Test CSP configuration for inline scripts."""
        config = SecurityHeadersConfig()
        csp = config.csp_header_value
        # Note: Current config allows 'unsafe-inline' for compatibility
        # In a more secure setup, this should use nonces or hashes
        assert "'unsafe-inline'" in csp

    def test_csp_blocks_object_embedding(self):
        """Test CSP blocks object embedding."""
        config = SecurityHeadersConfig()
        csp = config.csp_header_value
        assert "object-src 'none'" in csp

    def test_csp_restricts_base_uri(self):
        """Test CSP restricts base URI."""
        config = SecurityHeadersConfig()
        csp = config.csp_header_value
        assert "base-uri 'self'" in csp

    def test_csp_restricts_form_action(self):
        """Test CSP restricts form actions."""
        config = SecurityHeadersConfig()
        csp = config.csp_header_value
        assert "form-action 'self'" in csp

    def test_csp_image_sources(self):
        """Test CSP allows necessary image sources."""
        config = SecurityHeadersConfig()
        csp = config.csp_header_value
        assert "img-src 'self' data: https:" in csp

    def test_csp_custom_configuration(self):
        """Test CSP with custom configuration."""
        custom_config = SecurityHeadersConfig(
            csp_default_src=["'self'", "https://trusted.com"],
            csp_script_src=["'self'"],
            csp_style_src=["'self'"],  # Remove unsafe-inline from styles too
            csp_object_src=["'none'"],
        )
        csp = custom_config.csp_header_value
        assert "default-src 'self' https://trusted.com" in csp
        assert "script-src 'self'" in csp
        # Since we removed unsafe-inline from style-src, it shouldn't be in CSP
        assert "style-src 'self'" in csp
        assert "'unsafe-inline'" not in csp
class TestXFrameOptionsConfiguration:
    """Tests for the X-Frame-Options response header value."""

    def test_frame_options_deny_default(self):
        """The header defaults to DENY."""
        SecurityHeadersConfig()  # default construction must not fail
        produced = SecurityConfig().get_security_headers()
        assert produced["X-Frame-Options"] == "DENY"

    def test_frame_options_sameorigin(self):
        """The header can be switched to SAMEORIGIN."""
        produced = SecurityConfig(
            headers=SecurityHeadersConfig(x_frame_options="SAMEORIGIN")
        ).get_security_headers()
        assert produced["X-Frame-Options"] == "SAMEORIGIN"

    def test_frame_options_allow_from(self):
        """The header supports an ALLOW-FROM directive with an origin."""
        produced = SecurityConfig(
            headers=SecurityHeadersConfig(
                x_frame_options="ALLOW-FROM https://trusted.com"
            )
        ).get_security_headers()
        assert produced["X-Frame-Options"] == "ALLOW-FROM https://trusted.com"
class TestReferrerPolicyConfiguration:
    """Tests for the Referrer-Policy response header value."""

    def test_referrer_policy_default(self):
        """The default policy is strict-origin-when-cross-origin."""
        SecurityHeadersConfig()  # default construction must not fail
        produced = SecurityConfig().get_security_headers()
        assert produced["Referrer-Policy"] == "strict-origin-when-cross-origin"

    def test_referrer_policy_custom(self):
        """A custom policy value is passed through verbatim."""
        produced = SecurityConfig(
            headers=SecurityHeadersConfig(referrer_policy="no-referrer")
        ).get_security_headers()
        assert produced["Referrer-Policy"] == "no-referrer"
class TestPermissionsPolicyConfiguration:
    """Tests for the Permissions-Policy response header value."""

    def test_permissions_policy_blocks_dangerous_features(self):
        """Dangerous browser features are disabled by default."""
        SecurityHeadersConfig()  # default construction must not fail
        policy = SecurityConfig().get_security_headers()["Permissions-Policy"]
        for directive in ("geolocation=()", "microphone=()", "camera=()", "usb=()"):
            assert directive in policy

    def test_permissions_policy_custom(self):
        """A custom policy string is passed through verbatim."""
        wanted = "geolocation=(self), camera=(), microphone=()"
        produced = SecurityConfig(
            headers=SecurityHeadersConfig(permissions_policy=wanted)
        ).get_security_headers()
        assert produced["Permissions-Policy"] == wanted
class TestSecurityHeadersIntegration:
    """Test security headers integration with application.

    Each test builds a fresh FastAPI app, applies the security-headers
    middleware, registers an inline endpoint, and inspects the response.
    """

    def test_all_headers_applied(self):
        """Test that all security headers are applied to responses."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")

        # Check all expected headers are present
        expected_headers = [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "X-XSS-Protection",
            "Referrer-Policy",
            "Permissions-Policy",
            "Content-Security-Policy",
        ]

        for header in expected_headers:
            assert header in response.headers

    def test_headers_on_error_responses(self):
        """Test security headers are included on error responses."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/error")
        async def error_endpoint():
            from fastapi import HTTPException

            raise HTTPException(status_code=500, detail="Test error")

        client = TestClient(app)
        response = client.get("/error")

        # Even on errors, security headers should be present
        assert response.status_code == 500
        assert "X-Frame-Options" in response.headers
        assert "X-Content-Type-Options" in response.headers

    def test_headers_on_different_methods(self):
        """Test security headers on different HTTP methods."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def get_endpoint():
            return {"method": "GET"}

        @app.post("/test")
        async def post_endpoint():
            return {"method": "POST"}

        @app.put("/test")
        async def put_endpoint():
            return {"method": "PUT"}

        client = TestClient(app)

        # Every HTTP verb should receive the same hardened header set.
        methods = [(client.get, "/test"), (client.post, "/test"), (client.put, "/test")]
        for method_func, path in methods:
            response = method_func(path)
            assert "X-Frame-Options" in response.headers
            assert "Content-Security-Policy" in response.headers

    def test_headers_override_existing(self):
        """Test security headers override any existing headers."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            from fastapi import Response

            # Endpoint deliberately sets an insecure value to verify the
            # middleware wins over handler-set headers.
            response = Response(content='{"message": "test"}')
            response.headers["X-Frame-Options"] = "ALLOWALL"  # Insecure value
            return response

        client = TestClient(app)
        response = client.get("/test")

        # Security middleware should override the insecure value
        assert response.headers["X-Frame-Options"] == "DENY"
class TestSecurityHeadersValidation:
    """Test security headers validation and best practices."""

    def test_no_server_header_disclosure(self):
        """Test that server information is not disclosed."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")

        # Should not disclose server information
        server_header = response.headers.get("Server", "")
        assert "uvicorn" not in server_header.lower()

    def test_no_powered_by_header(self):
        """Test that X-Powered-By header is not present."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")

        # X-Powered-By would leak framework details; it must be absent.
        assert "X-Powered-By" not in response.headers

    def test_content_type_nosniff(self):
        """Test X-Content-Type-Options prevents MIME sniffing."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")

        assert response.headers["X-Content-Type-Options"] == "nosniff"

    def test_xss_protection_enabled(self):
        """Test X-XSS-Protection is properly configured."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)
        response = client.get("/test")

        # Legacy header; "1; mode=block" is the conventional strict value.
        xss_protection = response.headers["X-XSS-Protection"]
        assert "1" in xss_protection
        assert "mode=block" in xss_protection
class TestSecurityHeadersPerformance:
    """Test security headers don't impact performance significantly."""

    def test_headers_middleware_performance(self):
        """Test security headers middleware performance.

        FIX: uses ``time.perf_counter()`` (monotonic, high resolution) rather
        than ``time.time()`` — wall-clock time can jump under NTP adjustments,
        which makes elapsed-duration measurements unreliable and this test
        flaky.
        """
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)

        # Make multiple requests to test performance
        import time

        start_time = time.perf_counter()
        for _ in range(100):
            response = client.get("/test")
            assert response.status_code == 200
        total_time = time.perf_counter() - start_time

        # Should complete 100 requests quickly (less than 5 seconds)
        assert total_time < 5.0

    def test_headers_memory_usage(self):
        """Test security headers don't cause memory leaks."""
        app = FastAPI()
        apply_security_headers_to_fastapi(app)

        @app.get("/test")
        async def test_endpoint():
            return {"message": "test"}

        client = TestClient(app)

        # Make many requests to check for memory leaks
        for _ in range(1000):
            response = client.get("/test")
            assert "X-Frame-Options" in response.headers

        # If we reach here without memory issues, test passes
if __name__ == "__main__":
    # Allow running this test module directly: executes pytest verbosely
    # on just this file.
    pytest.main([__file__, "-v"])
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/parallel_research.py:
--------------------------------------------------------------------------------
```python
"""
Parallel Research Execution Utilities
This module provides infrastructure for spawning and managing parallel research
subagents for comprehensive financial analysis.
"""
import asyncio
import logging
import time
from collections.abc import Callable
from typing import Any
from ..agents.circuit_breaker import circuit_breaker
from ..config.settings import get_settings
from .orchestration_logging import (
get_orchestration_logger,
log_agent_execution,
log_method_call,
log_parallel_execution,
log_performance_metrics,
log_resource_usage,
)
logger = logging.getLogger(__name__)
settings = get_settings()
class ParallelResearchConfig:
    """Configuration for parallel research operations.

    Defaults favour speed: high concurrency, short per-agent timeouts,
    fallbacks disabled, and only a minimal inter-launch rate-limit delay.
    """

    def __init__(
        self,
        max_concurrent_agents: int = 6,
        timeout_per_agent: int = 60,
        enable_fallbacks: bool = False,
        rate_limit_delay: float = 0.05,
        batch_size: int = 3,
        use_worker_pool: bool = True,
    ):
        """Store the supplied settings verbatim.

        Args:
            max_concurrent_agents: Upper bound on simultaneously running agents.
            timeout_per_agent: Per-agent timeout in seconds.
            enable_fallbacks: Whether fallback strategies may run.
            rate_limit_delay: Delay (seconds) between task launches for API
                rate limiting.
            batch_size: Grouping size when tasks are batched.
            use_worker_pool: Whether the worker-pool execution pattern is used.
        """
        self.max_concurrent_agents = max_concurrent_agents
        self.timeout_per_agent = timeout_per_agent
        self.enable_fallbacks = enable_fallbacks
        self.rate_limit_delay = rate_limit_delay
        self.batch_size = batch_size
        self.use_worker_pool = use_worker_pool
class ResearchTask:
"""Represents a single research task for parallel execution."""
def __init__(
self,
task_id: str,
task_type: str,
target_topic: str,
focus_areas: list[str],
priority: int = 1,
timeout: int | None = None,
):
self.task_id = task_id
self.task_type = task_type # fundamental, technical, sentiment, competitive
self.target_topic = target_topic
self.focus_areas = focus_areas
self.priority = priority
self.timeout = timeout
self.start_time: float | None = None
self.end_time: float | None = None
self.status: str = "pending" # pending, running, completed, failed
self.result: dict[str, Any] | None = None
self.error: str | None = None
class ResearchResult:
    """Aggregate outcome of a parallel research run."""

    def __init__(self):
        """Initialize empty containers and zeroed counters."""
        self.task_results: dict[str, ResearchTask] = {}  # task_id -> task
        self.synthesis: dict[str, Any] | None = None  # optional combined summary
        self.total_execution_time: float = 0.0  # wall-clock seconds
        self.successful_tasks: int = 0
        self.failed_tasks: int = 0
        self.parallel_efficiency: float = 0.0  # sequential-time / wall-time ratio
class ParallelResearchOrchestrator:
    """Orchestrates parallel research agent execution.

    Concurrency is capped by an ``asyncio.BoundedSemaphore`` sized to
    ``config.max_concurrent_agents``; each task acquires a slot before
    running. Results are aggregated into a ``ResearchResult`` with basic
    parallel-efficiency metrics.
    """

    def __init__(self, config: ParallelResearchConfig | None = None):
        """Initialize the orchestrator.

        Args:
            config: Optional configuration; defaults are used when omitted.
        """
        self.config = config or ParallelResearchConfig()
        self.active_tasks: dict[str, ResearchTask] = {}
        # Bounded semaphore: releasing more often than acquired raises
        # ValueError, surfacing accounting bugs instead of hiding them.
        self._semaphore = asyncio.BoundedSemaphore(self.config.max_concurrent_agents)
        self.orchestration_logger = get_orchestration_logger("ParallelOrchestrator")

        # Track active workers for better coordination
        self._active_workers = 0
        self._worker_lock = asyncio.Lock()

        # Log initialization
        self.orchestration_logger.info(
            "🎛️ ORCHESTRATOR_INIT",
            max_agents=self.config.max_concurrent_agents,
        )

    @log_method_call(component="ParallelOrchestrator", include_timing=True)
    async def execute_parallel_research(
        self,
        tasks: list[ResearchTask],
        research_executor,
        synthesis_callback: Callable[..., Any] | None = None,
    ) -> ResearchResult:
        """
        Execute multiple research tasks in parallel with intelligent coordination.

        Args:
            tasks: List of research tasks to execute
            research_executor: Function to execute individual research tasks
            synthesis_callback: Optional function to synthesize results

        Returns:
            ResearchResult with aggregated results and synthesis
        """
        self.orchestration_logger.set_request_context(
            session_id=tasks[0].task_id.split("_")[0] if tasks else "unknown",
            task_count=len(tasks),
        )

        # Log task overview
        self.orchestration_logger.info(
            "📋 TASK_OVERVIEW",
            task_count=len(tasks),
            max_concurrent=self.config.max_concurrent_agents,
        )

        start_time = time.time()

        # Create result container
        result = ResearchResult()

        with log_parallel_execution(
            "ParallelOrchestrator", "research execution", len(tasks)
        ) as exec_logger:
            try:
                # Prepare tasks (sets timeouts, sorts by priority). NOTE:
                # preparation truncates to max_concurrent_agents, so excess
                # tasks are intentionally not executed.
                prepared_tasks = await self._prepare_tasks(tasks)
                exec_logger.info(
                    "🔧 TASKS_PREPARED", prepared_count=len(prepared_tasks)
                )

                # Schedule all tasks immediately for maximum parallelism;
                # the semaphore inside _execute_single_task enforces the cap.
                exec_logger.info("🚀 PARALLEL_EXECUTION_START")

                running_tasks = []
                for task in prepared_tasks:
                    task_future = asyncio.create_task(
                        self._execute_single_task(task, research_executor)
                    )
                    running_tasks.append(task_future)

                    # Minimal stagger between launches for API rate limits.
                    if self.config.rate_limit_delay > 0 and len(running_tasks) < len(
                        prepared_tasks
                    ):
                        await asyncio.sleep(
                            self.config.rate_limit_delay * 0.1
                        )  # 10% of configured delay

                # Collect completions as they finish for responsiveness.
                completed_tasks = []
                for task_future in asyncio.as_completed(running_tasks):
                    try:
                        result_task = await task_future
                        completed_tasks.append(result_task)
                    except Exception as e:
                        # Record the failure without blocking other tasks.
                        completed_tasks.append(e)

                exec_logger.info("🏁 PARALLEL_EXECUTION_COMPLETE")

                # Process results
                result = await self._process_task_results(
                    prepared_tasks, completed_tasks, start_time
                )

                # Log performance metrics
                log_performance_metrics(
                    "ParallelOrchestrator",
                    {
                        "total_tasks": len(tasks),
                        "successful_tasks": result.successful_tasks,
                        "failed_tasks": result.failed_tasks,
                        "parallel_efficiency": result.parallel_efficiency,
                        "total_duration": result.total_execution_time,
                    },
                )

                # Synthesize results if callback provided
                if synthesis_callback and result.successful_tasks > 0:
                    exec_logger.info("🧠 SYNTHESIS_START")
                    try:
                        result.synthesis = await synthesis_callback(result.task_results)
                        exec_logger.info("✅ SYNTHESIS_SUCCESS")
                    except Exception as e:
                        exec_logger.error("❌ SYNTHESIS_FAILED", error=str(e))
                        result.synthesis = {"error": f"Synthesis failed: {str(e)}"}
                else:
                    exec_logger.info("⏭️ SYNTHESIS_SKIPPED")

                return result

            except Exception as e:
                exec_logger.error("💥 PARALLEL_EXECUTION_FAILED", error=str(e))
                result.total_execution_time = time.time() - start_time
                return result

    async def _prepare_tasks(self, tasks: list[ResearchTask]) -> list[ResearchTask]:
        """Prepare tasks for execution by setting timeouts and priorities.

        Tasks are sorted by descending priority, registered in
        ``active_tasks``, and the list is truncated to
        ``max_concurrent_agents`` entries.
        """
        prepared = []

        for task in sorted(tasks, key=lambda t: t.priority, reverse=True):
            # Set default timeout if not specified
            if not task.timeout:
                task.timeout = self.config.timeout_per_agent

            # Set task to pending status
            task.status = "pending"
            self.active_tasks[task.task_id] = task
            prepared.append(task)

        return prepared[: self.config.max_concurrent_agents]

    @circuit_breaker("parallel_research_task", failure_threshold=2, recovery_timeout=30)
    async def _execute_single_task(
        self, task: ResearchTask, research_executor
    ) -> ResearchTask:
        """Execute a single research task under the concurrency cap.

        BUGFIX: the previous implementation set ``acquired`` from
        ``semaphore.locked()`` without calling ``acquire()`` when the
        semaphore appeared free, so concurrency was never actually limited
        and the unconditional ``release()`` could over-release the bounded
        semaphore (raising ValueError). The slot is now always acquired and
        released via ``async with``.
        """
        async with self._semaphore:
            task.start_time = time.time()
            task.status = "running"

            # Track active worker count
            async with self._worker_lock:
                self._active_workers += 1

            with log_agent_execution(
                task.task_type, task.task_id, task.focus_areas
            ) as agent_logger:
                try:
                    agent_logger.info(
                        "🎯 TASK_EXECUTION_START",
                        timeout=task.timeout,
                        priority=task.priority,
                    )

                    # Shield the research call so external cancellation of this
                    # wrapper does not abort in-flight critical work; wait_for
                    # still enforces the per-task timeout.
                    result = await asyncio.shield(
                        asyncio.wait_for(research_executor(task), timeout=task.timeout)
                    )

                    task.result = result
                    task.status = "completed"
                    task.end_time = time.time()

                    # Log successful completion
                    execution_time = task.end_time - task.start_time
                    agent_logger.info(
                        "✨ TASK_EXECUTION_SUCCESS",
                        duration=f"{execution_time:.3f}s",
                    )

                    # Log resource usage if available
                    if isinstance(result, dict) and "metrics" in result:
                        log_resource_usage(
                            f"{task.task_type}Agent",
                            api_calls=result["metrics"].get("api_calls"),
                            cache_hits=result["metrics"].get("cache_hits"),
                        )

                    return task

                except TimeoutError:
                    task.error = f"Task timeout after {task.timeout}s"
                    task.status = "failed"
                    agent_logger.error("⏰ TASK_TIMEOUT", timeout=task.timeout)
                except Exception as e:
                    task.error = str(e)
                    task.status = "failed"
                    agent_logger.error("💥 TASK_EXECUTION_FAILED", error=str(e))
                finally:
                    # Runs on success and failure: stamp the end time and
                    # decrement the worker counter.
                    task.end_time = time.time()
                    async with self._worker_lock:
                        self._active_workers -= 1

            return task

    async def _process_task_results(
        self, tasks: list[ResearchTask], completed_tasks: list[Any], start_time: float
    ) -> ResearchResult:
        """Process and aggregate results from completed tasks.

        Counts successes/failures from the task objects themselves and
        computes parallel efficiency as summed per-task runtime divided by
        wall-clock time (an approximate speedup factor).
        """
        result = ResearchResult()
        result.total_execution_time = time.time() - start_time

        for task in tasks:
            result.task_results[task.task_id] = task
            if task.status == "completed":
                result.successful_tasks += 1
            else:
                result.failed_tasks += 1

        # Calculate parallel efficiency
        if result.total_execution_time > 0:
            total_sequential_time = sum(
                (task.end_time or 0) - (task.start_time or 0)
                for task in tasks
                if task.start_time
            )
            result.parallel_efficiency = (
                (total_sequential_time / result.total_execution_time)
                if total_sequential_time > 0
                else 0.0
            )

        logger.info(
            f"Parallel research completed: {result.successful_tasks} successful, "
            f"{result.failed_tasks} failed, {result.parallel_efficiency:.2f}x speedup"
        )

        return result
class TaskDistributionEngine:
    """Intelligent task distribution for research topics.

    Splits a single research topic into specialized ``ResearchTask``s by
    scoring the topic against a static keyword/focus-area taxonomy.
    """

    # Static taxonomy: per research type, the keywords used for relevance
    # scoring and the focus areas assigned to created tasks.
    TASK_TYPES = {
        "fundamental": {
            "keywords": [
                "earnings",
                "revenue",
                "profit",
                "cash flow",
                "debt",
                "valuation",
            ],
            "focus_areas": ["financials", "fundamentals", "earnings", "balance_sheet"],
        },
        "technical": {
            "keywords": [
                "price",
                "chart",
                "trend",
                "support",
                "resistance",
                "momentum",
            ],
            "focus_areas": ["technical_analysis", "chart_patterns", "indicators"],
        },
        "sentiment": {
            "keywords": [
                "sentiment",
                "news",
                "analyst",
                "opinion",
                "rating",
                "recommendation",
            ],
            "focus_areas": ["market_sentiment", "analyst_ratings", "news_sentiment"],
        },
        "competitive": {
            "keywords": [
                "competitor",
                "market share",
                "industry",
                "competitive",
                "peers",
            ],
            "focus_areas": [
                "competitive_analysis",
                "industry_analysis",
                "market_position",
            ],
        },
    }

    @log_method_call(component="TaskDistributionEngine", include_timing=True)
    def distribute_research_tasks(
        self, topic: str, session_id: str, focus_areas: list[str] | None = None
    ) -> list[ResearchTask]:
        """
        Intelligently distribute a research topic into specialized tasks.

        Args:
            topic: Main research topic
            session_id: Session identifier for tracking
            focus_areas: Optional specific areas to focus on

        Returns:
            List of specialized research tasks (at least one; falls back to a
            generic fundamental-analysis task when nothing scores above the
            relevance threshold)
        """
        distribution_logger = get_orchestration_logger("TaskDistributionEngine")
        distribution_logger.set_request_context(session_id=session_id)

        distribution_logger.info(
            "🎯 TASK_DISTRIBUTION_START",
            session_id=session_id,
        )

        tasks = []
        topic_lower = topic.lower()

        # Determine which task types are relevant
        relevant_types = self._analyze_topic_relevance(topic_lower, focus_areas)

        # Log relevance analysis results
        distribution_logger.info("🧠 RELEVANCE_ANALYSIS")

        # Create tasks for relevant types
        created_tasks = []
        for task_type, score in relevant_types.items():
            if score > 0.3:  # Relevance threshold
                task = ResearchTask(
                    task_id=f"{session_id}_{task_type}",
                    task_type=task_type,
                    target_topic=topic,
                    focus_areas=self.TASK_TYPES[task_type]["focus_areas"],
                    priority=int(score * 10),  # Convert to 1-10 priority
                )
                tasks.append(task)
                created_tasks.append(
                    {
                        "type": task_type,
                        "priority": task.priority,
                        "score": score,
                        "focus_areas": task.focus_areas[:3],  # Limit for logging
                    }
                )

        # Log created tasks
        if created_tasks:
            distribution_logger.info(
                "✅ TASKS_CREATED",
                task_count=len(created_tasks),
            )

        # Ensure at least one task (fallback to fundamental analysis)
        if not tasks:
            distribution_logger.warning(
                "⚠️ NO_RELEVANT_TASKS_FOUND - using fallback",
                threshold=0.3,
                max_score=max(relevant_types.values()) if relevant_types else 0,
            )

            fallback_task = ResearchTask(
                task_id=f"{session_id}_fundamental",
                task_type="fundamental",
                target_topic=topic,
                focus_areas=["general_analysis"],
                priority=5,
            )
            tasks.append(fallback_task)

            distribution_logger.info(
                "🔄 FALLBACK_TASK_CREATED", task_type="fundamental"
            )

        # Final summary
        task_summary = {
            "total_tasks": len(tasks),
            "task_types": [t.task_type for t in tasks],
            "avg_priority": sum(t.priority for t in tasks) / len(tasks) if tasks else 0,
        }
        distribution_logger.info("🎉 TASK_DISTRIBUTION_COMPLETE", **task_summary)

        return tasks

    def _analyze_topic_relevance(
        self, topic: str, focus_areas: list[str] | None = None
    ) -> dict[str, float]:
        """Analyze topic relevance to different research types.

        Scores each task type in [0, 1]: up to 0.6 from keyword substring
        matches in the topic, plus up to 0.4 from focus-area matches (or a
        fixed per-type default when no focus areas are given).

        NOTE(review): keyword matching is plain substring containment, so
        e.g. "profit" also matches "nonprofit" — acceptable for coarse
        routing, but worth confirming.
        """
        relevance_scores = {}

        for task_type, config in self.TASK_TYPES.items():
            score = 0.0

            # Score based on keywords in topic
            keyword_matches = sum(
                1 for keyword in config["keywords"] if keyword in topic
            )
            score += keyword_matches / len(config["keywords"]) * 0.6

            # Score based on focus areas
            if focus_areas:
                focus_matches = sum(
                    1
                    for focus in focus_areas
                    if any(area in focus.lower() for area in config["focus_areas"])
                )
                score += focus_matches / len(config["focus_areas"]) * 0.4
            else:
                # Default relevance for common research types
                score += {
                    "fundamental": 0.8,
                    "sentiment": 0.6,
                    "technical": 0.4,
                    "competitive": 0.5,
                }.get(task_type, 0.3)

            relevance_scores[task_type] = min(score, 1.0)

        return relevance_scores
# Export key classes for easy import
__all__ = [
"ParallelResearchConfig",
"ResearchTask",
"ResearchResult",
"ParallelResearchOrchestrator",
"TaskDistributionEngine",
]
```
--------------------------------------------------------------------------------
/maverick_mcp/utils/data_chunking.py:
--------------------------------------------------------------------------------
```python
"""
Data chunking utilities for memory-efficient processing of large datasets.
Provides streaming, batching, and generator-based approaches for handling large DataFrames.
"""
import logging
import math
from collections.abc import Callable, Generator
from typing import Any, Literal
import numpy as np
import pandas as pd
from maverick_mcp.utils.memory_profiler import (
force_garbage_collection,
get_dataframe_memory_usage,
memory_context,
optimize_dataframe,
)
logger = logging.getLogger(__name__)
# Default chunk size configurations
DEFAULT_CHUNK_SIZE_MB = 50.0
MAX_CHUNK_SIZE_MB = 200.0
MIN_ROWS_PER_CHUNK = 100
class DataChunker:
    """Advanced data chunking utility with multiple strategies.

    Supports row-count, memory-size, and date-period chunking of DataFrames,
    with optional per-chunk dtype optimization and periodic garbage
    collection to bound peak memory.
    """

    def __init__(
        self,
        chunk_size_mb: float = DEFAULT_CHUNK_SIZE_MB,
        min_rows_per_chunk: int = MIN_ROWS_PER_CHUNK,
        optimize_chunks: bool = True,
        auto_gc: bool = True,
    ):
        """Initialize data chunker.

        Args:
            chunk_size_mb: Target chunk size in megabytes (capped at
                MAX_CHUNK_SIZE_MB)
            min_rows_per_chunk: Minimum rows per chunk
            optimize_chunks: Whether to optimize chunk memory usage
            auto_gc: Whether to automatically run garbage collection
        """
        # Cap the requested chunk size at the module-wide maximum.
        self.chunk_size_mb = min(chunk_size_mb, MAX_CHUNK_SIZE_MB)
        self.chunk_size_bytes = int(self.chunk_size_mb * 1024 * 1024)
        self.min_rows_per_chunk = min_rows_per_chunk
        self.optimize_chunks = optimize_chunks
        self.auto_gc = auto_gc

        logger.debug(
            f"DataChunker initialized: {self.chunk_size_mb}MB chunks, "
            f"min {self.min_rows_per_chunk} rows"
        )

    def estimate_chunk_size(self, df: pd.DataFrame) -> tuple[int, int]:
        """Estimate optimal chunk size for a DataFrame.

        Args:
            df: DataFrame to analyze

        Returns:
            Tuple of (rows_per_chunk, estimated_chunks)
        """
        total_memory = df.memory_usage(deep=True).sum()
        memory_per_row = total_memory / len(df) if len(df) > 0 else 0

        if memory_per_row == 0:
            # Empty (or zero-memory) frame: a single chunk covers everything.
            # NOTE(review): for an empty df this returns (0, 1), which makes
            # chunk_by_rows divide by zero — confirm callers never pass an
            # empty frame.
            return len(df), 1

        # Calculate rows per chunk based on memory target
        rows_per_chunk = max(
            self.min_rows_per_chunk, int(self.chunk_size_bytes / memory_per_row)
        )

        # Ensure we don't exceed the DataFrame size
        rows_per_chunk = min(rows_per_chunk, len(df))
        estimated_chunks = math.ceil(len(df) / rows_per_chunk)

        logger.debug(
            f"Estimated chunking: {rows_per_chunk} rows/chunk, "
            f"{estimated_chunks} chunks total"
        )

        return rows_per_chunk, estimated_chunks

    def chunk_by_rows(
        self, df: pd.DataFrame, rows_per_chunk: int | None = None
    ) -> Generator[pd.DataFrame, None, None]:
        """Chunk DataFrame by number of rows.

        Args:
            df: DataFrame to chunk
            rows_per_chunk: Rows per chunk (auto-estimated if None)

        Yields:
            DataFrame chunks (each an independent copy of its slice)
        """
        if rows_per_chunk is None:
            rows_per_chunk, _ = self.estimate_chunk_size(df)

        total_chunks = math.ceil(len(df) / rows_per_chunk)

        logger.debug(
            f"Chunking {len(df)} rows into {total_chunks} chunks "
            f"of ~{rows_per_chunk} rows each"
        )

        for i, start_idx in enumerate(range(0, len(df), rows_per_chunk)):
            end_idx = min(start_idx + rows_per_chunk, len(df))
            # Copy so downstream mutation of a chunk cannot affect `df`.
            chunk = df.iloc[start_idx:end_idx].copy()

            if self.optimize_chunks:
                chunk = optimize_dataframe(chunk)

            logger.debug(
                f"Yielding chunk {i + 1}/{total_chunks}: rows {start_idx}-{end_idx - 1}"
            )

            yield chunk

            # Cleanup after yielding
            if self.auto_gc:
                del chunk
                if i % 5 == 0:  # GC every 5 chunks
                    force_garbage_collection()

    def chunk_by_memory(self, df: pd.DataFrame) -> Generator[pd.DataFrame, None, None]:
        """Chunk DataFrame by memory size.

        Args:
            df: DataFrame to chunk

        Yields:
            DataFrame chunks
        """
        total_memory = df.memory_usage(deep=True).sum()

        # Small enough to fit the target: yield the whole frame as one chunk.
        if total_memory <= self.chunk_size_bytes:
            if self.optimize_chunks:
                df = optimize_dataframe(df)
            yield df
            return

        # Use row-based chunking with memory-based estimation
        yield from self.chunk_by_rows(df)

    def chunk_by_date(
        self,
        df: pd.DataFrame,
        freq: Literal["D", "W", "M", "Q", "Y"] = "M",
        date_column: str | None = None,
    ) -> Generator[pd.DataFrame, None, None]:
        """Chunk DataFrame by date periods.

        Args:
            df: DataFrame to chunk (must have datetime index or date_column)
            freq: Frequency for chunking (D=daily, W=weekly, M=monthly, etc.)
            date_column: Name of date column (uses index if None)

        Yields:
            DataFrame chunks by date periods (empty periods are skipped)

        Raises:
            ValueError: If date_column is missing or the index is not datetime.
        """
        if date_column:
            if date_column not in df.columns:
                raise ValueError(f"Date column '{date_column}' not found")
        elif not isinstance(df.index, pd.DatetimeIndex):
            raise ValueError(
                "DataFrame must have datetime index or specify date_column"
            )

        # Group by period
        period_groups = df.groupby(
            pd.Grouper(key=date_column, freq=freq)
            if date_column
            else pd.Grouper(freq=freq)
        )

        total_periods = len(period_groups)
        logger.debug(f"Chunking by {freq} periods: {total_periods} chunks")

        for i, (period, group) in enumerate(period_groups):
            if len(group) == 0:
                continue

            if self.optimize_chunks:
                group = optimize_dataframe(group)

            logger.debug(
                f"Yielding period chunk {i + 1}/{total_periods}: "
                f"{period} ({len(group)} rows)"
            )

            yield group

            if self.auto_gc and i % 3 == 0:  # GC every 3 periods
                force_garbage_collection()

    def process_in_chunks(
        self,
        df: pd.DataFrame,
        processor: Callable[[pd.DataFrame], Any],
        combiner: Callable[[list], Any] | None = None,
        chunk_method: Literal["rows", "memory", "date"] = "memory",
        **chunk_kwargs,
    ) -> Any:
        """Process DataFrame in chunks and combine results.

        Args:
            df: DataFrame to process
            processor: Function to apply to each chunk
            combiner: Function to combine results (default: list)
            chunk_method: Chunking method to use
            **chunk_kwargs: Additional arguments for chunking method

        Returns:
            Combined results; if no combiner is given and each result is a
            DataFrame, the results are concatenated into one DataFrame.

        Raises:
            ValueError: If chunk_method is not recognized.
        """
        results = []

        # Select chunking method
        if chunk_method == "rows":
            chunk_generator = self.chunk_by_rows(df, **chunk_kwargs)
        elif chunk_method == "memory":
            chunk_generator = self.chunk_by_memory(df)
        elif chunk_method == "date":
            chunk_generator = self.chunk_by_date(df, **chunk_kwargs)
        else:
            raise ValueError(f"Unknown chunk method: {chunk_method}")

        with memory_context("chunk_processing"):
            for i, chunk in enumerate(chunk_generator):
                try:
                    with memory_context(f"chunk_{i}"):
                        result = processor(chunk)
                        results.append(result)
                except Exception as e:
                    # Fail fast: one bad chunk aborts the whole run.
                    logger.error(f"Error processing chunk {i}: {e}")
                    raise

        # Combine results
        if combiner:
            return combiner(results)
        elif results and isinstance(results[0], pd.DataFrame):
            # Auto-combine DataFrames
            return pd.concat(results, ignore_index=True)
        else:
            return results
class StreamingDataProcessor:
    """Streaming data processor for very large datasets.

    Wraps pandas' chunked CSV/SQL readers and applies a processor callable
    to each chunk, with optional dtype optimization and periodic GC.
    """

    def __init__(self, chunk_size_mb: float = DEFAULT_CHUNK_SIZE_MB):
        """Initialize streaming processor.

        Args:
            chunk_size_mb: Chunk size in MB
        """
        self.chunk_size_mb = chunk_size_mb
        self.chunker = DataChunker(chunk_size_mb=chunk_size_mb)

    def stream_from_csv(
        self,
        filepath: str,
        processor: Callable[[pd.DataFrame], Any],
        chunksize: int | None = None,
        **read_kwargs,
    ) -> Generator[Any, None, None]:
        """Stream process CSV file in chunks.

        Args:
            filepath: Path to CSV file
            processor: Function to process each chunk
            chunksize: Rows per chunk (auto-estimated if None)
            **read_kwargs: Additional arguments for pd.read_csv

        Yields:
            Processed results for each chunk
        """
        # Estimate chunk size if not provided
        if chunksize is None:
            # Read a sample to estimate memory usage per row, then free it.
            sample = pd.read_csv(filepath, nrows=1000, **read_kwargs)
            memory_per_row = sample.memory_usage(deep=True).sum() / len(sample)
            chunksize = max(100, int(self.chunker.chunk_size_bytes / memory_per_row))
            del sample
            force_garbage_collection()

        logger.info(f"Streaming CSV with {chunksize} rows per chunk")

        chunk_reader = pd.read_csv(filepath, chunksize=chunksize, **read_kwargs)

        for i, chunk in enumerate(chunk_reader):
            with memory_context(f"csv_chunk_{i}"):
                # Optimize chunk if needed
                if self.chunker.optimize_chunks:
                    chunk = optimize_dataframe(chunk)

                result = processor(chunk)
                yield result

                # Clean up
                del chunk
                if i % 5 == 0:
                    force_garbage_collection()

    def stream_from_database(
        self,
        query: str,
        connection,
        processor: Callable[[pd.DataFrame], Any],
        chunksize: int | None = None,
    ) -> Generator[Any, None, None]:
        """Stream process database query results in chunks.

        Args:
            query: SQL query
            connection: Database connection
            processor: Function to process each chunk
            chunksize: Rows per chunk

        Yields:
            Processed results for each chunk
        """
        if chunksize is None:
            chunksize = 10000  # Default for database queries

        logger.info(f"Streaming database query with {chunksize} rows per chunk")

        chunk_reader = pd.read_sql(query, connection, chunksize=chunksize)

        for i, chunk in enumerate(chunk_reader):
            with memory_context(f"db_chunk_{i}"):
                if self.chunker.optimize_chunks:
                    chunk = optimize_dataframe(chunk)

                result = processor(chunk)
                yield result

                del chunk
                if i % 3 == 0:
                    force_garbage_collection()
def optimize_dataframe_dtypes(
    df: pd.DataFrame, aggressive: bool = False, categorical_threshold: float = 0.5
) -> pd.DataFrame:
    """Optimize DataFrame data types for memory efficiency.

    Object columns with a low unique-value ratio become categorical,
    integer columns are downcast to the narrowest width that holds their
    range, and (with ``aggressive``) float64 columns are downcast to
    float32 when the conversion is verified lossless.

    Args:
        df: DataFrame to optimize (a copy is returned; ``df`` is untouched).
        aggressive: Use aggressive optimizations (float32 downcast may lose
            precision beyond rtol=1e-6).
        categorical_threshold: Maximum unique/total ratio under which an
            object column is converted to categorical.

    Returns:
        Optimized DataFrame.
    """
    logger.debug(f"Optimizing DataFrame dtypes: {df.shape}")
    initial_memory = df.memory_usage(deep=True).sum()
    df_opt = df.copy()
    for col in df_opt.columns:
        col_type = df_opt[col].dtype
        try:
            if col_type == "object":
                # Convert string columns to categorical if beneficial.
                unique_count = df_opt[col].nunique()
                total_count = len(df_opt[col])
                # Guard total_count == 0 so empty frames don't raise
                # ZeroDivisionError (previously only swallowed by except).
                if (
                    total_count > 0
                    and unique_count / total_count < categorical_threshold
                ):
                    df_opt[col] = df_opt[col].astype("category")
                    logger.debug(f"Converted {col} to categorical")
            elif "int" in str(col_type):
                # Downcast integers to the smallest width that fits.
                c_min = df_opt[col].min()
                c_max = df_opt[col].max()
                if c_min >= np.iinfo(np.int8).min and c_max <= np.iinfo(np.int8).max:
                    df_opt[col] = df_opt[col].astype(np.int8)
                elif (
                    c_min >= np.iinfo(np.int16).min and c_max <= np.iinfo(np.int16).max
                ):
                    df_opt[col] = df_opt[col].astype(np.int16)
                elif (
                    c_min >= np.iinfo(np.int32).min and c_max <= np.iinfo(np.int32).max
                ):
                    df_opt[col] = df_opt[col].astype(np.int32)
            elif col_type == "float64":
                # Downcast float64 to float32 only if no precision loss.
                if aggressive:
                    # Check if conversion preserves data (NaNs masked via
                    # fillna(0); equal_nan kept for safety).
                    temp = df_opt[col].astype(np.float32)
                    if np.allclose(
                        df_opt[col].fillna(0), temp.fillna(0), rtol=1e-6, equal_nan=True
                    ):
                        df_opt[col] = temp
                        logger.debug(f"Converted {col} to float32")
        except Exception as e:
            # Best effort: leave the column unchanged if anything goes wrong.
            logger.debug(f"Could not optimize column {col}: {e}")
            continue
    final_memory = df_opt.memory_usage(deep=True).sum()
    memory_saved = initial_memory - final_memory
    if memory_saved > 0:
        logger.info(
            f"DataFrame optimization saved {memory_saved / (1024**2):.2f}MB "
            f"({memory_saved / initial_memory * 100:.1f}% reduction)"
        )
    return df_opt
def create_memory_efficient_dataframe(
    data: dict | list,
    optimize: bool = True,
    categorical_columns: list[str] | None = None,
) -> pd.DataFrame:
    """Create a memory-efficient DataFrame from data.

    Args:
        data: Data to create the DataFrame from (any shape accepted by
            pd.DataFrame).
        optimize: Whether to run dtype optimization on the result.
        categorical_columns: Columns to convert to categorical dtype;
            names absent from the data are silently skipped.

    Returns:
        Memory-optimized DataFrame.
    """
    with memory_context("creating_dataframe"):
        df = pd.DataFrame(data)
        if categorical_columns:
            for col in categorical_columns:
                if col in df.columns:
                    df[col] = df[col].astype("category")
        if optimize:
            df = optimize_dataframe_dtypes(df)
        return df
def batch_process_large_dataframe(
    df: pd.DataFrame,
    operation: Callable,
    batch_size: int | None = None,
    combine_results: bool = True,
) -> Any:
    """Process a large DataFrame in batches to manage memory.

    Args:
        df: Large DataFrame to process.
        operation: Function applied to each batch.
        batch_size: Rows per batch; when None, DataChunker estimates a
            batch size from memory usage.
        combine_results: Whether to combine per-batch results
            (DataFrames are concatenated, numbers summed, lists flattened).

    Returns:
        The combined result, or the list of per-batch results when
        ``combine_results`` is False or the result type is not combinable.
    """
    chunker = DataChunker()
    if batch_size:
        chunk_generator = chunker.chunk_by_rows(df, batch_size)
    else:
        chunk_generator = chunker.chunk_by_memory(df)
    results = []
    with memory_context("batch_processing"):
        for i, batch in enumerate(chunk_generator):
            logger.debug(f"Processing batch {i + 1}")
            with memory_context(f"batch_{i}"):
                results.append(operation(batch))
    if combine_results and results:
        if isinstance(results[0], pd.DataFrame):
            return pd.concat(results, ignore_index=True)
        elif isinstance(results[0], int | float):
            return sum(results)
        elif isinstance(results[0], list):
            # Flatten one level of nesting.
            return [item for sublist in results for item in sublist]
    return results
class LazyDataFrame:
    """Lazy evaluation wrapper for large DataFrames.

    Defers loading of CSV-backed data so that schema inspection and
    chunked processing can happen without materializing the whole frame.
    """

    def __init__(self, data_source: str | pd.DataFrame, chunk_size_mb: float = 50.0):
        """Initialize lazy DataFrame.

        Args:
            data_source: CSV file path or an in-memory DataFrame.
            chunk_size_mb: Chunk size (MB) used for chunked processing.
        """
        self.data_source = data_source
        self.chunker = DataChunker(chunk_size_mb=chunk_size_mb)
        self._cached_info: dict[str, Any] | None = None  # memoized get_info()

    def get_info(self) -> dict[str, Any]:
        """Get DataFrame information without loading the full data."""
        if self._cached_info:
            return self._cached_info
        if isinstance(self.data_source, str):
            # Read just the header and a small sample to infer the schema.
            sample = pd.read_csv(self.data_source, nrows=100)
            # Count rows line-by-line without loading the file; the context
            # manager closes the handle (the original leaked an open file).
            with open(self.data_source) as f:
                total_rows = sum(1 for _ in f) - 1  # Subtract header
            self._cached_info = {
                "columns": sample.columns.tolist(),
                "dtypes": sample.dtypes.to_dict(),
                "estimated_rows": total_rows,
                "sample_memory_mb": sample.memory_usage(deep=True).sum() / (1024**2),
            }
        else:
            self._cached_info = get_dataframe_memory_usage(self.data_source)
        return self._cached_info

    def apply_chunked(self, operation: Callable) -> Any:
        """Apply operation in chunks, streaming from disk when file-backed."""
        if isinstance(self.data_source, str):
            processor = StreamingDataProcessor(self.chunker.chunk_size_mb)
            results = list(processor.stream_from_csv(self.data_source, operation))
        else:
            results = self.chunker.process_in_chunks(self.data_source, operation)
        return results

    def to_optimized_dataframe(self) -> pd.DataFrame:
        """Load (or copy) the full DataFrame and optimize its dtypes."""
        if isinstance(self.data_source, str):
            df = pd.read_csv(self.data_source)
        else:
            df = self.data_source.copy()
        return optimize_dataframe_dtypes(df)
# Utility functions for common operations
def chunked_concat(
    dataframes: list[pd.DataFrame], chunk_size: int = 10
) -> pd.DataFrame:
    """Concatenate DataFrames in chunks to manage memory.

    Args:
        dataframes: List of DataFrames to concatenate.
        chunk_size: Number of DataFrames to concat at once.

    Returns:
        Concatenated DataFrame (empty DataFrame for empty input).
    """
    if not dataframes:
        return pd.DataFrame()
    if len(dataframes) <= chunk_size:
        return pd.concat(dataframes, ignore_index=True)
    # Concatenate in groups of chunk_size to bound intermediate memory.
    results = []
    for i in range(0, len(dataframes), chunk_size):
        chunk = dataframes[i : i + chunk_size]
        with memory_context(f"concat_chunk_{i // chunk_size}"):
            results.append(pd.concat(chunk, ignore_index=True))
        # NOTE: the caller's list still references every input frame, so they
        # cannot be freed here; the original `for df in chunk: del df` only
        # unbound the loop variable and was a no-op. A GC pass still helps
        # reclaim concat temporaries.
        force_garbage_collection()
    # Final concatenation of the per-group partial results.
    with memory_context("final_concat"):
        final_result = pd.concat(results, ignore_index=True)
    return final_result
def memory_efficient_groupby(
    df: pd.DataFrame, group_col: str, agg_func: Callable, chunk_size_mb: float = 50.0
) -> pd.DataFrame:
    """Perform memory-efficient groupby operations.

    The frame is aggregated chunk-by-chunk and the partial results are
    aggregated again with the same function.

    WARNING: this two-pass scheme is only correct for decomposable
    aggregations (e.g. sum, min, max). Functions such as mean or count
    produce wrong results because ``agg_func`` is applied a second time to
    already-aggregated partials.

    Args:
        df: DataFrame to group.
        group_col: Column to group by.
        agg_func: Aggregation function applied to each group.
        chunk_size_mb: Chunk size in MB.

    Returns:
        Aggregated DataFrame.

    Raises:
        ValueError: If ``group_col`` is not a column of ``df``.
    """
    if group_col not in df.columns:
        raise ValueError(f"Group column '{group_col}' not found")
    chunker = DataChunker(chunk_size_mb=chunk_size_mb)

    def process_chunk(chunk):
        # Partial aggregation within a single chunk.
        return chunk.groupby(group_col).apply(agg_func).reset_index()

    results = chunker.process_in_chunks(df, process_chunk)
    # Combine partials and re-aggregate groups that were split across chunks.
    combined = pd.concat(results, ignore_index=True)
    final_result = combined.groupby(group_col).apply(agg_func).reset_index()
    return final_result
```