This is page 25 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
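A minimal sketch for pulling every page of this dump into one string (assumes the endpoint returns plain text; the URL pattern and page count are taken from the note above):

```python
import urllib.request

# URL pattern from the note above; {x} is the 1-based page number.
BASE = "http://codebase.md/wshobson/maverick-mcp?lines=true&page={x}"

pages = []
for x in range(1, 40):  # 39 pages in total
    with urllib.request.urlopen(BASE.format(x=x)) as resp:
        pages.append(resp.read().decode("utf-8"))

full_context = "\n".join(pages)
```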
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .jules
│ └── bolt.md
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ ├── unit
│ │ └── test_stock_repository_adapter.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_security_penetration.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Security Penetration Testing Suite for MaverickMCP.
3 |
4 | This suite performs security penetration testing to validate that
5 | security protections are active and effective against real attack vectors.
6 |
7 | Tests include:
8 | - Authentication bypass attempts
9 | - CSRF attack vectors
10 | - Rate limiting evasion
11 | - Input validation bypass
12 | - Session hijacking attempts
13 | - SQL injection prevention
14 | - XSS protection validation
15 | - Information disclosure prevention
16 | """
17 |
18 | import time
19 | from datetime import UTC, datetime, timedelta
20 | from uuid import uuid4
21 |
22 | import pytest
23 | from fastapi.testclient import TestClient
24 |
25 | from maverick_mcp.api.api_server import create_api_app
26 |
27 |
28 | @pytest.fixture
29 | def security_test_app():
30 | """Create app for security testing."""
31 | return create_api_app()
32 |
33 |
34 | @pytest.fixture
35 | def security_client(security_test_app):
36 | """Create client for security testing."""
37 | return TestClient(security_test_app)
38 |
39 |
40 | @pytest.fixture
41 | def test_user():
42 | """Test user for security testing."""
43 | return {
44 | "email": f"sectest{uuid4().hex[:8]}@example.com",
45 | "password": "SecurePass123!",
46 | "name": "Security Test User",
47 | "company": "Test Security Inc",
48 | }
49 |
50 |
51 | class TestAuthenticationSecurity:
52 | """Test authentication security against bypass attempts."""
53 |
54 | @pytest.mark.integration
55 | def test_jwt_token_manipulation_resistance(self, security_client, test_user):
56 | """Test resistance to JWT token manipulation attacks."""
57 |
58 | # Register and login
59 | security_client.post("/auth/register", json=test_user)
60 | login_response = security_client.post(
61 | "/auth/login",
62 | json={"email": test_user["email"], "password": test_user["password"]},
63 | )
64 |
65 | # Extract tokens from cookies
66 | cookies = login_response.cookies
67 | access_token_cookie = cookies.get("maverick_access_token")
68 |
69 | if not access_token_cookie:
70 | pytest.skip("JWT tokens not in cookies - may be test environment")
71 |
72 | # Attempt 1: Modified JWT signature
73 | tampered_token = access_token_cookie[:-10] + "tampered123"
74 |
75 | response = security_client.get(
76 | "/user/profile", cookies={"maverick_access_token": tampered_token}
77 | )
78 | assert response.status_code == 401 # Should reject tampered token
79 |
80 | # Attempt 2: Algorithm confusion attack (trying "none" algorithm)
81 | none_algorithm_token = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJ1c2VyX2lkIjoxLCJleHAiOjk5OTk5OTk5OTl9."
82 |
83 | response = security_client.get(
84 | "/user/profile", cookies={"maverick_access_token": none_algorithm_token}
85 | )
86 | assert response.status_code == 401 # Should reject "none" algorithm
87 |
88 |         # Attempt 3: Expired token
89 |         _expired_claims = {  # claim shape for an expired token (illustrative)
90 |             "user_id": 1,
91 |             "exp": int((datetime.now(UTC) - timedelta(hours=1)).timestamp()),
92 |             "iat": int((datetime.now(UTC) - timedelta(hours=2)).timestamp()),
93 |             "jti": "expired_token",
94 |         }
95 |
96 |         # Forging a signed expired token would require the server's secret,
97 |         # so this test only documents the claim shape of an expired JWT.
98 |
99 | @pytest.mark.integration
100 | def test_session_fixation_protection(self, security_client, test_user):
101 | """Test protection against session fixation attacks."""
102 |
103 | # Get initial session state
104 | initial_response = security_client.get("/auth/login")
105 | initial_cookies = initial_response.cookies
106 |
107 | # Login with potential pre-set session
108 | security_client.post("/auth/register", json=test_user)
109 | login_response = security_client.post(
110 | "/auth/login",
111 | json={"email": test_user["email"], "password": test_user["password"]},
112 | cookies=initial_cookies, # Try to maintain old session
113 | )
114 |
115 | # Verify new session is created (cookies should be different)
116 | new_cookies = login_response.cookies
117 |
118 | # Session should be regenerated after login
119 | if "maverick_access_token" in new_cookies:
120 | # New token should be different from any pre-existing one
121 | assert login_response.status_code == 200
122 |
123 | @pytest.mark.integration
124 | def test_concurrent_session_limits(self, security_client, test_user):
125 | """Test limits on concurrent sessions."""
126 |
127 | # Register user
128 | security_client.post("/auth/register", json=test_user)
129 |
130 | # Create multiple concurrent sessions
131 | session_responses = []
132 | for _i in range(5):
133 | client_instance = TestClient(security_client.app)
134 | response = client_instance.post(
135 | "/auth/login",
136 | json={"email": test_user["email"], "password": test_user["password"]},
137 | )
138 | session_responses.append(response)
139 |
140 | # All should succeed (or be limited if concurrent session limits implemented)
141 | success_count = sum(1 for r in session_responses if r.status_code == 200)
142 | assert success_count >= 1 # At least one should succeed
143 |
144 | # If concurrent session limits are implemented, test that old sessions are invalidated
145 |
146 | @pytest.mark.integration
147 | def test_password_brute_force_protection(self, security_client, test_user):
148 | """Test protection against password brute force attacks."""
149 |
150 | # Register user
151 | security_client.post("/auth/register", json=test_user)
152 |
153 | # Attempt multiple failed logins
154 | failed_attempts = []
155 | for i in range(10):
156 | response = security_client.post(
157 | "/auth/login",
158 | json={"email": test_user["email"], "password": f"wrong_password_{i}"},
159 | )
160 | failed_attempts.append(response.status_code)
161 |
162 | # Small delay to avoid overwhelming the system
163 | time.sleep(0.1)
164 |
165 |         # Wrong-password attempts should be rejected (401) or rate limited (429)
166 |         assert all(status in (401, 429) for status in failed_attempts)
167 |
168 | # After multiple failures, account should be locked or rate limited
169 | # Test with correct password - should be blocked if protection is active
170 | final_attempt = security_client.post(
171 | "/auth/login",
172 | json={"email": test_user["email"], "password": test_user["password"]},
173 | )
174 |
175 | # If brute force protection is active, should be rate limited
176 | # Otherwise, should succeed
177 | assert final_attempt.status_code in [200, 401, 429]
178 |
179 |
180 | class TestCSRFAttackVectors:
181 | """Test CSRF protection against various attack vectors."""
182 |
183 | @pytest.mark.integration
184 | def test_csrf_attack_simulation(self, security_client, test_user):
185 | """Simulate CSRF attacks to test protection."""
186 |
187 | # Setup authenticated session
188 | security_client.post("/auth/register", json=test_user)
189 | login_response = security_client.post(
190 | "/auth/login",
191 | json={"email": test_user["email"], "password": test_user["password"]},
192 | )
193 | csrf_token = login_response.json().get("csrf_token")
194 |
195 | # Attack 1: Missing CSRF token
196 | attack_response_1 = security_client.post(
197 | "/user/profile", json={"name": "Attacked Name"}
198 | )
199 | assert attack_response_1.status_code == 403
200 | assert "CSRF" in attack_response_1.json()["detail"]
201 |
202 | # Attack 2: Invalid CSRF token
203 | attack_response_2 = security_client.post(
204 | "/user/profile",
205 | json={"name": "Attacked Name"},
206 | headers={"X-CSRF-Token": "invalid_token_123"},
207 | )
208 | assert attack_response_2.status_code == 403
209 |
210 | # Attack 3: CSRF token from different session
211 | # Create second user and get their CSRF token
212 | other_user = {
213 | "email": f"other{uuid4().hex[:8]}@example.com",
214 | "password": "OtherPass123!",
215 | "name": "Other User",
216 | }
217 |
218 | other_client = TestClient(security_client.app)
219 | other_client.post("/auth/register", json=other_user)
220 | other_login = other_client.post(
221 | "/auth/login",
222 | json={"email": other_user["email"], "password": other_user["password"]},
223 | )
224 | other_csrf = other_login.json().get("csrf_token")
225 |
226 | # Try to use other user's CSRF token
227 | attack_response_3 = security_client.post(
228 | "/user/profile",
229 | json={"name": "Cross-User Attack"},
230 | headers={"X-CSRF-Token": other_csrf},
231 | )
232 | assert attack_response_3.status_code == 403
233 |
234 | # Legitimate request should work
235 | legitimate_response = security_client.post(
236 | "/user/profile",
237 | json={"name": "Legitimate Update"},
238 | headers={"X-CSRF-Token": csrf_token},
239 | )
240 | assert legitimate_response.status_code == 200
241 |
242 | @pytest.mark.integration
243 | def test_csrf_double_submit_validation(self, security_client, test_user):
244 | """Test CSRF double-submit cookie validation."""
245 |
246 | # Setup session
247 | security_client.post("/auth/register", json=test_user)
248 | login_response = security_client.post(
249 | "/auth/login",
250 | json={"email": test_user["email"], "password": test_user["password"]},
251 | )
252 | csrf_token = login_response.json().get("csrf_token")
253 | cookies = login_response.cookies
254 |
255 | # Attack: Modify CSRF cookie but keep header the same
256 | modified_cookies = cookies.copy()
257 | if "maverick_csrf_token" in modified_cookies:
258 | modified_cookies["maverick_csrf_token"] = "modified_csrf_token"
259 |
260 | attack_response = security_client.post(
261 | "/user/profile",
262 | json={"name": "CSRF Cookie Attack"},
263 | headers={"X-CSRF-Token": csrf_token},
264 | cookies=modified_cookies,
265 | )
266 | assert attack_response.status_code == 403
267 |
268 | @pytest.mark.integration
269 | def test_csrf_token_entropy_and_uniqueness(self, security_client, test_user):
270 | """Test CSRF tokens have sufficient entropy and are unique."""
271 |
272 | # Register user
273 | security_client.post("/auth/register", json=test_user)
274 |
275 | # Generate multiple CSRF tokens
276 | csrf_tokens = []
277 | for _i in range(5):
278 | response = security_client.post(
279 | "/auth/login",
280 | json={"email": test_user["email"], "password": test_user["password"]},
281 | )
282 | csrf_token = response.json().get("csrf_token")
283 | if csrf_token:
284 | csrf_tokens.append(csrf_token)
285 |
286 | if csrf_tokens:
287 | # All tokens should be unique
288 | assert len(set(csrf_tokens)) == len(csrf_tokens)
289 |
290 | # Tokens should have sufficient length (at least 32 chars)
291 | for token in csrf_tokens:
292 | assert len(token) >= 32
293 |
294 | # Tokens should not be predictable patterns
295 | for i, token in enumerate(csrf_tokens[1:], 1):
296 | # Should not be sequential or pattern-based
297 | assert token != csrf_tokens[0] + str(i)
298 | assert not token.startswith(csrf_tokens[0][:-5])
299 |
300 |
301 | class TestRateLimitingEvasion:
302 | """Test rate limiting against evasion attempts."""
303 |
304 | @pytest.mark.integration
305 | def test_ip_based_rate_limit_evasion(self, security_client):
306 | """Test attempts to evade IP-based rate limiting."""
307 |
308 | # Test basic rate limiting
309 | responses = []
310 | for _i in range(25):
311 | response = security_client.get("/api/data")
312 | responses.append(response.status_code)
313 |
314 | # Should hit rate limit
315 |         _success_count = sum(1 for status in responses if status == 200)
316 | rate_limited_count = sum(1 for status in responses if status == 429)
317 | assert rate_limited_count > 0 # Should have some rate limited responses
318 |
319 | # Attempt 1: X-Forwarded-For header spoofing
320 | spoofed_responses = []
321 | for i in range(10):
322 | response = security_client.get(
323 | "/api/data", headers={"X-Forwarded-For": f"192.168.1.{i}"}
324 | )
325 | spoofed_responses.append(response.status_code)
326 |
327 | # Should still be rate limited (proper implementation should use real IP)
328 |         _spoofed_rate_limited = sum(1 for status in spoofed_responses if status == 429)
329 |
330 | # Attempt 2: X-Real-IP header spoofing
331 | real_ip_responses = []
332 | for i in range(5):
333 | response = security_client.get(
334 | "/api/data", headers={"X-Real-IP": f"10.0.0.{i}"}
335 | )
336 | real_ip_responses.append(response.status_code)
337 |
338 | # Rate limiting should not be easily bypassed
339 |
340 | @pytest.mark.integration
341 | def test_user_agent_rotation_evasion(self, security_client):
342 | """Test rate limiting against user agent rotation."""
343 |
344 | user_agents = [
345 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
346 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
347 | "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
348 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:91.0) Gecko/20100101",
349 | "Mozilla/5.0 (iPhone; CPU iPhone OS 14_7_1 like Mac OS X)",
350 | ]
351 |
352 | # Attempt to evade rate limiting by rotating user agents
353 | ua_responses = []
354 | for i in range(15):
355 | ua = user_agents[i % len(user_agents)]
356 | response = security_client.get("/api/data", headers={"User-Agent": ua})
357 | ua_responses.append(response.status_code)
358 |
359 | # Should still enforce rate limiting regardless of user agent
360 |         _ua_rate_limited = sum(1 for status in ua_responses if status == 429)
361 | # Should have some rate limiting if effective
362 |
363 | @pytest.mark.integration
364 | def test_distributed_rate_limit_evasion(self, security_client):
365 | """Test against distributed rate limit evasion attempts."""
366 |
367 | # Simulate requests with small delays (trying to stay under rate limits)
368 | distributed_responses = []
369 | for _i in range(10):
370 | response = security_client.get("/api/data")
371 | distributed_responses.append(response.status_code)
372 | time.sleep(0.1) # Small delay
373 |
374 | # Even with delays, sustained high-rate requests should be limited
375 | # This tests if rate limiting has proper time windows
376 |
377 |
378 | class TestInputValidationBypass:
379 | """Test input validation against bypass attempts."""
380 |
381 | @pytest.mark.integration
382 | def test_sql_injection_prevention(self, security_client, test_user):
383 | """Test SQL injection prevention."""
384 |
385 | # SQL injection payloads
386 | sql_payloads = [
387 | "'; DROP TABLE users; --",
388 | "' OR '1'='1",
389 | "' UNION SELECT * FROM users --",
390 | "'; DELETE FROM users WHERE '1'='1",
391 | "' OR 1=1 --",
392 | "admin'--",
393 | "admin'/*",
394 | "' OR 'x'='x",
395 | "' AND id IS NULL; --",
396 | "'OR 1=1#",
397 | ]
398 |
399 | # Test SQL injection in login email field
400 | for payload in sql_payloads:
401 | response = security_client.post(
402 | "/auth/login", json={"email": payload, "password": "any_password"}
403 | )
404 |
405 | # Should handle gracefully without SQL errors
406 | assert response.status_code in [400, 401, 422] # Not 500 (SQL error)
407 |
408 | # Response should not contain SQL error messages
409 | response_text = response.text.lower()
410 | sql_error_indicators = [
411 | "syntax error",
412 | "sql",
413 | "mysql",
414 | "postgresql",
415 | "sqlite",
416 | "database",
417 | "column",
418 | "table",
419 | "select",
420 | "union",
421 | ]
422 |
423 | for indicator in sql_error_indicators:
424 | assert indicator not in response_text
425 |
426 | # Test SQL injection in registration fields
427 | for field in ["name", "company"]:
428 | malicious_user = test_user.copy()
429 | malicious_user[field] = "'; DROP TABLE users; --"
430 |
431 | response = security_client.post("/auth/register", json=malicious_user)
432 |
433 | # Should either reject or sanitize the input
434 | assert response.status_code in [200, 201, 400, 422]
435 |
436 | if response.status_code in [200, 201]:
437 | # If accepted, verify it's sanitized
438 | login_response = security_client.post(
439 | "/auth/login",
440 | json={
441 | "email": malicious_user["email"],
442 | "password": malicious_user["password"],
443 | },
444 | )
445 |
446 | if login_response.status_code == 200:
447 | csrf_token = login_response.json().get("csrf_token")
448 | profile_response = security_client.get(
449 | "/user/profile", headers={"X-CSRF-Token": csrf_token}
450 | )
451 |
452 | if profile_response.status_code == 200:
453 | profile_data = profile_response.json()
454 | # SQL injection should be sanitized
455 | assert "DROP TABLE" not in profile_data.get(field, "")
456 |
457 | @pytest.mark.integration
458 | def test_xss_prevention(self, security_client, test_user):
459 | """Test XSS prevention."""
460 |
461 | xss_payloads = [
462 | "<script>alert('XSS')</script>",
463 | "<img src=x onerror=alert('XSS')>",
464 | "javascript:alert('XSS')",
465 | "<svg onload=alert('XSS')>",
466 | "<iframe src=javascript:alert('XSS')>",
467 | "';alert('XSS');//",
468 | "<body onload=alert('XSS')>",
469 | "<input onfocus=alert('XSS') autofocus>",
470 | "<select onfocus=alert('XSS') autofocus>",
471 | "<textarea onfocus=alert('XSS') autofocus>",
472 | ]
473 |
474 | for payload in xss_payloads:
475 | # Test XSS in user registration
476 | malicious_user = test_user.copy()
477 | malicious_user["email"] = f"xss{uuid4().hex[:8]}@example.com"
478 | malicious_user["name"] = payload
479 |
480 | response = security_client.post("/auth/register", json=malicious_user)
481 |
482 | if response.status_code in [200, 201]:
483 | # Login and check profile
484 | login_response = security_client.post(
485 | "/auth/login",
486 | json={
487 | "email": malicious_user["email"],
488 | "password": malicious_user["password"],
489 | },
490 | )
491 |
492 | if login_response.status_code == 200:
493 | csrf_token = login_response.json().get("csrf_token")
494 | profile_response = security_client.get(
495 | "/user/profile", headers={"X-CSRF-Token": csrf_token}
496 | )
497 |
498 | if profile_response.status_code == 200:
499 | profile_data = profile_response.json()
500 | stored_name = profile_data.get("name", "")
501 |
502 | # XSS should be escaped or removed
503 | assert "<script>" not in stored_name
504 | assert "javascript:" not in stored_name
505 | assert "onerror=" not in stored_name
506 | assert "onload=" not in stored_name
507 | assert "alert(" not in stored_name
508 |
509 | @pytest.mark.integration
510 | def test_path_traversal_prevention(self, security_client):
511 | """Test path traversal prevention."""
512 |
513 | path_traversal_payloads = [
514 | "../../../etc/passwd",
515 | "..\\..\\..\\windows\\system32\\config\\sam",
516 | "....//....//....//etc/passwd",
517 | "%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd",
518 | "..%252f..%252f..%252fetc%252fpasswd",
519 | "..%c0%af..%c0%af..%c0%afetc%c0%afpasswd",
520 | ]
521 |
522 | # Test path traversal in file access endpoints (if any)
523 | for payload in path_traversal_payloads:
524 | # Test in URL path
525 | response = security_client.get(f"/api/files/{payload}")
526 |
527 | # Should return 404 or 400, not 500 or file contents
528 | assert response.status_code in [400, 404, 422]
529 |
530 | # Should not return file system contents
531 | response_text = response.text.lower()
532 | sensitive_file_indicators = [
533 | "root:",
534 | "daemon:",
535 | "bin:",
536 | "sys:", # /etc/passwd content
537 | "[boot loader]",
538 | "[operating systems]", # Windows boot.ini
539 | "password",
540 | "hash",
541 | "secret",
542 | ]
543 |
544 | for indicator in sensitive_file_indicators:
545 | assert indicator not in response_text
546 |
547 | @pytest.mark.integration
548 | def test_command_injection_prevention(self, security_client, test_user):
549 | """Test command injection prevention."""
550 |
551 | command_injection_payloads = [
552 | "; cat /etc/passwd",
553 | "| cat /etc/passwd",
554 | "& dir",
555 | "`cat /etc/passwd`",
556 | "$(cat /etc/passwd)",
557 | "; rm -rf /",
558 | "&& rm -rf /",
559 | "|| rm -rf /",
560 | "; shutdown -h now",
561 | "'; whoami; echo '",
562 | ]
563 |
564 | # Test command injection in various fields
565 | for payload in command_injection_payloads:
566 | malicious_user = test_user.copy()
567 | malicious_user["email"] = f"cmd{uuid4().hex[:8]}@example.com"
568 | malicious_user["company"] = payload
569 |
570 | response = security_client.post("/auth/register", json=malicious_user)
571 |
572 | # Should handle gracefully
573 | assert response.status_code in [200, 201, 400, 422]
574 |
575 | # Should not execute system commands
576 | response_text = response.text
577 | command_output_indicators = [
578 | "root:",
579 | "daemon:",
580 | "bin:", # Output of cat /etc/passwd
581 | "total ",
582 | "drwx", # Output of ls -la
583 | "uid=",
584 | "gid=", # Output of whoami/id
585 | ]
586 |
587 | for indicator in command_output_indicators:
588 | assert indicator not in response_text
589 |
590 |
591 | class TestInformationDisclosure:
592 | """Test prevention of information disclosure."""
593 |
594 | @pytest.mark.integration
595 | def test_error_message_sanitization(self, security_client):
596 | """Test that error messages don't leak sensitive information."""
597 |
598 | # Test 404 error
599 | response = security_client.get("/nonexistent/endpoint/123")
600 | assert response.status_code == 404
601 |
602 | error_data = response.json()
603 | error_message = str(error_data).lower()
604 |
605 | # Should not contain sensitive system information
606 | sensitive_info = [
607 | "/users/",
608 | "/home/",
609 | "\\users\\",
610 | "\\home\\", # File paths
611 | "password",
612 | "secret",
613 | "key",
614 | "token",
615 | "jwt", # Credentials
616 | "localhost",
617 | "127.0.0.1",
618 | "redis://",
619 | "postgresql://", # Internal addresses
620 | "traceback",
621 | "stack trace",
622 | "exception",
623 | "error at", # Stack traces
624 | "python",
625 | "uvicorn",
626 | "fastapi",
627 | "sqlalchemy", # Framework details
628 | "database",
629 | "sql",
630 | "query",
631 | "connection", # Database details
632 | ]
633 |
634 | for info in sensitive_info:
635 | assert info not in error_message
636 |
637 | # Should include request ID for tracking
638 | assert "request_id" in error_data or "error_id" in error_data
639 |
640 | @pytest.mark.integration
641 | def test_debug_information_disclosure(self, security_client):
642 | """Test that debug information is not disclosed."""
643 |
644 | # Attempt to trigger various error conditions
645 | error_test_cases = [
646 | ("/auth/login", {"invalid": "json_structure"}),
647 | ("/user/profile", {}), # Missing authentication
648 | ]
649 |
650 | for endpoint, data in error_test_cases:
651 | response = security_client.post(endpoint, json=data)
652 |
653 | # Should not contain debug information
654 | response_text = response.text.lower()
655 | debug_indicators = [
656 | "traceback",
657 | "stack trace",
658 | "file ",
659 | "line ",
660 | "exception",
661 | "raise ",
662 | "assert",
663 | "debug",
664 | "__file__",
665 | "__name__",
666 | "locals()",
667 | "globals()",
668 | ]
669 |
670 | for indicator in debug_indicators:
671 | assert indicator not in response_text
672 |
673 | @pytest.mark.integration
674 | def test_version_information_disclosure(self, security_client):
675 | """Test that version information is not disclosed."""
676 |
677 | # Test common endpoints that might leak version info
678 | test_endpoints = [
679 | "/health",
680 | "/",
681 | "/api/docs",
682 | "/metrics",
683 | ]
684 |
685 | for endpoint in test_endpoints:
686 | response = security_client.get(endpoint)
687 |
688 | if response.status_code == 200:
689 | response_text = response.text.lower()
690 |
691 | # Should not contain detailed version information
692 | version_indicators = [
693 | "python/",
694 | "fastapi/",
695 | "uvicorn/",
696 | "nginx/",
697 | "version",
698 | "build",
699 | "commit",
700 | "git",
701 | "dev",
702 | "debug",
703 | "staging",
704 | "test",
705 | ]
706 |
707 | # Some version info might be acceptable in health endpoints
708 | if endpoint != "/health":
709 | for indicator in version_indicators:
710 | assert indicator not in response_text
711 |
712 | @pytest.mark.integration
713 | def test_user_enumeration_prevention(self, security_client):
714 | """Test prevention of user enumeration attacks."""
715 |
716 | # Test with valid email (user exists)
717 | existing_user = {
718 | "email": f"existing{uuid4().hex[:8]}@example.com",
719 | "password": "ValidPass123!",
720 | "name": "Existing User",
721 | }
722 | security_client.post("/auth/register", json=existing_user)
723 |
724 | # Test login with existing user but wrong password
725 | response_existing = security_client.post(
726 | "/auth/login",
727 | json={"email": existing_user["email"], "password": "wrong_password"},
728 | )
729 |
730 | # Test login with non-existing user
731 | response_nonexisting = security_client.post(
732 | "/auth/login",
733 | json={
734 | "email": f"nonexisting{uuid4().hex[:8]}@example.com",
735 | "password": "any_password",
736 | },
737 | )
738 |
739 | # Both should return similar error messages and status codes
740 | assert response_existing.status_code == response_nonexisting.status_code
741 |
742 | # Error messages should not distinguish between cases
743 | error_1 = response_existing.json().get("detail", "")
744 | error_2 = response_nonexisting.json().get("detail", "")
745 |
746 | # Should not contain user-specific information
747 | user_specific_terms = [
748 | "user not found",
749 | "user does not exist",
750 | "invalid user",
751 | "email not found",
752 | "account not found",
753 | "user unknown",
754 | ]
755 |
756 | for term in user_specific_terms:
757 | assert term.lower() not in error_1.lower()
758 | assert term.lower() not in error_2.lower()
759 |
760 |
761 | if __name__ == "__main__":
762 | pytest.main([__file__, "-v", "--tb=short"])
763 |
```
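The hard-coded "none"-algorithm token in `test_jwt_token_manipulation_resistance` is simply a base64url-encoded header `{"alg":"none","typ":"JWT"}` plus claims `{"user_id":1,"exp":9999999999}` with an empty signature segment. A minimal sketch of how such an unsigned token is built (standard library only; shown to make the fixture reproducible, not as an attack tool):

```python
import base64
import json


def b64url(data: bytes) -> str:
    """Base64url-encode without padding, per RFC 7515."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


header = b64url(json.dumps({"alg": "none", "typ": "JWT"}, separators=(",", ":")).encode())
claims = b64url(json.dumps({"user_id": 1, "exp": 9999999999}, separators=(",", ":")).encode())

# The trailing dot marks the empty signature segment; a compliant verifier must reject it.
unsigned_token = f"{header}.{claims}."
assert unsigned_token.startswith("eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.")
```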
--------------------------------------------------------------------------------
/maverick_mcp/core/technical_analysis.py:
--------------------------------------------------------------------------------
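A minimal usage sketch for the indicator pipeline defined in this file (illustrative only; `prices.csv` is a hypothetical OHLCV export, e.g. from yfinance):

```python
import pandas as pd

from maverick_mcp.core.technical_analysis import (
    add_technical_indicators,
    analyze_rsi,
    analyze_trend,
)

# Any DataFrame with open/high/low/close/volume columns works; the pipeline
# lowercases column names before computing indicators.
ohlcv = pd.read_csv("prices.csv", index_col=0, parse_dates=True)

enriched = add_technical_indicators(ohlcv)
print("Trend strength (0-7):", analyze_trend(enriched))
print(analyze_rsi(enriched))
```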
```python
1 | """
2 | Technical analysis functions for Maverick-MCP.
3 |
4 | This module contains functions for performing technical analysis on financial data,
5 | including calculating indicators, analyzing trends, and generating trading signals.
6 |
7 | DISCLAIMER: All technical analysis functions in this module are for educational
8 | purposes only. Technical indicators are mathematical calculations based on historical
9 | data and do not predict future price movements. Past performance does not guarantee
10 | future results. Always conduct thorough research and consult with qualified financial
11 | professionals before making investment decisions.
12 | """
13 |
14 | import logging
15 | from collections.abc import Sequence
16 | from typing import Any
17 |
18 | import numpy as np
19 | import pandas as pd
20 | import pandas_ta as ta
21 |
22 | from maverick_mcp.config.technical_constants import TECHNICAL_CONFIG
23 |
24 | # Set up logging
25 | logging.basicConfig(
26 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
27 | )
28 | logger = logging.getLogger("maverick_mcp.technical_analysis")
29 |
30 |
31 | def _get_column_case_insensitive(df: pd.DataFrame, column_name: str) -> str | None:
32 | """
33 | Get the actual column name from the dataframe in a case-insensitive way.
34 |
35 | Args:
36 | df: DataFrame to search
37 | column_name: Name of the column to find (case-insensitive)
38 |
39 | Returns:
40 | The actual column name if found, None otherwise
41 | """
42 | if column_name in df.columns:
43 | return column_name
44 |
45 | column_name_lower = column_name.lower()
46 | for col in df.columns:
47 | if col.lower() == column_name_lower:
48 | return col
49 | return None
50 |
51 |
52 | def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
53 | """
54 | Add technical indicators to the dataframe
55 |
56 | Args:
57 | df: DataFrame with OHLCV price data
58 |
59 | Returns:
60 | DataFrame with added technical indicators
61 | """
62 | # Ensure column names are lowercase
63 | df = df.copy()
64 | df.columns = [col.lower() for col in df.columns]
65 |
66 | # Use pandas_ta for all indicators with configurable parameters
67 | # EMA
68 | df["ema_21"] = ta.ema(df["close"], length=TECHNICAL_CONFIG.EMA_PERIOD)
69 | # SMA
70 | df["sma_50"] = ta.sma(df["close"], length=TECHNICAL_CONFIG.SMA_SHORT_PERIOD)
71 | df["sma_200"] = ta.sma(df["close"], length=TECHNICAL_CONFIG.SMA_LONG_PERIOD)
72 | # RSI
73 | df["rsi"] = ta.rsi(df["close"], length=TECHNICAL_CONFIG.RSI_PERIOD)
74 | # MACD
75 | macd = ta.macd(
76 | df["close"],
77 | fast=TECHNICAL_CONFIG.MACD_FAST_PERIOD,
78 | slow=TECHNICAL_CONFIG.MACD_SLOW_PERIOD,
79 | signal=TECHNICAL_CONFIG.MACD_SIGNAL_PERIOD,
80 | )
81 | if macd is not None and not macd.empty:
82 | df["macd_12_26_9"] = macd["MACD_12_26_9"]
83 | df["macds_12_26_9"] = macd["MACDs_12_26_9"]
84 | df["macdh_12_26_9"] = macd["MACDh_12_26_9"]
85 | else:
86 | df["macd_12_26_9"] = np.nan
87 | df["macds_12_26_9"] = np.nan
88 | df["macdh_12_26_9"] = np.nan
89 | # Bollinger Bands
90 | bbands = ta.bbands(df["close"], length=20, std=2.0)
91 | if bbands is not None and not bbands.empty:
92 | resolved_columns = _resolve_bollinger_columns(bbands.columns)
93 | if resolved_columns:
94 | mid_col, upper_col, lower_col = resolved_columns
95 | df["sma_20"] = bbands[mid_col]
96 | df["bbu_20_2.0"] = bbands[upper_col]
97 | df["bbl_20_2.0"] = bbands[lower_col]
98 | else:
99 | logger.warning(
100 | "Bollinger Bands columns missing expected names: %s",
101 | list(bbands.columns),
102 | )
103 | df["sma_20"] = np.nan
104 | df["bbu_20_2.0"] = np.nan
105 | df["bbl_20_2.0"] = np.nan
106 | else:
107 | df["sma_20"] = np.nan
108 | df["bbu_20_2.0"] = np.nan
109 | df["bbl_20_2.0"] = np.nan
110 | df["stdev"] = df["close"].rolling(window=20).std()
111 | # ATR
112 | df["atr"] = ta.atr(df["high"], df["low"], df["close"], length=14)
113 | # Stochastic Oscillator
114 | stoch = ta.stoch(df["high"], df["low"], df["close"], k=14, d=3, smooth_k=3)
115 | if stoch is not None and not stoch.empty:
116 | df["stochk_14_3_3"] = stoch["STOCHk_14_3_3"]
117 | df["stochd_14_3_3"] = stoch["STOCHd_14_3_3"]
118 | else:
119 | df["stochk_14_3_3"] = np.nan
120 | df["stochd_14_3_3"] = np.nan
121 | # ADX
122 | adx = ta.adx(df["high"], df["low"], df["close"], length=14)
123 | if adx is not None and not adx.empty:
124 | df["adx_14"] = adx["ADX_14"]
125 | else:
126 | df["adx_14"] = np.nan
127 |
128 | return df
129 |
130 |
131 | def _resolve_bollinger_columns(columns: Sequence[str]) -> tuple[str, str, str] | None:
132 | """Resolve Bollinger Band column names across pandas-ta variants."""
133 |
134 | candidate_sets = [
135 | ("BBM_20_2.0", "BBU_20_2.0", "BBL_20_2.0"),
136 | ("BBM_20_2", "BBU_20_2", "BBL_20_2"),
137 | ]
138 |
139 | for candidate in candidate_sets:
140 | if set(candidate).issubset(columns):
141 | return candidate
142 |
143 | mid_candidates = [column for column in columns if column.startswith("BBM_")]
144 | upper_candidates = [column for column in columns if column.startswith("BBU_")]
145 | lower_candidates = [column for column in columns if column.startswith("BBL_")]
146 |
147 | if mid_candidates and upper_candidates and lower_candidates:
148 | return mid_candidates[0], upper_candidates[0], lower_candidates[0]
149 |
150 | return None
151 |
152 |
153 | def identify_support_levels(df: pd.DataFrame) -> list[float]:
154 | """
155 | Identify support levels using recent lows
156 |
157 | Args:
158 | df: DataFrame with price data
159 |
160 | Returns:
161 | List of support price levels
162 | """
163 | # Use the lowest points in recent periods
164 | last_month = df.iloc[-30:] if len(df) >= 30 else df
165 | min_price = last_month["low"].min()
166 |
167 | # Additional support levels
168 | support_levels = [
169 | round(min_price, 2),
170 | round(df["close"].iloc[-1] * 0.95, 2), # 5% below current price
171 | round(df["close"].iloc[-1] * 0.90, 2), # 10% below current price
172 | ]
173 |
174 | return sorted(set(support_levels))
175 |
176 |
177 | def identify_resistance_levels(df: pd.DataFrame) -> list[float]:
178 | """
179 | Identify resistance levels using recent highs
180 |
181 | Args:
182 | df: DataFrame with price data
183 |
184 | Returns:
185 | List of resistance price levels
186 | """
187 | # Use the highest points in recent periods
188 | last_month = df.iloc[-30:] if len(df) >= 30 else df
189 | max_price = last_month["high"].max()
190 |
191 | # Additional resistance levels
192 | resistance_levels = [
193 | round(max_price, 2),
194 | round(df["close"].iloc[-1] * 1.05, 2), # 5% above current price
195 | round(df["close"].iloc[-1] * 1.10, 2), # 10% above current price
196 | ]
197 |
198 | return sorted(set(resistance_levels))
199 |
200 |
201 | def analyze_trend(df: pd.DataFrame) -> int:
202 | """
203 | Calculate the trend strength of a stock based on various technical indicators.
204 |
205 | Args:
206 | df: DataFrame with price and indicator data
207 |
208 | Returns:
209 | Integer trend strength score (0-7)
210 | """
211 | try:
212 | trend_strength = 0
213 | close_price = df["close"].iloc[-1]
214 |
215 | # Check SMA 50
216 | sma_50 = df["sma_50"].iloc[-1]
217 | if pd.notna(sma_50) and close_price > sma_50:
218 | trend_strength += 1
219 |
220 | # Check EMA 21
221 | ema_21 = df["ema_21"].iloc[-1]
222 | if pd.notna(ema_21) and close_price > ema_21:
223 | trend_strength += 1
224 |
225 | # Check EMA 21 vs SMA 50
226 | if pd.notna(ema_21) and pd.notna(sma_50) and ema_21 > sma_50:
227 | trend_strength += 1
228 |
229 | # Check SMA 50 vs SMA 200
230 | sma_200 = df["sma_200"].iloc[-1]
231 | if pd.notna(sma_50) and pd.notna(sma_200) and sma_50 > sma_200:
232 | trend_strength += 1
233 |
234 | # Check RSI
235 | rsi = df["rsi"].iloc[-1]
236 | if pd.notna(rsi) and rsi > 50:
237 | trend_strength += 1
238 |
239 | # Check MACD
240 | macd = df["macd_12_26_9"].iloc[-1]
241 | if pd.notna(macd) and macd > 0:
242 | trend_strength += 1
243 |
244 | # Check ADX
245 | adx = df["adx_14"].iloc[-1]
246 | if pd.notna(adx) and adx > 25:
247 | trend_strength += 1
248 |
249 | return trend_strength
250 | except Exception as e:
251 | logger.error(f"Error calculating trend strength: {e}")
252 | return 0
253 |
254 |
255 | def analyze_rsi(df: pd.DataFrame) -> dict[str, Any]:
256 | """
257 | Analyze RSI indicator
258 |
259 | Args:
260 | df: DataFrame with price and indicator data
261 |
262 | Returns:
263 | Dictionary with RSI analysis
264 | """
265 | try:
266 | # Check if dataframe is valid and has RSI column
267 | if df.empty:
268 | return {
269 | "current": None,
270 | "signal": "unavailable",
271 | "description": "No data available for RSI calculation",
272 | }
273 |
274 | if "rsi" not in df.columns:
275 | return {
276 | "current": None,
277 | "signal": "unavailable",
278 | "description": "RSI indicator not calculated",
279 | }
280 |
281 | if len(df) == 0:
282 | return {
283 | "current": None,
284 | "signal": "unavailable",
285 | "description": "Insufficient data for RSI calculation",
286 | }
287 |
288 | rsi = df["rsi"].iloc[-1]
289 |
290 | # Check if RSI is NaN
291 | if pd.isna(rsi):
292 | return {
293 | "current": None,
294 | "signal": "unavailable",
295 | "description": "RSI data not available (insufficient data points)",
296 | }
297 |
298 | if rsi > 70:
299 | signal = "overbought"
300 | elif rsi < 30:
301 | signal = "oversold"
302 | elif rsi > 50:
303 | signal = "bullish"
304 | else:
305 | signal = "bearish"
306 |
307 | return {
308 | "current": round(rsi, 2),
309 | "signal": signal,
310 | "description": f"RSI is currently at {round(rsi, 2)}, indicating {signal} conditions.",
311 | }
312 | except Exception as e:
313 | logger.error(f"Error analyzing RSI: {e}")
314 | return {
315 | "current": None,
316 | "signal": "error",
317 | "description": f"Error calculating RSI: {str(e)}",
318 | }
319 |
320 |
321 | def analyze_macd(df: pd.DataFrame) -> dict[str, Any]:
322 | """
323 | Analyze MACD indicator
324 |
325 | Args:
326 | df: DataFrame with price and indicator data
327 |
328 | Returns:
329 | Dictionary with MACD analysis
330 | """
331 | try:
332 | macd = df["macd_12_26_9"].iloc[-1]
333 | signal = df["macds_12_26_9"].iloc[-1]
334 | histogram = df["macdh_12_26_9"].iloc[-1]
335 |
336 | # Check if any values are NaN
337 | if pd.isna(macd) or pd.isna(signal) or pd.isna(histogram):
338 | return {
339 | "macd": None,
340 | "signal": None,
341 | "histogram": None,
342 | "indicator": "unavailable",
343 | "crossover": "unavailable",
344 | "description": "MACD data not available (insufficient data points)",
345 | }
346 |
347 | if macd > signal and histogram > 0:
348 | signal_type = "bullish"
349 | elif macd < signal and histogram < 0:
350 | signal_type = "bearish"
351 | elif macd > signal and macd < 0:
352 | signal_type = "improving"
353 | elif macd < signal and macd > 0:
354 | signal_type = "weakening"
355 | else:
356 | signal_type = "neutral"
357 |
358 | # Check for crossover (ensure we have enough data)
359 | crossover = "no recent crossover"
360 | if len(df) >= 2:
361 | prev_macd = df["macd_12_26_9"].iloc[-2]
362 | prev_signal = df["macds_12_26_9"].iloc[-2]
363 | if pd.notna(prev_macd) and pd.notna(prev_signal):
364 | if prev_macd <= prev_signal and macd > signal:
365 | crossover = "bullish crossover detected"
366 | elif prev_macd >= prev_signal and macd < signal:
367 | crossover = "bearish crossover detected"
368 |
369 | return {
370 | "macd": round(macd, 2),
371 | "signal": round(signal, 2),
372 | "histogram": round(histogram, 2),
373 | "indicator": signal_type,
374 | "crossover": crossover,
375 | "description": f"MACD is {signal_type} with {crossover}.",
376 | }
377 | except Exception as e:
378 | logger.error(f"Error analyzing MACD: {e}")
379 | return {
380 | "macd": None,
381 | "signal": None,
382 | "histogram": None,
383 | "indicator": "error",
384 | "crossover": "error",
385 | "description": "Error calculating MACD",
386 | }
387 |
388 |
389 | def analyze_stochastic(df: pd.DataFrame) -> dict[str, Any]:
390 | """
391 | Analyze Stochastic Oscillator
392 |
393 | Args:
394 | df: DataFrame with price and indicator data
395 |
396 | Returns:
397 | Dictionary with stochastic oscillator analysis
398 | """
399 | try:
400 | k = df["stochk_14_3_3"].iloc[-1]
401 | d = df["stochd_14_3_3"].iloc[-1]
402 |
403 | # Check if values are NaN
404 | if pd.isna(k) or pd.isna(d):
405 | return {
406 | "k": None,
407 | "d": None,
408 | "signal": "unavailable",
409 | "crossover": "unavailable",
410 | "description": "Stochastic data not available (insufficient data points)",
411 | }
412 |
413 | if k > 80 and d > 80:
414 | signal = "overbought"
415 | elif k < 20 and d < 20:
416 | signal = "oversold"
417 | elif k > d:
418 | signal = "bullish"
419 | else:
420 | signal = "bearish"
421 |
422 | # Check for crossover (ensure we have enough data)
423 | crossover = "no recent crossover"
424 | if len(df) >= 2:
425 | prev_k = df["stochk_14_3_3"].iloc[-2]
426 | prev_d = df["stochd_14_3_3"].iloc[-2]
427 | if pd.notna(prev_k) and pd.notna(prev_d):
428 | if prev_k <= prev_d and k > d:
429 | crossover = "bullish crossover detected"
430 | elif prev_k >= prev_d and k < d:
431 | crossover = "bearish crossover detected"
432 |
433 | return {
434 | "k": round(k, 2),
435 | "d": round(d, 2),
436 | "signal": signal,
437 | "crossover": crossover,
438 | "description": f"Stochastic Oscillator is {signal} with {crossover}.",
439 | }
440 | except Exception as e:
441 | logger.error(f"Error analyzing Stochastic: {e}")
442 | return {
443 | "k": None,
444 | "d": None,
445 | "signal": "error",
446 | "crossover": "error",
447 | "description": "Error calculating Stochastic",
448 | }
449 |
450 |
451 | def analyze_bollinger_bands(df: pd.DataFrame) -> dict[str, Any]:
452 | """
453 | Analyze Bollinger Bands
454 |
455 | Args:
456 | df: DataFrame with price and indicator data
457 |
458 | Returns:
459 | Dictionary with Bollinger Bands analysis
460 | """
461 | try:
462 | current_price = df["close"].iloc[-1]
463 | upper_band = df["bbu_20_2.0"].iloc[-1]
464 | lower_band = df["bbl_20_2.0"].iloc[-1]
465 | middle_band = df["sma_20"].iloc[-1]
466 |
467 | # Check if any values are NaN
468 | if pd.isna(upper_band) or pd.isna(lower_band) or pd.isna(middle_band):
469 | return {
470 | "upper_band": None,
471 | "middle_band": None,
472 | "lower_band": None,
473 | "position": "unavailable",
474 | "signal": "unavailable",
475 | "volatility": "unavailable",
476 | "description": "Bollinger Bands data not available (insufficient data points)",
477 | }
478 |
479 | if current_price > upper_band:
480 | position = "above upper band"
481 | signal = "overbought"
482 | elif current_price < lower_band:
483 | position = "below lower band"
484 | signal = "oversold"
485 | elif current_price > middle_band:
486 | position = "above middle band"
487 | signal = "bullish"
488 | else:
489 | position = "below middle band"
490 | signal = "bearish"
491 |
492 | # Check for BB squeeze (volatility contraction)
493 | volatility = "stable"
494 | if len(df) >= 5:
495 | try:
496 | bb_widths = []
497 | for i in range(-5, 0):
498 | upper = df["bbu_20_2.0"].iloc[i]
499 | lower = df["bbl_20_2.0"].iloc[i]
500 | middle = df["sma_20"].iloc[i]
501 | if (
502 | pd.notna(upper)
503 | and pd.notna(lower)
504 | and pd.notna(middle)
505 | and middle != 0
506 | ):
507 | bb_widths.append((upper - lower) / middle)
508 |
509 | if len(bb_widths) == 5:
510 | if all(bb_widths[i] < bb_widths[i - 1] for i in range(1, 5)):
511 | volatility = "contracting (potential breakout ahead)"
512 | elif all(bb_widths[i] > bb_widths[i - 1] for i in range(1, 5)):
513 | volatility = "expanding (increased volatility)"
514 | except Exception:
515 | # If volatility calculation fails, keep it as stable
516 | pass
517 |
518 | return {
519 | "upper_band": round(upper_band, 2),
520 | "middle_band": round(middle_band, 2),
521 | "lower_band": round(lower_band, 2),
522 | "position": position,
523 | "signal": signal,
524 | "volatility": volatility,
525 | "description": f"Price is {position}, indicating {signal} conditions. Volatility is {volatility}.",
526 | }
527 | except Exception as e:
528 | logger.error(f"Error analyzing Bollinger Bands: {e}")
529 | return {
530 | "upper_band": None,
531 | "middle_band": None,
532 | "lower_band": None,
533 | "position": "error",
534 | "signal": "error",
535 | "volatility": "error",
536 | "description": "Error calculating Bollinger Bands",
537 | }
538 |
539 |
540 | def analyze_volume(df: pd.DataFrame) -> dict[str, Any]:
541 | """
542 | Analyze volume patterns
543 |
544 | Args:
545 | df: DataFrame with price and volume data
546 |
547 | Returns:
548 | Dictionary with volume analysis
549 | """
550 | try:
551 | current_volume = df["volume"].iloc[-1]
552 |
553 | # Check if we have enough data for average
554 | if len(df) < 10:
555 | avg_volume = df["volume"].mean()
556 | else:
557 | avg_volume = df["volume"].iloc[-10:].mean()
558 |
559 | # Check for invalid values
560 | if pd.isna(current_volume) or pd.isna(avg_volume) or avg_volume == 0:
561 | return {
562 | "current": None,
563 | "average": None,
564 | "ratio": None,
565 | "description": "unavailable",
566 | "signal": "unavailable",
567 | }
568 |
569 | volume_ratio = current_volume / avg_volume
570 |
571 | if volume_ratio > 1.5:
572 | volume_desc = "above average"
573 | if len(df) >= 2 and df["close"].iloc[-1] > df["close"].iloc[-2]:
574 | signal = "bullish (high volume on up move)"
575 | else:
576 | signal = "bearish (high volume on down move)"
577 | elif volume_ratio < 0.7:
578 | volume_desc = "below average"
579 | signal = "weak conviction"
580 | else:
581 | volume_desc = "average"
582 | signal = "neutral"
583 |
584 | return {
585 | "current": int(current_volume),
586 | "average": int(avg_volume),
587 | "ratio": round(volume_ratio, 2),
588 | "description": volume_desc,
589 | "signal": signal,
590 | }
591 | except Exception as e:
592 | logger.error(f"Error analyzing volume: {e}")
593 | return {
594 | "current": None,
595 | "average": None,
596 | "ratio": None,
597 | "description": "error",
598 | "signal": "error",
599 | }
600 |
601 |
602 | def identify_chart_patterns(df: pd.DataFrame) -> list[str]:
603 | """
604 | Identify common chart patterns
605 |
606 | Args:
607 | df: DataFrame with price data
608 |
609 | Returns:
610 | List of identified chart patterns
611 | """
612 | patterns = []
613 |
614 | # Check for potential double bottom (W formation)
615 | if len(df) >= 40:
616 | recent_lows = df["low"].iloc[-40:].values
617 | potential_bottoms = []
618 |
619 | for i in range(1, len(recent_lows) - 1):
620 | if (
621 | recent_lows[i] < recent_lows[i - 1]
622 | and recent_lows[i] < recent_lows[i + 1]
623 | ):
624 | potential_bottoms.append(i)
625 |
626 | if (
627 | len(potential_bottoms) >= 2
628 | and potential_bottoms[-1] - potential_bottoms[-2] >= 5
629 | ):
630 | if (
631 | abs(
632 | recent_lows[potential_bottoms[-1]]
633 | - recent_lows[potential_bottoms[-2]]
634 | )
635 | / recent_lows[potential_bottoms[-2]]
636 | < 0.05
637 | ):
638 | patterns.append("Double Bottom (W)")
639 |
640 | # Check for potential double top (M formation)
641 | if len(df) >= 40:
642 | recent_highs = df["high"].iloc[-40:].values
643 | potential_tops = []
644 |
645 | for i in range(1, len(recent_highs) - 1):
646 | if (
647 | recent_highs[i] > recent_highs[i - 1]
648 | and recent_highs[i] > recent_highs[i + 1]
649 | ):
650 | potential_tops.append(i)
651 |
652 | if len(potential_tops) >= 2 and potential_tops[-1] - potential_tops[-2] >= 5:
653 | if (
654 | abs(recent_highs[potential_tops[-1]] - recent_highs[potential_tops[-2]])
655 | / recent_highs[potential_tops[-2]]
656 | < 0.05
657 | ):
658 | patterns.append("Double Top (M)")
659 |
660 | # Check for bullish flag/pennant
661 | if len(df) >= 20:
662 | recent_prices = df["close"].iloc[-20:].values
663 | if (
664 | recent_prices[0] < recent_prices[10]
665 | and all(
666 | recent_prices[i] >= recent_prices[i - 1] * 0.99 for i in range(1, 10)
667 | )
668 | and all(
669 | abs(recent_prices[i] - recent_prices[i - 1]) / recent_prices[i - 1]
670 | < 0.02
671 | for i in range(11, 20)
672 | )
673 | ):
674 | patterns.append("Bullish Flag/Pennant")
675 |
676 | # Check for bearish flag/pennant
677 | if len(df) >= 20:
678 | recent_prices = df["close"].iloc[-20:].values
679 | if (
680 | recent_prices[0] > recent_prices[10]
681 | and all(
682 | recent_prices[i] <= recent_prices[i - 1] * 1.01 for i in range(1, 10)
683 | )
684 | and all(
685 | abs(recent_prices[i] - recent_prices[i - 1]) / recent_prices[i - 1]
686 | < 0.02
687 | for i in range(11, 20)
688 | )
689 | ):
690 | patterns.append("Bearish Flag/Pennant")
691 |
692 | return patterns
693 |
694 |
695 | def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
696 | """
697 | Calculate Average True Range (ATR) for the given dataframe.
698 |
699 | Args:
700 | df: DataFrame with high, low, and close price data
701 | period: Period for ATR calculation (default: 14)
702 |
703 | Returns:
704 | Series with ATR values
705 | """
706 | # Optimized to avoid copying the entire dataframe
707 | high_col = _get_column_case_insensitive(df, "high")
708 | low_col = _get_column_case_insensitive(df, "low")
709 | close_col = _get_column_case_insensitive(df, "close")
710 |
711 | if not (high_col and low_col and close_col):
712 | # Fall back to the copy-based method if the columns are not found
713 | # (preserves the previous behavior, where missing columns may raise an error downstream)
714 | logger.warning("Could not find High, Low, Close columns case-insensitively. Falling back to copy method.")
715 | df_copy = df.copy()
716 | df_copy.columns = [col.lower() for col in df_copy.columns]
717 | return ta.atr(df_copy["high"], df_copy["low"], df_copy["close"], length=period)
718 |
719 | # Use pandas_ta to calculate ATR
720 | atr = ta.atr(df[high_col], df[low_col], df[close_col], length=period)
721 |
722 | # Ensure we return a Series
723 | if isinstance(atr, pd.Series):
724 | return atr
725 | elif isinstance(atr, pd.DataFrame):
726 | # If it's a DataFrame, take the first column
727 | return pd.Series(atr.iloc[:, 0])
728 | elif atr is not None:
729 | # If it's a numpy array or other iterable
730 | return pd.Series(atr)
731 | else:
732 | # Return empty series if calculation failed
733 | return pd.Series(dtype=float)
734 |
735 |
736 | def generate_outlook(
737 | df: pd.DataFrame,
738 | trend: str,
739 | rsi_analysis: dict[str, Any],
740 | macd_analysis: dict[str, Any],
741 | stoch_analysis: dict[str, Any],
742 | ) -> str:
743 | """
744 | Generate an overall outlook based on technical analysis
745 |
746 | Args:
747 | df: DataFrame with price and indicator data
748 | trend: Trend direction from analyze_trend
749 | rsi_analysis: RSI analysis from analyze_rsi
750 | macd_analysis: MACD analysis from analyze_macd
751 | stoch_analysis: Stochastic analysis from analyze_stochastic
752 |
753 | Returns:
754 | String with overall market outlook
755 | """
756 | bullish_signals = 0
757 | bearish_signals = 0
758 |
759 | # Count signals from different indicators
760 | if trend == "uptrend":
761 | bullish_signals += 2
762 | elif trend == "downtrend":
763 | bearish_signals += 2
764 |
765 | if rsi_analysis["signal"] == "bullish" or rsi_analysis["signal"] == "oversold":
766 | bullish_signals += 1
767 | elif rsi_analysis["signal"] == "bearish" or rsi_analysis["signal"] == "overbought":
768 | bearish_signals += 1
769 |
770 | if (
771 | macd_analysis["indicator"] == "bullish"
772 | or macd_analysis["crossover"] == "bullish crossover detected"
773 | ):
774 | bullish_signals += 1
775 | elif (
776 | macd_analysis["indicator"] == "bearish"
777 | or macd_analysis["crossover"] == "bearish crossover detected"
778 | ):
779 | bearish_signals += 1
780 |
781 | if stoch_analysis["signal"] == "bullish" or stoch_analysis["signal"] == "oversold":
782 | bullish_signals += 1
783 | elif (
784 | stoch_analysis["signal"] == "bearish"
785 | or stoch_analysis["signal"] == "overbought"
786 | ):
787 | bearish_signals += 1
788 |
789 | # Generate outlook based on signals
790 | if bullish_signals >= 4:
791 | return "strongly bullish"
792 | elif bullish_signals > bearish_signals:
793 | return "moderately bullish"
794 | elif bearish_signals >= 4:
795 | return "strongly bearish"
796 | elif bearish_signals > bullish_signals:
797 | return "moderately bearish"
798 | else:
799 | return "neutral"
800 |
801 |
802 | def calculate_rsi(df: pd.DataFrame, period: int = 14) -> pd.Series:
803 | """
804 | Calculate RSI (Relative Strength Index) for the given dataframe.
805 |
806 | Args:
807 | df: DataFrame with price data
808 | period: Period for RSI calculation (default: 14)
809 |
810 | Returns:
811 | Series with RSI values
812 | """
813 | # Optimized to avoid copying the entire dataframe
814 | close_col = _get_column_case_insensitive(df, "close")
815 |
816 | # Ensure we have the required 'close' column
817 | if not close_col:
818 | # If the column cannot be found case-insensitively, raise immediately.
819 | # The original implementation copied the frame and lowercased the columns
820 | # before checking for "close", so raising here is equivalent for a missing column.
821 | raise ValueError("DataFrame must contain a 'close' or 'Close' column")
822 |
823 | # Use pandas_ta to calculate RSI
824 | rsi = ta.rsi(df[close_col], length=period)
825 |
826 | # Ensure we return a Series
827 | if isinstance(rsi, pd.Series):
828 | return rsi
829 | elif rsi is not None:
830 | # If it's a numpy array or other iterable
831 | return pd.Series(rsi, index=df.index)
832 | else:
833 | # Return empty series if calculation failed
834 | return pd.Series(dtype=float, index=df.index)
835 |
836 |
837 | def calculate_sma(df: pd.DataFrame, period: int) -> pd.Series:
838 | """
839 | Calculate Simple Moving Average (SMA) for the given dataframe.
840 |
841 | Args:
842 | df: DataFrame with price data
843 | period: Period for SMA calculation
844 |
845 | Returns:
846 | Series with SMA values
847 | """
848 | # Optimized to avoid copying the entire dataframe
849 | close_col = _get_column_case_insensitive(df, "close")
850 |
851 | # Ensure we have the required 'close' column
852 | if not close_col:
853 | raise ValueError("DataFrame must contain a 'close' or 'Close' column")
854 |
855 | # Use pandas_ta to calculate SMA
856 | sma = ta.sma(df[close_col], length=period)
857 |
858 | # Ensure we return a Series
859 | if isinstance(sma, pd.Series):
860 | return sma
861 | elif sma is not None:
862 | # If it's a numpy array or other iterable
863 | return pd.Series(sma, index=df.index)
864 | else:
865 | # Return empty series if calculation failed
866 | return pd.Series(dtype=float, index=df.index)
867 |
```
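For orientation, a minimal driver showing how the helpers above compose. This is a sketch, not part of the module: the synthetic DataFrame is purely illustrative, the column mapping assumes pandas_ta's default `BBU_20_2.0`/`BBL_20_2.0` naming (lowercased to match what the analyzers read), and `analyze_bollinger_bands` is assumed to take the frame alone, as its siblings do.

```python
import numpy as np
import pandas as pd
import pandas_ta as ta

# Synthetic OHLCV data standing in for a real price history (illustrative only).
rng = np.random.default_rng(42)
close = 100 + np.cumsum(rng.normal(0, 1, 60))
df = pd.DataFrame(
    {
        "open": close - 0.5,
        "high": close + 1.0,
        "low": close - 1.0,
        "close": close,
        "volume": rng.integers(1_000_000, 5_000_000, 60),
    }
)

# The analyzers read lowercase pandas_ta-style columns (e.g. "bbu_20_2.0"),
# so map pandas_ta's BBU_20_2.0/BBL_20_2.0 output accordingly.
bbands = ta.bbands(df["close"], length=20, std=2.0)
df["bbu_20_2.0"] = bbands["BBU_20_2.0"]
df["bbl_20_2.0"] = bbands["BBL_20_2.0"]
df["sma_20"] = ta.sma(df["close"], length=20)

print(analyze_bollinger_bands(df))
print(analyze_volume(df))
print(identify_chart_patterns(df))

# generate_outlook tallies weighted signals: trend counts double, each
# oscillator counts once. Here: uptrend (+2) + bullish RSI (+1) + bullish
# MACD (+1) = 4 bullish points, which crosses the >= 4 "strongly bullish" bar.
print(
    generate_outlook(
        df,
        trend="uptrend",
        rsi_analysis={"signal": "bullish"},
        macd_analysis={"indicator": "bullish", "crossover": "none"},
        stoch_analysis={"signal": "neutral"},
    )
)  # -> "strongly bullish"
```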
--------------------------------------------------------------------------------
/maverick_mcp/tests/test_models_functional.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Functional tests for SQLAlchemy models that test the actual data operations.
3 |
4 | These tests verify model functionality by reading from the existing production database
5 | without creating any new tables or modifying data.
6 | """
7 |
8 | import os
9 | import uuid
10 | from datetime import datetime, timedelta
11 | from decimal import Decimal
12 |
13 | import pytest
14 | from sqlalchemy import create_engine, text
15 | from sqlalchemy.exc import ProgrammingError
16 | from sqlalchemy.orm import sessionmaker
17 |
18 | from maverick_mcp.data.models import (
19 | DATABASE_URL,
20 | MaverickBearStocks,
21 | MaverickStocks,
22 | PriceCache,
23 | Stock,
24 | SupplyDemandBreakoutStocks,
25 | get_latest_maverick_screening,
26 | )
27 |
28 |
29 | @pytest.fixture(scope="session")
30 | def read_only_engine():
31 | """Create a read-only database engine for the existing database."""
32 | # Use the existing database URL from environment or default
33 | db_url = os.getenv("POSTGRES_URL", DATABASE_URL)
34 |
35 | try:
36 | # Create engine with read-only intent
37 | engine = create_engine(db_url, echo=False)
38 | # Test the connection
39 | with engine.connect() as conn:
40 | conn.execute(text("SELECT 1"))
41 | except Exception as e:
42 | pytest.skip(f"Database not available: {e}")
43 | return
44 |
45 | yield engine
46 |
47 | engine.dispose()
48 |
49 |
50 | @pytest.fixture(scope="function")
51 | def db_session(read_only_engine):
52 | """Create a read-only database session for each test."""
53 | SessionLocal = sessionmaker(bind=read_only_engine)
54 | session = SessionLocal()
55 |
56 | yield session
57 |
58 | session.rollback() # Rollback any potential changes
59 | session.close()
60 |
61 |
62 | class TestStockModelReadOnly:
63 | """Test the Stock model functionality with read-only operations."""
64 |
65 | def test_query_stocks(self, db_session):
66 | """Test querying existing stocks from the database."""
67 | # Query for any existing stocks
68 | stocks = db_session.query(Stock).limit(5).all()
69 |
70 | # If there are stocks in the database, verify their structure
71 | for stock in stocks:
72 | assert hasattr(stock, "stock_id")
73 | assert hasattr(stock, "ticker_symbol")
74 | assert hasattr(stock, "created_at")
75 | assert hasattr(stock, "updated_at")
76 |
77 | # Verify timestamps are timezone-aware
78 | if stock.created_at:
79 | assert stock.created_at.tzinfo is not None
80 | if stock.updated_at:
81 | assert stock.updated_at.tzinfo is not None
82 |
83 | def test_query_by_ticker(self, db_session):
84 | """Test querying stock by ticker symbol."""
85 | # Try to find a common stock like AAPL
86 | stock = db_session.query(Stock).filter_by(ticker_symbol="AAPL").first()
87 |
88 | if stock:
89 | assert stock.ticker_symbol == "AAPL"
90 | assert isinstance(stock.stock_id, uuid.UUID)
91 |
92 | def test_stock_repr(self, db_session):
93 | """Test string representation of Stock."""
94 | stock = db_session.query(Stock).first()
95 | if stock:
96 | repr_str = repr(stock)
97 | assert "<Stock(" in repr_str
98 | assert "ticker=" in repr_str
99 | assert stock.ticker_symbol in repr_str
100 |
101 | def test_stock_relationships(self, db_session):
102 | """Test stock relationships with price caches."""
103 | # Find a stock that has price data
104 | stock_with_prices = db_session.query(Stock).join(PriceCache).distinct().first()
105 |
106 | if stock_with_prices:
107 | # Access the relationship
108 | price_caches = stock_with_prices.price_caches
109 | assert isinstance(price_caches, list)
110 |
111 | # Verify each price cache belongs to this stock
112 | for pc in price_caches[:5]: # Check first 5
113 | assert pc.stock_id == stock_with_prices.stock_id
114 | assert pc.stock == stock_with_prices
115 |
116 |
117 | class TestPriceCacheModelReadOnly:
118 | """Test the PriceCache model functionality with read-only operations."""
119 |
120 | def test_query_price_cache(self, db_session):
121 | """Test querying existing price cache entries."""
122 | # Query for any existing price data
123 | prices = db_session.query(PriceCache).limit(10).all()
124 |
125 | # Verify structure of price entries
126 | for price in prices:
127 | assert hasattr(price, "price_cache_id")
128 | assert hasattr(price, "stock_id")
129 | assert hasattr(price, "date")
130 | assert hasattr(price, "close_price")
131 |
132 | # Verify data types
133 | if price.price_cache_id:
134 | assert isinstance(price.price_cache_id, uuid.UUID)
135 | if price.close_price:
136 | assert isinstance(price.close_price, Decimal)
137 |
138 | def test_price_cache_repr(self, db_session):
139 | """Test string representation of PriceCache."""
140 | price = db_session.query(PriceCache).first()
141 | if price:
142 | repr_str = repr(price)
143 | assert "<PriceCache(" in repr_str
144 | assert "stock_id=" in repr_str
145 | assert "date=" in repr_str
146 | assert "close=" in repr_str
147 |
148 | def test_get_price_data(self, db_session):
149 | """Test retrieving price data as DataFrame for existing tickers."""
150 | # Try to get price data for a common stock
151 | # Use a recent date range that might have data
152 | end_date = datetime.now().date()
153 | start_date = end_date - timedelta(days=30)
154 |
155 | # Try common tickers
156 | for ticker in ["AAPL", "MSFT", "GOOGL"]:
157 | df = PriceCache.get_price_data(
158 | db_session,
159 | ticker,
160 | start_date.strftime("%Y-%m-%d"),
161 | end_date.strftime("%Y-%m-%d"),
162 | )
163 |
164 | if not df.empty:
165 | # Verify DataFrame structure
166 | assert df.index.name == "date"
167 | assert "open" in df.columns
168 | assert "high" in df.columns
169 | assert "low" in df.columns
170 | assert "close" in df.columns
171 | assert "volume" in df.columns
172 | assert "symbol" in df.columns
173 | assert df["symbol"].iloc[0] == ticker
174 |
175 | # Verify data types
176 | assert df["close"].dtype == float
177 | assert df["volume"].dtype == int
178 | break
179 |
180 | def test_stock_relationship(self, db_session):
181 | """Test relationship back to Stock."""
182 | # Find a price entry with stock relationship
183 | price = db_session.query(PriceCache).join(Stock).first()
184 |
185 | if price:
186 | assert price.stock is not None
187 | assert price.stock.stock_id == price.stock_id
188 | assert hasattr(price.stock, "ticker_symbol")
189 |
190 |
191 | @pytest.mark.integration
192 | class TestMaverickStocksReadOnly:
193 | """Test MaverickStocks model functionality with read-only operations."""
194 |
195 | def test_query_maverick_stocks(self, db_session):
196 | """Test querying existing maverick stock entries."""
197 | try:
198 | # Query for any existing maverick stocks
199 | mavericks = db_session.query(MaverickStocks).limit(10).all()
200 |
201 | # Verify structure of maverick entries
202 | for maverick in mavericks:
203 | assert hasattr(maverick, "id")
204 | assert hasattr(maverick, "stock")
205 | assert hasattr(maverick, "close")
206 | assert hasattr(maverick, "combined_score")
207 | assert hasattr(maverick, "momentum_score")
208 | except Exception as e:
209 | if "does not exist" in str(e):
210 | pytest.skip(f"MaverickStocks table not found: {e}")
211 | else:
212 | raise
213 |
214 | def test_maverick_repr(self, db_session):
215 | """Test string representation of MaverickStocks."""
216 | try:
217 | maverick = db_session.query(MaverickStocks).first()
218 | if maverick:
219 | repr_str = repr(maverick)
220 | assert "<MaverickStock(" in repr_str
221 | assert "stock=" in repr_str
222 | assert "close=" in repr_str
223 | assert "score=" in repr_str
224 | except ProgrammingError as e:
225 | if "does not exist" in str(e):
226 | pytest.skip(f"MaverickStocks table not found: {e}")
227 | else:
228 | raise
229 |
230 | def test_get_top_stocks(self, db_session):
231 | """Test retrieving top maverick stocks."""
232 | try:
233 | # Get top stocks from existing data
234 | top_stocks = MaverickStocks.get_top_stocks(db_session, limit=20)
235 |
236 | # Verify results are sorted by combined_score
237 | if len(top_stocks) > 1:
238 | for i in range(len(top_stocks) - 1):
239 | assert (
240 | top_stocks[i].combined_score >= top_stocks[i + 1].combined_score
241 | )
242 |
243 | # Verify limit is respected
244 | assert len(top_stocks) <= 20
245 | except ProgrammingError as e:
246 | if "does not exist" in str(e):
247 | pytest.skip(f"MaverickStocks table not found: {e}")
248 | else:
249 | raise
250 |
251 | def test_maverick_to_dict(self, db_session):
252 | """Test converting MaverickStocks to dictionary."""
253 | try:
254 | maverick = db_session.query(MaverickStocks).first()
255 | if maverick:
256 | data = maverick.to_dict()
257 |
258 | # Verify expected keys
259 | expected_keys = [
260 | "stock",
261 | "close",
262 | "volume",
263 | "momentum_score",
264 | "adr_pct",
265 | "pattern",
266 | "squeeze",
267 | "consolidation",
268 | "entry",
269 | "combined_score",
270 | "compression_score",
271 | "pattern_detected",
272 | ]
273 | for key in expected_keys:
274 | assert key in data
275 |
276 | # Verify data types
277 | assert isinstance(data["stock"], str)
278 | assert isinstance(data["combined_score"], int | type(None))
279 | except ProgrammingError as e:
280 | if "does not exist" in str(e):
281 | pytest.skip(f"MaverickStocks table not found: {e}")
282 | else:
283 | raise
284 |
285 |
286 | @pytest.mark.integration
287 | class TestMaverickBearStocksReadOnly:
288 | """Test MaverickBearStocks model functionality with read-only operations."""
289 |
290 | def test_query_bear_stocks(self, db_session):
291 | """Test querying existing maverick bear stock entries."""
292 | try:
293 | # Query for any existing bear stocks
294 | bears = db_session.query(MaverickBearStocks).limit(10).all()
295 |
296 | # Verify structure of bear entries
297 | for bear in bears:
298 | assert hasattr(bear, "id")
299 | assert hasattr(bear, "stock")
300 | assert hasattr(bear, "close")
301 | assert hasattr(bear, "score")
302 | assert hasattr(bear, "momentum_score")
303 | assert hasattr(bear, "rsi_14")
304 | assert hasattr(bear, "atr_contraction")
305 | assert hasattr(bear, "big_down_vol")
306 | except Exception as e:
307 | if "does not exist" in str(e):
308 | pytest.skip(f"MaverickBearStocks table not found: {e}")
309 | else:
310 | raise
311 |
312 | def test_bear_repr(self, db_session):
313 | """Test string representation of MaverickBearStocks."""
314 | try:
315 | bear = db_session.query(MaverickBearStocks).first()
316 | if bear:
317 | repr_str = repr(bear)
318 | assert "<MaverickBearStock(" in repr_str
319 | assert "stock=" in repr_str
320 | assert "close=" in repr_str
321 | assert "score=" in repr_str
322 | except ProgrammingError as e:
323 | if "does not exist" in str(e):
324 | pytest.skip(f"MaverickBearStocks table not found: {e}")
325 | else:
326 | raise
327 |
328 | def test_get_top_bear_stocks(self, db_session):
329 | """Test retrieving top bear stocks."""
330 | try:
331 | # Get top bear stocks from existing data
332 | top_bears = MaverickBearStocks.get_top_stocks(db_session, limit=20)
333 |
334 | # Verify results are sorted by score
335 | if len(top_bears) > 1:
336 | for i in range(len(top_bears) - 1):
337 | assert top_bears[i].score >= top_bears[i + 1].score
338 |
339 | # Verify limit is respected
340 | assert len(top_bears) <= 20
341 | except ProgrammingError as e:
342 | if "does not exist" in str(e):
343 | pytest.skip(f"MaverickBearStocks table not found: {e}")
344 | else:
345 | raise
346 |
347 | def test_bear_to_dict(self, db_session):
348 | """Test converting MaverickBearStocks to dictionary."""
349 | try:
350 | bear = db_session.query(MaverickBearStocks).first()
351 | if bear:
352 | data = bear.to_dict()
353 |
354 | # Verify expected keys
355 | expected_keys = [
356 | "stock",
357 | "close",
358 | "volume",
359 | "momentum_score",
360 | "rsi_14",
361 | "macd",
362 | "macd_signal",
363 | "macd_histogram",
364 | "adr_pct",
365 | "atr",
366 | "atr_contraction",
367 | "avg_vol_30d",
368 | "big_down_vol",
369 | "score",
370 | "squeeze",
371 | "consolidation",
372 | ]
373 | for key in expected_keys:
374 | assert key in data
375 |
376 | # Verify boolean fields
377 | assert isinstance(data["atr_contraction"], bool)
378 | assert isinstance(data["big_down_vol"], bool)
379 | except ProgrammingError as e:
380 | if "does not exist" in str(e):
381 | pytest.skip(f"MaverickBearStocks table not found: {e}")
382 | else:
383 | raise
384 |
385 |
386 | @pytest.mark.integration
387 | class TestSupplyDemandBreakoutStocksReadOnly:
388 | """Test SupplyDemandBreakoutStocks model functionality with read-only operations."""
389 |
390 | def test_query_supply_demand_stocks(self, db_session):
391 | """Test querying existing supply/demand breakout stock entries."""
392 | try:
393 | # Query for any existing supply/demand breakout stocks
394 | stocks = db_session.query(SupplyDemandBreakoutStocks).limit(10).all()
395 |
396 | # Verify structure of supply/demand breakout entries
397 | for stock in stocks:
398 | assert hasattr(stock, "id")
399 | assert hasattr(stock, "stock")
400 | assert hasattr(stock, "close")
401 | assert hasattr(stock, "momentum_score")
402 | assert hasattr(stock, "sma_50")
403 | assert hasattr(stock, "sma_150")
404 | assert hasattr(stock, "sma_200")
405 | except Exception as e:
406 | if "does not exist" in str(e):
407 | pytest.skip(f"SupplyDemandBreakoutStocks table not found: {e}")
408 | else:
409 | raise
410 |
411 | def test_supply_demand_repr(self, db_session):
412 | """Test string representation of SupplyDemandBreakoutStocks."""
413 | try:
414 | supply_demand = db_session.query(SupplyDemandBreakoutStocks).first()
415 | if supply_demand:
416 | repr_str = repr(supply_demand)
417 | assert "<supply/demand breakoutStock(" in repr_str
418 | assert "stock=" in repr_str
419 | assert "close=" in repr_str
420 | assert "rs=" in repr_str
421 | except ProgrammingError as e:
422 | if "does not exist" in str(e):
423 | pytest.skip(f"SupplyDemandBreakoutStocks table not found: {e}")
424 | else:
425 | raise
426 |
427 | def test_get_top_supply_demand_stocks(self, db_session):
428 | """Test retrieving top supply/demand breakout stocks."""
429 | try:
430 | # Get top stocks from existing data
431 | top_stocks = SupplyDemandBreakoutStocks.get_top_stocks(db_session, limit=20)
432 |
433 | # Verify results are sorted by momentum_score
434 | if len(top_stocks) > 1:
435 | for i in range(len(top_stocks) - 1):
436 | assert (
437 | top_stocks[i].momentum_score >= top_stocks[i + 1].momentum_score
438 | )
439 |
440 | # Verify limit is respected
441 | assert len(top_stocks) <= 20
442 | except ProgrammingError as e:
443 | if "does not exist" in str(e):
444 | pytest.skip(f"SupplyDemandBreakoutStocks table not found: {e}")
445 | else:
446 | raise
447 |
448 | def test_get_stocks_above_moving_averages(self, db_session):
449 | """Test retrieving stocks meeting supply/demand breakout criteria."""
450 | try:
451 | # Get stocks that meet supply/demand breakout criteria from existing data
452 | stocks = SupplyDemandBreakoutStocks.get_stocks_above_moving_averages(
453 | db_session
454 | )
455 |
456 | # Verify all returned stocks meet the criteria
457 | for stock in stocks:
458 | assert stock.close > stock.sma_50
459 | assert stock.close > stock.sma_150
460 | assert stock.close > stock.sma_200
461 | assert stock.sma_50 > stock.sma_150
462 | assert stock.sma_150 > stock.sma_200
463 |
464 | # Verify they're sorted by momentum score
465 | if len(stocks) > 1:
466 | for i in range(len(stocks) - 1):
467 | assert stocks[i].momentum_score >= stocks[i + 1].momentum_score
468 | except ProgrammingError as e:
469 | if "does not exist" in str(e):
470 | pytest.skip(f"SupplyDemandBreakoutStocks table not found: {e}")
471 | else:
472 | raise
473 |
474 | def test_supply_demand_to_dict(self, db_session):
475 | """Test converting SupplyDemandBreakoutStocks to dictionary."""
476 | try:
477 | supply_demand = db_session.query(SupplyDemandBreakoutStocks).first()
478 | if supply_demand:
479 | data = supply_demand.to_dict()
480 |
481 | # Verify expected keys
482 | expected_keys = [
483 | "stock",
484 | "close",
485 | "volume",
486 | "momentum_score",
487 | "adr_pct",
488 | "pattern",
489 | "squeeze",
490 | "consolidation",
491 | "entry",
492 | "ema_21",
493 | "sma_50",
494 | "sma_150",
495 | "sma_200",
496 | "atr",
497 | "avg_volume_30d",
498 | ]
499 | for key in expected_keys:
500 | assert key in data
501 |
502 | # Verify data types
503 | assert isinstance(data["stock"], str)
504 | assert isinstance(data["momentum_score"], float | int)
505 | except ProgrammingError as e:
506 | if "does not exist" in str(e):
507 | pytest.skip(f"SupplyDemandBreakoutStocks table not found: {e}")
508 | else:
509 | raise
510 |
511 |
512 | @pytest.mark.integration
513 | class TestGetLatestMaverickScreeningReadOnly:
514 | """Test the get_latest_maverick_screening function with read-only operations."""
515 |
516 | def test_get_latest_screening(self):
517 | """Test retrieving latest screening results from existing data."""
518 | try:
519 | # Call the function directly - it creates its own session
520 | results = get_latest_maverick_screening()
521 |
522 | # Verify structure
523 | assert isinstance(results, dict)
524 | assert "maverick_stocks" in results
525 | assert "maverick_bear_stocks" in results
526 | assert "supply_demand_stocks" in results
527 |
528 | # Verify each result is a list of dictionaries
529 | assert isinstance(results["maverick_stocks"], list)
530 | assert isinstance(results["maverick_bear_stocks"], list)
531 | assert isinstance(results["supply_demand_stocks"], list)
532 |
533 | # If there are maverick stocks, verify their structure
534 | if results["maverick_stocks"]:
535 | stock_dict = results["maverick_stocks"][0]
536 | assert isinstance(stock_dict, dict)
537 | assert "stock" in stock_dict
538 | assert "combined_score" in stock_dict
539 |
540 | # Verify they're sorted by combined_score
541 | scores = [s["combined_score"] for s in results["maverick_stocks"]]
542 | assert scores == sorted(scores, reverse=True)
543 |
544 | # If there are bear stocks, verify their structure
545 | if results["maverick_bear_stocks"]:
546 | bear_dict = results["maverick_bear_stocks"][0]
547 | assert isinstance(bear_dict, dict)
548 | assert "stock" in bear_dict
549 | assert "score" in bear_dict
550 |
551 | # Verify they're sorted by score
552 | scores = [s["score"] for s in results["maverick_bear_stocks"]]
553 | assert scores == sorted(scores, reverse=True)
554 |
555 | # If there are supply/demand breakout stocks, verify their structure
556 | if results["supply_demand_stocks"]:
557 | min_dict = results["supply_demand_stocks"][0]
558 | assert isinstance(min_dict, dict)
559 | assert "stock" in min_dict
560 | assert "momentum_score" in min_dict
561 |
562 | # Verify they're sorted by momentum_score
563 | ratings = [s["momentum_score"] for s in results["supply_demand_stocks"]]
564 | assert ratings == sorted(ratings, reverse=True)
565 |
566 | except Exception as e:
567 | # If tables don't exist, that's okay for read-only tests
568 | if "does not exist" in str(e):
569 | pytest.skip(f"Screening tables not found in database: {e}")
570 | else:
571 | raise
572 |
573 |
574 | class TestDatabaseStructureReadOnly:
575 | """Test database structure and relationships with read-only operations."""
576 |
577 | def test_stock_ticker_query_performance(self, db_session):
578 | """Test that ticker queries work efficiently (index should exist)."""
579 | # Query for a specific ticker - should be fast if indexed
580 | import time
581 |
582 | start_time = time.time()
583 |
584 | # Try to find a stock by ticker
585 | stock = db_session.query(Stock).filter_by(ticker_symbol="AAPL").first()
586 |
587 | query_time = time.time() - start_time
588 |
589 | # Query should be reasonably fast if properly indexed
590 | # Allow up to 1 second for connection overhead
591 | assert query_time < 1.0
592 |
593 | # If stock exists, verify it has expected fields
594 | if stock:
595 | assert stock.ticker_symbol == "AAPL"
596 |
597 | def test_price_cache_date_query_performance(self, db_session):
598 | """Test that price cache queries by stock and date are efficient."""
599 | # First find a stock with prices
600 | stock_with_prices = db_session.query(Stock).join(PriceCache).first()
601 |
602 | if stock_with_prices:
603 | # Get a recent date
604 | recent_price = (
605 | db_session.query(PriceCache)
606 | .filter_by(stock_id=stock_with_prices.stock_id)
607 | .order_by(PriceCache.date.desc())
608 | .first()
609 | )
610 |
611 | if recent_price:
612 | # Query for specific stock_id and date - should be fast
613 | import time
614 |
615 | start_time = time.time()
616 |
617 | result = (
618 | db_session.query(PriceCache)
619 | .filter_by(
620 | stock_id=stock_with_prices.stock_id, date=recent_price.date
621 | )
622 | .first()
623 | )
624 |
625 | query_time = time.time() - start_time
626 |
627 | # Query should be reasonably fast if composite index exists
628 | assert query_time < 1.0
629 | assert result is not None
630 | assert result.price_cache_id == recent_price.price_cache_id
631 |
632 |
633 | class TestDataTypesAndConstraintsReadOnly:
634 | """Test data types and constraints with read-only operations."""
635 |
636 | def test_null_values_in_existing_data(self, db_session):
637 | """Test handling of null values in optional fields in existing data."""
638 | # Query stocks that might have null values
639 | stocks = db_session.query(Stock).limit(20).all()
640 |
641 | for stock in stocks:
642 | # These fields are optional and can be None
643 | assert hasattr(stock, "company_name")
644 | assert hasattr(stock, "sector")
645 | assert hasattr(stock, "industry")
646 |
647 | # Verify ticker_symbol is never null (it's required)
648 | assert stock.ticker_symbol is not None
649 | assert isinstance(stock.ticker_symbol, str)
650 |
651 | def test_decimal_precision_in_existing_data(self, db_session):
652 | """Test decimal precision in existing price data."""
653 | # Query some price data
654 | prices = db_session.query(PriceCache).limit(10).all()
655 |
656 | for price in prices:
657 | # Verify decimal fields
658 | if price.close_price is not None:
659 | assert isinstance(price.close_price, Decimal)
660 | # Check precision (should have at most 2 decimal places)
661 | str_price = str(price.close_price)
662 | if "." in str_price:
663 | decimal_places = len(str_price.split(".")[1])
664 | assert decimal_places <= 2
665 |
666 | # Same for other price fields
667 | for field in ["open_price", "high_price", "low_price"]:
668 | value = getattr(price, field)
669 | if value is not None:
670 | assert isinstance(value, Decimal)
671 |
672 | def test_volume_data_types(self, db_session):
673 | """Test volume data types in existing data."""
674 | # Query price data with volumes
675 | prices = (
676 | db_session.query(PriceCache)
677 | .filter(PriceCache.volume.isnot(None))
678 | .limit(10)
679 | .all()
680 | )
681 |
682 | for price in prices:
683 | assert isinstance(price.volume, int)
684 | assert price.volume >= 0
685 |
686 | def test_timezone_handling_in_existing_data(self, db_session):
687 | """Test that timestamps have timezone info in existing data."""
688 | # Query any model with timestamps
689 | stocks = db_session.query(Stock).limit(5).all()
690 |
691 | # Skip test if no stocks found
692 | if not stocks:
693 | pytest.skip("No stock data found in database")
694 |
695 | # Check if data has timezone info (newer data should, legacy data might not)
696 | has_tz_info = False
697 | for stock in stocks:
698 | if stock.created_at and stock.created_at.tzinfo is not None:
699 | has_tz_info = True
700 | # Data should have timezone info (not necessarily UTC for legacy data)
701 | # New data created by the app will be UTC
702 |
703 | if stock.updated_at and stock.updated_at.tzinfo is not None:
704 | has_tz_info = True
705 | # Data should have timezone info (not necessarily UTC for legacy data)
706 |
707 | # This test just verifies that timezone-aware timestamps are being used
708 | # Legacy data might not be UTC, but new data will be
709 | if has_tz_info:
710 | # Pass - data has timezone info which is what we want
711 | pass
712 | else:
713 | pytest.skip(
714 | "Legacy data without timezone info - new data will have timezone info"
715 | )
716 |
717 | def test_relationships_integrity(self, db_session):
718 | """Test that relationships maintain referential integrity."""
719 | # Find prices with valid stock relationships
720 | prices_with_stocks = db_session.query(PriceCache).join(Stock).limit(10).all()
721 |
722 | for price in prices_with_stocks:
723 | # Verify the relationship is intact
724 | assert price.stock is not None
725 | assert price.stock.stock_id == price.stock_id
726 |
727 | # Verify reverse relationship
728 | assert price in price.stock.price_caches
729 |
```
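The fixtures above enforce read-only behavior only by convention (a rollback on teardown). If you want the database itself to reject writes during a test run, one option — a sketch with a placeholder connection URL — is to mark every pooled connection's session read-only at connect time:

```python
from sqlalchemy import create_engine, event, text

# Placeholder URL; substitute your POSTGRES_URL / DATABASE_URL.
engine = create_engine("postgresql://localhost/maverick_mcp", echo=False)


@event.listens_for(engine, "connect")
def _set_read_only(dbapi_connection, connection_record):
    # Force the session read-only at the server, so a stray INSERT in a
    # test fails loudly instead of mutating production data.
    cursor = dbapi_connection.cursor()
    cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY")
    cursor.close()


with engine.connect() as conn:
    conn.execute(text("SELECT 1"))  # reads succeed; writes now raise
```

With this in place, an accidental flush or INSERT inside a test raises a database error rather than touching production rows.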
--------------------------------------------------------------------------------
/examples/llm_optimization_example.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | LLM Optimization Example for Research Agents - Speed-Optimized Edition.
3 |
4 | This example demonstrates how to use the comprehensive LLM optimization strategies
5 | with new speed-optimized models to prevent research agent timeouts while maintaining
6 | research quality. Features 2-3x speed improvements with Gemini 2.5 Flash and GPT-4o Mini.
7 | """
8 |
9 | import asyncio
10 | import logging
11 | import os
12 | import time
13 | from typing import Any
14 |
15 | from maverick_mcp.agents.optimized_research import (
16 | OptimizedDeepResearchAgent,
17 | create_optimized_research_agent,
18 | )
19 | from maverick_mcp.config.llm_optimization_config import (
20 | ModelSelectionStrategy,
21 | ResearchComplexity,
22 | create_adaptive_config,
23 | create_balanced_config,
24 | create_emergency_config,
25 | create_fast_config,
26 | )
27 | from maverick_mcp.providers.openrouter_provider import (
28 | OpenRouterProvider,
29 | TaskType,
30 | )
31 |
32 | # Set up logging
33 | logging.basicConfig(
34 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
35 | )
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | class OptimizationExamples:
40 | """Examples demonstrating LLM optimization strategies."""
41 |
42 | def __init__(self, openrouter_api_key: str):
43 | """Initialize with OpenRouter API key."""
44 | self.openrouter_api_key = openrouter_api_key
45 |
46 | async def example_1_emergency_research(self) -> dict[str, Any]:
47 | """
48 | Example 1: Emergency research with <20 second time budget.
49 |
50 | Use case: Real-time alerts or urgent market events requiring immediate analysis.
51 | """
52 | logger.info("🚨 Example 1: Emergency Research (<20s)")
53 |
54 | # Create emergency configuration (for optimization reference)
55 | _ = create_emergency_config(time_budget=15.0)
56 |
57 | # Create optimized agent
58 | agent = create_optimized_research_agent(
59 | openrouter_api_key=self.openrouter_api_key,
60 | persona="aggressive", # Aggressive for quick decisions
61 | time_budget_seconds=15.0,
62 | target_confidence=0.6, # Lower bar for emergency
63 | )
64 |
65 | # Execute emergency research
66 | start_time = time.time()
67 |
68 | result = await agent.research_comprehensive(
69 | topic="NVDA earnings surprise impact",
70 | session_id="emergency_001",
71 | depth="basic",
72 | focus_areas=["sentiment", "catalyst"],
73 | time_budget_seconds=15.0,
74 | target_confidence=0.6,
75 | )
76 |
77 | execution_time = time.time() - start_time
78 |
79 | logger.info(f"✅ Emergency research completed in {execution_time:.2f}s")
80 | logger.info(
81 | f"Optimization features used: {result.get('optimization_metrics', {}).get('optimization_features_used', [])}"
82 | )
83 |
84 | return {
85 | "scenario": "emergency",
86 | "time_budget": 15.0,
87 | "actual_time": execution_time,
88 | "success": execution_time < 20, # Success if under 20s
89 | "confidence": result.get("findings", {}).get("confidence_score", 0),
90 | "sources_processed": result.get("sources_analyzed", 0),
91 | "optimization_features": result.get("optimization_metrics", {}).get(
92 | "optimization_features_used", []
93 | ),
94 | }
95 |
96 | async def example_2_fast_research(self) -> dict[str, Any]:
97 | """
98 | Example 2: Fast research with 45 second time budget.
99 |
100 | Use case: Quick analysis for trading decisions or portfolio updates.
101 | """
102 | logger.info("⚡ Example 2: Fast Research (45s)")
103 |
104 | # Create fast configuration
105 | _ = create_fast_config(time_budget=45.0)
106 |
107 | # Create optimized agent
108 | agent = create_optimized_research_agent(
109 | openrouter_api_key=self.openrouter_api_key,
110 | persona="moderate",
111 | time_budget_seconds=45.0,
112 | target_confidence=0.7,
113 | )
114 |
115 | start_time = time.time()
116 |
117 | result = await agent.research_comprehensive(
118 | topic="Tesla Q4 2024 delivery numbers analysis",
119 | session_id="fast_001",
120 | depth="standard",
121 | focus_areas=["fundamental", "sentiment"],
122 | time_budget_seconds=45.0,
123 | target_confidence=0.7,
124 | )
125 |
126 | execution_time = time.time() - start_time
127 |
128 | logger.info(f"✅ Fast research completed in {execution_time:.2f}s")
129 |
130 | return {
131 | "scenario": "fast",
132 | "time_budget": 45.0,
133 | "actual_time": execution_time,
134 | "success": execution_time < 60,
135 | "confidence": result.get("findings", {}).get("confidence_score", 0),
136 | "sources_processed": result.get("sources_analyzed", 0),
137 | "early_terminated": result.get("findings", {}).get(
138 | "early_terminated", False
139 | ),
140 | }
141 |
142 | async def example_3_balanced_research(self) -> dict[str, Any]:
143 | """
144 | Example 3: Balanced research with 2 minute time budget.
145 |
146 | Use case: Standard research for investment decisions.
147 | """
148 | logger.info("⚖️ Example 3: Balanced Research (120s)")
149 |
150 | # Create balanced configuration
151 | _ = create_balanced_config(time_budget=120.0)
152 |
153 | agent = create_optimized_research_agent(
154 | openrouter_api_key=self.openrouter_api_key,
155 | persona="conservative",
156 | time_budget_seconds=120.0,
157 | target_confidence=0.75,
158 | )
159 |
160 | start_time = time.time()
161 |
162 | result = await agent.research_comprehensive(
163 | topic="Microsoft cloud services competitive position 2024",
164 | session_id="balanced_001",
165 | depth="comprehensive",
166 | focus_areas=["competitive", "fundamental", "technical"],
167 | time_budget_seconds=120.0,
168 | target_confidence=0.75,
169 | )
170 |
171 | execution_time = time.time() - start_time
172 |
173 | logger.info(f"✅ Balanced research completed in {execution_time:.2f}s")
174 |
175 | return {
176 | "scenario": "balanced",
177 | "time_budget": 120.0,
178 | "actual_time": execution_time,
179 | "success": execution_time < 150, # 25% buffer
180 | "confidence": result.get("findings", {}).get("confidence_score", 0),
181 | "sources_processed": result.get("sources_analyzed", 0),
182 | "processing_mode": result.get("findings", {}).get(
183 | "processing_mode", "unknown"
184 | ),
185 | }
186 |
187 | async def example_4_adaptive_research(self) -> dict[str, Any]:
188 | """
189 | Example 4: Adaptive research that adjusts based on complexity and available time.
190 |
191 | Use case: Dynamic research where time constraints may vary.
192 | """
193 | logger.info("🎯 Example 4: Adaptive Research")
194 |
195 | # Simulate varying time constraints
196 | scenarios = [
197 | {
198 | "time_budget": 30,
199 | "complexity": ResearchComplexity.SIMPLE,
200 | "topic": "Apple stock price today",
201 | },
202 | {
203 | "time_budget": 90,
204 | "complexity": ResearchComplexity.MODERATE,
205 | "topic": "Federal Reserve interest rate policy impact on tech stocks",
206 | },
207 | {
208 | "time_budget": 180,
209 | "complexity": ResearchComplexity.COMPLEX,
210 | "topic": "Cryptocurrency regulation implications for financial institutions",
211 | },
212 | ]
213 |
214 | results = []
215 |
216 | for i, scenario in enumerate(scenarios):
217 | logger.info(
218 | f"📊 Adaptive scenario {i + 1}: {scenario['complexity'].value} complexity, {scenario['time_budget']}s budget"
219 | )
220 |
221 | # Create adaptive configuration
222 | config = create_adaptive_config(
223 | time_budget_seconds=scenario["time_budget"],
224 | complexity=scenario["complexity"],
225 | )
226 |
227 | agent = create_optimized_research_agent(
228 | openrouter_api_key=self.openrouter_api_key, persona="moderate"
229 | )
230 |
231 | start_time = time.time()
232 |
233 | result = await agent.research_comprehensive(
234 | topic=scenario["topic"],
235 | session_id=f"adaptive_{i + 1:03d}",
236 | time_budget_seconds=scenario["time_budget"],
237 | target_confidence=config.preset.target_confidence,
238 | )
239 |
240 | execution_time = time.time() - start_time
241 |
242 | scenario_result = {
243 | "scenario_id": i + 1,
244 | "complexity": scenario["complexity"].value,
245 | "time_budget": scenario["time_budget"],
246 | "actual_time": execution_time,
247 | "success": execution_time < scenario["time_budget"] * 1.1, # 10% buffer
248 | "confidence": result.get("findings", {}).get("confidence_score", 0),
249 | "sources_processed": result.get("sources_analyzed", 0),
250 | "adaptations_used": result.get("optimization_metrics", {}).get(
251 | "optimization_features_used", []
252 | ),
253 | }
254 |
255 | results.append(scenario_result)
256 |
257 | logger.info(
258 | f"✅ Adaptive scenario {i + 1} completed in {execution_time:.2f}s"
259 | )
260 |
261 | return {
262 | "scenario": "adaptive",
263 | "scenarios_tested": len(scenarios),
264 | "results": results,
265 | "overall_success": all(r["success"] for r in results),
266 | }
267 |
268 | async def example_5_optimization_comparison(self) -> dict[str, Any]:
269 | """
270 | Example 5: Compare optimized vs non-optimized research performance.
271 |
272 | Use case: Demonstrate the effectiveness of optimizations.
273 | """
274 | logger.info("📈 Example 5: Optimization Comparison")
275 |
276 | test_topic = "Amazon Web Services market share growth 2024"
277 | time_budget = 90.0
278 |
279 | results = {}
280 |
281 | # Test with optimizations enabled
282 | logger.info("🔧 Testing WITH optimizations...")
283 |
284 | optimized_agent = OptimizedDeepResearchAgent(
285 | openrouter_provider=OpenRouterProvider(self.openrouter_api_key),
286 | persona="moderate",
287 | optimization_enabled=True,
288 | )
289 |
290 | start_time = time.time()
291 | optimized_result = await optimized_agent.research_comprehensive(
292 | topic=test_topic,
293 | session_id="comparison_optimized",
294 | time_budget_seconds=time_budget,
295 | target_confidence=0.75,
296 | )
297 | optimized_time = time.time() - start_time
298 |
299 | results["optimized"] = {
300 | "execution_time": optimized_time,
301 | "success": optimized_time < time_budget,
302 | "confidence": optimized_result.get("findings", {}).get(
303 | "confidence_score", 0
304 | ),
305 | "sources_processed": optimized_result.get("sources_analyzed", 0),
306 | "optimization_features": optimized_result.get(
307 | "optimization_metrics", {}
308 | ).get("optimization_features_used", []),
309 | }
310 |
311 | # Test with optimizations disabled
312 | logger.info("🐌 Testing WITHOUT optimizations...")
313 |
314 | standard_agent = OptimizedDeepResearchAgent(
315 | openrouter_provider=OpenRouterProvider(self.openrouter_api_key),
316 | persona="moderate",
317 | optimization_enabled=False,
318 | )
319 |
320 | start_time = time.time()
321 | try:
322 | standard_result = await asyncio.wait_for(
323 | standard_agent.research_comprehensive(
324 | topic=test_topic, session_id="comparison_standard", depth="standard"
325 | ),
326 | timeout=time_budget + 30, # Give extra time for timeout demonstration
327 | )
328 | standard_time = time.time() - start_time
329 |
330 | results["standard"] = {
331 | "execution_time": standard_time,
332 | "success": standard_time < time_budget,
333 | "confidence": standard_result.get("findings", {}).get(
334 | "confidence_score", 0
335 | ),
336 | "sources_processed": standard_result.get("sources_analyzed", 0),
337 | "timed_out": False,
338 | }
339 |
340 | except TimeoutError:
341 | standard_time = time_budget + 30
342 | results["standard"] = {
343 | "execution_time": standard_time,
344 | "success": False,
345 | "confidence": 0,
346 | "sources_processed": 0,
347 | "timed_out": True,
348 | }
349 |
350 | # Calculate improvement metrics
351 | time_improvement = (
352 | (
353 | results["standard"]["execution_time"]
354 | - results["optimized"]["execution_time"]
355 | )
356 | / results["standard"]["execution_time"]
357 | * 100
358 | )
359 | confidence_ratio = results["optimized"]["confidence"] / max(
360 | results["standard"]["confidence"], 0.01
361 | )
362 |
363 | results["comparison"] = {
364 | "time_improvement_pct": time_improvement,
365 | "optimized_faster": results["optimized"]["execution_time"]
366 | < results["standard"]["execution_time"],
367 | "confidence_ratio": confidence_ratio,
368 | "both_successful": results["optimized"]["success"]
369 | and results["standard"]["success"],
370 | }
371 |
372 | logger.info("📊 Optimization Results:")
373 | logger.info(
374 | f" Optimized: {results['optimized']['execution_time']:.2f}s (success: {results['optimized']['success']})"
375 | )
376 | logger.info(
377 | f" Standard: {results['standard']['execution_time']:.2f}s (success: {results['standard']['success']})"
378 | )
379 | logger.info(f" Time improvement: {time_improvement:.1f}%")
380 |
381 | return results
382 |
383 | async def example_6_speed_optimized_models(self) -> dict[str, Any]:
384 | """
385 | Example 6: Test the new speed-optimized models (Gemini 2.5 Flash, GPT-4o Mini).
386 |
387 | Use case: Demonstrate 2-3x speed improvements with the fastest available models.
388 | """
389 | logger.info("🚀 Example 6: Speed-Optimized Models Test")
390 |
391 | speed_test_results = {}
392 |
393 | # Test Gemini 2.5 Flash (199 tokens/sec - fastest)
394 | logger.info("🔥 Testing Gemini 2.5 Flash (199 tokens/sec)...")
395 | provider = OpenRouterProvider(self.openrouter_api_key)
396 |
397 | gemini_llm = provider.get_llm(
398 | model_override="google/gemini-2.5-flash",
399 | task_type=TaskType.DEEP_RESEARCH,
400 | prefer_fast=True,
401 | )
402 |
403 | start_time = time.time()
404 | try:
405 | response = await gemini_llm.ainvoke(
406 | [
407 | {
408 | "role": "user",
409 | "content": "Analyze Tesla's Q4 2024 performance in exactly 3 bullet points. Be concise and factual.",
410 | }
411 | ]
412 | )
413 | gemini_time = time.time() - start_time
414 |
415 | # Safely handle content that could be string or list
416 | content_text = (
417 | response.content
418 | if isinstance(response.content, str)
419 | else str(response.content)
420 | if response.content
421 | else ""
422 | )
423 | speed_test_results["gemini_2_5_flash"] = {
424 | "execution_time": gemini_time,
425 | "tokens_per_second": len(content_text.split()) / gemini_time
426 | if gemini_time > 0
427 | else 0,
428 | "success": True,
429 | "response_quality": "high" if len(content_text) > 50 else "low",
430 | }
431 | except Exception as e:
432 | speed_test_results["gemini_2_5_flash"] = {
433 | "execution_time": 999,
434 | "success": False,
435 | "error": str(e),
436 | }
437 |
438 | # Test GPT-4o Mini (126 tokens/sec - excellent balance)
439 | logger.info("⚡ Testing GPT-4o Mini (126 tokens/sec)...")
440 |
441 | gpt_llm = provider.get_llm(
442 | model_override="openai/gpt-4o-mini",
443 | task_type=TaskType.MARKET_ANALYSIS,
444 | prefer_fast=True,
445 | )
446 |
447 | start_time = time.time()
448 | try:
449 | response = await gpt_llm.ainvoke(
450 | [
451 | {
452 | "role": "user",
453 | "content": "Analyze Amazon's cloud services competitive position in exactly 3 bullet points. Be concise and factual.",
454 | }
455 | ]
456 | )
457 | gpt_time = time.time() - start_time
458 |
459 | # Safely handle content that could be string or list
460 | content_text = (
461 | response.content
462 | if isinstance(response.content, str)
463 | else str(response.content)
464 | if response.content
465 | else ""
466 | )
467 | speed_test_results["gpt_4o_mini"] = {
468 | "execution_time": gpt_time,
469 | "tokens_per_second": len(content_text.split()) / gpt_time
470 | if gpt_time > 0
471 | else 0,
472 | "success": True,
473 | "response_quality": "high" if len(content_text) > 50 else "low",
474 | }
475 | except Exception as e:
476 | speed_test_results["gpt_4o_mini"] = {
477 | "execution_time": 999,
478 | "success": False,
479 | "error": str(e),
480 | }
481 |
482 | # Test Claude 3.5 Haiku (65.6 tokens/sec - old baseline)
483 | logger.info("🐌 Testing Claude 3.5 Haiku (65.6 tokens/sec - baseline)...")
484 |
485 | claude_llm = provider.get_llm(
486 | model_override="anthropic/claude-3.5-haiku",
487 | task_type=TaskType.QUICK_ANSWER,
488 | prefer_fast=True,
489 | )
490 |
491 | start_time = time.time()
492 | try:
493 | response = await claude_llm.ainvoke(
494 | [
495 | {
496 | "role": "user",
497 | "content": "Analyze Microsoft's AI strategy in exactly 3 bullet points. Be concise and factual.",
498 | }
499 | ]
500 | )
501 | claude_time = time.time() - start_time
502 |
503 | # Safely handle content that could be string or list
504 | content_text = (
505 | response.content
506 | if isinstance(response.content, str)
507 | else str(response.content)
508 | if response.content
509 | else ""
510 | )
511 | speed_test_results["claude_3_5_haiku"] = {
512 | "execution_time": claude_time,
513 | "tokens_per_second": len(content_text.split()) / claude_time
514 | if claude_time > 0
515 | else 0,
516 | "success": True,
517 | "response_quality": "high" if len(content_text) > 50 else "low",
518 | }
519 | except Exception as e:
520 | speed_test_results["claude_3_5_haiku"] = {
521 | "execution_time": 999,
522 | "success": False,
523 | "error": str(e),
524 | }
525 |
526 | # Calculate speed improvements
527 | baseline_time = speed_test_results.get("claude_3_5_haiku", {}).get(
528 | "execution_time", 10
529 | )
530 |
531 | if speed_test_results["gemini_2_5_flash"]["success"]:
532 | gemini_improvement = (
533 | (
534 | baseline_time
535 | - speed_test_results["gemini_2_5_flash"]["execution_time"]
536 | )
537 | / baseline_time
538 | * 100
539 | )
540 | else:
541 | gemini_improvement = 0
542 |
543 | if speed_test_results["gpt_4o_mini"]["success"]:
544 | gpt_improvement = (
545 | (baseline_time - speed_test_results["gpt_4o_mini"]["execution_time"])
546 | / baseline_time
547 | * 100
548 | )
549 | else:
550 | gpt_improvement = 0
551 |
552 | # Test emergency model selection
553 | emergency_models = ModelSelectionStrategy.get_model_priority(
554 | time_remaining=20.0,
555 | task_type=TaskType.DEEP_RESEARCH,
556 | complexity=ResearchComplexity.MODERATE,
557 | )
558 |
559 | logger.info("📊 Speed Test Results:")
560 | logger.info(
561 | f" Gemini 2.5 Flash: {speed_test_results['gemini_2_5_flash']['execution_time']:.2f}s ({gemini_improvement:+.1f}% vs baseline)"
562 | )
563 | logger.info(
564 | f" GPT-4o Mini: {speed_test_results['gpt_4o_mini']['execution_time']:.2f}s ({gpt_improvement:+.1f}% vs baseline)"
565 | )
566 | logger.info(
567 | f" Claude 3.5 Haiku: {speed_test_results['claude_3_5_haiku']['execution_time']:.2f}s (baseline)"
568 | )
569 | logger.info(f" Emergency models: {emergency_models[:2]}")
570 |
571 | return {
572 | "scenario": "speed_optimization",
573 | "models_tested": 3,
574 | "speed_results": speed_test_results,
575 | "improvements": {
576 | "gemini_2_5_flash_vs_baseline_pct": gemini_improvement,
577 | "gpt_4o_mini_vs_baseline_pct": gpt_improvement,
578 | },
579 | "emergency_models": emergency_models[:2],
580 | "success": all(
581 | result.get("success", False) for result in speed_test_results.values()
582 | ),
583 | "fastest_model": min(
584 | speed_test_results.items(),
585 | key=lambda x: x[1].get("execution_time", 999),
586 | )[0],
587 | "speed_optimization_effective": gemini_improvement > 30
588 | or gpt_improvement > 20, # 30%+ or 20%+ improvement
589 | }
590 |
591 | def test_model_selection_strategy(self) -> dict[str, Any]:
592 | """Test the updated model selection strategy with speed-optimized models."""
593 |
594 | logger.info("🎯 Testing Model Selection Strategy...")
595 |
596 | test_scenarios = [
597 | {"time": 15, "task": TaskType.DEEP_RESEARCH, "desc": "Ultra Emergency"},
598 | {"time": 25, "task": TaskType.MARKET_ANALYSIS, "desc": "Emergency"},
599 | {"time": 45, "task": TaskType.TECHNICAL_ANALYSIS, "desc": "Fast"},
600 | {"time": 120, "task": TaskType.RESULT_SYNTHESIS, "desc": "Balanced"},
601 | ]
602 |
603 | strategy_results = {}
604 |
605 | for scenario in test_scenarios:
606 | models = ModelSelectionStrategy.get_model_priority(
607 | time_remaining=scenario["time"],
608 | task_type=scenario["task"],
609 | complexity=ResearchComplexity.MODERATE,
610 | )
611 |
612 | strategy_results[scenario["desc"].lower()] = {
613 | "time_budget": scenario["time"],
614 | "primary_model": models[0] if models else "None",
615 | "backup_models": models[1:3] if len(models) > 1 else [],
616 | "total_available": len(models),
617 | "uses_speed_optimized": any(
618 | model in ["google/gemini-2.5-flash", "openai/gpt-4o-mini"]
619 | for model in models[:2]
620 | ),
621 | }
622 |
623 | logger.info(
624 | f" {scenario['desc']} ({scenario['time']}s): Primary = {models[0] if models else 'None'}"
625 | )
626 |
627 | return {
628 | "test_scenarios": len(test_scenarios),
629 | "strategy_results": strategy_results,
630 | "all_scenarios_use_speed_models": all(
631 | result["uses_speed_optimized"] for result in strategy_results.values()
632 | ),
633 | "success": True,
634 | }
635 |
636 | async def run_all_examples(self) -> dict[str, Any]:
637 | """Run all optimization examples and return combined results."""
638 |
639 | logger.info("🚀 Starting LLM Optimization Examples...")
640 |
641 | all_results = {}
642 |
643 | try:
644 | # Run each example
645 | all_results["emergency"] = await self.example_1_emergency_research()
646 | all_results["fast"] = await self.example_2_fast_research()
647 | all_results["balanced"] = await self.example_3_balanced_research()
648 | all_results["adaptive"] = await self.example_4_adaptive_research()
649 | all_results["comparison"] = await self.example_5_optimization_comparison()
650 | all_results[
651 | "speed_optimization"
652 | ] = await self.example_6_speed_optimized_models()
653 | all_results["model_strategy"] = self.test_model_selection_strategy()
654 |
655 | # Calculate overall success metrics
656 | successful_examples = sum(
657 | 1
658 | for result in all_results.values()
659 | if result.get("success") or result.get("overall_success")
660 | )
661 |
662 | all_results["summary"] = {
663 | "total_examples": 7, # Updated for new examples
664 | "successful_examples": successful_examples,
665 | "success_rate_pct": (successful_examples / 7) * 100,
666 | "optimization_effectiveness": "High"
667 | if successful_examples >= 6
668 | else "Moderate"
669 | if successful_examples >= 4
670 | else "Low",
671 | "speed_optimization_available": all_results.get(
672 | "speed_optimization", {}
673 | ).get("success", False),
674 | "speed_improvement_demonstrated": all_results.get(
675 | "speed_optimization", {}
676 | ).get("speed_optimization_effective", False),
677 | }
678 |
679 | logger.info(
680 | f"🎉 All examples completed! Success rate: {all_results['summary']['success_rate_pct']:.0f}%"
681 | )
682 |
683 | except Exception as e:
684 | logger.error(f"❌ Example execution failed: {e}")
685 | all_results["error"] = str(e)
686 |
687 | return all_results
688 |
689 |
690 | async def main():
691 | """Main function to run optimization examples."""
692 |
693 | # Get OpenRouter API key
694 | openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
695 | if not openrouter_api_key:
696 | logger.error("❌ OPENROUTER_API_KEY environment variable not set")
697 | return
698 |
699 | # Create examples instance
700 | examples = OptimizationExamples(openrouter_api_key)
701 |
702 | # Run all examples
703 | results = await examples.run_all_examples()
704 |
705 | # Print summary
706 | print("\n" + "=" * 80)
707 | print("LLM OPTIMIZATION RESULTS SUMMARY")
708 | print("=" * 80)
709 |
710 | if "summary" in results:
711 | summary = results["summary"]
712 | print(f"Total Examples: {summary['total_examples']}")
713 | print(f"Successful: {summary['successful_examples']}")
714 | print(f"Success Rate: {summary['success_rate_pct']:.0f}%")
715 | print(f"Effectiveness: {summary['optimization_effectiveness']}")
716 |
717 | if "comparison" in results and "comparison" in results["comparison"]:
718 | comp = results["comparison"]["comparison"]
719 | if comp.get("time_improvement_pct", 0) > 0:
720 | print(f"Speed Improvement: {comp['time_improvement_pct']:.1f}%")
721 |
722 | if "speed_optimization" in results and results["speed_optimization"].get("success"):
723 | speed_results = results["speed_optimization"]
724 | print(f"Fastest Model: {speed_results.get('fastest_model', 'Unknown')}")
725 |
726 | improvements = speed_results.get("improvements", {})
727 | if improvements.get("gemini_2_5_flash_vs_baseline_pct", 0) > 0:
728 | print(
729 | f"Gemini 2.5 Flash Speed Boost: {improvements['gemini_2_5_flash_vs_baseline_pct']:+.1f}%"
730 | )
731 | if improvements.get("gpt_4o_mini_vs_baseline_pct", 0) > 0:
732 | print(
733 | f"GPT-4o Mini Speed Boost: {improvements['gpt_4o_mini_vs_baseline_pct']:+.1f}%"
734 | )
735 |
736 | print("\nDetailed Results:")
737 | for example_name, result in results.items():
738 | if example_name not in ["summary", "error"]:
739 | if isinstance(result, dict):
740 | success = result.get("success") or result.get("overall_success")
741 | time_info = (
742 | f"{result.get('actual_time', 0):.1f}s"
743 | if "actual_time" in result
744 | else "N/A"
745 | )
746 | print(
747 | f" {example_name.title()}: {'✅ SUCCESS' if success else '❌ FAILED'} ({time_info})"
748 | )
749 |
750 | print("=" * 80)
751 |
752 |
753 | if __name__ == "__main__":
754 | # Run the examples
755 | asyncio.run(main())
756 |
```
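The model-selection strategy exercised in `test_model_selection_strategy` above can also be probed in isolation. The sketch below is illustrative only: the import path is an assumption (use whichever module the example script actually imports `ModelSelectionStrategy`, `TaskType`, and `ResearchComplexity` from), while the call signature mirrors the example code.

```python
# Hypothetical import path -- adjust to the module the example script uses.
from maverick_mcp.providers.openrouter_provider import (
    ModelSelectionStrategy,
    ResearchComplexity,
    TaskType,
)

# Ask the strategy which models fit a 30-second emergency budget.
models = ModelSelectionStrategy.get_model_priority(
    time_remaining=30,
    task_type=TaskType.MARKET_ANALYSIS,
    complexity=ResearchComplexity.MODERATE,
)
print(f"Primary model under a 30s budget: {models[0] if models else 'None'}")
```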
--------------------------------------------------------------------------------
/tests/test_parallel_research_orchestrator.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive test suite for ParallelResearchOrchestrator.
3 |
4 | This test suite covers:
5 | - Parallel task execution with concurrency control
6 | - Task distribution and load balancing
7 | - Error handling and timeout management
8 | - Synthesis callback functionality
9 | - Performance improvements over sequential execution
10 | - Circuit breaker integration
11 | - Resource usage monitoring
12 | """
13 |
14 | import asyncio
15 | import time
16 | from typing import Any
17 |
18 | import pytest
19 |
20 | from maverick_mcp.utils.parallel_research import (
21 | ParallelResearchConfig,
22 | ParallelResearchOrchestrator,
23 | ResearchResult,
24 | ResearchTask,
25 | TaskDistributionEngine,
26 | )
27 |
28 |
29 | class TestParallelResearchConfig:
30 | """Test ParallelResearchConfig configuration class."""
31 |
32 | def test_default_configuration(self):
33 | """Test default configuration values."""
34 | config = ParallelResearchConfig()
35 |
36 | assert config.max_concurrent_agents == 4
37 | assert config.timeout_per_agent == 180
38 | assert config.enable_fallbacks is False
39 | assert config.rate_limit_delay == 0.5
40 |
41 | def test_custom_configuration(self):
42 | """Test custom configuration values."""
43 | config = ParallelResearchConfig(
44 | max_concurrent_agents=8,
45 | timeout_per_agent=180,
46 | enable_fallbacks=False,
47 | rate_limit_delay=0.5,
48 | )
49 |
50 | assert config.max_concurrent_agents == 8
51 | assert config.timeout_per_agent == 180
52 | assert config.enable_fallbacks is False
53 | assert config.rate_limit_delay == 0.5
54 |
55 |
56 | class TestResearchTask:
57 | """Test ResearchTask data class."""
58 |
59 | def test_research_task_creation(self):
60 | """Test basic research task creation."""
61 | task = ResearchTask(
62 | task_id="test_123_fundamental",
63 | task_type="fundamental",
64 | target_topic="AAPL financial analysis",
65 | focus_areas=["earnings", "valuation", "growth"],
66 | priority=8,
67 | timeout=240,
68 | )
69 |
70 | assert task.task_id == "test_123_fundamental"
71 | assert task.task_type == "fundamental"
72 | assert task.target_topic == "AAPL financial analysis"
73 | assert task.focus_areas == ["earnings", "valuation", "growth"]
74 | assert task.priority == 8
75 | assert task.timeout == 240
76 | assert task.status == "pending"
77 | assert task.result is None
78 | assert task.error is None
79 |
80 | def test_task_lifecycle_tracking(self):
81 | """Test task lifecycle status tracking."""
82 | task = ResearchTask(
83 | task_id="lifecycle_test",
84 | task_type="sentiment",
85 | target_topic="TSLA sentiment analysis",
86 | focus_areas=["news", "social"],
87 | )
88 |
89 | # Initial state
90 | assert task.status == "pending"
91 | assert task.start_time is None
92 | assert task.end_time is None
93 |
94 | # Simulate task execution
95 | task.start_time = time.time()
96 | task.status = "running"
97 |
98 | # Simulate completion
99 | time.sleep(0.01) # Small delay to ensure different timestamps
100 | task.end_time = time.time()
101 | task.status = "completed"
102 | task.result = {"insights": ["Test insight"]}
103 |
104 | assert task.status == "completed"
105 | assert task.start_time < task.end_time
106 | assert task.result is not None
107 |
108 | def test_task_error_handling(self):
109 | """Test task error state tracking."""
110 | task = ResearchTask(
111 | task_id="error_test",
112 | task_type="technical",
113 | target_topic="NVDA technical analysis",
114 | focus_areas=["chart_patterns"],
115 | )
116 |
117 | # Simulate error
118 | task.status = "failed"
119 | task.error = "API timeout occurred"
120 | task.end_time = time.time()
121 |
122 | assert task.status == "failed"
123 | assert task.error == "API timeout occurred"
124 | assert task.result is None
125 |
126 |
127 | class TestParallelResearchOrchestrator:
128 | """Test ParallelResearchOrchestrator main functionality."""
129 |
130 | @pytest.fixture
131 | def config(self):
132 | """Create test configuration."""
133 | return ParallelResearchConfig(
134 | max_concurrent_agents=3,
135 | timeout_per_agent=5, # Short timeout for tests
136 | enable_fallbacks=True,
137 | rate_limit_delay=0.1, # Fast rate limit for tests
138 | )
139 |
140 | @pytest.fixture
141 | def orchestrator(self, config):
142 | """Create orchestrator with test configuration."""
143 | return ParallelResearchOrchestrator(config)
144 |
145 | @pytest.fixture
146 | def sample_tasks(self):
147 | """Create sample research tasks for testing."""
148 | return [
149 | ResearchTask(
150 | task_id="test_123_fundamental",
151 | task_type="fundamental",
152 | target_topic="AAPL analysis",
153 | focus_areas=["earnings", "valuation"],
154 | priority=8,
155 | ),
156 | ResearchTask(
157 | task_id="test_123_technical",
158 | task_type="technical",
159 | target_topic="AAPL analysis",
160 | focus_areas=["chart_patterns", "indicators"],
161 | priority=6,
162 | ),
163 | ResearchTask(
164 | task_id="test_123_sentiment",
165 | task_type="sentiment",
166 | target_topic="AAPL analysis",
167 | focus_areas=["news", "analyst_ratings"],
168 | priority=7,
169 | ),
170 | ]
171 |
172 | def test_orchestrator_initialization(self, config):
173 | """Test orchestrator initialization."""
174 | orchestrator = ParallelResearchOrchestrator(config)
175 |
176 | assert orchestrator.config == config
177 | assert orchestrator.active_tasks == {}
178 | assert orchestrator._semaphore._value == config.max_concurrent_agents
179 | assert orchestrator.orchestration_logger is not None
180 |
181 | def test_orchestrator_default_config(self):
182 | """Test orchestrator with default configuration."""
183 | orchestrator = ParallelResearchOrchestrator()
184 |
185 | assert orchestrator.config.max_concurrent_agents == 4
186 | assert orchestrator.config.timeout_per_agent == 180
187 |
188 | @pytest.mark.asyncio
189 | async def test_successful_parallel_execution(self, orchestrator, sample_tasks):
190 | """Test successful parallel execution of research tasks."""
191 |
192 | # Mock research executor that returns success
193 | async def mock_executor(task: ResearchTask) -> dict[str, Any]:
194 | await asyncio.sleep(0.1) # Simulate work
195 | return {
196 | "research_type": task.task_type,
197 | "insights": [f"Insight for {task.task_type}"],
198 | "sentiment": {"direction": "bullish", "confidence": 0.8},
199 | "credibility_score": 0.9,
200 | }
201 |
202 | # Mock synthesis callback
203 | async def mock_synthesis(
204 | task_results: dict[str, ResearchTask],
205 | ) -> dict[str, Any]:
206 | return {
207 | "synthesis": "Combined analysis from parallel research",
208 | "confidence_score": 0.85,
209 | "key_findings": ["Finding 1", "Finding 2"],
210 | }
211 |
212 | # Execute parallel research
213 | start_time = time.time()
214 | result = await orchestrator.execute_parallel_research(
215 | tasks=sample_tasks,
216 | research_executor=mock_executor,
217 | synthesis_callback=mock_synthesis,
218 | )
219 | execution_time = time.time() - start_time
220 |
221 | # Verify results
222 | assert isinstance(result, ResearchResult)
223 | assert result.successful_tasks == 3
224 | assert result.failed_tasks == 0
225 | assert result.synthesis is not None
226 | assert (
227 | result.synthesis["synthesis"] == "Combined analysis from parallel research"
228 | )
229 | assert len(result.task_results) == 3
230 |
231 | # Verify parallel efficiency (should be faster than sequential)
232 | assert (
233 | execution_time < 0.5
234 | ) # Should complete much faster than 3 * 0.1s sequential
235 | assert result.parallel_efficiency > 0.0 # Should show some efficiency
236 |
237 | @pytest.mark.asyncio
238 | async def test_concurrency_control(self, orchestrator, config):
239 | """Test that concurrency is properly limited."""
240 | execution_order = []
241 | active_count = 0
242 | max_concurrent = 0
243 |
244 | async def mock_executor(task: ResearchTask) -> dict[str, Any]:
245 | nonlocal active_count, max_concurrent
246 |
247 | active_count += 1
248 | max_concurrent = max(max_concurrent, active_count)
249 | execution_order.append(f"start_{task.task_id}")
250 |
251 | await asyncio.sleep(0.1) # Simulate work
252 |
253 | active_count -= 1
254 | execution_order.append(f"end_{task.task_id}")
255 | return {"result": f"completed_{task.task_id}"}
256 |
257 | # Create more tasks than max concurrent agents
258 | tasks = [
259 | ResearchTask(f"task_{i}", "fundamental", "topic", ["focus"], priority=i)
260 | for i in range(
261 | config.max_concurrent_agents + 2
262 | ) # 5 tasks, max 3 concurrent
263 | ]
264 |
265 | result = await orchestrator.execute_parallel_research(
266 | tasks=tasks,
267 | research_executor=mock_executor,
268 | )
269 |
270 | # Verify concurrency was limited
271 | assert max_concurrent <= config.max_concurrent_agents
272 | assert (
273 | result.successful_tasks == config.max_concurrent_agents
274 |         )  # Orchestrator caps the prepared task list at max_concurrent_agents
275 | assert len(execution_order) > 0
276 |
277 | @pytest.mark.asyncio
278 | async def test_task_timeout_handling(self, orchestrator):
279 | """Test handling of task timeouts."""
280 |
281 | async def slow_executor(task: ResearchTask) -> dict[str, Any]:
282 | await asyncio.sleep(10) # Longer than timeout
283 | return {"result": "should_not_complete"}
284 |
285 | tasks = [
286 | ResearchTask(
287 | "timeout_task",
288 | "fundamental",
289 | "slow topic",
290 | ["focus"],
291 | timeout=1, # Very short timeout
292 | )
293 | ]
294 |
295 | result = await orchestrator.execute_parallel_research(
296 | tasks=tasks,
297 | research_executor=slow_executor,
298 | )
299 |
300 | # Verify timeout was handled
301 | assert result.successful_tasks == 0
302 | assert result.failed_tasks == 1
303 |
304 | failed_task = result.task_results["timeout_task"]
305 | assert failed_task.status == "failed"
306 | assert "timeout" in failed_task.error.lower()
307 |
308 | @pytest.mark.asyncio
309 | async def test_task_error_handling(self, orchestrator, sample_tasks):
310 | """Test handling of task execution errors."""
311 |
312 | async def error_executor(task: ResearchTask) -> dict[str, Any]:
313 | if task.task_type == "technical":
314 | raise ValueError(f"Error in {task.task_type} analysis")
315 | return {"result": f"success_{task.task_type}"}
316 |
317 | result = await orchestrator.execute_parallel_research(
318 | tasks=sample_tasks,
319 | research_executor=error_executor,
320 | )
321 |
322 | # Verify mixed success/failure results
323 | assert result.successful_tasks == 2 # fundamental and sentiment should succeed
324 | assert result.failed_tasks == 1 # technical should fail
325 |
326 | # Check specific task status
327 | technical_task = next(
328 | task
329 | for task in result.task_results.values()
330 | if task.task_type == "technical"
331 | )
332 | assert technical_task.status == "failed"
333 | assert "Error in technical analysis" in technical_task.error
334 |
335 | @pytest.mark.asyncio
336 | async def test_task_preparation_and_prioritization(self, orchestrator):
337 | """Test task preparation and priority-based ordering."""
338 | tasks = [
339 | ResearchTask("low_priority", "technical", "topic", ["focus"], priority=2),
340 | ResearchTask(
341 | "high_priority", "fundamental", "topic", ["focus"], priority=9
342 | ),
343 | ResearchTask("med_priority", "sentiment", "topic", ["focus"], priority=5),
344 | ]
345 |
346 | async def track_executor(task: ResearchTask) -> dict[str, Any]:
347 | return {"task_id": task.task_id, "priority": task.priority}
348 |
349 | result = await orchestrator.execute_parallel_research(
350 | tasks=tasks,
351 | research_executor=track_executor,
352 | )
353 |
354 | # Verify all tasks were prepared (limited by max_concurrent_agents = 3)
355 | assert len(result.task_results) == 3
356 |
357 | # Verify tasks have default timeout set
358 | for task in result.task_results.values():
359 | assert task.timeout == orchestrator.config.timeout_per_agent
360 |
361 | @pytest.mark.asyncio
362 | async def test_synthesis_callback_error_handling(self, orchestrator, sample_tasks):
363 | """Test synthesis callback error handling."""
364 |
365 | async def success_executor(task: ResearchTask) -> dict[str, Any]:
366 | return {"result": f"success_{task.task_type}"}
367 |
368 | async def failing_synthesis(
369 | task_results: dict[str, ResearchTask],
370 | ) -> dict[str, Any]:
371 | raise RuntimeError("Synthesis failed!")
372 |
373 | result = await orchestrator.execute_parallel_research(
374 | tasks=sample_tasks,
375 | research_executor=success_executor,
376 | synthesis_callback=failing_synthesis,
377 | )
378 |
379 | # Verify tasks succeeded but synthesis failed gracefully
380 | assert result.successful_tasks == 3
381 | assert result.synthesis is not None
382 | assert "error" in result.synthesis
383 | assert "Synthesis failed" in result.synthesis["error"]
384 |
385 | @pytest.mark.asyncio
386 | async def test_no_synthesis_callback(self, orchestrator, sample_tasks):
387 | """Test execution without synthesis callback."""
388 |
389 | async def success_executor(task: ResearchTask) -> dict[str, Any]:
390 | return {"result": f"success_{task.task_type}"}
391 |
392 | result = await orchestrator.execute_parallel_research(
393 | tasks=sample_tasks,
394 | research_executor=success_executor,
395 | # No synthesis callback provided
396 | )
397 |
398 | assert result.successful_tasks == 3
399 | assert result.synthesis is None # Should be None when no callback
400 |
401 | @pytest.mark.asyncio
402 | async def test_rate_limiting_between_tasks(self, orchestrator):
403 | """Test rate limiting delays between task starts."""
404 | start_times = []
405 |
406 | async def timing_executor(task: ResearchTask) -> dict[str, Any]:
407 | start_times.append(time.time())
408 | await asyncio.sleep(0.05)
409 | return {"result": task.task_id}
410 |
411 | tasks = [
412 | ResearchTask(f"task_{i}", "fundamental", "topic", ["focus"])
413 | for i in range(3)
414 | ]
415 |
416 | await orchestrator.execute_parallel_research(
417 | tasks=tasks,
418 | research_executor=timing_executor,
419 | )
420 |
421 |         # Verify every task started; rate-limit spacing itself is timing-sensitive
422 | assert len(start_times) == 3
423 | # Note: Due to parallel execution, exact timing is hard to verify
424 | # but we can check that execution completed
425 |
426 | @pytest.mark.asyncio
427 | async def test_empty_task_list(self, orchestrator):
428 | """Test handling of empty task list."""
429 |
430 | async def unused_executor(task: ResearchTask) -> dict[str, Any]:
431 | return {"result": "should_not_be_called"}
432 |
433 | result = await orchestrator.execute_parallel_research(
434 | tasks=[],
435 | research_executor=unused_executor,
436 | )
437 |
438 | assert result.successful_tasks == 0
439 | assert result.failed_tasks == 0
440 | assert result.task_results == {}
441 | assert result.synthesis is None
442 |
443 | @pytest.mark.asyncio
444 | async def test_performance_metrics_calculation(self, orchestrator, sample_tasks):
445 | """Test calculation of performance metrics."""
446 | task_durations = []
447 |
448 | async def tracked_executor(task: ResearchTask) -> dict[str, Any]:
449 | start = time.time()
450 | await asyncio.sleep(0.05) # Simulate work
451 | duration = time.time() - start
452 | task_durations.append(duration)
453 | return {"result": task.task_id}
454 |
455 | result = await orchestrator.execute_parallel_research(
456 | tasks=sample_tasks,
457 | research_executor=tracked_executor,
458 | )
459 |
460 | # Verify performance metrics
461 | assert result.total_execution_time > 0
462 | assert result.parallel_efficiency > 0
463 |
464 | # Parallel efficiency should be roughly: sum(individual_durations) / total_wall_time
465 | expected_sequential_time = sum(task_durations)
466 | efficiency_ratio = expected_sequential_time / result.total_execution_time
467 |
468 | # Allow some tolerance for timing variations
469 | assert abs(result.parallel_efficiency - efficiency_ratio) < 0.5
470 |
471 | @pytest.mark.asyncio
472 | async def test_circuit_breaker_integration(self, orchestrator):
473 | """Test integration with circuit breaker pattern."""
474 | failure_count = 0
475 |
476 | async def circuit_breaker_executor(task: ResearchTask) -> dict[str, Any]:
477 | nonlocal failure_count
478 | failure_count += 1
479 | if failure_count <= 2: # First 2 calls fail
480 | raise RuntimeError("Circuit breaker test failure")
481 | return {"result": "success_after_failures"}
482 |
483 | tasks = [
484 | ResearchTask(f"cb_task_{i}", "fundamental", "topic", ["focus"])
485 | for i in range(3)
486 | ]
487 |
488 | # Note: The actual circuit breaker is applied in _execute_single_task
489 | # This test verifies that errors are properly handled
490 | result = await orchestrator.execute_parallel_research(
491 | tasks=tasks,
492 | research_executor=circuit_breaker_executor,
493 | )
494 |
495 | # Should have some failures and potentially some successes
496 | assert result.failed_tasks >= 2 # At least 2 should fail
497 | assert result.total_execution_time > 0
498 |
499 |
500 | class TestTaskDistributionEngine:
501 | """Test TaskDistributionEngine functionality."""
502 |
503 | def test_task_distribution_engine_creation(self):
504 | """Test creation of task distribution engine."""
505 | engine = TaskDistributionEngine()
506 | assert hasattr(engine, "TASK_TYPES")
507 | assert "fundamental" in engine.TASK_TYPES
508 | assert "technical" in engine.TASK_TYPES
509 | assert "sentiment" in engine.TASK_TYPES
510 | assert "competitive" in engine.TASK_TYPES
511 |
512 | def test_topic_relevance_analysis(self):
513 | """Test analysis of topic relevance to different research types."""
514 | engine = TaskDistributionEngine()
515 |
516 | # Test financial topic
517 | relevance = engine._analyze_topic_relevance(
518 | "AAPL earnings revenue profit analysis"
519 | )
520 |
521 | assert (
522 | relevance["fundamental"] > relevance["technical"]
523 | ) # Should favor fundamental
524 | assert all(0 <= score <= 1 for score in relevance.values()) # Valid range
525 | assert len(relevance) == 4 # All task types
526 |
527 | def test_distribute_research_tasks(self):
528 | """Test distribution of research topic into specialized tasks."""
529 | engine = TaskDistributionEngine()
530 |
531 | tasks = engine.distribute_research_tasks(
532 | topic="Tesla financial performance and market sentiment",
533 | session_id="test_123",
534 | focus_areas=["earnings", "sentiment"],
535 | )
536 |
537 | assert len(tasks) > 0
538 | assert all(isinstance(task, ResearchTask) for task in tasks)
539 |         assert all(
540 |             task.status == "pending" for task in tasks
541 |         )  # ResearchTask has no session_id field; new tasks start out pending
542 | assert all(
543 | task.target_topic == "Tesla financial performance and market sentiment"
544 | for task in tasks
545 | )
546 |
547 | # Verify task types are relevant
548 | task_types = {task.task_type for task in tasks}
549 | assert (
550 | "fundamental" in task_types or "sentiment" in task_types
551 | ) # Should include relevant types
552 |
553 | def test_fallback_task_creation(self):
554 | """Test fallback task creation when no relevant tasks found."""
555 | engine = TaskDistributionEngine()
556 |
557 |         # Force the fallback path by stubbing _analyze_topic_relevance so that
558 |         # every task type scores below the relevance threshold
559 | original_method = engine._analyze_topic_relevance
560 |
561 | def mock_low_relevance(topic, focus_areas=None):
562 | return {
563 | "fundamental": 0.1,
564 | "technical": 0.1,
565 | "sentiment": 0.1,
566 | "competitive": 0.1,
567 | }
568 |
569 | engine._analyze_topic_relevance = mock_low_relevance
570 | tasks = engine.distribute_research_tasks(
571 | topic="fallback test topic", session_id="fallback_test"
572 | )
573 | # Restore original method
574 | engine._analyze_topic_relevance = original_method
575 |
576 | # Should create at least one fallback task
577 | assert len(tasks) >= 1
578 | # Should have fundamental as fallback
579 | fallback_task = tasks[0]
580 | assert fallback_task.task_type == "fundamental"
581 | assert fallback_task.priority == 5 # Default priority
582 |
583 | def test_task_priority_assignment(self):
584 | """Test priority assignment based on relevance scores."""
585 | engine = TaskDistributionEngine()
586 |
587 | tasks = engine.distribute_research_tasks(
588 | topic="Apple dividend yield earnings cash flow stability", # Should favor fundamental
589 | session_id="priority_test",
590 | )
591 |
592 | # Find fundamental task (should have higher priority for this topic)
593 | fundamental_tasks = [task for task in tasks if task.task_type == "fundamental"]
594 | if fundamental_tasks:
595 | fundamental_task = fundamental_tasks[0]
596 | assert fundamental_task.priority >= 5 # Should have decent priority
597 |
598 | def test_focus_areas_integration(self):
599 | """Test integration of provided focus areas."""
600 | engine = TaskDistributionEngine()
601 |
602 | tasks = engine.distribute_research_tasks(
603 | topic="Microsoft analysis",
604 | session_id="focus_test",
605 | focus_areas=["technical_analysis", "chart_patterns"],
606 | )
607 |
608 |         # Focus areas suggesting technical analysis should yield relevant tasks
609 |         task_types = {task.task_type for task in tasks}
610 |         # Exact composition depends on relevance scoring, so only check non-empty
611 |         assert len(tasks) > 0 and len(task_types) > 0
612 |
613 |
614 | class TestResearchResult:
615 | """Test ResearchResult data structure."""
616 |
617 | def test_research_result_initialization(self):
618 | """Test ResearchResult initialization."""
619 | result = ResearchResult()
620 |
621 | assert result.task_results == {}
622 | assert result.synthesis is None
623 | assert result.total_execution_time == 0.0
624 | assert result.successful_tasks == 0
625 | assert result.failed_tasks == 0
626 | assert result.parallel_efficiency == 0.0
627 |
628 | def test_research_result_data_storage(self):
629 | """Test storing data in ResearchResult."""
630 | result = ResearchResult()
631 |
632 | # Add sample task results
633 | task1 = ResearchTask("task_1", "fundamental", "topic", ["focus"])
634 | task1.status = "completed"
635 | task2 = ResearchTask("task_2", "technical", "topic", ["focus"])
636 | task2.status = "failed"
637 |
638 | result.task_results = {"task_1": task1, "task_2": task2}
639 | result.successful_tasks = 1
640 | result.failed_tasks = 1
641 | result.total_execution_time = 2.5
642 | result.parallel_efficiency = 1.8
643 | result.synthesis = {"findings": "Test findings"}
644 |
645 | assert len(result.task_results) == 2
646 | assert result.successful_tasks == 1
647 | assert result.failed_tasks == 1
648 | assert result.total_execution_time == 2.5
649 | assert result.parallel_efficiency == 1.8
650 | assert result.synthesis["findings"] == "Test findings"
651 |
652 |
653 | @pytest.mark.integration
654 | class TestParallelResearchIntegration:
655 | """Integration tests for complete parallel research workflow."""
656 |
657 | @pytest.fixture
658 | def full_orchestrator(self):
659 | """Create orchestrator with realistic configuration."""
660 | config = ParallelResearchConfig(
661 | max_concurrent_agents=2, # Reduced for testing
662 | timeout_per_agent=10,
663 | enable_fallbacks=True,
664 | rate_limit_delay=0.1,
665 | )
666 | return ParallelResearchOrchestrator(config)
667 |
668 | @pytest.mark.asyncio
669 | async def test_end_to_end_parallel_research(self, full_orchestrator):
670 | """Test complete end-to-end parallel research workflow."""
671 | # Create realistic research tasks
672 | engine = TaskDistributionEngine()
673 | tasks = engine.distribute_research_tasks(
674 | topic="Apple Inc financial analysis and market outlook",
675 | session_id="integration_test",
676 | )
677 |
678 | # Mock a realistic research executor
679 | async def realistic_executor(task: ResearchTask) -> dict[str, Any]:
680 | await asyncio.sleep(0.1) # Simulate API calls
681 |
682 | return {
683 | "research_type": task.task_type,
684 | "insights": [
685 | f"{task.task_type} insight 1 for {task.target_topic}",
686 | f"{task.task_type} insight 2 based on {task.focus_areas[0] if task.focus_areas else 'general'}",
687 | ],
688 | "sentiment": {
689 | "direction": "bullish"
690 | if task.task_type != "technical"
691 | else "neutral",
692 | "confidence": 0.75,
693 | },
694 | "risk_factors": [f"{task.task_type} risk factor"],
695 | "opportunities": [f"{task.task_type} opportunity"],
696 | "credibility_score": 0.8,
697 | "sources": [
698 | {
699 | "title": f"Source for {task.task_type} research",
700 | "url": f"https://example.com/{task.task_type}",
701 | "credibility_score": 0.85,
702 | }
703 | ],
704 | }
705 |
706 | # Mock synthesis callback
707 | async def integration_synthesis(
708 | task_results: dict[str, ResearchTask],
709 | ) -> dict[str, Any]:
710 | successful_results = [
711 | task.result
712 | for task in task_results.values()
713 | if task.status == "completed" and task.result
714 | ]
715 |
716 | all_insights = []
717 | for result in successful_results:
718 | all_insights.extend(result.get("insights", []))
719 |
720 | return {
721 | "synthesis": f"Integrated analysis from {len(successful_results)} research angles",
722 | "confidence_score": 0.82,
723 | "key_findings": all_insights[:5], # Top 5 insights
724 | "overall_sentiment": "bullish",
725 | "research_depth": "comprehensive",
726 | }
727 |
728 | # Execute the integration test
729 | start_time = time.time()
730 | result = await full_orchestrator.execute_parallel_research(
731 | tasks=tasks,
732 | research_executor=realistic_executor,
733 | synthesis_callback=integration_synthesis,
734 | )
735 | execution_time = time.time() - start_time
736 |
737 | # Comprehensive verification
738 | assert isinstance(result, ResearchResult)
739 | assert result.successful_tasks > 0
740 | assert result.total_execution_time > 0
741 | assert execution_time < 5 # Should complete reasonably quickly
742 |
743 | # Verify synthesis was generated
744 | assert result.synthesis is not None
745 | assert "synthesis" in result.synthesis
746 | assert result.synthesis["confidence_score"] > 0
747 |
748 | # Verify task results structure
749 | for task_id, task in result.task_results.items():
750 | assert isinstance(task, ResearchTask)
751 | assert task.task_id == task_id
752 | if task.status == "completed":
753 | assert task.result is not None
754 | assert "insights" in task.result
755 | assert "sentiment" in task.result
756 |
757 | # Verify performance characteristics
758 | if result.successful_tasks > 1:
759 | assert result.parallel_efficiency > 1.0 # Should show parallel benefit
760 |
```
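As a standalone illustration of the efficiency metric asserted in `test_performance_metrics_calculation`, here is the arithmetic in miniature, assuming (as that test's own comment states) that `parallel_efficiency` is the summed per-task durations divided by total wall-clock time:

```python
# Three ~50 ms tasks that overlap almost completely on the wall clock.
task_durations = [0.05, 0.05, 0.05]  # what sequential execution would cost
wall_clock_time = 0.06               # observed parallel wall time

# Values above 1.0 indicate a real speedup from concurrency.
parallel_efficiency = sum(task_durations) / wall_clock_time
print(f"parallel efficiency: {parallel_efficiency:.2f}x")  # -> 2.50x
```

To run this suite without the slower integration path, the standard pytest marker filter applies, e.g. `pytest tests/test_parallel_research_orchestrator.py -m "not integration"` (assuming the `integration` marker is registered in the project's pytest configuration).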