This is page 28 of 28. Use http://codebase.md/wshobson/maverick-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .jules
│ └── bolt.md
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ ├── unit
│ │ └── test_stock_repository_adapter.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/maverick_mcp/data/models.py:
--------------------------------------------------------------------------------
```python
"""
SQLAlchemy models for MaverickMCP.
This module defines database models for financial data storage and analysis,
including PriceCache and Maverick screening models.
"""
from __future__ import annotations
import logging
import os
import threading
import uuid
from collections.abc import AsyncGenerator, Sequence
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal
import pandas as pd
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
Column,
Date,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
Uuid,
create_engine,
inspect,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, relationship, sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from maverick_mcp.config.settings import get_settings
from maverick_mcp.database.base import Base
# Set up logging
logger = logging.getLogger("maverick_mcp.data.models")
settings = get_settings()
# Helper function to get the right integer type for autoincrement primary keys
def get_primary_key_type():
"""Get the appropriate primary key type based on database backend."""
    # SQLite autoincrement requires INTEGER primary keys; PostgreSQL can use BIGINT
if "sqlite" in DATABASE_URL:
return Integer
else:
return BigInteger
# Database connection setup
# Use SQLite in-memory for GitHub Actions or other CI environments; otherwise try
# several environment variable names before falling back to a local SQLite file
if os.getenv("GITHUB_ACTIONS") == "true" or os.getenv("CI") == "true":
DATABASE_URL = "sqlite:///:memory:"
else:
DATABASE_URL = (
os.getenv("DATABASE_URL")
or os.getenv("POSTGRES_URL")
or "sqlite:///maverick_mcp.db" # Default to SQLite
)
# Database configuration from settings
DB_POOL_SIZE = settings.db.pool_size
DB_MAX_OVERFLOW = settings.db.pool_max_overflow
DB_POOL_TIMEOUT = settings.db.pool_timeout
DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600")) # 1 hour
DB_POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
DB_ECHO = os.getenv("DB_ECHO", "false").lower() == "true"
DB_USE_POOLING = os.getenv("DB_USE_POOLING", "true").lower() == "true"
# Log the connection string (without password) for debugging
if DATABASE_URL:
# Mask password in URL for logging
masked_url = DATABASE_URL
if "@" in DATABASE_URL and "://" in DATABASE_URL:
parts = DATABASE_URL.split("://", 1)
if len(parts) == 2 and "@" in parts[1]:
user_pass, host_db = parts[1].split("@", 1)
if ":" in user_pass:
user, _ = user_pass.split(":", 1)
masked_url = f"{parts[0]}://{user}:****@{host_db}"
logger.info(f"Using database URL: {masked_url}")
logger.info(f"Connection pooling: {'ENABLED' if DB_USE_POOLING else 'DISABLED'}")
if DB_USE_POOLING:
logger.info(
f"Pool config: size={DB_POOL_SIZE}, max_overflow={DB_MAX_OVERFLOW}, "
f"timeout={DB_POOL_TIMEOUT}s, recycle={DB_POOL_RECYCLE}s"
)
# Create engine with configurable connection pooling
if DB_USE_POOLING:
# Prepare connection arguments based on database type
if "postgresql" in DATABASE_URL:
# PostgreSQL-specific connection args
sync_connect_args = {
"connect_timeout": 10,
"application_name": "maverick_mcp",
"options": f"-c statement_timeout={settings.db.statement_timeout}",
}
elif "sqlite" in DATABASE_URL:
# SQLite-specific args - no SSL parameters
sync_connect_args = {"check_same_thread": False}
else:
# Default - no connection args
sync_connect_args = {}
# Use QueuePool for production environments
engine = create_engine(
DATABASE_URL,
poolclass=QueuePool,
pool_size=DB_POOL_SIZE,
max_overflow=DB_MAX_OVERFLOW,
pool_timeout=DB_POOL_TIMEOUT,
pool_recycle=DB_POOL_RECYCLE,
pool_pre_ping=DB_POOL_PRE_PING,
echo=DB_ECHO,
connect_args=sync_connect_args,
)
else:
# Prepare minimal connection arguments for NullPool
if "sqlite" in DATABASE_URL:
sync_connect_args = {"check_same_thread": False}
else:
sync_connect_args = {}
# Use NullPool for serverless/development environments
engine = create_engine(
DATABASE_URL,
poolclass=NullPool,
echo=DB_ECHO,
connect_args=sync_connect_args,
)
# Create session factory
_session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
_schema_lock = threading.Lock()
_schema_initialized = False
def ensure_database_schema(force: bool = False) -> bool:
"""Ensure the database schema exists for the configured engine.
Args:
force: When ``True`` the schema will be (re)created even if it appears
to exist already.
Returns:
``True`` if the schema creation routine executed, ``False`` otherwise.
"""
global _schema_initialized
# Fast path: skip inspection once the schema has been verified unless the
# caller explicitly requests a forced refresh.
if not force and _schema_initialized:
return False
with _schema_lock:
if not force and _schema_initialized:
return False
try:
inspector = inspect(engine)
existing_tables = set(inspector.get_table_names())
except SQLAlchemyError as exc: # pragma: no cover - safety net
logger.warning(
"Unable to inspect database schema; attempting to create tables anyway",
exc_info=exc,
)
existing_tables = set()
defined_tables = set(Base.metadata.tables.keys())
missing_tables = defined_tables - existing_tables
should_create = force or bool(missing_tables)
if should_create:
if missing_tables:
logger.info(
"Creating missing database tables: %s",
", ".join(sorted(missing_tables)),
)
else:
logger.info("Ensuring database schema is up to date")
Base.metadata.create_all(bind=engine)
_schema_initialized = True
return True
_schema_initialized = True
return False
class _SessionFactoryWrapper:
"""Session factory that ensures the schema exists before creating sessions."""
def __init__(self, factory: sessionmaker):
self._factory = factory
def __call__(self, *args, **kwargs):
ensure_database_schema()
return self._factory(*args, **kwargs)
def __getattr__(self, name):
return getattr(self._factory, name)
SessionLocal = _SessionFactoryWrapper(_session_factory)
# Create async engine - cached globally for reuse
_async_engine = None
_async_session_factory = None
def _get_async_engine():
"""Get or create the async engine singleton."""
global _async_engine
if _async_engine is None:
# Convert sync URL to async URL
if DATABASE_URL.startswith("sqlite://"):
async_url = DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
else:
async_url = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://")
# Create async engine - don't specify poolclass for async engines
# SQLAlchemy will use the appropriate async pool automatically
if DB_USE_POOLING:
# Prepare connection arguments based on database type
if "postgresql" in async_url:
# PostgreSQL-specific connection args
async_connect_args = {
"server_settings": {
"application_name": "maverick_mcp_async",
"statement_timeout": str(settings.db.statement_timeout),
}
}
elif "sqlite" in async_url:
# SQLite-specific args - no SSL parameters
async_connect_args = {"check_same_thread": False}
else:
# Default - no connection args
async_connect_args = {}
_async_engine = create_async_engine(
async_url,
# Don't specify poolclass - let SQLAlchemy choose the async pool
pool_size=DB_POOL_SIZE,
max_overflow=DB_MAX_OVERFLOW,
pool_timeout=DB_POOL_TIMEOUT,
pool_recycle=DB_POOL_RECYCLE,
pool_pre_ping=DB_POOL_PRE_PING,
echo=DB_ECHO,
connect_args=async_connect_args,
)
else:
# Prepare minimal connection arguments for NullPool
if "sqlite" in async_url:
async_connect_args = {"check_same_thread": False}
else:
async_connect_args = {}
_async_engine = create_async_engine(
async_url,
poolclass=NullPool,
echo=DB_ECHO,
connect_args=async_connect_args,
)
logger.info("Created async database engine")
return _async_engine
def _get_async_session_factory():
"""Get or create the async session factory singleton."""
global _async_session_factory
if _async_session_factory is None:
engine = _get_async_engine()
_async_session_factory = async_sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
logger.info("Created async session factory")
return _async_session_factory
def get_db():
"""Get database session."""
ensure_database_schema()
db = SessionLocal()
try:
yield db
finally:
db.close()
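# Illustrative sketch (not part of the original module): consuming get_db() as a plain
# generator outside a web framework. "_example_count_stocks" is a hypothetical helper.
def _example_count_stocks() -> int:
    """Open a session via get_db(), count cached stocks, and release the session."""
    gen = get_db()
    db = next(gen)
    try:
        return db.query(Stock).count()
    finally:
        gen.close()  # runs get_db()'s finally block, which closes the session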
# Async database support - imports moved to top of file
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""Get an async database session using the cached engine."""
# Get the cached session factory
async_session_factory = _get_async_session_factory()
# Create and yield a session
async with async_session_factory() as session:
try:
yield session
finally:
await session.close()
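# Illustrative sketch (not part of the original module): get_async_db() is an async
# generator, so it is consumed with "async for" (or via a framework's dependency
# injection). The coroutine name and query below are hypothetical examples.
async def _example_count_stocks_async() -> int:
    """Count cached stocks using the async session yielded by get_async_db()."""
    from sqlalchemy import func, select
    async for session in get_async_db():
        result = await session.execute(select(func.count()).select_from(Stock))
        return int(result.scalar_one())
    return 0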
async def close_async_db_connections():
"""Close the async database engine and cleanup connections."""
global _async_engine, _async_session_factory
if _async_engine:
await _async_engine.dispose()
_async_engine = None
_async_session_factory = None
logger.info("Closed async database engine")
def init_db():
"""Initialize database by creating all tables."""
ensure_database_schema(force=True)
class TimestampMixin:
"""Mixin for created_at and updated_at timestamps."""
created_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
nullable=False,
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
nullable=False,
)
class Stock(Base, TimestampMixin):
"""Stock model for storing basic stock information."""
__tablename__ = "mcp_stocks"
stock_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
ticker_symbol = Column(String(10), unique=True, nullable=False, index=True)
company_name = Column(String(255))
description = Column(Text)
sector = Column(String(100))
industry = Column(String(100))
exchange = Column(String(50))
country = Column(String(50))
currency = Column(String(3))
isin = Column(String(12))
# Additional stock metadata
market_cap = Column(BigInteger)
shares_outstanding = Column(BigInteger)
is_etf = Column(Boolean, default=False)
is_active = Column(Boolean, default=True, index=True)
# Relationships
price_caches = relationship(
"PriceCache",
back_populates="stock",
cascade="all, delete-orphan",
lazy="selectin", # Eager load price caches to prevent N+1 queries
)
maverick_stocks = relationship(
"MaverickStocks", back_populates="stock", cascade="all, delete-orphan"
)
maverick_bear_stocks = relationship(
"MaverickBearStocks", back_populates="stock", cascade="all, delete-orphan"
)
supply_demand_stocks = relationship(
"SupplyDemandBreakoutStocks",
back_populates="stock",
cascade="all, delete-orphan",
)
technical_cache = relationship(
"TechnicalCache", back_populates="stock", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<Stock(ticker={self.ticker_symbol}, name={self.company_name})>"
@classmethod
def get_or_create(cls, session: Session, ticker_symbol: str, **kwargs) -> Stock:
"""Get existing stock or create new one."""
stock = (
session.query(cls).filter_by(ticker_symbol=ticker_symbol.upper()).first()
)
if not stock:
stock = cls(ticker_symbol=ticker_symbol.upper(), **kwargs)
session.add(stock)
session.commit()
return stock
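# Illustrative sketch (not part of the original module): SessionLocal ensures the
# schema exists before handing out a session, so get_or_create can be called directly.
# The ticker and metadata values are example data only.
def _example_get_or_create_stock() -> Stock:
    with SessionLocal() as session:
        return Stock.get_or_create(
            session,
            "AAPL",
            company_name="Apple Inc.",
            sector="Technology",
        )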
class PriceCache(Base, TimestampMixin):
"""Cache for historical stock price data."""
__tablename__ = "mcp_price_cache"
__table_args__ = (
UniqueConstraint("stock_id", "date", name="mcp_price_cache_stock_date_unique"),
Index("mcp_price_cache_stock_id_date_idx", "stock_id", "date"),
Index("mcp_price_cache_ticker_date_idx", "stock_id", "date"),
)
price_cache_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
stock_id = Column(Uuid, ForeignKey("mcp_stocks.stock_id"), nullable=False)
date = Column(Date, nullable=False)
open_price = Column(Numeric(12, 4))
high_price = Column(Numeric(12, 4))
low_price = Column(Numeric(12, 4))
close_price = Column(Numeric(12, 4))
volume = Column(BigInteger)
# Relationships
stock = relationship(
"Stock", back_populates="price_caches", lazy="joined"
) # Eager load stock info
def __repr__(self):
return f"<PriceCache(stock_id={self.stock_id}, date={self.date}, close={self.close_price})>"
@classmethod
def get_price_data(
cls,
session: Session,
ticker_symbol: str,
start_date: str,
end_date: str | None = None,
) -> pd.DataFrame:
"""
Return a pandas DataFrame of price data for the specified symbol and date range.
Args:
session: Database session
ticker_symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (default: today)
Returns:
DataFrame with OHLCV data indexed by date
"""
if not end_date:
end_date = datetime.now(UTC).strftime("%Y-%m-%d")
# Query with join to get ticker symbol
query = (
session.query(
cls.date,
cls.open_price.label("open"),
cls.high_price.label("high"),
cls.low_price.label("low"),
cls.close_price.label("close"),
cls.volume,
)
.join(Stock)
.filter(
Stock.ticker_symbol == ticker_symbol.upper(),
cls.date >= pd.to_datetime(start_date).date(),
cls.date <= pd.to_datetime(end_date).date(),
)
.order_by(cls.date)
)
# Convert to DataFrame
df = pd.DataFrame(query.all())
if not df.empty:
df["date"] = pd.to_datetime(df["date"])
df.set_index("date", inplace=True)
# Convert decimal types to float
for col in ["open", "high", "low", "close"]:
df[col] = df[col].astype(float)
df["volume"] = df["volume"].astype(int)
df["symbol"] = ticker_symbol.upper()
return df
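# Illustrative sketch (not part of the original module): fetching cached OHLCV data
# as a DataFrame. Dates are plain YYYY-MM-DD strings; the symbol is an example value.
def _example_load_cached_prices() -> pd.DataFrame:
    with SessionLocal() as session:
        return PriceCache.get_price_data(
            session, "AAPL", start_date="2024-01-01", end_date="2024-06-30"
        )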
class MaverickStocks(Base, TimestampMixin):
"""Maverick stocks screening results - self-contained model."""
__tablename__ = "mcp_maverick_stocks"
__table_args__ = (
Index("mcp_maverick_stocks_combined_score_idx", "combined_score"),
Index(
"mcp_maverick_stocks_momentum_score_idx", "momentum_score"
), # formerly rs_rating_idx
Index("mcp_maverick_stocks_date_analyzed_idx", "date_analyzed"),
Index("mcp_maverick_stocks_stock_date_idx", "stock_id", "date_analyzed"),
)
id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
stock_id = Column(
Uuid,
ForeignKey("mcp_stocks.stock_id"),
nullable=False,
index=True,
)
date_analyzed = Column(
Date, nullable=False, default=lambda: datetime.now(UTC).date()
)
# OHLCV Data
open_price = Column(Numeric(12, 4), default=0)
high_price = Column(Numeric(12, 4), default=0)
low_price = Column(Numeric(12, 4), default=0)
close_price = Column(Numeric(12, 4), default=0)
volume = Column(BigInteger, default=0)
# Technical Indicators
ema_21 = Column(Numeric(12, 4), default=0)
sma_50 = Column(Numeric(12, 4), default=0)
sma_150 = Column(Numeric(12, 4), default=0)
sma_200 = Column(Numeric(12, 4), default=0)
momentum_score = Column(Numeric(5, 2), default=0) # formerly rs_rating
avg_vol_30d = Column(Numeric(15, 2), default=0)
adr_pct = Column(Numeric(5, 2), default=0)
atr = Column(Numeric(12, 4), default=0)
# Pattern Analysis
pattern_type = Column(String(50)) # 'pat' field
squeeze_status = Column(String(50)) # 'sqz' field
consolidation_status = Column(String(50)) # formerly vcp_status, 'vcp' field
entry_signal = Column(String(50)) # 'entry' field
# Scoring
compression_score = Column(Integer, default=0)
pattern_detected = Column(Integer, default=0)
combined_score = Column(Integer, default=0)
# Relationships
stock = relationship("Stock", back_populates="maverick_stocks")
def __repr__(self):
return f"<MaverickStock(stock_id={self.stock_id}, close={self.close_price}, score={self.combined_score})>"
@classmethod
def get_top_stocks(
cls, session: Session, limit: int = 20
) -> Sequence[MaverickStocks]:
"""Get top maverick stocks by combined score."""
return (
session.query(cls)
.join(Stock)
.order_by(cls.combined_score.desc())
.limit(limit)
.all()
)
@classmethod
def get_latest_analysis(
cls, session: Session, days_back: int = 1
) -> Sequence[MaverickStocks]:
"""Get latest maverick analysis within specified days."""
cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
return (
session.query(cls)
.join(Stock)
.filter(cls.date_analyzed >= cutoff_date)
.order_by(cls.combined_score.desc())
.all()
)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"stock_id": str(self.stock_id),
"ticker": self.stock.ticker_symbol if self.stock else None,
"date_analyzed": self.date_analyzed.isoformat()
if self.date_analyzed
else None,
"close": float(self.close_price) if self.close_price else 0,
"volume": self.volume,
"momentum_score": float(self.momentum_score)
if self.momentum_score
else 0, # formerly rs_rating
"adr_pct": float(self.adr_pct) if self.adr_pct else 0,
"pattern": self.pattern_type,
"squeeze": self.squeeze_status,
"consolidation": self.consolidation_status, # formerly vcp
"entry": self.entry_signal,
"combined_score": self.combined_score,
"compression_score": self.compression_score,
"pattern_detected": self.pattern_detected,
"ema_21": float(self.ema_21) if self.ema_21 else 0,
"sma_50": float(self.sma_50) if self.sma_50 else 0,
"sma_150": float(self.sma_150) if self.sma_150 else 0,
"sma_200": float(self.sma_200) if self.sma_200 else 0,
"atr": float(self.atr) if self.atr else 0,
"avg_vol_30d": float(self.avg_vol_30d) if self.avg_vol_30d else 0,
}
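# Illustrative sketch (not part of the original module): pulling the current top
# Maverick screening results and serializing them for a JSON response.
def _example_top_maverick_stocks(limit: int = 10) -> list[dict]:
    with SessionLocal() as session:
        return [row.to_dict() for row in MaverickStocks.get_top_stocks(session, limit)]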
class MaverickBearStocks(Base, TimestampMixin):
"""Maverick bear stocks screening results - self-contained model."""
__tablename__ = "mcp_maverick_bear_stocks"
__table_args__ = (
Index("mcp_maverick_bear_stocks_score_idx", "score"),
Index(
"mcp_maverick_bear_stocks_momentum_score_idx", "momentum_score"
), # formerly rs_rating_idx
Index("mcp_maverick_bear_stocks_date_analyzed_idx", "date_analyzed"),
Index("mcp_maverick_bear_stocks_stock_date_idx", "stock_id", "date_analyzed"),
)
id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
stock_id = Column(
Uuid,
ForeignKey("mcp_stocks.stock_id"),
nullable=False,
index=True,
)
date_analyzed = Column(
Date, nullable=False, default=lambda: datetime.now(UTC).date()
)
# OHLCV Data
open_price = Column(Numeric(12, 4), default=0)
high_price = Column(Numeric(12, 4), default=0)
low_price = Column(Numeric(12, 4), default=0)
close_price = Column(Numeric(12, 4), default=0)
volume = Column(BigInteger, default=0)
# Technical Indicators
momentum_score = Column(Numeric(5, 2), default=0) # formerly rs_rating
ema_21 = Column(Numeric(12, 4), default=0)
sma_50 = Column(Numeric(12, 4), default=0)
sma_200 = Column(Numeric(12, 4), default=0)
rsi_14 = Column(Numeric(5, 2), default=0)
# MACD Indicators
macd = Column(Numeric(12, 6), default=0)
macd_signal = Column(Numeric(12, 6), default=0)
macd_histogram = Column(Numeric(12, 6), default=0)
# Additional Bear Market Indicators
dist_days_20 = Column(Integer, default=0) # Days from 20 SMA
adr_pct = Column(Numeric(5, 2), default=0)
atr_contraction = Column(Boolean, default=False)
atr = Column(Numeric(12, 4), default=0)
avg_vol_30d = Column(Numeric(15, 2), default=0)
big_down_vol = Column(Boolean, default=False)
# Pattern Analysis
squeeze_status = Column(String(50)) # 'sqz' field
consolidation_status = Column(String(50)) # formerly vcp_status, 'vcp' field
# Scoring
score = Column(Integer, default=0)
# Relationships
stock = relationship("Stock", back_populates="maverick_bear_stocks")
def __repr__(self):
return f"<MaverickBearStock(stock_id={self.stock_id}, close={self.close_price}, score={self.score})>"
@classmethod
def get_top_stocks(
cls, session: Session, limit: int = 20
) -> Sequence[MaverickBearStocks]:
"""Get top maverick bear stocks by score."""
return (
session.query(cls).join(Stock).order_by(cls.score.desc()).limit(limit).all()
)
@classmethod
def get_latest_analysis(
cls, session: Session, days_back: int = 1
) -> Sequence[MaverickBearStocks]:
"""Get latest bear analysis within specified days."""
cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
return (
session.query(cls)
.join(Stock)
.filter(cls.date_analyzed >= cutoff_date)
.order_by(cls.score.desc())
.all()
)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"stock_id": str(self.stock_id),
"ticker": self.stock.ticker_symbol if self.stock else None,
"date_analyzed": self.date_analyzed.isoformat()
if self.date_analyzed
else None,
"close": float(self.close_price) if self.close_price else 0,
"volume": self.volume,
"momentum_score": float(self.momentum_score)
if self.momentum_score
else 0, # formerly rs_rating
"rsi_14": float(self.rsi_14) if self.rsi_14 else 0,
"macd": float(self.macd) if self.macd else 0,
"macd_signal": float(self.macd_signal) if self.macd_signal else 0,
"macd_histogram": float(self.macd_histogram) if self.macd_histogram else 0,
"adr_pct": float(self.adr_pct) if self.adr_pct else 0,
"atr": float(self.atr) if self.atr else 0,
"atr_contraction": self.atr_contraction,
"avg_vol_30d": float(self.avg_vol_30d) if self.avg_vol_30d else 0,
"big_down_vol": self.big_down_vol,
"score": self.score,
"squeeze": self.squeeze_status,
"consolidation": self.consolidation_status, # formerly vcp
"ema_21": float(self.ema_21) if self.ema_21 else 0,
"sma_50": float(self.sma_50) if self.sma_50 else 0,
"sma_200": float(self.sma_200) if self.sma_200 else 0,
"dist_days_20": self.dist_days_20,
}
class SupplyDemandBreakoutStocks(Base, TimestampMixin):
"""Supply/demand breakout stocks screening results - self-contained model.
This model identifies stocks experiencing accumulation breakouts with strong relative strength,
indicating a potential shift from supply to demand dominance in the market structure.
"""
__tablename__ = "mcp_supply_demand_breakouts"
__table_args__ = (
Index(
"mcp_supply_demand_breakouts_momentum_score_idx", "momentum_score"
), # formerly rs_rating_idx
Index("mcp_supply_demand_breakouts_date_analyzed_idx", "date_analyzed"),
Index(
"mcp_supply_demand_breakouts_stock_date_idx", "stock_id", "date_analyzed"
),
Index(
"mcp_supply_demand_breakouts_ma_filter_idx",
"close_price",
"sma_50",
"sma_150",
"sma_200",
),
)
id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
stock_id = Column(
Uuid,
ForeignKey("mcp_stocks.stock_id"),
nullable=False,
index=True,
)
date_analyzed = Column(
Date, nullable=False, default=lambda: datetime.now(UTC).date()
)
# OHLCV Data
open_price = Column(Numeric(12, 4), default=0)
high_price = Column(Numeric(12, 4), default=0)
low_price = Column(Numeric(12, 4), default=0)
close_price = Column(Numeric(12, 4), default=0)
volume = Column(BigInteger, default=0)
# Technical Indicators
ema_21 = Column(Numeric(12, 4), default=0)
sma_50 = Column(Numeric(12, 4), default=0)
sma_150 = Column(Numeric(12, 4), default=0)
sma_200 = Column(Numeric(12, 4), default=0)
momentum_score = Column(Numeric(5, 2), default=0) # formerly rs_rating
avg_volume_30d = Column(Numeric(15, 2), default=0)
adr_pct = Column(Numeric(5, 2), default=0)
atr = Column(Numeric(12, 4), default=0)
# Pattern Analysis
pattern_type = Column(String(50)) # 'pat' field
squeeze_status = Column(String(50)) # 'sqz' field
consolidation_status = Column(String(50)) # formerly vcp_status, 'vcp' field
entry_signal = Column(String(50)) # 'entry' field
# Supply/Demand Analysis
accumulation_rating = Column(Numeric(5, 2), default=0)
distribution_rating = Column(Numeric(5, 2), default=0)
breakout_strength = Column(Numeric(5, 2), default=0)
# Relationships
stock = relationship("Stock", back_populates="supply_demand_stocks")
def __repr__(self):
return f"<SupplyDemandBreakoutStock(stock_id={self.stock_id}, close={self.close_price}, momentum={self.momentum_score})>" # formerly rs
@classmethod
def get_top_stocks(
cls, session: Session, limit: int = 20
) -> Sequence[SupplyDemandBreakoutStocks]:
"""Get top supply/demand breakout stocks by momentum score.""" # formerly relative strength rating
return (
session.query(cls)
.join(Stock)
.order_by(cls.momentum_score.desc()) # formerly rs_rating
.limit(limit)
.all()
)
@classmethod
def get_stocks_above_moving_averages(
cls, session: Session
) -> Sequence[SupplyDemandBreakoutStocks]:
"""Get stocks in demand expansion phase - trading above all major moving averages.
This identifies stocks with:
- Price above 50, 150, and 200-day moving averages (demand zone)
- Upward trending moving averages (accumulation structure)
- Indicates institutional accumulation and supply absorption
"""
return (
session.query(cls)
.join(Stock)
.filter(
cls.close_price > cls.sma_50,
cls.close_price > cls.sma_150,
cls.close_price > cls.sma_200,
cls.sma_50 > cls.sma_150,
cls.sma_150 > cls.sma_200,
)
.order_by(cls.momentum_score.desc()) # formerly rs_rating
.all()
)
@classmethod
def get_latest_analysis(
cls, session: Session, days_back: int = 1
) -> Sequence[SupplyDemandBreakoutStocks]:
"""Get latest supply/demand analysis within specified days."""
cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
return (
session.query(cls)
.join(Stock)
.filter(cls.date_analyzed >= cutoff_date)
.order_by(cls.momentum_score.desc()) # formerly rs_rating
.all()
)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"stock_id": str(self.stock_id),
"ticker": self.stock.ticker_symbol if self.stock else None,
"date_analyzed": self.date_analyzed.isoformat()
if self.date_analyzed
else None,
"close": float(self.close_price) if self.close_price else 0,
"volume": self.volume,
"momentum_score": float(self.momentum_score)
if self.momentum_score
else 0, # formerly rs_rating
"adr_pct": float(self.adr_pct) if self.adr_pct else 0,
"pattern": self.pattern_type,
"squeeze": self.squeeze_status,
"consolidation": self.consolidation_status, # formerly vcp
"entry": self.entry_signal,
"ema_21": float(self.ema_21) if self.ema_21 else 0,
"sma_50": float(self.sma_50) if self.sma_50 else 0,
"sma_150": float(self.sma_150) if self.sma_150 else 0,
"sma_200": float(self.sma_200) if self.sma_200 else 0,
"atr": float(self.atr) if self.atr else 0,
"avg_volume_30d": float(self.avg_volume_30d) if self.avg_volume_30d else 0,
"accumulation_rating": float(self.accumulation_rating)
if self.accumulation_rating
else 0,
"distribution_rating": float(self.distribution_rating)
if self.distribution_rating
else 0,
"breakout_strength": float(self.breakout_strength)
if self.breakout_strength
else 0,
}
class TechnicalCache(Base, TimestampMixin):
"""Cache for calculated technical indicators."""
__tablename__ = "mcp_technical_cache"
__table_args__ = (
UniqueConstraint(
"stock_id",
"date",
"indicator_type",
name="mcp_technical_cache_stock_date_indicator_unique",
),
Index("mcp_technical_cache_stock_date_idx", "stock_id", "date"),
Index("mcp_technical_cache_indicator_idx", "indicator_type"),
Index("mcp_technical_cache_date_idx", "date"),
)
id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
stock_id = Column(Uuid, ForeignKey("mcp_stocks.stock_id"), nullable=False)
date = Column(Date, nullable=False)
indicator_type = Column(
String(50), nullable=False
) # 'SMA_20', 'EMA_21', 'RSI_14', etc.
# Flexible indicator values
value = Column(Numeric(20, 8)) # Primary indicator value
value_2 = Column(Numeric(20, 8)) # Secondary value (e.g., MACD signal)
value_3 = Column(Numeric(20, 8)) # Tertiary value (e.g., MACD histogram)
# Text values for complex indicators
meta_data = Column(Text) # JSON string for additional metadata
# Calculation parameters
period = Column(Integer) # Period used (20 for SMA_20, etc.)
parameters = Column(Text) # JSON string for additional parameters
# Relationships
stock = relationship("Stock", back_populates="technical_cache")
def __repr__(self):
return (
f"<TechnicalCache(stock_id={self.stock_id}, date={self.date}, "
f"indicator={self.indicator_type}, value={self.value})>"
)
@classmethod
def get_indicator(
cls,
session: Session,
ticker_symbol: str,
indicator_type: str,
start_date: str,
end_date: str | None = None,
) -> pd.DataFrame:
"""
Get technical indicator data for a symbol and date range.
Args:
session: Database session
ticker_symbol: Stock ticker symbol
indicator_type: Type of indicator (e.g., 'SMA_20', 'RSI_14')
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format (default: today)
Returns:
DataFrame with indicator data indexed by date
"""
if not end_date:
end_date = datetime.now(UTC).strftime("%Y-%m-%d")
query = (
session.query(
cls.date,
cls.value,
cls.value_2,
cls.value_3,
cls.meta_data,
cls.parameters,
)
.join(Stock)
.filter(
Stock.ticker_symbol == ticker_symbol.upper(),
cls.indicator_type == indicator_type,
cls.date >= pd.to_datetime(start_date).date(),
cls.date <= pd.to_datetime(end_date).date(),
)
.order_by(cls.date)
)
df = pd.DataFrame(query.all())
if not df.empty:
df["date"] = pd.to_datetime(df["date"])
df.set_index("date", inplace=True)
# Convert decimal types to float
for col in ["value", "value_2", "value_3"]:
if col in df.columns:
df[col] = df[col].astype(float)
df["symbol"] = ticker_symbol.upper()
df["indicator_type"] = indicator_type
return df
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"stock_id": str(self.stock_id),
"date": self.date.isoformat() if self.date else None,
"indicator_type": self.indicator_type,
"value": float(self.value) if self.value else None,
"value_2": float(self.value_2) if self.value_2 else None,
"value_3": float(self.value_3) if self.value_3 else None,
"period": self.period,
"meta_data": self.meta_data,
"parameters": self.parameters,
}
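# Illustrative sketch (not part of the original module): reading a cached indicator
# series. The indicator key follows the '<NAME>_<PERIOD>' convention noted above;
# the symbol and dates are example values.
def _example_load_rsi_series() -> pd.DataFrame:
    with SessionLocal() as session:
        return TechnicalCache.get_indicator(
            session, "AAPL", "RSI_14", start_date="2024-01-01"
        )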
# Backtesting Models
class BacktestResult(Base, TimestampMixin):
"""Main backtest results table with comprehensive metrics."""
__tablename__ = "mcp_backtest_results"
__table_args__ = (
Index("mcp_backtest_results_symbol_idx", "symbol"),
Index("mcp_backtest_results_strategy_idx", "strategy_type"),
Index("mcp_backtest_results_date_idx", "backtest_date"),
Index("mcp_backtest_results_sharpe_idx", "sharpe_ratio"),
Index("mcp_backtest_results_total_return_idx", "total_return"),
Index("mcp_backtest_results_symbol_strategy_idx", "symbol", "strategy_type"),
)
backtest_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
# Basic backtest metadata
symbol = Column(String(10), nullable=False, index=True)
strategy_type = Column(String(50), nullable=False)
backtest_date = Column(
DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC)
)
# Date range and setup
start_date = Column(Date, nullable=False)
end_date = Column(Date, nullable=False)
initial_capital = Column(Numeric(15, 2), default=10000.0)
# Trading costs and parameters
fees = Column(Numeric(6, 4), default=0.001) # 0.1% default
slippage = Column(Numeric(6, 4), default=0.001) # 0.1% default
# Strategy parameters (stored as JSON for flexibility)
parameters = Column(JSON)
# Key Performance Metrics
total_return = Column(Numeric(10, 4)) # Total return percentage
annualized_return = Column(Numeric(10, 4)) # Annualized return percentage
sharpe_ratio = Column(Numeric(8, 4))
sortino_ratio = Column(Numeric(8, 4))
calmar_ratio = Column(Numeric(8, 4))
# Risk Metrics
max_drawdown = Column(Numeric(8, 4)) # Maximum drawdown percentage
max_drawdown_duration = Column(Integer) # Days
volatility = Column(Numeric(8, 4)) # Annualized volatility
downside_volatility = Column(Numeric(8, 4)) # Downside deviation
# Trade Statistics
total_trades = Column(Integer, default=0)
winning_trades = Column(Integer, default=0)
losing_trades = Column(Integer, default=0)
win_rate = Column(Numeric(5, 4)) # Win rate percentage
# P&L Statistics
profit_factor = Column(Numeric(8, 4)) # Gross profit / Gross loss
average_win = Column(Numeric(12, 4))
average_loss = Column(Numeric(12, 4))
largest_win = Column(Numeric(12, 4))
largest_loss = Column(Numeric(12, 4))
# Portfolio Value Metrics
final_portfolio_value = Column(Numeric(15, 2))
peak_portfolio_value = Column(Numeric(15, 2))
# Additional Analysis
beta = Column(Numeric(8, 4)) # Market beta
alpha = Column(Numeric(8, 4)) # Alpha vs market
# Time series data (stored as JSON for efficient queries)
equity_curve = Column(JSON) # Daily portfolio values
drawdown_series = Column(JSON) # Daily drawdown values
# Execution metadata
execution_time_seconds = Column(Numeric(8, 3)) # How long the backtest took
data_points = Column(Integer) # Number of data points used
# Status and notes
status = Column(String(20), default="completed") # completed, failed, in_progress
error_message = Column(Text) # Error details if status = failed
notes = Column(Text) # User notes
# Relationships
trades = relationship(
"BacktestTrade",
back_populates="backtest_result",
cascade="all, delete-orphan",
lazy="selectin",
)
optimization_results = relationship(
"OptimizationResult",
back_populates="backtest_result",
cascade="all, delete-orphan",
)
def __repr__(self):
return (
f"<BacktestResult(id={self.backtest_id}, symbol={self.symbol}, "
f"strategy={self.strategy_type}, return={self.total_return})>"
)
@classmethod
def get_by_symbol_and_strategy(
cls, session: Session, symbol: str, strategy_type: str, limit: int = 10
) -> Sequence[BacktestResult]:
"""Get recent backtests for a specific symbol and strategy."""
return (
session.query(cls)
.filter(cls.symbol == symbol.upper(), cls.strategy_type == strategy_type)
.order_by(cls.backtest_date.desc())
.limit(limit)
.all()
)
@classmethod
def get_best_performing(
cls, session: Session, metric: str = "sharpe_ratio", limit: int = 20
) -> Sequence[BacktestResult]:
"""Get best performing backtests by specified metric."""
metric_column = getattr(cls, metric, cls.sharpe_ratio)
return (
session.query(cls)
.filter(cls.status == "completed")
.order_by(metric_column.desc())
.limit(limit)
.all()
)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"backtest_id": str(self.backtest_id),
"symbol": self.symbol,
"strategy_type": self.strategy_type,
"backtest_date": self.backtest_date.isoformat()
if self.backtest_date
else None,
"start_date": self.start_date.isoformat() if self.start_date else None,
"end_date": self.end_date.isoformat() if self.end_date else None,
"initial_capital": float(self.initial_capital)
if self.initial_capital
else 0,
"total_return": float(self.total_return) if self.total_return else 0,
"sharpe_ratio": float(self.sharpe_ratio) if self.sharpe_ratio else 0,
"max_drawdown": float(self.max_drawdown) if self.max_drawdown else 0,
"win_rate": float(self.win_rate) if self.win_rate else 0,
"total_trades": self.total_trades,
"parameters": self.parameters,
"status": self.status,
}
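# Illustrative sketch (not part of the original module): querying stored backtest
# results. The metric name must match one of the numeric columns defined above.
def _example_best_backtests() -> list[dict]:
    with SessionLocal() as session:
        results = BacktestResult.get_best_performing(
            session, metric="sharpe_ratio", limit=5
        )
        return [r.to_dict() for r in results]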
class BacktestTrade(Base, TimestampMixin):
"""Individual trade records from backtests."""
__tablename__ = "mcp_backtest_trades"
__table_args__ = (
Index("mcp_backtest_trades_backtest_idx", "backtest_id"),
Index("mcp_backtest_trades_entry_date_idx", "entry_date"),
Index("mcp_backtest_trades_exit_date_idx", "exit_date"),
Index("mcp_backtest_trades_pnl_idx", "pnl"),
Index("mcp_backtest_trades_backtest_entry_idx", "backtest_id", "entry_date"),
)
trade_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
backtest_id = Column(
Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
)
# Trade identification
trade_number = Column(
Integer, nullable=False
) # Sequential trade number in backtest
# Entry details
entry_date = Column(Date, nullable=False)
entry_price = Column(Numeric(12, 4), nullable=False)
entry_time = Column(DateTime(timezone=True)) # For intraday backtests
# Exit details
exit_date = Column(Date)
exit_price = Column(Numeric(12, 4))
exit_time = Column(DateTime(timezone=True))
# Position details
position_size = Column(Numeric(15, 2)) # Number of shares/units
direction = Column(String(5), nullable=False) # 'long' or 'short'
# P&L and performance
pnl = Column(Numeric(12, 4)) # Profit/Loss in currency
pnl_percent = Column(Numeric(8, 4)) # P&L as percentage
# Risk metrics for this trade
mae = Column(Numeric(8, 4)) # Maximum Adverse Excursion
mfe = Column(Numeric(8, 4)) # Maximum Favorable Excursion
# Trade duration
duration_days = Column(Integer)
duration_hours = Column(Numeric(8, 2)) # For intraday precision
# Exit reason and fees
exit_reason = Column(String(50)) # stop_loss, take_profit, signal, time_exit
fees_paid = Column(Numeric(10, 4), default=0)
slippage_cost = Column(Numeric(10, 4), default=0)
# Relationships
backtest_result = relationship(
"BacktestResult", back_populates="trades", lazy="joined"
)
def __repr__(self):
return (
f"<BacktestTrade(id={self.trade_id}, backtest_id={self.backtest_id}, "
f"pnl={self.pnl}, duration={self.duration_days}d)>"
)
@classmethod
def get_trades_for_backtest(
cls, session: Session, backtest_id: str
) -> Sequence[BacktestTrade]:
"""Get all trades for a specific backtest."""
return (
session.query(cls)
.filter(cls.backtest_id == backtest_id)
.order_by(cls.entry_date, cls.trade_number)
.all()
)
@classmethod
def get_winning_trades(
cls, session: Session, backtest_id: str
) -> Sequence[BacktestTrade]:
"""Get winning trades for a backtest."""
return (
session.query(cls)
.filter(cls.backtest_id == backtest_id, cls.pnl > 0)
.order_by(cls.pnl.desc())
.all()
)
@classmethod
def get_losing_trades(
cls, session: Session, backtest_id: str
) -> Sequence[BacktestTrade]:
"""Get losing trades for a backtest."""
return (
session.query(cls)
.filter(cls.backtest_id == backtest_id, cls.pnl < 0)
.order_by(cls.pnl)
.all()
)
class OptimizationResult(Base, TimestampMixin):
"""Parameter optimization results for strategies."""
__tablename__ = "mcp_optimization_results"
__table_args__ = (
Index("mcp_optimization_results_backtest_idx", "backtest_id"),
Index("mcp_optimization_results_param_set_idx", "parameter_set"),
Index("mcp_optimization_results_objective_idx", "objective_value"),
)
optimization_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
backtest_id = Column(
Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
)
# Optimization metadata
optimization_date = Column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
parameter_set = Column(Integer, nullable=False) # Set number in optimization run
# Parameters tested (JSON for flexibility)
parameters = Column(JSON, nullable=False)
# Optimization objective and results
objective_function = Column(
String(50)
) # sharpe_ratio, total_return, profit_factor, etc.
objective_value = Column(Numeric(12, 6)) # Value of objective function
# Key metrics for this parameter set
total_return = Column(Numeric(10, 4))
sharpe_ratio = Column(Numeric(8, 4))
max_drawdown = Column(Numeric(8, 4))
win_rate = Column(Numeric(5, 4))
profit_factor = Column(Numeric(8, 4))
total_trades = Column(Integer)
# Ranking within optimization
rank = Column(Integer) # 1 = best, 2 = second best, etc.
# Statistical significance
is_statistically_significant = Column(Boolean, default=False)
p_value = Column(Numeric(8, 6)) # Statistical significance test result
# Relationships
backtest_result = relationship(
"BacktestResult", back_populates="optimization_results", lazy="joined"
)
def __repr__(self):
return (
f"<OptimizationResult(id={self.optimization_id}, "
f"objective={self.objective_value}, rank={self.rank})>"
)
@classmethod
def get_best_parameters(
cls, session: Session, backtest_id: str, limit: int = 5
) -> Sequence[OptimizationResult]:
"""Get top performing parameter sets for a backtest."""
return (
session.query(cls)
.filter(cls.backtest_id == backtest_id)
.order_by(cls.rank)
.limit(limit)
.all()
)
class WalkForwardTest(Base, TimestampMixin):
"""Walk-forward validation test results."""
__tablename__ = "mcp_walk_forward_tests"
__table_args__ = (
Index("mcp_walk_forward_tests_parent_idx", "parent_backtest_id"),
Index("mcp_walk_forward_tests_period_idx", "test_period_start"),
Index("mcp_walk_forward_tests_performance_idx", "out_of_sample_return"),
)
walk_forward_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
parent_backtest_id = Column(
Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
)
# Test configuration
test_date = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
window_size_months = Column(Integer, nullable=False) # Training window size
step_size_months = Column(Integer, nullable=False) # Step size for walking forward
# Time periods
training_start = Column(Date, nullable=False)
training_end = Column(Date, nullable=False)
test_period_start = Column(Date, nullable=False)
test_period_end = Column(Date, nullable=False)
# Optimization results from training period
optimal_parameters = Column(JSON) # Best parameters from training
training_performance = Column(Numeric(10, 4)) # Training period return
# Out-of-sample test results
out_of_sample_return = Column(Numeric(10, 4))
out_of_sample_sharpe = Column(Numeric(8, 4))
out_of_sample_drawdown = Column(Numeric(8, 4))
out_of_sample_trades = Column(Integer)
# Performance vs training expectations
    performance_ratio = Column(Numeric(8, 4))  # Out-of-sample return / Training return
degradation_factor = Column(Numeric(8, 4)) # How much performance degraded
# Statistical validation
is_profitable = Column(Boolean)
is_statistically_significant = Column(Boolean, default=False)
# Relationships
parent_backtest = relationship(
"BacktestResult", foreign_keys=[parent_backtest_id], lazy="joined"
)
def __repr__(self):
return (
f"<WalkForwardTest(id={self.walk_forward_id}, "
f"return={self.out_of_sample_return}, ratio={self.performance_ratio})>"
)
@classmethod
def get_walk_forward_results(
cls, session: Session, parent_backtest_id: str
) -> Sequence[WalkForwardTest]:
"""Get all walk-forward test results for a backtest."""
return (
session.query(cls)
.filter(cls.parent_backtest_id == parent_backtest_id)
.order_by(cls.test_period_start)
.all()
)
class BacktestPortfolio(Base, TimestampMixin):
"""Portfolio-level backtests with multiple symbols."""
__tablename__ = "mcp_backtest_portfolios"
__table_args__ = (
Index("mcp_backtest_portfolios_name_idx", "portfolio_name"),
Index("mcp_backtest_portfolios_date_idx", "backtest_date"),
Index("mcp_backtest_portfolios_return_idx", "total_return"),
)
portfolio_backtest_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
# Portfolio identification
portfolio_name = Column(String(100), nullable=False)
description = Column(Text)
# Test metadata
backtest_date = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
start_date = Column(Date, nullable=False)
end_date = Column(Date, nullable=False)
# Portfolio composition
symbols = Column(JSON, nullable=False) # List of symbols
weights = Column(JSON) # Portfolio weights (if not equal weight)
rebalance_frequency = Column(String(20)) # daily, weekly, monthly, quarterly
# Portfolio parameters
initial_capital = Column(Numeric(15, 2), default=100000.0)
max_positions = Column(Integer) # Maximum concurrent positions
position_sizing_method = Column(
String(50)
) # equal_weight, volatility_weighted, etc.
# Risk management
portfolio_stop_loss = Column(Numeric(6, 4)) # Portfolio-level stop loss
max_sector_allocation = Column(Numeric(5, 4)) # Maximum allocation per sector
correlation_threshold = Column(
Numeric(5, 4)
) # Maximum correlation between holdings
# Performance metrics (portfolio level)
total_return = Column(Numeric(10, 4))
annualized_return = Column(Numeric(10, 4))
sharpe_ratio = Column(Numeric(8, 4))
sortino_ratio = Column(Numeric(8, 4))
max_drawdown = Column(Numeric(8, 4))
volatility = Column(Numeric(8, 4))
# Portfolio-specific metrics
diversification_ratio = Column(Numeric(8, 4)) # Portfolio vol / Weighted avg vol
concentration_index = Column(Numeric(8, 4)) # Herfindahl index
turnover_rate = Column(Numeric(8, 4)) # Portfolio turnover
# Individual component backtests (JSON references)
component_backtest_ids = Column(JSON) # List of individual backtest IDs
# Time series data
portfolio_equity_curve = Column(JSON)
portfolio_weights_history = Column(JSON) # Historical weights over time
# Status
status = Column(String(20), default="completed")
notes = Column(Text)
def __repr__(self):
return (
f"<BacktestPortfolio(id={self.portfolio_backtest_id}, "
f"name={self.portfolio_name}, return={self.total_return})>"
)
@classmethod
def get_portfolio_backtests(
cls, session: Session, portfolio_name: str | None = None, limit: int = 10
) -> Sequence[BacktestPortfolio]:
"""Get portfolio backtests, optionally filtered by name."""
query = session.query(cls).order_by(cls.backtest_date.desc())
if portfolio_name:
query = query.filter(cls.portfolio_name == portfolio_name)
return query.limit(limit).all()
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"portfolio_backtest_id": str(self.portfolio_backtest_id),
"portfolio_name": self.portfolio_name,
"symbols": self.symbols,
"start_date": self.start_date.isoformat() if self.start_date else None,
"end_date": self.end_date.isoformat() if self.end_date else None,
"total_return": float(self.total_return) if self.total_return else 0,
"sharpe_ratio": float(self.sharpe_ratio) if self.sharpe_ratio else 0,
"max_drawdown": float(self.max_drawdown) if self.max_drawdown else 0,
"status": self.status,
}
# Helper functions for working with the models
def bulk_insert_price_data(
session: Session, ticker_symbol: str, df: pd.DataFrame
) -> int:
"""
Bulk insert price data from a DataFrame.
Args:
session: Database session
ticker_symbol: Stock ticker symbol
df: DataFrame with OHLCV data (must have date index)
Returns:
        Number of new records inserted (0 if every row is already cached)
"""
if df.empty:
return 0
# Get or create stock
stock = Stock.get_or_create(session, ticker_symbol)
# First, check how many records already exist
existing_dates = set()
if hasattr(df.index[0], "date"):
dates_to_check = [d.date() for d in df.index]
else:
dates_to_check = list(df.index)
existing_query = session.query(PriceCache.date).filter(
PriceCache.stock_id == stock.stock_id, PriceCache.date.in_(dates_to_check)
)
existing_dates = {row[0] for row in existing_query.all()}
# Prepare data for bulk insert
records = []
new_count = 0
for date_idx, row in df.iterrows():
# Handle different index types - datetime index vs date index
if hasattr(date_idx, "date") and callable(date_idx.date):
date_val = date_idx.date() # type: ignore[attr-defined]
elif hasattr(date_idx, "to_pydatetime") and callable(date_idx.to_pydatetime):
date_val = date_idx.to_pydatetime().date() # type: ignore[attr-defined]
else:
# Assume it's already a date-like object
date_val = date_idx
# Skip if already exists
if date_val in existing_dates:
continue
new_count += 1
# Handle both lowercase and capitalized column names from yfinance
open_val = row.get("open", row.get("Open", 0))
high_val = row.get("high", row.get("High", 0))
low_val = row.get("low", row.get("Low", 0))
close_val = row.get("close", row.get("Close", 0))
volume_val = row.get("volume", row.get("Volume", 0))
# Handle None values
if volume_val is None:
volume_val = 0
records.append(
{
"stock_id": stock.stock_id,
"date": date_val,
"open_price": Decimal(str(open_val)),
"high_price": Decimal(str(high_val)),
"low_price": Decimal(str(low_val)),
"close_price": Decimal(str(close_val)),
"volume": int(volume_val),
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
}
)
# Only insert if there are new records
if records:
# Use database-specific upsert logic
if "postgresql" in DATABASE_URL:
from sqlalchemy.dialects.postgresql import insert
stmt = insert(PriceCache).values(records)
stmt = stmt.on_conflict_do_nothing(index_elements=["stock_id", "date"])
else:
# For SQLite, use INSERT OR IGNORE
from sqlalchemy import insert
stmt = insert(PriceCache).values(records)
            # The generic insert() construct has no on_conflict_do_nothing; emulate it with INSERT OR IGNORE
stmt = stmt.prefix_with("OR IGNORE")
result = session.execute(stmt)
session.commit()
# Log if rowcount differs from expected
if result.rowcount != new_count:
logger.warning(
f"Expected to insert {new_count} records but rowcount was {result.rowcount}"
)
return result.rowcount
else:
logger.debug(
f"All {len(df)} records already exist in cache for {ticker_symbol}"
)
return 0
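# Illustrative usage sketch (not part of the original module; assumes yfinance is installed
# and SessionLocal is configured as above). Shows how bulk_insert_price_data might be called
# with a yfinance-style OHLCV frame; the ticker and period are hypothetical.
def _example_bulk_insert_price_data() -> None:  # pragma: no cover - documentation sketch
    import yfinance as yf
    df = yf.download("AAPL", period="1mo", interval="1d")
    with SessionLocal() as session:
        # Capitalized yfinance columns (Open/High/Low/Close/Volume) are normalized above.
        inserted = bulk_insert_price_data(session, "AAPL", df)
        logger.info("Inserted %d new AAPL price rows", inserted)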
def get_latest_maverick_screening(days_back: int = 1) -> dict:
"""Get latest screening results from all maverick tables."""
with SessionLocal() as session:
results = {
"maverick_stocks": [
stock.to_dict()
for stock in MaverickStocks.get_latest_analysis(
session, days_back=days_back
)
],
"maverick_bear_stocks": [
stock.to_dict()
for stock in MaverickBearStocks.get_latest_analysis(
session, days_back=days_back
)
],
"supply_demand_breakouts": [
stock.to_dict()
for stock in SupplyDemandBreakoutStocks.get_latest_analysis(
session, days_back=days_back
)
],
}
return results
def bulk_insert_screening_data(
session: Session,
model_class,
screening_data: list[dict],
date_analyzed: date | None = None,
) -> int:
"""
Bulk insert screening data for any screening model.
Args:
session: Database session
model_class: The screening model class (MaverickStocks, etc.)
screening_data: List of screening result dictionaries
date_analyzed: Date of analysis (default: today)
Returns:
Number of records inserted
"""
if not screening_data:
return 0
if date_analyzed is None:
date_analyzed = datetime.now(UTC).date()
# Remove existing data for this date
session.query(model_class).filter(
model_class.date_analyzed == date_analyzed
).delete()
inserted_count = 0
for data in screening_data:
# Get or create stock
ticker = data.get("ticker") or data.get("symbol")
if not ticker:
continue
stock = Stock.get_or_create(session, ticker)
# Create screening record
record_data = {
"stock_id": stock.stock_id,
"date_analyzed": date_analyzed,
}
# Map common fields
field_mapping = {
"open": "open_price",
"high": "high_price",
"low": "low_price",
"close": "close_price",
"pat": "pattern_type",
"sqz": "squeeze_status",
"vcp": "consolidation_status",
"entry": "entry_signal",
}
for key, value in data.items():
if key in ["ticker", "symbol"]:
continue
mapped_key = field_mapping.get(key, key)
if hasattr(model_class, mapped_key):
record_data[mapped_key] = value
record = model_class(**record_data)
session.add(record)
inserted_count += 1
session.commit()
return inserted_count
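# Illustrative usage sketch (not part of the original module; the screening field values are
# hypothetical). Shows how a day's screening results might be persisted for MaverickStocks.
def _example_bulk_insert_screening_data() -> None:  # pragma: no cover - documentation sketch
    screening_rows = [
        {"ticker": "NVDA", "close": 455.12, "pat": "Cup and Handle"},
        {"ticker": "AMD", "close": 112.40, "sqz": "firing"},
    ]
    with SessionLocal() as session:
        count = bulk_insert_screening_data(session, MaverickStocks, screening_rows)
        logger.info("Inserted %d Maverick screening rows", count)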
# ============================================================================
# Portfolio Management Models
# ============================================================================
class UserPortfolio(TimestampMixin, Base):
"""
User portfolio for tracking investment holdings.
    Follows a personal-use design with a single user_id ("default") for the personal
    MaverickMCP server. Stores portfolio metadata and the relationship to its positions.
Attributes:
id: Unique portfolio identifier (UUID)
user_id: User identifier (default: "default" for single-user)
name: Portfolio display name
positions: Relationship to PortfolioPosition records
"""
__tablename__ = "mcp_portfolios"
id = Column(Uuid, primary_key=True, default=uuid.uuid4)
user_id = Column(String(100), nullable=False, default="default", index=True)
name = Column(String(200), nullable=False, default="My Portfolio")
# Relationships
positions = relationship(
"PortfolioPosition",
back_populates="portfolio",
cascade="all, delete-orphan",
lazy="selectin", # Efficient loading
)
# Indexes for queries
__table_args__ = (
Index("idx_portfolio_user", "user_id"),
UniqueConstraint("user_id", "name", name="uq_user_portfolio_name"),
)
def __repr__(self):
return f"<UserPortfolio(id={self.id}, name='{self.name}', positions={len(self.positions)})>"
class PortfolioPosition(TimestampMixin, Base):
"""
Individual position within a portfolio with cost basis tracking.
Stores position details with high-precision Decimal types for financial accuracy.
Uses average cost basis method for educational simplicity.
Attributes:
id: Unique position identifier (UUID)
portfolio_id: Foreign key to parent portfolio
ticker: Stock ticker symbol (e.g., "AAPL")
shares: Number of shares owned (supports fractional shares)
average_cost_basis: Average cost per share
total_cost: Total capital invested (shares × average_cost_basis)
purchase_date: Earliest purchase date for this position
notes: Optional user notes about the position
"""
__tablename__ = "mcp_portfolio_positions"
id = Column(Uuid, primary_key=True, default=uuid.uuid4)
portfolio_id = Column(
Uuid, ForeignKey("mcp_portfolios.id", ondelete="CASCADE"), nullable=False
)
# Position details with financial precision
ticker = Column(String(20), nullable=False, index=True)
shares = Column(
Numeric(20, 8), nullable=False
) # High precision for fractional shares
average_cost_basis = Column(
Numeric(12, 4), nullable=False
) # 4 decimal places (cents)
total_cost = Column(Numeric(20, 4), nullable=False) # Total capital invested
purchase_date = Column(DateTime(timezone=True), nullable=False) # Earliest purchase
notes = Column(Text, nullable=True) # Optional user notes
# Relationships
portfolio = relationship("UserPortfolio", back_populates="positions")
# Indexes for efficient queries
__table_args__ = (
Index("idx_position_portfolio", "portfolio_id"),
Index("idx_position_ticker", "ticker"),
Index("idx_position_portfolio_ticker", "portfolio_id", "ticker"),
UniqueConstraint("portfolio_id", "ticker", name="uq_portfolio_position_ticker"),
)
def __repr__(self):
return f"<PortfolioPosition(ticker='{self.ticker}', shares={self.shares}, cost_basis={self.average_cost_basis})>"
# Auth models removed for personal use - no multi-user functionality needed
# Initialize tables when this module is run directly
if __name__ == "__main__":
logger.info("Creating database tables...")
init_db()
logger.info("Database tables created successfully!")
```
--------------------------------------------------------------------------------
/maverick_mcp/agents/deep_research.py:
--------------------------------------------------------------------------------
```python
"""
DeepResearchAgent implementation using 2025 LangGraph patterns.
Provides comprehensive financial research capabilities with web search,
content analysis, sentiment detection, and source validation.
"""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Iterable
from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import BaseTool, tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph # type: ignore[import-untyped]
from langgraph.types import Command # type: ignore[import-untyped]
from maverick_mcp.agents.base import PersonaAwareAgent
from maverick_mcp.agents.circuit_breaker import circuit_manager
from maverick_mcp.config.settings import get_settings
from maverick_mcp.exceptions import (
WebSearchError,
)
from maverick_mcp.memory.stores import ConversationStore
from maverick_mcp.utils.orchestration_logging import (
get_orchestration_logger,
log_agent_execution,
log_method_call,
log_performance_metrics,
log_synthesis_operation,
)
try: # pragma: no cover - optional dependency
from tavily import TavilyClient # type: ignore[import-not-found]
except ImportError: # pragma: no cover
TavilyClient = None # type: ignore[assignment]
# Import moved to avoid circular dependency - will import where needed
from maverick_mcp.workflows.state import DeepResearchState
logger = logging.getLogger(__name__)
settings = get_settings()
# Global search provider cache
_search_provider_cache: dict[str, Any] = {}
async def get_cached_search_provider(exa_api_key: str | None = None) -> Any | None:
"""Get cached Exa search provider to avoid repeated initialization delays."""
cache_key = f"exa:{exa_api_key is not None}"
if cache_key in _search_provider_cache:
return _search_provider_cache[cache_key]
logger.info("Initializing Exa search provider")
provider = None
# Initialize Exa provider with caching
if exa_api_key:
try:
provider = ExaSearchProvider(exa_api_key)
logger.info("Initialized Exa search provider")
# Cache the provider
_search_provider_cache[cache_key] = provider
except ImportError as e:
logger.warning(f"Failed to initialize Exa provider: {e}")
return provider
# Research depth levels optimized for quick searches
RESEARCH_DEPTH_LEVELS = {
"basic": {
"max_sources": 3,
"max_searches": 1, # Reduced for speed
"analysis_depth": "summary",
"validation_required": False,
},
"standard": {
"max_sources": 5, # Reduced from 8
"max_searches": 2, # Reduced from 4
"analysis_depth": "detailed",
"validation_required": False, # Disabled for speed
},
"comprehensive": {
"max_sources": 10, # Reduced from 15
"max_searches": 3, # Reduced from 6
"analysis_depth": "comprehensive",
"validation_required": False, # Disabled for speed
},
"exhaustive": {
"max_sources": 15, # Reduced from 25
"max_searches": 5, # Reduced from 10
"analysis_depth": "exhaustive",
"validation_required": True,
},
}
# Persona-specific research focus areas
PERSONA_RESEARCH_FOCUS = {
"conservative": {
"keywords": [
"dividend",
"stability",
"risk",
"debt",
"cash flow",
"established",
],
"sources": [
"sec filings",
"annual reports",
"rating agencies",
"dividend history",
],
"risk_focus": "downside protection",
"time_horizon": "long-term",
},
"moderate": {
"keywords": ["growth", "value", "balance", "diversification", "fundamentals"],
"sources": ["financial statements", "analyst reports", "industry analysis"],
"risk_focus": "risk-adjusted returns",
"time_horizon": "medium-term",
},
"aggressive": {
"keywords": ["growth", "momentum", "opportunity", "innovation", "expansion"],
"sources": [
"news",
"earnings calls",
"industry trends",
"competitive analysis",
],
"risk_focus": "upside potential",
"time_horizon": "short to medium-term",
},
"day_trader": {
"keywords": [
"catalysts",
"earnings",
"news",
"volume",
"volatility",
"momentum",
],
"sources": ["breaking news", "social sentiment", "earnings announcements"],
"risk_focus": "short-term risks",
"time_horizon": "intraday to weekly",
},
}
class WebSearchProvider:
"""Base class for web search providers with early abort mechanism."""
def __init__(self, api_key: str):
self.api_key = api_key
self.rate_limiter = None # Implement rate limiting
self._failure_count = 0
self._max_failures = 3 # Abort after 3 consecutive failures
self._is_healthy = True
self.settings = get_settings()
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
def _calculate_timeout(
self, query: str, timeout_budget: float | None = None
) -> float:
"""Calculate generous timeout for thorough research operations."""
query_words = len(query.split())
# Generous timeout calculation for thorough search operations
if query_words <= 3:
base_timeout = 30.0 # Simple queries - 30s for thorough results
elif query_words <= 8:
base_timeout = 45.0 # Standard queries - 45s for comprehensive search
else:
base_timeout = 60.0 # Complex queries - 60s for exhaustive search
# Apply budget constraints if available
if timeout_budget and timeout_budget > 0:
# Use generous portion of available budget per search operation
budget_timeout = max(
timeout_budget * 0.6, 30.0
) # At least 30s, use 60% of budget
calculated_timeout = min(base_timeout, budget_timeout)
# Ensure minimum timeout (at least 30s for thorough search)
calculated_timeout = max(calculated_timeout, 30.0)
else:
calculated_timeout = base_timeout
# Final timeout with generous minimum for thorough search
final_timeout = max(calculated_timeout, 30.0)
return final_timeout
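    # Worked example of the timeout rule above (illustrative numbers):
    #   a 5-word query -> base_timeout = 45.0s
    #   with timeout_budget = 120.0s -> budget portion = max(120.0 * 0.6, 30.0) = 72.0s,
    #   so final timeout = max(min(45.0, 72.0), 30.0) = 45.0s; with no budget the base
    #   45.0s is used. Every path is floored at the 30.0s minimum.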
def _record_failure(self, error_type: str = "unknown") -> None:
"""Record a search failure and check if provider should be disabled."""
self._failure_count += 1
# Use separate thresholds for timeout vs other failures
timeout_threshold = getattr(
self.settings.performance, "search_timeout_failure_threshold", 12
)
# Much more tolerant of timeout failures - they may be due to network/complexity
if error_type == "timeout" and self._failure_count >= timeout_threshold:
self._is_healthy = False
logger.warning(
f"Search provider {self.__class__.__name__} disabled after "
f"{self._failure_count} consecutive timeout failures (threshold: {timeout_threshold})"
)
elif error_type != "timeout" and self._failure_count >= self._max_failures * 2:
# Be more lenient for non-timeout failures (2x threshold)
self._is_healthy = False
logger.warning(
f"Search provider {self.__class__.__name__} disabled after "
f"{self._failure_count} total non-timeout failures"
)
logger.debug(
f"Provider {self.__class__.__name__} failure recorded: "
f"type={error_type}, count={self._failure_count}, healthy={self._is_healthy}"
)
def _record_success(self) -> None:
"""Record a successful search and reset failure count."""
if self._failure_count > 0:
logger.info(
f"Search provider {self.__class__.__name__} recovered after "
f"{self._failure_count} failures"
)
self._failure_count = 0
self._is_healthy = True
def is_healthy(self) -> bool:
"""Check if provider is healthy and should be used."""
return self._is_healthy
async def search(
self, query: str, num_results: int = 10, timeout_budget: float | None = None
) -> list[dict[str, Any]]:
"""Perform web search and return results."""
raise NotImplementedError
async def get_content(self, url: str) -> dict[str, Any]:
"""Extract content from URL."""
raise NotImplementedError
async def search_multiple_providers(
self,
queries: list[str],
providers: list[str] | None = None,
max_results_per_query: int = 5,
) -> dict[str, list[dict[str, Any]]]:
"""Search using multiple providers and return aggregated results."""
providers = providers or ["exa"] # Default to available providers
results = {}
for provider_name in providers:
provider_results = []
for query in queries:
try:
query_results = await self.search(query, max_results_per_query)
provider_results.extend(query_results or [])
except Exception as e:
self.logger.warning(
f"Search failed for provider {provider_name}, query '{query}': {e}"
)
continue
results[provider_name] = provider_results
return results
def _timeframe_to_date(self, timeframe: str) -> str | None:
"""Convert timeframe string to date string."""
from datetime import datetime, timedelta
now = datetime.now()
if timeframe == "1d":
date = now - timedelta(days=1)
elif timeframe == "1w":
date = now - timedelta(weeks=1)
elif timeframe == "1m":
date = now - timedelta(days=30)
else:
# Invalid or unsupported timeframe, return None
return None
return date.strftime("%Y-%m-%d")
class ExaSearchProvider(WebSearchProvider):
"""Exa search provider for comprehensive web search using MCP tools with financial optimization."""
def __init__(self, api_key: str):
super().__init__(api_key)
# Store the API key for verification
self._api_key_verified = bool(api_key)
# Financial-specific domain preferences for better results
self.financial_domains = [
"sec.gov",
"edgar.sec.gov",
"investor.gov",
"bloomberg.com",
"reuters.com",
"wsj.com",
"ft.com",
"marketwatch.com",
"yahoo.com/finance",
"finance.yahoo.com",
"morningstar.com",
"fool.com",
"seekingalpha.com",
"investopedia.com",
"barrons.com",
"cnbc.com",
"nasdaq.com",
"nyse.com",
"finra.org",
"federalreserve.gov",
"treasury.gov",
"bls.gov",
]
# Domains to exclude for financial searches
self.excluded_domains = [
"facebook.com",
"twitter.com",
"x.com",
"instagram.com",
"tiktok.com",
"reddit.com",
"pinterest.com",
"linkedin.com",
"youtube.com",
"wikipedia.org",
]
logger.info("Initialized ExaSearchProvider with financial optimization")
async def search(
self, query: str, num_results: int = 10, timeout_budget: float | None = None
) -> list[dict[str, Any]]:
"""Search using Exa via async client for comprehensive web results with adaptive timeout."""
return await self._search_with_strategy(
query, num_results, timeout_budget, "auto"
)
async def search_financial(
self,
query: str,
num_results: int = 10,
timeout_budget: float | None = None,
strategy: str = "hybrid",
) -> list[dict[str, Any]]:
"""
Enhanced financial search with optimized queries and domain targeting.
Args:
query: Search query
num_results: Number of results to return
timeout_budget: Timeout budget in seconds
strategy: Search strategy - 'hybrid', 'authoritative', 'comprehensive', or 'auto'
"""
return await self._search_with_strategy(
query, num_results, timeout_budget, strategy
)
async def _search_with_strategy(
self, query: str, num_results: int, timeout_budget: float | None, strategy: str
) -> list[dict[str, Any]]:
"""Internal method to handle different search strategies."""
# Check provider health before attempting search
if not self.is_healthy():
logger.warning("Exa provider is unhealthy - skipping search")
raise WebSearchError("Exa provider disabled due to repeated failures")
# Calculate adaptive timeout
search_timeout = self._calculate_timeout(query, timeout_budget)
try:
# Use search-specific circuit breaker settings (more tolerant)
circuit_breaker = await circuit_manager.get_or_create(
"exa_search",
failure_threshold=getattr(
self.settings.performance,
"search_circuit_breaker_failure_threshold",
8,
),
recovery_timeout=getattr(
self.settings.performance,
"search_circuit_breaker_recovery_timeout",
30,
),
)
async def _search():
# Use the async exa-py library for web search
try:
from exa_py import AsyncExa
# Initialize AsyncExa client with API key
async_exa_client = AsyncExa(api_key=self.api_key)
# Configure search parameters based on strategy
search_params = self._get_search_params(
query, num_results, strategy
)
# Call Exa search with optimized parameters
exa_response = await async_exa_client.search_and_contents(
**search_params
)
# Convert Exa response to standard format with enhanced metadata
results = []
if exa_response and hasattr(exa_response, "results"):
for result in exa_response.results:
# Enhanced result processing with financial relevance scoring
financial_relevance = self._calculate_financial_relevance(
result
)
results.append(
{
"url": result.url or "",
"title": result.title or "No Title",
"content": (result.text or "")[:2000],
"raw_content": (result.text or "")[
:5000
], # Increased for financial content
"published_date": result.published_date or "",
"score": result.score
if hasattr(result, "score")
and result.score is not None
else 0.7,
"financial_relevance": financial_relevance,
"provider": "exa",
"author": result.author
if hasattr(result, "author")
and result.author is not None
else "",
"domain": self._extract_domain(result.url or ""),
"is_authoritative": self._is_authoritative_source(
result.url or ""
),
}
)
# Sort results by financial relevance and score
results.sort(
key=lambda x: (x["financial_relevance"], x["score"]),
reverse=True,
)
return results
except ImportError:
logger.error("exa-py library not available - cannot perform search")
raise WebSearchError(
"exa-py library required for ExaSearchProvider"
)
except Exception as e:
logger.error(f"Error calling Exa API: {e}")
raise e
# Use adaptive timeout based on query complexity and budget
result = await asyncio.wait_for(
circuit_breaker.call(_search), timeout=search_timeout
)
self._record_success() # Record successful search
logger.debug(
f"Exa search completed in {search_timeout:.1f}s timeout window"
)
return result
except TimeoutError:
self._record_failure("timeout") # Record timeout as specific failure type
query_snippet = query[:100] + ("..." if len(query) > 100 else "")
logger.error(
f"Exa search timeout after {search_timeout:.1f} seconds (failure #{self._failure_count}) "
f"for query: '{query_snippet}'"
)
raise WebSearchError(
f"Exa search timed out after {search_timeout:.1f} seconds"
)
except Exception as e:
self._record_failure("error") # Record non-timeout failure
logger.error(f"Exa search error (failure #{self._failure_count}): {e}")
raise WebSearchError(f"Exa search failed: {str(e)}")
def _get_search_params(
self, query: str, num_results: int, strategy: str
) -> dict[str, Any]:
"""
Generate optimized search parameters based on strategy and query type.
Args:
query: Search query
num_results: Number of results
strategy: Search strategy
Returns:
Dictionary of search parameters for Exa API
"""
# Base parameters
params = {
"query": query,
"num_results": num_results,
"text": {"max_characters": 5000}, # Increased for financial content
}
# Strategy-specific optimizations
if strategy == "authoritative":
# Focus on authoritative financial sources
# Note: Exa API doesn't allow both include_domains and exclude_domains with content
params.update(
{
"include_domains": self.financial_domains[
:10
], # Top authoritative sources
"type": "auto", # Let Exa decide neural vs keyword
"start_published_date": "2020-01-01", # Recent financial data
}
)
elif strategy == "comprehensive":
# Broad search across all financial sources
params.update(
{
"exclude_domains": self.excluded_domains,
"type": "neural", # Better for comprehensive understanding
"start_published_date": "2018-01-01", # Broader historical context
}
)
elif strategy == "hybrid":
# Balanced approach with domain preferences
params.update(
{
"exclude_domains": self.excluded_domains,
"type": "auto", # Hybrid neural/keyword approach
"start_published_date": "2019-01-01",
# Use domain weighting rather than strict inclusion
}
)
else: # "auto" or default
# Standard search with basic optimizations
params.update(
{
"exclude_domains": self.excluded_domains[:5], # Basic exclusions
"type": "auto",
}
)
# Add financial-specific query enhancements
enhanced_query = self._enhance_financial_query(query)
if enhanced_query != query:
params["query"] = enhanced_query
return params
def _enhance_financial_query(self, query: str) -> str:
"""
Enhance queries with financial context and terminology.
Args:
query: Original search query
Returns:
Enhanced query with financial context
"""
# Financial keywords that improve search quality
financial_terms = {
"earnings",
"revenue",
"profit",
"loss",
"financial",
"quarterly",
"annual",
"SEC",
"10-K",
"10-Q",
"balance sheet",
"income statement",
"cash flow",
"dividend",
"stock",
"share",
"market cap",
"valuation",
}
query_lower = query.lower()
# Check if query already contains financial terms
has_financial_context = any(term in query_lower for term in financial_terms)
# Add context for company/stock queries
if not has_financial_context:
# Detect if it's a company or stock symbol query
if any(
indicator in query_lower
for indicator in ["company", "corp", "inc", "$", "stock"]
):
return f"{query} financial analysis earnings revenue"
elif len(query.split()) <= 3 and query.isupper(): # Likely stock symbol
return f"{query} stock financial performance earnings"
elif "analysis" in query_lower or "research" in query_lower:
return f"{query} financial data SEC filings"
return query
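    # Illustrative behaviour of the enhancement above (hypothetical queries):
    #   "NVDA"                   -> "NVDA stock financial performance earnings" (bare ticker)
    #   "Apple Inc outlook"      -> "Apple Inc outlook financial analysis earnings revenue"
    #   "AAPL earnings guidance" -> unchanged (already carries financial context)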
def _calculate_financial_relevance(self, result) -> float:
"""
Calculate financial relevance score for a search result.
Args:
result: Exa search result object
Returns:
Financial relevance score (0.0 to 1.0)
"""
score = 0.0
# Domain-based scoring
domain = self._extract_domain(result.url)
if domain in self.financial_domains:
if domain in ["sec.gov", "edgar.sec.gov", "federalreserve.gov"]:
score += 0.4 # Highest authority
elif domain in ["bloomberg.com", "reuters.com", "wsj.com", "ft.com"]:
score += 0.3 # High-quality financial news
else:
score += 0.2 # Other financial sources
# Content-based scoring
if hasattr(result, "text") and result.text:
text_lower = result.text.lower()
# Financial terminology scoring
financial_keywords = [
"earnings",
"revenue",
"profit",
"financial",
"quarterly",
"annual",
"sec filing",
"10-k",
"10-q",
"balance sheet",
"income statement",
"cash flow",
"dividend",
"market cap",
"valuation",
"analyst",
"forecast",
"guidance",
"ebitda",
"eps",
"pe ratio",
]
keyword_matches = sum(
1 for keyword in financial_keywords if keyword in text_lower
)
score += min(keyword_matches * 0.05, 0.3) # Max 0.3 from keywords
# Title-based scoring
if hasattr(result, "title") and result.title:
title_lower = result.title.lower()
if any(
term in title_lower
for term in ["financial", "earnings", "quarterly", "annual", "sec"]
):
score += 0.1
# Recency bonus for financial data
if hasattr(result, "published_date") and result.published_date:
try:
from datetime import datetime
# Handle different date formats
date_str = str(result.published_date)
if date_str and date_str != "":
# Handle ISO format with Z
if date_str.endswith("Z"):
date_str = date_str.replace("Z", "+00:00")
pub_date = datetime.fromisoformat(date_str)
days_old = (datetime.now(UTC) - pub_date).days
if days_old <= 30:
score += 0.1 # Recent data bonus
elif days_old <= 90:
score += 0.05 # Somewhat recent bonus
except (ValueError, AttributeError, TypeError):
pass # Skip if date parsing fails
return min(score, 1.0) # Cap at 1.0
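    # Worked example of the scoring above (illustrative): a reuters.com article titled
    # "Q3 earnings beat" published 10 days ago whose text mentions "earnings", "revenue"
    # and "guidance" scores roughly 0.3 (domain) + 0.15 (3 keywords * 0.05) + 0.1 (title)
    # + 0.1 (recency) = 0.65.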
def _extract_domain(self, url: str) -> str:
"""Extract domain from URL."""
try:
from urllib.parse import urlparse
return urlparse(url).netloc.lower().replace("www.", "")
except Exception:
return ""
def _is_authoritative_source(self, url: str) -> bool:
"""Check if URL is from an authoritative financial source."""
domain = self._extract_domain(url)
authoritative_domains = [
"sec.gov",
"edgar.sec.gov",
"federalreserve.gov",
"treasury.gov",
"bloomberg.com",
"reuters.com",
"wsj.com",
"ft.com",
]
return domain in authoritative_domains
class TavilySearchProvider(WebSearchProvider):
"""Tavily search provider with sensible filtering for financial research."""
def __init__(self, api_key: str):
super().__init__(api_key)
self.excluded_domains = {
"facebook.com",
"twitter.com",
"x.com",
"instagram.com",
"reddit.com",
}
async def search(
self, query: str, num_results: int = 10, timeout_budget: float | None = None
) -> list[dict[str, Any]]:
if not self.is_healthy():
raise WebSearchError("Tavily provider disabled due to repeated failures")
timeout = self._calculate_timeout(query, timeout_budget)
circuit_breaker = await circuit_manager.get_or_create(
"tavily_search",
failure_threshold=8,
recovery_timeout=30,
)
async def _search() -> list[dict[str, Any]]:
if TavilyClient is None:
raise ImportError("tavily package is required for TavilySearchProvider")
client = TavilyClient(api_key=self.api_key)
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: client.search(query=query, max_results=num_results),
)
return self._process_results(response.get("results", []))
return await circuit_breaker.call(_search, timeout=timeout)
def _process_results(
self, results: Iterable[dict[str, Any]]
) -> list[dict[str, Any]]:
processed: list[dict[str, Any]] = []
for item in results:
url = item.get("url", "")
if any(domain in url for domain in self.excluded_domains):
continue
processed.append(
{
"url": url,
"title": item.get("title"),
"content": item.get("content") or item.get("raw_content", ""),
"raw_content": item.get("raw_content"),
"published_date": item.get("published_date"),
"score": item.get("score", 0.0),
"provider": "tavily",
}
)
return processed
class ContentAnalyzer:
"""AI-powered content analysis for research results with batch processing capability."""
def __init__(self, llm: BaseChatModel):
self.llm = llm
self._batch_size = 4 # Process up to 4 sources concurrently
@staticmethod
def _coerce_message_content(raw_content: Any) -> str:
"""Convert LLM response content to a string for JSON parsing."""
if isinstance(raw_content, str):
return raw_content
if isinstance(raw_content, list):
parts: list[str] = []
for item in raw_content:
if isinstance(item, dict):
text_value = item.get("text")
if isinstance(text_value, str):
parts.append(text_value)
else:
parts.append(str(text_value))
else:
parts.append(str(item))
return "".join(parts)
return str(raw_content)
async def analyze_content(
self, content: str, persona: str, analysis_focus: str = "general"
) -> dict[str, Any]:
"""Analyze content with AI for insights, sentiment, and relevance."""
persona_focus = PERSONA_RESEARCH_FOCUS.get(
persona, PERSONA_RESEARCH_FOCUS["moderate"]
)
analysis_prompt = f"""
Analyze this financial content from the perspective of a {persona} investor.
Content to analyze:
        {content[:3000]}
Focus Areas: {", ".join(persona_focus["keywords"])}
Risk Focus: {persona_focus["risk_focus"]}
Time Horizon: {persona_focus["time_horizon"]}
Provide analysis in the following structure:
1. KEY_INSIGHTS: 3-5 bullet points of most important insights
2. SENTIMENT: Overall sentiment (bullish/bearish/neutral) with confidence (0-1)
3. RISK_FACTORS: Key risks identified relevant to {persona} investors
4. OPPORTUNITIES: Investment opportunities or catalysts identified
5. CREDIBILITY: Assessment of source credibility (0-1 score)
6. RELEVANCE: How relevant is this to {persona} investment strategy (0-1 score)
7. SUMMARY: 2-3 sentence summary for {persona} investors
Format as JSON with clear structure.
"""
try:
response = await self.llm.ainvoke(
[
SystemMessage(
content="You are a financial content analyst. Return only valid JSON."
),
HumanMessage(content=analysis_prompt),
]
)
raw_content = self._coerce_message_content(response.content).strip()
analysis = json.loads(raw_content)
return {
"insights": analysis.get("KEY_INSIGHTS", []),
"sentiment": {
"direction": analysis.get("SENTIMENT", {}).get(
"direction", "neutral"
),
"confidence": analysis.get("SENTIMENT", {}).get("confidence", 0.5),
},
"risk_factors": analysis.get("RISK_FACTORS", []),
"opportunities": analysis.get("OPPORTUNITIES", []),
"credibility_score": analysis.get("CREDIBILITY", 0.5),
"relevance_score": analysis.get("RELEVANCE", 0.5),
"summary": analysis.get("SUMMARY", ""),
"analysis_timestamp": datetime.now(),
}
except Exception as e:
logger.warning(f"AI content analysis failed: {e}, using fallback")
return self._fallback_analysis(content, persona)
def _fallback_analysis(self, content: str, persona: str) -> dict[str, Any]:
"""Fallback analysis using keyword matching."""
persona_focus = PERSONA_RESEARCH_FOCUS.get(
persona, PERSONA_RESEARCH_FOCUS["moderate"]
)
content_lower = content.lower()
# Simple sentiment analysis
positive_words = [
"growth",
"increase",
"profit",
"success",
"opportunity",
"strong",
]
negative_words = ["decline", "loss", "risk", "problem", "concern", "weak"]
positive_count = sum(1 for word in positive_words if word in content_lower)
negative_count = sum(1 for word in negative_words if word in content_lower)
if positive_count > negative_count:
sentiment = "bullish"
confidence = 0.6
elif negative_count > positive_count:
sentiment = "bearish"
confidence = 0.6
else:
sentiment = "neutral"
confidence = 0.5
# Relevance scoring based on keywords
keyword_matches = sum(
1 for keyword in persona_focus["keywords"] if keyword in content_lower
)
relevance_score = min(keyword_matches / len(persona_focus["keywords"]), 1.0)
return {
"insights": [f"Fallback analysis for {persona} investor perspective"],
"sentiment": {"direction": sentiment, "confidence": confidence},
"risk_factors": ["Unable to perform detailed risk analysis"],
"opportunities": ["Unable to identify specific opportunities"],
"credibility_score": 0.5,
"relevance_score": relevance_score,
"summary": f"Content analysis for {persona} investor using fallback method",
"analysis_timestamp": datetime.now(),
"fallback_used": True,
}
async def analyze_content_batch(
self,
content_items: list[tuple[str, str]],
persona: str,
analysis_focus: str = "general",
) -> list[dict[str, Any]]:
"""
Analyze multiple content items in parallel batches for improved performance.
Args:
content_items: List of (content, source_identifier) tuples
persona: Investor persona for analysis perspective
analysis_focus: Focus area for analysis
Returns:
List of analysis results in same order as input
"""
if not content_items:
return []
# Process items in batches to avoid overwhelming the LLM
results = []
for i in range(0, len(content_items), self._batch_size):
batch = content_items[i : i + self._batch_size]
# Create concurrent tasks for this batch
tasks = [
self.analyze_content(content, persona, analysis_focus)
for content, _ in batch
]
# Wait for all tasks in this batch to complete
try:
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and handle exceptions
for j, result in enumerate(batch_results):
if isinstance(result, Exception):
logger.warning(
f"Batch analysis failed for item {i + j}: {result}"
)
# Use fallback for failed items
content, source_id = batch[j]
fallback_result = self._fallback_analysis(content, persona)
fallback_result["source_identifier"] = source_id
fallback_result["batch_processed"] = True
results.append(fallback_result)
elif isinstance(result, dict):
enriched_result = dict(result)
enriched_result["source_identifier"] = batch[j][1]
enriched_result["batch_processed"] = True
results.append(enriched_result)
else:
content, source_id = batch[j]
fallback_result = self._fallback_analysis(content, persona)
fallback_result["source_identifier"] = source_id
fallback_result["batch_processed"] = True
results.append(fallback_result)
except Exception as e:
logger.error(f"Batch analysis completely failed: {e}")
# Fallback for entire batch
for content, source_id in batch:
fallback_result = self._fallback_analysis(content, persona)
fallback_result["source_identifier"] = source_id
fallback_result["batch_processed"] = True
fallback_result["batch_error"] = str(e)
results.append(fallback_result)
logger.info(
f"Batch content analysis completed: {len(content_items)} items processed "
f"in {(len(content_items) + self._batch_size - 1) // self._batch_size} batches"
)
return results
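    # Illustrative usage sketch (hypothetical inputs, not part of the original class):
    #     analyzer = ContentAnalyzer(llm)
    #     items = [
    #         ("Apple raised guidance and boosted its dividend.", "article-1"),
    #         ("Chip demand is weakening across the sector.", "article-2"),
    #     ]
    #     analyses = await analyzer.analyze_content_batch(items, persona="moderate")
    #     # each result carries 'source_identifier', 'sentiment', 'insights', ...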
async def analyze_content_items(
self,
content_items: list[dict[str, Any]],
focus_areas: list[str],
) -> dict[str, Any]:
"""
Analyze content items for test compatibility.
Args:
content_items: List of search result dictionaries with content/text field
focus_areas: List of focus areas for analysis
Returns:
Dictionary with aggregated analysis results
"""
if not content_items:
return {
"insights": [],
"sentiment_scores": [],
"credibility_scores": [],
}
# For test compatibility, directly use LLM with test-compatible format
analyzed_results = []
for item in content_items:
content = item.get("text") or item.get("content") or ""
if content:
try:
# Direct LLM call for test compatibility
prompt = f"Analyze: {content[:500]}"
response = await self.llm.ainvoke(
[
SystemMessage(
content="You are a financial content analyst. Return only valid JSON."
),
HumanMessage(content=prompt),
]
)
coerced_content = self._coerce_message_content(
response.content
).strip()
analysis = json.loads(coerced_content)
analyzed_results.append(analysis)
except Exception as e:
logger.warning(f"Content analysis failed: {e}")
# Add fallback analysis
analyzed_results.append(
{
"insights": [
{"insight": "Analysis failed", "confidence": 0.1}
],
"sentiment": {"direction": "neutral", "confidence": 0.5},
"credibility": 0.5,
}
)
# Aggregate results
all_insights = []
sentiment_scores = []
credibility_scores = []
for result in analyzed_results:
# Handle test format with nested insight objects
insights = result.get("insights", [])
if isinstance(insights, list):
for insight in insights:
if isinstance(insight, dict) and "insight" in insight:
all_insights.append(insight["insight"])
elif isinstance(insight, str):
all_insights.append(insight)
else:
all_insights.append(str(insight))
sentiment = result.get("sentiment", {})
if sentiment:
sentiment_scores.append(sentiment)
credibility = result.get(
"credibility_score", result.get("credibility", 0.5)
)
credibility_scores.append(credibility)
return {
"insights": all_insights,
"sentiment_scores": sentiment_scores,
"credibility_scores": credibility_scores,
}
async def _analyze_single_content(
self, content_item: dict[str, Any] | str, focus_areas: list[str] | None = None
) -> dict[str, Any]:
"""Analyze single content item - used by tests."""
if isinstance(content_item, dict):
content = content_item.get("text") or content_item.get("content") or ""
else:
content = content_item
try:
result = await self.analyze_content(content, "moderate")
# Ensure test-compatible format
if "credibility_score" in result and "credibility" not in result:
result["credibility"] = result["credibility_score"]
return result
except Exception as e:
logger.warning(f"Single content analysis failed: {e}")
# Return fallback result
return {
"sentiment": {"direction": "neutral", "confidence": 0.5},
"credibility": 0.5,
"credibility_score": 0.5,
"insights": [],
"risk_factors": [],
"opportunities": [],
}
async def _extract_themes(
self, content_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Extract themes from content items - used by tests."""
if not content_items:
return []
# Use LLM to extract structured themes
try:
content_text = "\n".join(
[item.get("text", item.get("content", "")) for item in content_items]
)
prompt = f"""
Extract key themes from the following content and return as JSON:
{content_text[:2000]}
Return format: {{"themes": [{{"theme": "theme_name", "relevance": 0.9, "mentions": 10}}]}}
"""
response = await self.llm.ainvoke(
[
SystemMessage(
content="You are a theme extraction AI. Return only valid JSON."
),
HumanMessage(content=prompt),
]
)
result = json.loads(
ContentAnalyzer._coerce_message_content(response.content)
)
return result.get("themes", [])
except Exception as e:
logger.warning(f"Theme extraction failed: {e}")
# Fallback to simple keyword-based themes
themes = []
for item in content_items:
content = item.get("text") or item.get("content") or ""
if content:
content_lower = content.lower()
if "growth" in content_lower:
themes.append(
{"theme": "Growth", "relevance": 0.8, "mentions": 1}
)
if "earnings" in content_lower:
themes.append(
{"theme": "Earnings", "relevance": 0.7, "mentions": 1}
)
if "technology" in content_lower:
themes.append(
{"theme": "Technology", "relevance": 0.6, "mentions": 1}
)
return themes
class DeepResearchAgent(PersonaAwareAgent):
"""
Deep research agent using 2025 LangGraph patterns.
Provides comprehensive financial research with web search, content analysis,
sentiment detection, and source validation.
"""
def __init__(
self,
llm: BaseChatModel,
persona: str = "moderate",
checkpointer: MemorySaver | None = None,
ttl_hours: int = 24, # Research results cached longer
exa_api_key: str | None = None,
default_depth: str = "standard",
max_sources: int | None = None,
research_depth: str | None = None,
enable_parallel_execution: bool = True,
parallel_config=None, # Type: ParallelResearchConfig | None
):
"""Initialize deep research agent."""
# Import here to avoid circular dependency
from maverick_mcp.utils.parallel_research import (
ParallelResearchConfig,
ParallelResearchOrchestrator,
TaskDistributionEngine,
)
# Store API key for immediate loading of search provider (pre-initialization)
self._exa_api_key = exa_api_key
self._search_providers_loaded = False
self.search_providers = []
# Pre-initialize search providers immediately (async init will be called separately)
self._initialization_pending = True
# Configuration
self.default_depth = research_depth or default_depth
self.max_sources = max_sources or RESEARCH_DEPTH_LEVELS.get(
self.default_depth, {}
).get("max_sources", 10)
self.content_analyzer = ContentAnalyzer(llm)
# Parallel execution configuration
self.enable_parallel_execution = enable_parallel_execution
self.parallel_config = parallel_config or ParallelResearchConfig(
max_concurrent_agents=settings.data_limits.max_parallel_agents,
timeout_per_agent=180, # 3 minutes per agent for thorough research
enable_fallbacks=False, # Disable fallbacks for speed
rate_limit_delay=0.5, # Reduced delay for faster execution
)
self.parallel_orchestrator = ParallelResearchOrchestrator(self.parallel_config)
self.task_distributor = TaskDistributionEngine()
# Get research-specific tools
research_tools = self._get_research_tools()
# Initialize base class
super().__init__(
llm=llm,
tools=research_tools,
persona=persona,
checkpointer=checkpointer or MemorySaver(),
ttl_hours=ttl_hours,
)
# Initialize components
self.conversation_store = ConversationStore(ttl_hours=ttl_hours)
@property
def web_search_provider(self):
"""Compatibility property for tests - returns first search provider."""
return self.search_providers[0] if self.search_providers else None
def _is_insight_relevant_for_persona(
self, insight: dict[str, Any], characteristics: dict[str, Any]
) -> bool:
"""Check if an insight is relevant for a given persona - used by tests."""
# Simple implementation for test compatibility
# In a real implementation, this would analyze the insight against persona characteristics
return True # Default permissive approach as mentioned in test comments
async def initialize(self) -> None:
"""Pre-initialize Exa search provider to eliminate lazy loading overhead during research."""
if not self._initialization_pending:
return
try:
provider = await get_cached_search_provider(self._exa_api_key)
self.search_providers = [provider] if provider else []
self._search_providers_loaded = True
self._initialization_pending = False
if not self.search_providers:
logger.warning(
"Exa search provider not available - research capabilities will be limited"
)
else:
logger.info("Pre-initialized Exa search provider")
except Exception as e:
logger.error(f"Failed to pre-initialize Exa search provider: {e}")
self.search_providers = []
self._search_providers_loaded = True
self._initialization_pending = False
logger.info(
f"DeepResearchAgent pre-initialized with {len(self.search_providers)} search providers, "
f"parallel execution: {self.enable_parallel_execution}"
)
async def _ensure_search_providers_loaded(self) -> None:
"""Ensure search providers are loaded - fallback to initialization if not pre-initialized."""
if self._search_providers_loaded:
return
# Check if initialization was marked as needed
if hasattr(self, "_needs_initialization") and self._needs_initialization:
logger.info("Performing deferred initialization of search providers")
await self.initialize()
self._needs_initialization = False
else:
# Fallback to pre-initialization if not done during agent creation
logger.warning(
"Search providers not pre-initialized - falling back to lazy loading"
)
await self.initialize()
def get_state_schema(self) -> type:
"""Return DeepResearchState schema."""
return DeepResearchState
def _get_research_tools(self) -> list[BaseTool]:
"""Get tools specific to research capabilities."""
tools = []
@tool
async def web_search_financial(
query: str,
num_results: int = 10,
provider: str = "auto",
strategy: str = "hybrid",
) -> dict[str, Any]:
"""
Search the web for financial information using optimized providers and strategies.
Args:
query: Search query for financial information
num_results: Number of results to return (default: 10)
provider: Search provider to use ('auto', 'exa', 'tavily')
strategy: Search strategy ('hybrid', 'authoritative', 'comprehensive', 'auto')
"""
return await self._perform_financial_search(
query, num_results, provider, strategy
)
@tool
async def analyze_company_fundamentals(
symbol: str, depth: str = "standard"
) -> dict[str, Any]:
"""Research company fundamentals including financials, competitive position, and outlook."""
return await self._research_company_fundamentals(symbol, depth)
@tool
async def analyze_market_sentiment(
topic: str, timeframe: str = "7d"
) -> dict[str, Any]:
"""Analyze market sentiment around a topic using news and social signals."""
return await self._analyze_market_sentiment_tool(topic, timeframe)
@tool
async def validate_research_claims(
claims: list[str], sources: list[str]
) -> dict[str, Any]:
"""Validate research claims against multiple sources for fact-checking."""
return await self._validate_claims(claims, sources)
tools.extend(
[
web_search_financial,
analyze_company_fundamentals,
analyze_market_sentiment,
validate_research_claims,
]
)
return tools
async def _perform_web_search(
self, query: str, num_results: int, provider: str = "auto"
) -> dict[str, Any]:
"""Fallback web search across configured providers."""
await self._ensure_search_providers_loaded()
if not self.search_providers:
return {
"error": "No search providers available",
"results": [],
"total_results": 0,
}
aggregated_results: list[dict[str, Any]] = []
target = provider.lower()
for provider_obj in self.search_providers:
provider_name = provider_obj.__class__.__name__.lower()
if target != "auto" and target not in provider_name:
continue
try:
results = await provider_obj.search(query, num_results)
aggregated_results.extend(results)
if target != "auto":
break
except Exception as error: # pragma: no cover - fallback logging
logger.warning(
"Fallback web search failed for provider %s: %s",
provider_obj.__class__.__name__,
error,
)
if not aggregated_results:
return {
"error": "Search failed",
"results": [],
"total_results": 0,
}
truncated_results = aggregated_results[:num_results]
return {
"results": truncated_results,
"total_results": len(truncated_results),
"search_duration": 0.0,
"search_strategy": "fallback",
}
async def _research_company_fundamentals(
self, symbol: str, depth: str = "standard"
) -> dict[str, Any]:
"""Convenience wrapper for company fundamental research used by tools."""
session_id = f"fundamentals-{symbol}-{uuid4().hex}"
focus_areas = [
"fundamentals",
"financials",
"valuation",
"risk_management",
"growth_drivers",
]
return await self.research_comprehensive(
topic=f"{symbol} company fundamentals analysis",
session_id=session_id,
depth=depth,
focus_areas=focus_areas,
timeframe="180d",
use_parallel_execution=False,
)
async def _analyze_market_sentiment_tool(
self, topic: str, timeframe: str = "7d"
) -> dict[str, Any]:
"""Wrapper used by the sentiment analysis tool."""
session_id = f"sentiment-{uuid4().hex}"
return await self.analyze_market_sentiment(
topic=topic,
session_id=session_id,
timeframe=timeframe,
use_parallel_execution=False,
)
async def _validate_claims(
self, claims: list[str], sources: list[str]
) -> dict[str, Any]:
"""Lightweight claim validation used for tool compatibility."""
validation_results: list[dict[str, Any]] = []
for claim in claims:
source_checks = []
for source in sources:
source_checks.append(
{
"source": source,
"status": "not_verified",
"confidence": 0.0,
"notes": "Automatic validation not available in fallback mode",
}
)
validation_results.append(
{
"claim": claim,
"validated": False,
"confidence": 0.0,
"evidence": [],
"source_checks": source_checks,
}
)
return {
"results": validation_results,
"summary": "Claim validation is currently using fallback heuristics.",
}
async def _perform_financial_search(
self, query: str, num_results: int, provider: str, strategy: str
) -> dict[str, Any]:
"""
Perform optimized financial search with enhanced strategies.
Args:
query: Search query
num_results: Number of results
provider: Search provider preference
strategy: Search strategy
Returns:
Dictionary with search results and metadata
"""
if not self.search_providers:
return {
"error": "No search providers available",
"results": [],
"total_results": 0,
}
start_time = datetime.now()
all_results = []
# Use Exa provider with financial optimization if available
exa_provider = None
for p in self.search_providers:
if isinstance(p, ExaSearchProvider):
exa_provider = p
break
if exa_provider and (provider == "auto" or provider == "exa"):
try:
# Use the enhanced financial search method
results = await exa_provider.search_financial(
query, num_results, strategy=strategy
)
# Add search metadata
for result in results:
result.update(
{
"search_strategy": strategy,
"search_timestamp": start_time.isoformat(),
"enhanced_query": query,
}
)
all_results.extend(results)
logger.info(
f"Financial search completed: {len(results)} results "
f"using strategy '{strategy}' in {(datetime.now() - start_time).total_seconds():.2f}s"
)
except Exception as e:
logger.error(f"Enhanced financial search failed: {e}")
# Fallback to regular search if available
if hasattr(self, "_perform_web_search"):
return await self._perform_web_search(query, num_results, provider)
else:
return {
"error": f"Financial search failed: {str(e)}",
"results": [],
"total_results": 0,
}
else:
# Use regular search providers
try:
for provider_obj in self.search_providers:
if (
provider == "auto"
or provider.lower() in str(type(provider_obj)).lower()
):
results = await provider_obj.search(query, num_results)
all_results.extend(results)
break
except Exception as e:
logger.error(f"Fallback search failed: {e}")
return {
"error": f"Search failed: {str(e)}",
"results": [],
"total_results": 0,
}
# Sort by financial relevance and authority
all_results.sort(
key=lambda x: (
x.get("financial_relevance", 0),
x.get("is_authoritative", False),
x.get("score", 0),
),
reverse=True,
)
return {
"results": all_results[:num_results],
"total_results": len(all_results),
"search_strategy": strategy,
"search_duration": (datetime.now() - start_time).total_seconds(),
"enhanced_search": True,
}
def _build_graph(self):
"""Build research workflow graph with multi-step research process."""
workflow = StateGraph(DeepResearchState)
# Core research workflow nodes
workflow.add_node("plan_research", self._plan_research)
workflow.add_node("execute_searches", self._execute_searches)
workflow.add_node("analyze_content", self._analyze_content)
workflow.add_node("validate_sources", self._validate_sources)
workflow.add_node("synthesize_findings", self._synthesize_findings)
workflow.add_node("generate_citations", self._generate_citations)
# Specialized research nodes
workflow.add_node("sentiment_analysis", self._sentiment_analysis)
workflow.add_node("fundamental_analysis", self._fundamental_analysis)
workflow.add_node("competitive_analysis", self._competitive_analysis)
# Quality control nodes
workflow.add_node("fact_validation", self._fact_validation)
workflow.add_node("source_credibility", self._source_credibility)
# Define workflow edges
workflow.add_edge(START, "plan_research")
workflow.add_edge("plan_research", "execute_searches")
workflow.add_edge("execute_searches", "analyze_content")
# Conditional routing based on research type
workflow.add_conditional_edges(
"analyze_content",
self._route_specialized_analysis,
{
"sentiment": "sentiment_analysis",
"fundamental": "fundamental_analysis",
"competitive": "competitive_analysis",
"validation": "validate_sources",
"synthesis": "synthesize_findings",
},
)
# Specialized analysis flows
workflow.add_edge("sentiment_analysis", "validate_sources")
workflow.add_edge("fundamental_analysis", "validate_sources")
workflow.add_edge("competitive_analysis", "validate_sources")
# Quality control flow
workflow.add_edge("validate_sources", "fact_validation")
workflow.add_edge("fact_validation", "source_credibility")
workflow.add_edge("source_credibility", "synthesize_findings")
# Final steps
workflow.add_edge("synthesize_findings", "generate_citations")
workflow.add_edge("generate_citations", END)
return workflow.compile(checkpointer=self.checkpointer)
@log_method_call(component="DeepResearchAgent", include_timing=True)
async def research_comprehensive(
self,
topic: str,
session_id: str,
depth: str | None = None,
focus_areas: list[str] | None = None,
timeframe: str = "30d",
timeout_budget: float | None = None, # Total timeout budget in seconds
**kwargs,
) -> dict[str, Any]:
"""
Comprehensive research on a financial topic.
Args:
topic: Research topic or company/symbol
session_id: Session identifier
depth: Research depth (basic/standard/comprehensive/exhaustive)
focus_areas: Specific areas to focus on
timeframe: Time range for research
timeout_budget: Total timeout budget in seconds (enables budget allocation)
**kwargs: Additional parameters
Returns:
Comprehensive research results with analysis and citations
"""
# Ensure search providers are loaded (cached for performance)
await self._ensure_search_providers_loaded()
# Check if search providers are available
if not self.search_providers:
return {
"error": "Research functionality unavailable - no search providers configured",
"details": "Please configure EXA_API_KEY environment variable to enable research capabilities",
"topic": topic,
"available_functionality": "Limited to pre-existing data and basic analysis",
}
start_time = datetime.now()
depth = depth or self.default_depth
# Calculate timeout budget allocation for generous research timeouts
timeout_budgets = {}
if timeout_budget and timeout_budget > 0:
timeout_budgets = {
"search_budget": timeout_budget
* 0.50, # 50% for search operations (generous allocation)
"analysis_budget": timeout_budget * 0.30, # 30% for content analysis
"synthesis_budget": timeout_budget * 0.20, # 20% for result synthesis
"total_budget": timeout_budget,
"allocation_strategy": "comprehensive_research",
}
logger.info(
f"TIMEOUT_BUDGET_ALLOCATION: total={timeout_budget}s → "
f"search={timeout_budgets['search_budget']:.1f}s, "
f"analysis={timeout_budgets['analysis_budget']:.1f}s, "
f"synthesis={timeout_budgets['synthesis_budget']:.1f}s"
)
# Initialize research state
initial_state = {
"messages": [HumanMessage(content=f"Research: {topic}")],
"persona": self.persona.name,
"session_id": session_id,
"timestamp": datetime.now(),
"research_topic": topic,
"research_depth": depth,
"focus_areas": focus_areas
or PERSONA_RESEARCH_FOCUS[self.persona.name.lower()]["keywords"],
"timeframe": timeframe,
"search_queries": [],
"search_results": [],
"analyzed_content": [],
"validated_sources": [],
"research_findings": [],
"sentiment_analysis": {},
"source_credibility_scores": {},
"citations": [],
"research_status": "planning",
"research_confidence": 0.0,
"source_diversity_score": 0.0,
"fact_validation_results": [],
"execution_time_ms": 0.0,
"api_calls_made": 0,
"cache_hits": 0,
"cache_misses": 0,
# Timeout budget allocation for intelligent time management
"timeout_budgets": timeout_budgets,
# Legacy fields
"token_count": 0,
"error": None,
"analyzed_stocks": {},
"key_price_levels": {},
"last_analysis_time": {},
"conversation_context": {},
}
# Add additional parameters
initial_state.update(kwargs)
# Set up orchestration logging
orchestration_logger = get_orchestration_logger("DeepResearchAgent")
orchestration_logger.set_request_context(
session_id=session_id,
research_topic=topic[:50], # Truncate for logging
research_depth=depth,
)
# Check if parallel execution is enabled and requested
use_parallel = kwargs.get(
"use_parallel_execution", self.enable_parallel_execution
)
orchestration_logger.info(
"🔍 RESEARCH_START",
execution_mode="parallel" if use_parallel else "sequential",
focus_areas=focus_areas[:3] if focus_areas else None,
timeframe=timeframe,
)
if use_parallel:
orchestration_logger.info("🚀 PARALLEL_EXECUTION_SELECTED")
try:
result = await self._execute_parallel_research(
topic=topic,
session_id=session_id,
depth=depth,
focus_areas=focus_areas,
timeframe=timeframe,
initial_state=initial_state,
start_time=start_time,
**kwargs,
)
orchestration_logger.info("✅ PARALLEL_EXECUTION_SUCCESS")
return result
except Exception as e:
orchestration_logger.warning(
"⚠️ PARALLEL_FALLBACK_TRIGGERED",
error=str(e),
fallback_mode="sequential",
)
# Fall through to sequential execution
# Execute research workflow (sequential)
orchestration_logger.info("🔄 SEQUENTIAL_EXECUTION_START")
try:
result = await self.graph.ainvoke(
initial_state,
config={
"configurable": {
"thread_id": session_id,
"checkpoint_ns": "deep_research",
}
},
)
# Calculate execution time
execution_time = (datetime.now() - start_time).total_seconds() * 1000
result["execution_time_ms"] = execution_time
return self._format_research_response(result)
except Exception as e:
logger.error(f"Error in deep research: {e}")
return {
"status": "error",
"error": str(e),
"execution_time_ms": (datetime.now() - start_time).total_seconds()
* 1000,
"agent_type": "deep_research",
}
# Workflow node implementations
async def _plan_research(self, state: DeepResearchState) -> Command:
"""Plan research strategy based on topic and persona."""
topic = state["research_topic"]
depth_config = RESEARCH_DEPTH_LEVELS[state["research_depth"]]
persona_focus = PERSONA_RESEARCH_FOCUS[self.persona.name.lower()]
# Generate search queries based on topic and persona
search_queries = await self._generate_search_queries(
topic, persona_focus, depth_config
)
return Command(
goto="execute_searches",
update={"search_queries": search_queries, "research_status": "searching"},
)
async def _safe_search(
self,
provider: WebSearchProvider,
query: str,
num_results: int = 5,
timeout_budget: float | None = None,
) -> list[dict[str, Any]]:
"""Safely execute search with a provider, handling exceptions gracefully."""
try:
return await provider.search(
query, num_results=num_results, timeout_budget=timeout_budget
)
except Exception as e:
logger.warning(
f"Search failed for '{query}' with provider {type(provider).__name__}: {e}"
)
return [] # Return empty list on failure
async def _execute_searches(self, state: DeepResearchState) -> Command:
"""Execute web searches using available providers with timeout budget awareness."""
search_queries = state["search_queries"]
depth_config = RESEARCH_DEPTH_LEVELS[state["research_depth"]]
# Calculate timeout budget per search operation
timeout_budgets = state.get("timeout_budgets", {})
search_budget = timeout_budgets.get("search_budget")
if search_budget:
# Divide search budget across queries and providers
total_search_operations = len(
search_queries[: depth_config["max_searches"]]
) * len(self.search_providers)
timeout_per_search = (
search_budget / max(total_search_operations, 1)
if total_search_operations > 0
else search_budget
)
logger.info(
f"SEARCH_BUDGET_ALLOCATION: {search_budget:.1f}s total → "
f"{timeout_per_search:.1f}s per search ({total_search_operations} operations)"
)
else:
timeout_per_search = None
all_results = []
# Create all search tasks for parallel execution with budget-aware timeouts
search_tasks = []
for query in search_queries[: depth_config["max_searches"]]:
for provider in self.search_providers:
# Create async task for each provider/query combination with timeout budget
search_tasks.append(
self._safe_search(
provider,
query,
num_results=5,
timeout_budget=timeout_per_search,
)
)
# Execute all searches in parallel using asyncio.gather()
if search_tasks:
parallel_results = await asyncio.gather(
*search_tasks, return_exceptions=True
)
# Process results and filter out exceptions
for result in parallel_results:
if isinstance(result, Exception):
# Log the exception but continue with other results
logger.warning(f"Search task failed: {result}")
elif isinstance(result, list):
all_results.extend(result)
elif result is not None:
all_results.append(result)
# Deduplicate and limit results
unique_results = []
seen_urls = set()
        for result in all_results:
            url = result.get("url")
            if (
                url
                and url not in seen_urls
                and len(unique_results) < depth_config["max_sources"]
            ):
                unique_results.append(result)
                seen_urls.add(url)
logger.info(
f"Search completed: {len(unique_results)} unique results from {len(all_results)} total"
)
return Command(
goto="analyze_content",
update={"search_results": unique_results, "research_status": "analyzing"},
)
async def _analyze_content(self, state: DeepResearchState) -> Command:
"""Analyze search results using AI content analysis."""
search_results = state["search_results"]
analyzed_content = []
# Analyze each piece of content
for result in search_results:
if result.get("content"):
analysis = await self.content_analyzer.analyze_content(
content=result["content"],
persona=self.persona.name.lower(),
analysis_focus=state["research_depth"],
)
analyzed_content.append({**result, "analysis": analysis})
return Command(
goto="validate_sources",
update={
"analyzed_content": analyzed_content,
"research_status": "validating",
},
)
def _route_specialized_analysis(self, state: DeepResearchState) -> str:
"""Route to specialized analysis based on research focus."""
        focus_areas = state.get("focus_areas", [])
        # Use substring matching so compound focus areas such as
        # "market_sentiment" or "competitive_analysis" are routed correctly
        focus_text = " ".join(focus_areas).lower()
        if any(word in focus_text for word in ["sentiment", "news", "social"]):
            return "sentiment"
        elif any(
            word in focus_text for word in ["fundamental", "financial", "earnings"]
        ):
            return "fundamental"
        elif any(word in focus_text for word in ["competitive", "market", "industry"]):
            return "competitive"
        else:
            return "validation"
async def _validate_sources(self, state: DeepResearchState) -> Command:
"""Validate source credibility and filter results."""
analyzed_content = state["analyzed_content"]
validated_sources = []
credibility_scores = {}
for content in analyzed_content:
# Calculate credibility score based on multiple factors
credibility_score = self._calculate_source_credibility(content)
credibility_scores[content["url"]] = credibility_score
# Only include sources above credibility threshold
if credibility_score >= 0.6: # Configurable threshold
validated_sources.append(content)
return Command(
goto="synthesize_findings",
update={
"validated_sources": validated_sources,
"source_credibility_scores": credibility_scores,
"research_status": "synthesizing",
},
)
async def _synthesize_findings(self, state: DeepResearchState) -> Command:
"""Synthesize research findings into coherent insights."""
validated_sources = state["validated_sources"]
# Generate synthesis using LLM
synthesis_prompt = self._build_synthesis_prompt(validated_sources, state)
synthesis_response = await self.llm.ainvoke(
[
SystemMessage(content="You are a financial research synthesizer."),
HumanMessage(content=synthesis_prompt),
]
)
raw_synthesis = ContentAnalyzer._coerce_message_content(
synthesis_response.content
)
research_findings = {
"synthesis": raw_synthesis,
"key_insights": self._extract_key_insights(validated_sources),
"overall_sentiment": self._calculate_overall_sentiment(validated_sources),
"risk_assessment": self._assess_risks(validated_sources),
"investment_implications": self._derive_investment_implications(
validated_sources
),
"confidence_score": self._calculate_research_confidence(validated_sources),
}
return Command(
goto="generate_citations",
update={
"research_findings": research_findings,
"research_confidence": research_findings["confidence_score"],
"research_status": "completing",
},
)
async def _generate_citations(self, state: DeepResearchState) -> Command:
"""Generate proper citations for all sources."""
validated_sources = state["validated_sources"]
citations = []
for i, source in enumerate(validated_sources, 1):
citation = {
"id": i,
"title": source.get("title", "Untitled"),
"url": source["url"],
"published_date": source.get("published_date"),
"author": source.get("author"),
"credibility_score": state["source_credibility_scores"].get(
source["url"], 0.5
),
"relevance_score": source.get("analysis", {}).get(
"relevance_score", 0.5
),
}
citations.append(citation)
return Command(
goto="__end__",
update={"citations": citations, "research_status": "completed"},
)
# Helper methods
async def _generate_search_queries(
self, topic: str, persona_focus: dict[str, Any], depth_config: dict[str, Any]
) -> list[str]:
"""Generate search queries optimized for the research topic and persona."""
base_queries = [
f"{topic} financial analysis",
f"{topic} investment research",
f"{topic} market outlook",
]
# Add persona-specific queries
persona_queries = [
f"{topic} {keyword}" for keyword in persona_focus["keywords"][:3]
]
# Add source-specific queries
source_queries = [
f"{topic} {source}" for source in persona_focus["sources"][:2]
]
all_queries = base_queries + persona_queries + source_queries
return all_queries[: depth_config["max_searches"]]
def _calculate_source_credibility(self, content: dict[str, Any]) -> float:
"""Calculate credibility score for a source."""
score = 0.5 # Base score
        # Domain credibility: check high-credibility domains first so that
        # e.g. sec.gov receives the larger boost rather than the generic .gov bump
        url = content.get("url", "")
        if any(
            domain in url
            for domain in [
                "sec.gov",
                "investopedia.com",
                "bloomberg.com",
                "reuters.com",
            ]
        ):
            score += 0.3
        elif any(domain in url for domain in [".gov", ".edu", ".org"]):
            score += 0.2
# Publication date recency
pub_date = content.get("published_date")
if pub_date:
try:
date_obj = datetime.fromisoformat(pub_date.replace("Z", "+00:00"))
days_old = (datetime.now() - date_obj).days
if days_old < 30:
score += 0.1
elif days_old < 90:
score += 0.05
except (ValueError, TypeError, AttributeError):
pass
# Content analysis credibility
if "analysis" in content:
analysis_cred = content["analysis"].get("credibility_score", 0.5)
score = (score + analysis_cred) / 2
return min(score, 1.0)
def _build_synthesis_prompt(
self, sources: list[dict[str, Any]], state: DeepResearchState
) -> str:
"""Build synthesis prompt for final research output."""
topic = state["research_topic"]
persona = self.persona.name
prompt = f"""
Synthesize comprehensive research findings on '{topic}' for a {persona} investor.
Research Sources ({len(sources)} validated sources):
"""
for i, source in enumerate(sources, 1):
analysis = source.get("analysis", {})
prompt += f"\n{i}. {source.get('title', 'Unknown Title')}"
prompt += f" - Insights: {', '.join(analysis.get('insights', [])[:2])}"
prompt += f" - Sentiment: {analysis.get('sentiment', {}).get('direction', 'neutral')}"
prompt += f" - Credibility: {state['source_credibility_scores'].get(source['url'], 0.5):.2f}"
prompt += f"""
Please provide a comprehensive synthesis that includes:
1. Executive Summary (2-3 sentences)
2. Key Findings (5-7 bullet points)
3. Investment Implications for {persona} investors
4. Risk Considerations
5. Recommended Actions
6. Confidence Level and reasoning
Tailor the analysis specifically for {persona} investment characteristics and risk tolerance.
"""
return prompt
def _extract_key_insights(self, sources: list[dict[str, Any]]) -> list[str]:
"""Extract and consolidate key insights from all sources."""
all_insights = []
for source in sources:
analysis = source.get("analysis", {})
insights = analysis.get("insights", [])
all_insights.extend(insights)
# Simple deduplication (could be enhanced with semantic similarity)
unique_insights = list(dict.fromkeys(all_insights))
return unique_insights[:10] # Return top 10 insights
def _calculate_overall_sentiment(
self, sources: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate overall sentiment from all sources."""
sentiments = []
weights = []
for source in sources:
analysis = source.get("analysis", {})
sentiment = analysis.get("sentiment", {})
# Convert sentiment to numeric value
direction = sentiment.get("direction", "neutral")
if direction == "bullish":
sentiment_value = 1
elif direction == "bearish":
sentiment_value = -1
else:
sentiment_value = 0
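            # Weight each source's vote by its analysis confidence and source credibility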
confidence = sentiment.get("confidence", 0.5)
credibility = source.get("credibility_score", 0.5)
sentiments.append(sentiment_value)
weights.append(confidence * credibility)
if not sentiments:
return {"direction": "neutral", "confidence": 0.5, "consensus": 0.5}
# Weighted average sentiment
weighted_sentiment = sum(
s * w for s, w in zip(sentiments, weights, strict=False)
) / sum(weights)
# Convert back to direction
if weighted_sentiment > 0.2:
overall_direction = "bullish"
elif weighted_sentiment < -0.2:
overall_direction = "bearish"
else:
overall_direction = "neutral"
# Calculate consensus (how much sources agree)
        mean_sentiment = sum(sentiments) / len(sentiments)
        sentiment_variance = sum(
            (s - mean_sentiment) ** 2 for s in sentiments
        ) / len(sentiments)
        consensus = max(0.0, 1 - sentiment_variance)
return {
"direction": overall_direction,
"confidence": abs(weighted_sentiment),
"consensus": consensus,
"source_count": len(sentiments),
}
def _assess_risks(self, sources: list[dict[str, Any]]) -> list[str]:
"""Consolidate risk assessments from all sources."""
all_risks = []
for source in sources:
analysis = source.get("analysis", {})
risks = analysis.get("risk_factors", [])
all_risks.extend(risks)
# Deduplicate and return top risks
unique_risks = list(dict.fromkeys(all_risks))
return unique_risks[:8]
def _derive_investment_implications(
self, sources: list[dict[str, Any]]
) -> dict[str, Any]:
"""Derive investment implications based on research findings."""
opportunities = []
threats = []
for source in sources:
analysis = source.get("analysis", {})
opps = analysis.get("opportunities", [])
risks = analysis.get("risk_factors", [])
opportunities.extend(opps)
threats.extend(risks)
return {
"opportunities": list(dict.fromkeys(opportunities))[:5],
"threats": list(dict.fromkeys(threats))[:5],
"recommended_action": self._recommend_action(sources),
"time_horizon": PERSONA_RESEARCH_FOCUS[self.persona.name.lower()][
"time_horizon"
],
}
def _recommend_action(self, sources: list[dict[str, Any]]) -> str:
"""Recommend investment action based on research findings."""
overall_sentiment = self._calculate_overall_sentiment(sources)
if (
overall_sentiment["direction"] == "bullish"
and overall_sentiment["confidence"] > 0.7
):
if self.persona.name.lower() == "conservative":
return "Consider gradual position building with proper risk management"
else:
return "Consider initiating position with appropriate position sizing"
elif (
overall_sentiment["direction"] == "bearish"
and overall_sentiment["confidence"] > 0.7
):
return "Exercise caution - consider waiting for better entry or avoiding"
else:
return "Monitor closely - mixed signals suggest waiting for clarity"
def _calculate_research_confidence(self, sources: list[dict[str, Any]]) -> float:
"""Calculate overall confidence in research findings."""
if not sources:
return 0.0
# Factors that increase confidence
source_count_factor = min(
len(sources) / 10, 1.0
) # More sources = higher confidence
avg_credibility = sum(
source.get("credibility_score", 0.5) for source in sources
) / len(sources)
avg_relevance = sum(
source.get("analysis", {}).get("relevance_score", 0.5) for source in sources
) / len(sources)
# Diversity of sources (different domains)
unique_domains = len(
{source["url"].split("/")[2] for source in sources if "url" in source}
)
diversity_factor = min(unique_domains / 5, 1.0)
# Combine factors
confidence = (
source_count_factor + avg_credibility + avg_relevance + diversity_factor
) / 4
return round(confidence, 2)
def _format_research_response(self, result: dict[str, Any]) -> dict[str, Any]:
"""Format research response for consistent output."""
return {
"status": "success",
"agent_type": "deep_research",
"persona": result.get("persona"),
"research_topic": result.get("research_topic"),
"research_depth": result.get("research_depth"),
"findings": result.get("research_findings", {}),
"sources_analyzed": len(result.get("validated_sources", [])),
"confidence_score": result.get("research_confidence", 0.0),
"citations": result.get("citations", []),
"execution_time_ms": result.get("execution_time_ms", 0.0),
"search_queries_used": result.get("search_queries", []),
"source_diversity": result.get("source_diversity_score", 0.0),
}
# Specialized research analysis methods
async def _sentiment_analysis(self, state: DeepResearchState) -> Command:
"""Perform specialized sentiment analysis."""
logger.info("Performing sentiment analysis")
# For now, route to content analysis with sentiment focus
original_focus = state.get("focus_areas", [])
state["focus_areas"] = ["market_sentiment", "sentiment", "mood"]
result = await self._analyze_content(state)
state["focus_areas"] = original_focus # Restore original focus
return result
async def _fundamental_analysis(self, state: DeepResearchState) -> Command:
"""Perform specialized fundamental analysis."""
logger.info("Performing fundamental analysis")
# For now, route to content analysis with fundamental focus
original_focus = state.get("focus_areas", [])
state["focus_areas"] = ["fundamentals", "financials", "valuation"]
result = await self._analyze_content(state)
state["focus_areas"] = original_focus # Restore original focus
return result
async def _competitive_analysis(self, state: DeepResearchState) -> Command:
"""Perform specialized competitive analysis."""
logger.info("Performing competitive analysis")
# For now, route to content analysis with competitive focus
original_focus = state.get("focus_areas", [])
state["focus_areas"] = ["competitive_landscape", "market_share", "competitors"]
result = await self._analyze_content(state)
state["focus_areas"] = original_focus # Restore original focus
return result
async def _fact_validation(self, state: DeepResearchState) -> Command:
"""Perform fact validation on research findings."""
logger.info("Performing fact validation")
# For now, route to source validation
return await self._validate_sources(state)
async def _source_credibility(self, state: DeepResearchState) -> Command:
"""Assess source credibility and reliability."""
logger.info("Assessing source credibility")
# For now, route to source validation
return await self._validate_sources(state)
async def research_company_comprehensive(
self,
symbol: str,
session_id: str,
include_competitive_analysis: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Comprehensive company research.
Args:
symbol: Stock symbol to research
session_id: Session identifier
include_competitive_analysis: Whether to include competitive analysis
**kwargs: Additional parameters
Returns:
Comprehensive company research results
"""
topic = f"{symbol} company financial analysis and outlook"
if include_competitive_analysis:
kwargs["focus_areas"] = kwargs.get("focus_areas", []) + [
"competitive_analysis",
"market_position",
]
return await self.research_comprehensive(
topic=topic, session_id=session_id, **kwargs
)
async def research_topic(
self,
query: str,
session_id: str,
focus_areas: list[str] | None = None,
timeframe: str = "30d",
**kwargs,
) -> dict[str, Any]:
"""
General topic research.
Args:
query: Research query or topic
session_id: Session identifier
focus_areas: Specific areas to focus on
timeframe: Time range for research
**kwargs: Additional parameters
Returns:
Research results for the given topic
"""
return await self.research_comprehensive(
topic=query,
session_id=session_id,
focus_areas=focus_areas,
timeframe=timeframe,
**kwargs,
)
async def analyze_market_sentiment(
self, topic: str, session_id: str, timeframe: str = "7d", **kwargs
) -> dict[str, Any]:
"""
Analyze market sentiment around a topic.
Args:
topic: Topic to analyze sentiment for
session_id: Session identifier
timeframe: Time range for analysis
**kwargs: Additional parameters
Returns:
Market sentiment analysis results
"""
return await self.research_comprehensive(
topic=f"market sentiment analysis: {topic}",
session_id=session_id,
focus_areas=["sentiment", "market_mood", "investor_sentiment"],
timeframe=timeframe,
**kwargs,
)
# Parallel Execution Implementation
@log_method_call(component="DeepResearchAgent", include_timing=True)
async def _execute_parallel_research(
self,
topic: str,
session_id: str,
depth: str,
focus_areas: list[str] | None = None,
timeframe: str = "30d",
initial_state: dict[str, Any] | None = None,
start_time: datetime | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Execute research using parallel subagent execution.
Args:
topic: Research topic
session_id: Session identifier
depth: Research depth level
focus_areas: Specific focus areas
timeframe: Research timeframe
initial_state: Initial state for backward compatibility
start_time: Start time for execution measurement
**kwargs: Additional parameters
Returns:
Research results in same format as sequential execution
"""
orchestration_logger = get_orchestration_logger("ParallelExecution")
orchestration_logger.set_request_context(session_id=session_id)
try:
# Generate research tasks using task distributor
orchestration_logger.info("🎯 TASK_DISTRIBUTION_START")
research_tasks = self.task_distributor.distribute_research_tasks(
topic=topic, session_id=session_id, focus_areas=focus_areas
)
orchestration_logger.info(
"📋 TASKS_GENERATED",
task_count=len(research_tasks),
task_types=[t.task_type for t in research_tasks],
)
# Execute tasks in parallel
orchestration_logger.info("🚀 PARALLEL_ORCHESTRATION_START")
research_result = (
await self.parallel_orchestrator.execute_parallel_research(
tasks=research_tasks,
research_executor=self._execute_subagent_task,
synthesis_callback=self._synthesize_parallel_results,
)
)
# Log parallel execution metrics
log_performance_metrics(
"ParallelExecution",
{
"total_tasks": research_result.successful_tasks
+ research_result.failed_tasks,
"successful_tasks": research_result.successful_tasks,
"failed_tasks": research_result.failed_tasks,
"parallel_efficiency": research_result.parallel_efficiency,
"execution_time": research_result.total_execution_time,
},
)
# Convert parallel results to expected format
orchestration_logger.info("🔄 RESULT_FORMATTING_START")
formatted_result = await self._format_parallel_research_response(
research_result=research_result,
topic=topic,
session_id=session_id,
depth=depth,
initial_state=initial_state,
start_time=start_time,
)
orchestration_logger.info(
"✅ PARALLEL_RESEARCH_COMPLETE",
result_confidence=formatted_result.get("confidence_score", 0.0),
sources_analyzed=formatted_result.get("sources_analyzed", 0),
)
return formatted_result
except Exception as e:
orchestration_logger.error("❌ PARALLEL_RESEARCH_FAILED", error=str(e))
raise # Re-raise to trigger fallback to sequential
async def _execute_subagent_task(
self, task
) -> dict[str, Any]: # Type: ResearchTask
"""
Execute a single research task using specialized subagent.
Args:
task: ResearchTask to execute
Returns:
Research results from specialized subagent
"""
with log_agent_execution(
task.task_type, task.task_id, task.focus_areas
) as agent_logger:
agent_logger.info(
"🎯 SUBAGENT_ROUTING",
target_topic=task.target_topic[:50],
focus_count=len(task.focus_areas),
priority=task.priority,
)
# Route to appropriate subagent based on task type
if task.task_type == "fundamental":
subagent = FundamentalResearchAgent(self)
return await subagent.execute_research(task)
elif task.task_type == "technical":
subagent = TechnicalResearchAgent(self)
return await subagent.execute_research(task)
elif task.task_type == "sentiment":
subagent = SentimentResearchAgent(self)
return await subagent.execute_research(task)
elif task.task_type == "competitive":
subagent = CompetitiveResearchAgent(self)
return await subagent.execute_research(task)
else:
# Default to fundamental analysis
agent_logger.warning("⚠️ UNKNOWN_TASK_TYPE", fallback="fundamental")
subagent = FundamentalResearchAgent(self)
return await subagent.execute_research(task)
async def _synthesize_parallel_results(
self,
task_results, # Type: dict[str, ResearchTask]
) -> dict[str, Any]:
"""
Synthesize results from multiple parallel research tasks.
Args:
task_results: Dictionary of task IDs to ResearchTask objects
Returns:
Synthesized research insights
"""
synthesis_logger = get_orchestration_logger("ResultSynthesis")
log_synthesis_operation(
"parallel_research_synthesis",
len(task_results),
f"Synthesizing from {len(task_results)} research tasks",
)
# Extract successful results
successful_results = {
task_id: task.result
for task_id, task in task_results.items()
if task.status == "completed" and task.result
}
synthesis_logger.info(
"📊 SYNTHESIS_INPUT_ANALYSIS",
total_tasks=len(task_results),
successful_tasks=len(successful_results),
failed_tasks=len(task_results) - len(successful_results),
)
if not successful_results:
synthesis_logger.warning("⚠️ NO_SUCCESSFUL_RESULTS")
return {
"synthesis": "No research results available for synthesis",
"confidence_score": 0.0,
}
all_insights = []
all_risks = []
all_opportunities = []
sentiment_scores = []
credibility_scores = []
# Aggregate results from all successful tasks
for task_id, task in task_results.items():
if task.status == "completed" and task.result:
task_type = task_id.split("_")[-1] if "_" in task_id else "unknown"
synthesis_logger.debug(
"🔍 PROCESSING_TASK_RESULT",
task_id=task_id,
task_type=task_type,
has_insights="insights" in task.result
if isinstance(task.result, dict)
else False,
)
result = task.result
# Extract insights
insights = result.get("insights", [])
all_insights.extend(insights)
# Extract risks and opportunities
risks = result.get("risk_factors", [])
opportunities = result.get("opportunities", [])
all_risks.extend(risks)
all_opportunities.extend(opportunities)
# Extract sentiment
sentiment = result.get("sentiment", {})
if sentiment:
sentiment_scores.append(sentiment)
# Extract credibility
credibility = result.get("credibility_score", 0.5)
credibility_scores.append(credibility)
# Calculate overall metrics
overall_sentiment = self._calculate_aggregated_sentiment(sentiment_scores)
average_credibility = (
sum(credibility_scores) / len(credibility_scores)
if credibility_scores
else 0.5
)
# Generate synthesis using LLM
synthesis_prompt = self._build_parallel_synthesis_prompt(
task_results, all_insights, all_risks, all_opportunities, overall_sentiment
)
try:
synthesis_response = await self.llm.ainvoke(
[
SystemMessage(
content="You are a financial research synthesizer. Combine insights from multiple specialized research agents."
),
HumanMessage(content=synthesis_prompt),
]
)
synthesis_text = ContentAnalyzer._coerce_message_content(
synthesis_response.content
)
synthesis_logger.info("🧠 LLM_SYNTHESIS_SUCCESS")
except Exception as e:
synthesis_logger.warning(
"⚠️ LLM_SYNTHESIS_FAILED", error=str(e), fallback="text_fallback"
)
synthesis_text = self._generate_fallback_synthesis(
all_insights, overall_sentiment
)
synthesis_result = {
"synthesis": synthesis_text,
"key_insights": list(dict.fromkeys(all_insights))[
:10
], # Deduplicate and limit
"overall_sentiment": overall_sentiment,
"risk_assessment": list(dict.fromkeys(all_risks))[:8],
"investment_implications": {
"opportunities": list(dict.fromkeys(all_opportunities))[:5],
"threats": list(dict.fromkeys(all_risks))[:5],
"recommended_action": self._derive_parallel_recommendation(
overall_sentiment
),
},
"confidence_score": average_credibility,
"task_breakdown": {
task_id: {
"type": task.task_type,
"status": task.status,
"execution_time": (task.end_time - task.start_time)
if task.start_time and task.end_time
else 0,
}
for task_id, task in task_results.items()
},
}
synthesis_logger.info(
"✅ SYNTHESIS_COMPLETE",
insights_count=len(all_insights),
overall_confidence=average_credibility,
sentiment_direction=synthesis_result["overall_sentiment"]["direction"],
)
return synthesis_result
async def _format_parallel_research_response(
self,
research_result,
topic: str,
session_id: str,
depth: str,
initial_state: dict[str, Any] | None,
start_time: datetime | None,
) -> dict[str, Any]:
"""Format parallel research results to match expected sequential format."""
if start_time is not None:
execution_time = (datetime.now() - start_time).total_seconds() * 1000
else:
execution_time = 0.0
# Extract synthesis from research result
synthesis = research_result.synthesis or {}
state_snapshot: dict[str, Any] = initial_state or {}
# Create citations from task results
citations = []
all_sources = []
citation_id = 1
for _task_id, task in research_result.task_results.items():
if task.status == "completed" and task.result:
sources = task.result.get("sources", [])
for source in sources:
citation = {
"id": citation_id,
"title": source.get("title", "Unknown Title"),
"url": source.get("url", ""),
"published_date": source.get("published_date"),
"author": source.get("author"),
"credibility_score": source.get("credibility_score", 0.5),
"relevance_score": source.get("relevance_score", 0.5),
"research_type": task.task_type,
}
citations.append(citation)
all_sources.append(source)
citation_id += 1
return {
"status": "success",
"agent_type": "deep_research",
"execution_mode": "parallel",
"persona": state_snapshot.get("persona"),
"research_topic": topic,
"research_depth": depth,
"findings": synthesis,
"sources_analyzed": len(all_sources),
"confidence_score": synthesis.get("confidence_score", 0.0),
"citations": citations,
"execution_time_ms": execution_time,
"parallel_execution_stats": {
"total_tasks": len(research_result.task_results),
"successful_tasks": research_result.successful_tasks,
"failed_tasks": research_result.failed_tasks,
"parallel_efficiency": research_result.parallel_efficiency,
"task_breakdown": synthesis.get("task_breakdown", {}),
},
"search_queries_used": [], # Will be populated by subagents
"source_diversity": len({source.get("url", "") for source in all_sources})
/ max(len(all_sources), 1),
}
# Helper methods for parallel execution
def _calculate_aggregated_sentiment(
self, sentiment_scores: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate overall sentiment from multiple sentiment scores."""
if not sentiment_scores:
return {"direction": "neutral", "confidence": 0.5}
# Convert sentiment directions to numeric values
numeric_scores = []
confidences = []
for sentiment in sentiment_scores:
direction = sentiment.get("direction", "neutral")
confidence = sentiment.get("confidence", 0.5)
if direction == "bullish":
numeric_scores.append(1 * confidence)
elif direction == "bearish":
numeric_scores.append(-1 * confidence)
else:
numeric_scores.append(0)
confidences.append(confidence)
# Calculate weighted average
avg_score = sum(numeric_scores) / len(numeric_scores)
avg_confidence = sum(confidences) / len(confidences)
# Convert back to direction
if avg_score > 0.2:
direction = "bullish"
elif avg_score < -0.2:
direction = "bearish"
else:
direction = "neutral"
return {
"direction": direction,
"confidence": avg_confidence,
"consensus": 1 - abs(avg_score) if abs(avg_score) < 1 else 0,
"source_count": len(sentiment_scores),
}
def _build_parallel_synthesis_prompt(
self,
task_results: dict[str, Any], # Actually dict[str, ResearchTask]
all_insights: list[str],
all_risks: list[str],
all_opportunities: list[str],
overall_sentiment: dict[str, Any],
) -> str:
"""Build synthesis prompt for parallel research results."""
successful_tasks = [
task for task in task_results.values() if task.status == "completed"
]
prompt = f"""
Synthesize comprehensive research findings from {len(successful_tasks)} specialized research agents.
Research Task Results:
"""
for task in successful_tasks:
if task.result:
prompt += f"\n{task.task_type.upper()} RESEARCH:"
prompt += f" - Status: {task.status}"
prompt += f" - Key Insights: {', '.join(task.result.get('insights', [])[:3])}"
prompt += f" - Sentiment: {task.result.get('sentiment', {}).get('direction', 'neutral')}"
prompt += f"""
AGGREGATED DATA:
- Total Insights: {len(all_insights)}
- Risk Factors: {len(all_risks)}
- Opportunities: {len(all_opportunities)}
- Overall Sentiment: {overall_sentiment.get("direction")} (confidence: {overall_sentiment.get("confidence", 0.5):.2f})
Please provide a comprehensive synthesis that includes:
1. Executive Summary (2-3 sentences)
2. Key Findings from all research areas
3. Investment Implications for {self.persona.name} investors
4. Risk Assessment and Mitigation
5. Recommended Actions based on parallel analysis
6. Confidence Level and reasoning
Focus on insights that are supported by multiple research agents and highlight any contradictions.
"""
return prompt
def _generate_fallback_synthesis(
self, insights: list[str], sentiment: dict[str, Any]
) -> str:
"""Generate fallback synthesis when LLM synthesis fails."""
return f"""
Research synthesis generated from {len(insights)} insights.
Overall sentiment: {sentiment.get("direction", "neutral")} with {sentiment.get("confidence", 0.5):.2f} confidence.
Key insights identified:
{chr(10).join(f"- {insight}" for insight in insights[:5])}
This is a fallback synthesis due to LLM processing limitations.
"""
def _derive_parallel_recommendation(self, sentiment: dict[str, Any]) -> str:
"""Derive investment recommendation from parallel analysis."""
direction = sentiment.get("direction", "neutral")
confidence = sentiment.get("confidence", 0.5)
if direction == "bullish" and confidence > 0.7:
return "Strong buy signal based on parallel analysis from multiple research angles"
elif direction == "bullish" and confidence > 0.5:
return "Consider position building with appropriate risk management"
elif direction == "bearish" and confidence > 0.7:
return "Exercise significant caution - multiple research areas show negative signals"
elif direction == "bearish" and confidence > 0.5:
return "Monitor closely - mixed to negative signals suggest waiting"
else:
return "Neutral stance recommended - parallel analysis shows mixed signals"
# Specialized Subagent Classes
class BaseSubagent:
"""Base class for specialized research subagents."""
def __init__(self, parent_agent: DeepResearchAgent):
self.parent = parent_agent
self.llm = parent_agent.llm
self.search_providers = parent_agent.search_providers
self.content_analyzer = parent_agent.content_analyzer
self.persona = parent_agent.persona
self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
async def execute_research(self, task) -> dict[str, Any]: # task: ResearchTask
"""Execute research task - to be implemented by subclasses."""
raise NotImplementedError
async def _safe_search(
self,
provider: WebSearchProvider,
query: str,
num_results: int = 5,
timeout_budget: float | None = None,
) -> list[dict[str, Any]]:
"""Safely execute search with a provider, handling exceptions gracefully."""
try:
return await provider.search(
query, num_results=num_results, timeout_budget=timeout_budget
)
except Exception as e:
self.logger.warning(
f"Search failed for '{query}' with provider {type(provider).__name__}: {e}"
)
return [] # Return empty list on failure
async def _perform_specialized_search(
self,
topic: str,
specialized_queries: list[str],
max_results: int = 10,
timeout_budget: float | None = None,
) -> list[dict[str, Any]]:
"""Perform specialized web search for this subagent type."""
all_results = []
# Create all search tasks for parallel execution
search_tasks = []
        # Split the result budget evenly across queries, requesting at least one per query
        results_per_query = (
            max(1, max_results // len(specialized_queries))
            if specialized_queries
            else max_results
        )
# Calculate timeout per search if budget provided
if timeout_budget:
total_searches = len(specialized_queries) * len(self.search_providers)
timeout_per_search = timeout_budget / max(total_searches, 1)
else:
timeout_per_search = None
for query in specialized_queries:
for provider in self.search_providers:
# Create async task for each provider/query combination
search_tasks.append(
self._safe_search(
provider,
query,
num_results=results_per_query,
timeout_budget=timeout_per_search,
)
)
# Execute all searches in parallel using asyncio.gather()
if search_tasks:
parallel_results = await asyncio.gather(
*search_tasks, return_exceptions=True
)
# Process results and filter out exceptions
for result in parallel_results:
if isinstance(result, Exception):
# Log the exception but continue with other results
self.logger.warning(f"Search task failed: {result}")
elif isinstance(result, list):
all_results.extend(result)
elif result is not None:
all_results.append(result)
# Deduplicate results
seen_urls = set()
unique_results = []
        for result in all_results:
            url = result.get("url")
            if url and url not in seen_urls:
                seen_urls.add(url)
                unique_results.append(result)
return unique_results[:max_results]
async def _analyze_search_results(
self, results: list[dict[str, Any]], analysis_focus: str
) -> list[dict[str, Any]]:
"""Analyze search results with specialized focus."""
analyzed_results = []
for result in results:
if result.get("content"):
try:
analysis = await self.content_analyzer.analyze_content(
content=result["content"],
persona=self.persona.name.lower(),
analysis_focus=analysis_focus,
)
# Add source credibility
credibility_score = self._calculate_source_credibility(result)
analysis["credibility_score"] = credibility_score
analyzed_results.append(
{
**result,
"analysis": analysis,
"credibility_score": credibility_score,
}
)
except Exception as e:
self.logger.warning(
f"Content analysis failed for {result.get('url', 'unknown')}: {e}"
)
return analyzed_results
def _calculate_source_credibility(self, source: dict[str, Any]) -> float:
"""Calculate credibility score for a source - reuse from parent."""
return self.parent._calculate_source_credibility(source)
class FundamentalResearchAgent(BaseSubagent):
"""Specialized agent for fundamental financial analysis."""
async def execute_research(self, task) -> dict[str, Any]: # task: ResearchTask
"""Execute fundamental analysis research."""
self.logger.info(f"Executing fundamental research for: {task.target_topic}")
# Generate fundamental-specific search queries
queries = self._generate_fundamental_queries(task.target_topic)
# Perform specialized search
search_results = await self._perform_specialized_search(
topic=task.target_topic, specialized_queries=queries, max_results=8
)
# Analyze results with fundamental focus
analyzed_results = await self._analyze_search_results(
search_results, analysis_focus="fundamental_analysis"
)
# Extract fundamental-specific insights
insights = []
risks = []
opportunities = []
sources = []
for result in analyzed_results:
analysis = result.get("analysis", {})
insights.extend(analysis.get("insights", []))
risks.extend(analysis.get("risk_factors", []))
opportunities.extend(analysis.get("opportunities", []))
sources.append(
{
"title": result.get("title", ""),
"url": result.get("url", ""),
"credibility_score": result.get("credibility_score", 0.5),
"published_date": result.get("published_date"),
"author": result.get("author"),
}
)
return {
"research_type": "fundamental",
"insights": list(dict.fromkeys(insights))[:8], # Deduplicate
"risk_factors": list(dict.fromkeys(risks))[:6],
"opportunities": list(dict.fromkeys(opportunities))[:6],
"sentiment": self._calculate_fundamental_sentiment(analyzed_results),
"credibility_score": self._calculate_average_credibility(analyzed_results),
"sources": sources,
"focus_areas": [
"earnings",
"valuation",
"financial_health",
"growth_prospects",
],
}
def _generate_fundamental_queries(self, topic: str) -> list[str]:
"""Generate fundamental analysis specific queries."""
return [
f"{topic} earnings report financial results",
f"{topic} revenue growth profit margins",
f"{topic} balance sheet debt ratio financial health",
f"{topic} valuation PE ratio price earnings",
f"{topic} cash flow dividend payout",
]
def _calculate_fundamental_sentiment(
self, results: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate sentiment specific to fundamental analysis."""
sentiments = []
for result in results:
analysis = result.get("analysis", {})
sentiment = analysis.get("sentiment", {})
if sentiment:
sentiments.append(sentiment)
if not sentiments:
return {"direction": "neutral", "confidence": 0.5}
# Simple aggregation for now
bullish_count = sum(1 for s in sentiments if s.get("direction") == "bullish")
bearish_count = sum(1 for s in sentiments if s.get("direction") == "bearish")
if bullish_count > bearish_count:
return {"direction": "bullish", "confidence": 0.7}
elif bearish_count > bullish_count:
return {"direction": "bearish", "confidence": 0.7}
else:
return {"direction": "neutral", "confidence": 0.5}
def _calculate_average_credibility(self, results: list[dict[str, Any]]) -> float:
"""Calculate average credibility of sources."""
if not results:
return 0.5
credibility_scores = [r.get("credibility_score", 0.5) for r in results]
return sum(credibility_scores) / len(credibility_scores)
class TechnicalResearchAgent(BaseSubagent):
"""Specialized agent for technical analysis research."""
async def execute_research(self, task) -> dict[str, Any]: # task: ResearchTask
"""Execute technical analysis research."""
self.logger.info(f"Executing technical research for: {task.target_topic}")
queries = self._generate_technical_queries(task.target_topic)
search_results = await self._perform_specialized_search(
topic=task.target_topic, specialized_queries=queries, max_results=6
)
analyzed_results = await self._analyze_search_results(
search_results, analysis_focus="technical_analysis"
)
# Extract technical-specific insights
insights = []
risks = []
opportunities = []
sources = []
for result in analyzed_results:
analysis = result.get("analysis", {})
insights.extend(analysis.get("insights", []))
risks.extend(analysis.get("risk_factors", []))
opportunities.extend(analysis.get("opportunities", []))
sources.append(
{
"title": result.get("title", ""),
"url": result.get("url", ""),
"credibility_score": result.get("credibility_score", 0.5),
"published_date": result.get("published_date"),
"author": result.get("author"),
}
)
return {
"research_type": "technical",
"insights": list(dict.fromkeys(insights))[:8],
"risk_factors": list(dict.fromkeys(risks))[:6],
"opportunities": list(dict.fromkeys(opportunities))[:6],
"sentiment": self._calculate_technical_sentiment(analyzed_results),
"credibility_score": self._calculate_average_credibility(analyzed_results),
"sources": sources,
"focus_areas": [
"price_action",
"chart_patterns",
"technical_indicators",
"support_resistance",
],
}
def _generate_technical_queries(self, topic: str) -> list[str]:
"""Generate technical analysis specific queries."""
return [
f"{topic} technical analysis chart pattern",
f"{topic} price target support resistance",
f"{topic} RSI MACD technical indicators",
f"{topic} breakout trend analysis",
f"{topic} volume analysis price movement",
]
def _calculate_technical_sentiment(
self, results: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate sentiment specific to technical analysis."""
# Similar to fundamental but focused on technical indicators
sentiments = [
r.get("analysis", {}).get("sentiment", {})
for r in results
if r.get("analysis")
]
sentiments = [s for s in sentiments if s]
if not sentiments:
return {"direction": "neutral", "confidence": 0.5}
bullish_count = sum(1 for s in sentiments if s.get("direction") == "bullish")
bearish_count = sum(1 for s in sentiments if s.get("direction") == "bearish")
if bullish_count > bearish_count:
return {"direction": "bullish", "confidence": 0.6}
elif bearish_count > bullish_count:
return {"direction": "bearish", "confidence": 0.6}
else:
return {"direction": "neutral", "confidence": 0.5}
def _calculate_average_credibility(self, results: list[dict[str, Any]]) -> float:
"""Calculate average credibility of sources."""
if not results:
return 0.5
credibility_scores = [r.get("credibility_score", 0.5) for r in results]
return sum(credibility_scores) / len(credibility_scores)
class SentimentResearchAgent(BaseSubagent):
"""Specialized agent for market sentiment analysis."""
async def execute_research(self, task) -> dict[str, Any]: # task: ResearchTask
"""Execute sentiment analysis research."""
self.logger.info(f"Executing sentiment research for: {task.target_topic}")
queries = self._generate_sentiment_queries(task.target_topic)
search_results = await self._perform_specialized_search(
topic=task.target_topic, specialized_queries=queries, max_results=10
)
analyzed_results = await self._analyze_search_results(
search_results, analysis_focus="sentiment_analysis"
)
# Extract sentiment-specific insights
insights = []
risks = []
opportunities = []
sources = []
for result in analyzed_results:
analysis = result.get("analysis", {})
insights.extend(analysis.get("insights", []))
risks.extend(analysis.get("risk_factors", []))
opportunities.extend(analysis.get("opportunities", []))
sources.append(
{
"title": result.get("title", ""),
"url": result.get("url", ""),
"credibility_score": result.get("credibility_score", 0.5),
"published_date": result.get("published_date"),
"author": result.get("author"),
}
)
return {
"research_type": "sentiment",
"insights": list(dict.fromkeys(insights))[:8],
"risk_factors": list(dict.fromkeys(risks))[:6],
"opportunities": list(dict.fromkeys(opportunities))[:6],
"sentiment": self._calculate_market_sentiment(analyzed_results),
"credibility_score": self._calculate_average_credibility(analyzed_results),
"sources": sources,
"focus_areas": [
"market_sentiment",
"analyst_opinions",
"news_sentiment",
"social_sentiment",
],
}
def _generate_sentiment_queries(self, topic: str) -> list[str]:
"""Generate sentiment analysis specific queries."""
return [
f"{topic} analyst rating recommendation upgrade downgrade",
f"{topic} market sentiment investor opinion",
f"{topic} news sentiment positive negative",
f"{topic} social sentiment reddit twitter discussion",
f"{topic} institutional investor sentiment",
]
def _calculate_market_sentiment(
self, results: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate overall market sentiment."""
sentiments = [
r.get("analysis", {}).get("sentiment", {})
for r in results
if r.get("analysis")
]
sentiments = [s for s in sentiments if s]
if not sentiments:
return {"direction": "neutral", "confidence": 0.5}
# Weighted by confidence
weighted_scores = []
total_confidence = 0
for sentiment in sentiments:
direction = sentiment.get("direction", "neutral")
confidence = sentiment.get("confidence", 0.5)
if direction == "bullish":
weighted_scores.append(1 * confidence)
elif direction == "bearish":
weighted_scores.append(-1 * confidence)
else:
weighted_scores.append(0)
total_confidence += confidence
if not weighted_scores:
return {"direction": "neutral", "confidence": 0.5}
avg_score = sum(weighted_scores) / len(weighted_scores)
avg_confidence = total_confidence / len(sentiments)
if avg_score > 0.3:
return {"direction": "bullish", "confidence": avg_confidence}
elif avg_score < -0.3:
return {"direction": "bearish", "confidence": avg_confidence}
else:
return {"direction": "neutral", "confidence": avg_confidence}
def _calculate_average_credibility(self, results: list[dict[str, Any]]) -> float:
"""Calculate average credibility of sources."""
if not results:
return 0.5
credibility_scores = [r.get("credibility_score", 0.5) for r in results]
return sum(credibility_scores) / len(credibility_scores)
class CompetitiveResearchAgent(BaseSubagent):
"""Specialized agent for competitive and industry analysis."""
async def execute_research(self, task) -> dict[str, Any]: # task: ResearchTask
"""Execute competitive analysis research."""
self.logger.info(f"Executing competitive research for: {task.target_topic}")
queries = self._generate_competitive_queries(task.target_topic)
search_results = await self._perform_specialized_search(
topic=task.target_topic, specialized_queries=queries, max_results=8
)
analyzed_results = await self._analyze_search_results(
search_results, analysis_focus="competitive_analysis"
)
# Extract competitive-specific insights
insights = []
risks = []
opportunities = []
sources = []
for result in analyzed_results:
analysis = result.get("analysis", {})
insights.extend(analysis.get("insights", []))
risks.extend(analysis.get("risk_factors", []))
opportunities.extend(analysis.get("opportunities", []))
sources.append(
{
"title": result.get("title", ""),
"url": result.get("url", ""),
"credibility_score": result.get("credibility_score", 0.5),
"published_date": result.get("published_date"),
"author": result.get("author"),
}
)
return {
"research_type": "competitive",
"insights": list(dict.fromkeys(insights))[:8],
"risk_factors": list(dict.fromkeys(risks))[:6],
"opportunities": list(dict.fromkeys(opportunities))[:6],
"sentiment": self._calculate_competitive_sentiment(analyzed_results),
"credibility_score": self._calculate_average_credibility(analyzed_results),
"sources": sources,
"focus_areas": [
"competitive_position",
"market_share",
"industry_trends",
"competitive_advantages",
],
}
def _generate_competitive_queries(self, topic: str) -> list[str]:
"""Generate competitive analysis specific queries."""
return [
f"{topic} market share competitive position industry",
f"{topic} competitors comparison competitive advantage",
f"{topic} industry analysis market trends",
f"{topic} competitive landscape market dynamics",
f"{topic} industry outlook sector performance",
]
def _calculate_competitive_sentiment(
self, results: list[dict[str, Any]]
) -> dict[str, Any]:
"""Calculate sentiment specific to competitive positioning."""
sentiments = [
r.get("analysis", {}).get("sentiment", {})
for r in results
if r.get("analysis")
]
sentiments = [s for s in sentiments if s]
if not sentiments:
return {"direction": "neutral", "confidence": 0.5}
# Focus on competitive strength indicators
bullish_count = sum(1 for s in sentiments if s.get("direction") == "bullish")
bearish_count = sum(1 for s in sentiments if s.get("direction") == "bearish")
if bullish_count > bearish_count:
return {"direction": "bullish", "confidence": 0.6}
elif bearish_count > bullish_count:
return {"direction": "bearish", "confidence": 0.6}
else:
return {"direction": "neutral", "confidence": 0.5}
def _calculate_average_credibility(self, results: list[dict[str, Any]]) -> float:
"""Calculate average credibility of sources."""
if not results:
return 0.5
credibility_scores = [r.get("credibility_score", 0.5) for r in results]
return sum(credibility_scores) / len(credibility_scores)
```