This is page 28 of 29. Use http://codebase.md/wshobson/maverick-mcp?page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   ├── question.md
│   │   └── security_report.md
│   ├── pull_request_template.md
│   └── workflows
│       ├── claude-code-review.yml
│       └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│   ├── launch.json
│   └── settings.json
├── alembic
│   ├── env.py
│   ├── script.py.mako
│   └── versions
│       ├── 001_initial_schema.py
│       ├── 003_add_performance_indexes.py
│       ├── 006_rename_metadata_columns.py
│       ├── 008_performance_optimization_indexes.py
│       ├── 009_rename_to_supply_demand.py
│       ├── 010_self_contained_schema.py
│       ├── 011_remove_proprietary_terms.py
│       ├── 013_add_backtest_persistence_models.py
│       ├── 014_add_portfolio_models.py
│       ├── 08e3945a0c93_merge_heads.py
│       ├── 9374a5c9b679_merge_heads_for_testing.py
│       ├── abf9b9afb134_merge_multiple_heads.py
│       ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│       ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│       ├── f0696e2cac15_add_essential_performance_indexes.py
│       └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│   ├── api
│   │   └── backtesting.md
│   ├── BACKTESTING.md
│   ├── COST_BASIS_SPECIFICATION.md
│   ├── deep_research_agent.md
│   ├── exa_research_testing_strategy.md
│   ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│   ├── PORTFOLIO.md
│   ├── SETUP_SELF_CONTAINED.md
│   └── speed_testing_framework.md
├── examples
│   ├── complete_speed_validation.py
│   ├── deep_research_integration.py
│   ├── llm_optimization_example.py
│   ├── llm_speed_demo.py
│   ├── monitoring_example.py
│   ├── parallel_research_example.py
│   ├── speed_optimization_demo.py
│   └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── circuit_breaker.py
│   │   ├── deep_research.py
│   │   ├── market_analysis.py
│   │   ├── optimized_research.py
│   │   ├── supervisor.py
│   │   └── technical_analysis.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── api_server.py
│   │   ├── connection_manager.py
│   │   ├── dependencies
│   │   │   ├── __init__.py
│   │   │   ├── stock_analysis.py
│   │   │   └── technical_analysis.py
│   │   ├── error_handling.py
│   │   ├── inspector_compatible_sse.py
│   │   ├── inspector_sse.py
│   │   ├── middleware
│   │   │   ├── error_handling.py
│   │   │   ├── mcp_logging.py
│   │   │   ├── rate_limiting_enhanced.py
│   │   │   └── security.py
│   │   ├── openapi_config.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── agents.py
│   │   │   ├── backtesting.py
│   │   │   ├── data_enhanced.py
│   │   │   ├── data.py
│   │   │   ├── health_enhanced.py
│   │   │   ├── health_tools.py
│   │   │   ├── health.py
│   │   │   ├── intelligent_backtesting.py
│   │   │   ├── introspection.py
│   │   │   ├── mcp_prompts.py
│   │   │   ├── monitoring.py
│   │   │   ├── news_sentiment_enhanced.py
│   │   │   ├── performance.py
│   │   │   ├── portfolio.py
│   │   │   ├── research.py
│   │   │   ├── screening_ddd.py
│   │   │   ├── screening_parallel.py
│   │   │   ├── screening.py
│   │   │   ├── technical_ddd.py
│   │   │   ├── technical_enhanced.py
│   │   │   ├── technical.py
│   │   │   └── tool_registry.py
│   │   ├── server.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── base_service.py
│   │   │   ├── market_service.py
│   │   │   ├── portfolio_service.py
│   │   │   ├── prompt_service.py
│   │   │   └── resource_service.py
│   │   ├── simple_sse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── insomnia_export.py
│   │       └── postman_export.py
│   ├── application
│   │   ├── __init__.py
│   │   ├── commands
│   │   │   └── __init__.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_dto.py
│   │   ├── queries
│   │   │   ├── __init__.py
│   │   │   └── get_technical_analysis.py
│   │   └── screening
│   │       ├── __init__.py
│   │       ├── dtos.py
│   │       └── queries.py
│   ├── backtesting
│   │   ├── __init__.py
│   │   ├── ab_testing.py
│   │   ├── analysis.py
│   │   ├── batch_processing_stub.py
│   │   ├── batch_processing.py
│   │   ├── model_manager.py
│   │   ├── optimization.py
│   │   ├── persistence.py
│   │   ├── retraining_pipeline.py
│   │   ├── strategies
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── ml
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── ensemble.py
│   │   │   │   ├── feature_engineering.py
│   │   │   │   └── regime_aware.py
│   │   │   ├── ml_strategies.py
│   │   │   ├── parser.py
│   │   │   └── templates.py
│   │   ├── strategy_executor.py
│   │   ├── vectorbt_engine.py
│   │   └── visualization.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── database_self_contained.py
│   │   ├── database.py
│   │   ├── llm_optimization_config.py
│   │   ├── logging_settings.py
│   │   ├── plotly_config.py
│   │   ├── security_utils.py
│   │   ├── security.py
│   │   ├── settings.py
│   │   ├── technical_constants.py
│   │   ├── tool_estimation.py
│   │   └── validation.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── technical_analysis.py
│   │   └── visualization.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cache_manager.py
│   │   ├── cache.py
│   │   ├── django_adapter.py
│   │   ├── health.py
│   │   ├── models.py
│   │   ├── performance.py
│   │   ├── session_management.py
│   │   └── validation.py
│   ├── database
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── optimization.py
│   ├── dependencies.py
│   ├── domain
│   │   ├── __init__.py
│   │   ├── entities
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis.py
│   │   ├── events
│   │   │   └── __init__.py
│   │   ├── portfolio.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   ├── entities.py
│   │   │   ├── services.py
│   │   │   └── value_objects.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   └── technical_analysis_service.py
│   │   ├── stock_analysis
│   │   │   ├── __init__.py
│   │   │   └── stock_analysis_service.py
│   │   └── value_objects
│   │       ├── __init__.py
│   │       └── technical_indicators.py
│   ├── exceptions.py
│   ├── infrastructure
│   │   ├── __init__.py
│   │   ├── cache
│   │   │   └── __init__.py
│   │   ├── caching
│   │   │   ├── __init__.py
│   │   │   └── cache_management_service.py
│   │   ├── connection_manager.py
│   │   ├── data_fetching
│   │   │   ├── __init__.py
│   │   │   └── stock_data_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_checker.py
│   │   ├── persistence
│   │   │   ├── __init__.py
│   │   │   └── stock_repository.py
│   │   ├── providers
│   │   │   └── __init__.py
│   │   ├── screening
│   │   │   ├── __init__.py
│   │   │   └── repositories.py
│   │   └── sse_optimizer.py
│   ├── langchain_tools
│   │   ├── __init__.py
│   │   ├── adapters.py
│   │   └── registry.py
│   ├── logging_config.py
│   ├── memory
│   │   ├── __init__.py
│   │   └── stores.py
│   ├── monitoring
│   │   ├── __init__.py
│   │   ├── health_check.py
│   │   ├── health_monitor.py
│   │   ├── integration_example.py
│   │   ├── metrics.py
│   │   ├── middleware.py
│   │   └── status_dashboard.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── dependencies.py
│   │   ├── factories
│   │   │   ├── __init__.py
│   │   │   ├── config_factory.py
│   │   │   └── provider_factory.py
│   │   ├── implementations
│   │   │   ├── __init__.py
│   │   │   ├── cache_adapter.py
│   │   │   ├── macro_data_adapter.py
│   │   │   ├── market_data_adapter.py
│   │   │   ├── persistence_adapter.py
│   │   │   └── stock_data_adapter.py
│   │   ├── interfaces
│   │   │   ├── __init__.py
│   │   │   ├── cache.py
│   │   │   ├── config.py
│   │   │   ├── macro_data.py
│   │   │   ├── market_data.py
│   │   │   ├── persistence.py
│   │   │   └── stock_data.py
│   │   ├── llm_factory.py
│   │   ├── macro_data.py
│   │   ├── market_data.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_cache.py
│   │   │   ├── mock_config.py
│   │   │   ├── mock_macro_data.py
│   │   │   ├── mock_market_data.py
│   │   │   ├── mock_persistence.py
│   │   │   └── mock_stock_data.py
│   │   ├── openrouter_provider.py
│   │   ├── optimized_screening.py
│   │   ├── optimized_stock_data.py
│   │   └── stock_data.py
│   ├── README.md
│   ├── tests
│   │   ├── __init__.py
│   │   ├── README_INMEMORY_TESTS.md
│   │   ├── test_cache_debug.py
│   │   ├── test_fixes_validation.py
│   │   ├── test_in_memory_routers.py
│   │   ├── test_in_memory_server.py
│   │   ├── test_macro_data_provider.py
│   │   ├── test_mailgun_email.py
│   │   ├── test_market_calendar_caching.py
│   │   ├── test_mcp_tool_fixes_pytest.py
│   │   ├── test_mcp_tool_fixes.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_models_functional.py
│   │   ├── test_server.py
│   │   ├── test_stock_data_enhanced.py
│   │   ├── test_stock_data_provider.py
│   │   └── test_technical_analysis.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── performance_monitoring.py
│   │   ├── portfolio_manager.py
│   │   ├── risk_management.py
│   │   └── sentiment_analysis.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── agent_errors.py
│   │   ├── batch_processing.py
│   │   ├── cache_warmer.py
│   │   ├── circuit_breaker_decorators.py
│   │   ├── circuit_breaker_services.py
│   │   ├── circuit_breaker.py
│   │   ├── data_chunking.py
│   │   ├── database_monitoring.py
│   │   ├── debug_utils.py
│   │   ├── fallback_strategies.py
│   │   ├── llm_optimization.py
│   │   ├── logging_example.py
│   │   ├── logging_init.py
│   │   ├── logging.py
│   │   ├── mcp_logging.py
│   │   ├── memory_profiler.py
│   │   ├── monitoring_middleware.py
│   │   ├── monitoring.py
│   │   ├── orchestration_logging.py
│   │   ├── parallel_research.py
│   │   ├── parallel_screening.py
│   │   ├── quick_cache.py
│   │   ├── resource_manager.py
│   │   ├── shutdown.py
│   │   ├── stock_helpers.py
│   │   ├── structured_logger.py
│   │   ├── tool_monitoring.py
│   │   ├── tracing.py
│   │   └── yfinance_pool.py
│   ├── validation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── data.py
│   │   ├── middleware.py
│   │   ├── portfolio.py
│   │   ├── responses.py
│   │   ├── screening.py
│   │   └── technical.py
│   └── workflows
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── market_analyzer.py
│       │   ├── optimizer_agent.py
│       │   ├── strategy_selector.py
│       │   └── validator_agent.py
│       ├── backtesting_workflow.py
│       └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│   ├── dev.sh
│   ├── INSTALLATION_GUIDE.md
│   ├── load_example.py
│   ├── load_market_data.py
│   ├── load_tiingo_data.py
│   ├── migrate_db.py
│   ├── README_TIINGO_LOADER.md
│   ├── requirements_tiingo.txt
│   ├── run_stock_screening.py
│   ├── run-migrations.sh
│   ├── seed_db.py
│   ├── seed_sp500.py
│   ├── setup_database.sh
│   ├── setup_self_contained.py
│   ├── setup_sp500_database.sh
│   ├── test_seeded_data.py
│   ├── test_tiingo_loader.py
│   ├── tiingo_config.py
│   └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── core
│   │   └── test_technical_analysis.py
│   ├── data
│   │   └── test_portfolio_models.py
│   ├── domain
│   │   ├── conftest.py
│   │   ├── test_portfolio_entities.py
│   │   └── test_technical_analysis_service.py
│   ├── fixtures
│   │   └── orchestration_fixtures.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── README.md
│   │   ├── run_integration_tests.sh
│   │   ├── test_api_technical.py
│   │   ├── test_chaos_engineering.py
│   │   ├── test_config_management.py
│   │   ├── test_full_backtest_workflow_advanced.py
│   │   ├── test_full_backtest_workflow.py
│   │   ├── test_high_volume.py
│   │   ├── test_mcp_tools.py
│   │   ├── test_orchestration_complete.py
│   │   ├── test_portfolio_persistence.py
│   │   ├── test_redis_cache.py
│   │   ├── test_security_integration.py.disabled
│   │   └── vcr_setup.py
│   ├── performance
│   │   ├── __init__.py
│   │   ├── test_benchmarks.py
│   │   ├── test_load.py
│   │   ├── test_profiling.py
│   │   └── test_stress.py
│   ├── providers
│   │   └── test_stock_data_simple.py
│   ├── README.md
│   ├── test_agents_router_mcp.py
│   ├── test_backtest_persistence.py
│   ├── test_cache_management_service.py
│   ├── test_cache_serialization.py
│   ├── test_circuit_breaker.py
│   ├── test_database_pool_config_simple.py
│   ├── test_database_pool_config.py
│   ├── test_deep_research_functional.py
│   ├── test_deep_research_integration.py
│   ├── test_deep_research_parallel_execution.py
│   ├── test_error_handling.py
│   ├── test_event_loop_integrity.py
│   ├── test_exa_research_integration.py
│   ├── test_exception_hierarchy.py
│   ├── test_financial_search.py
│   ├── test_graceful_shutdown.py
│   ├── test_integration_simple.py
│   ├── test_langgraph_workflow.py
│   ├── test_market_data_async.py
│   ├── test_market_data_simple.py
│   ├── test_mcp_orchestration_functional.py
│   ├── test_ml_strategies.py
│   ├── test_optimized_research_agent.py
│   ├── test_orchestration_integration.py
│   ├── test_orchestration_logging.py
│   ├── test_orchestration_tools_simple.py
│   ├── test_parallel_research_integration.py
│   ├── test_parallel_research_orchestrator.py
│   ├── test_parallel_research_performance.py
│   ├── test_performance_optimizations.py
│   ├── test_production_validation.py
│   ├── test_provider_architecture.py
│   ├── test_rate_limiting_enhanced.py
│   ├── test_runner_validation.py
│   ├── test_security_comprehensive.py.disabled
│   ├── test_security_cors.py
│   ├── test_security_enhancements.py.disabled
│   ├── test_security_headers.py
│   ├── test_security_penetration.py
│   ├── test_session_management.py
│   ├── test_speed_optimization_validation.py
│   ├── test_stock_analysis_dependencies.py
│   ├── test_stock_analysis_service.py
│   ├── test_stock_data_fetching_service.py
│   ├── test_supervisor_agent.py
│   ├── test_supervisor_functional.py
│   ├── test_tool_estimation_config.py
│   ├── test_visualization.py
│   └── utils
│       ├── test_agent_errors.py
│       ├── test_logging.py
│       ├── test_parallel_screening.py
│       └── test_quick_cache.py
├── tools
│   ├── check_orchestration_config.py
│   ├── experiments
│   │   ├── validation_examples.py
│   │   └── validation_fixed.py
│   ├── fast_dev.sh
│   ├── hot_reload.py
│   ├── quick_test.py
│   └── templates
│       ├── new_router_template.py
│       ├── new_tool_template.py
│       ├── screening_strategy_template.py
│       └── test_template.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/maverick_mcp/utils/llm_optimization.py:
--------------------------------------------------------------------------------

```python
"""
LLM-side optimizations for research agents to prevent timeouts.

This module provides comprehensive optimization strategies including:
- Adaptive model selection based on time constraints
- Progressive token budgeting with confidence tracking
- Parallel LLM processing with intelligent load balancing
- Optimized prompt engineering for speed
- Early termination based on confidence thresholds
- Content filtering to reduce processing overhead
"""

import asyncio
import logging
import re
import time
from datetime import datetime
from enum import Enum
from typing import Any

from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from maverick_mcp.providers.openrouter_provider import (
    OpenRouterProvider,
    TaskType,
)
from maverick_mcp.utils.orchestration_logging import (
    get_orchestration_logger,
    log_method_call,
)

logger = logging.getLogger(__name__)


class ResearchPhase(str, Enum):
    """Research phases for token allocation."""

    SEARCH = "search"
    CONTENT_ANALYSIS = "content_analysis"
    SYNTHESIS = "synthesis"
    VALIDATION = "validation"


class ModelConfiguration(BaseModel):
    """Configuration for model selection with time optimization."""

    model_id: str = Field(description="OpenRouter model identifier")
    max_tokens: int = Field(description="Maximum output tokens")
    temperature: float = Field(description="Model temperature")
    timeout_seconds: float = Field(description="Request timeout")
    parallel_batch_size: int = Field(
        default=1, description="Sources per batch for this model"
    )


class TokenAllocation(BaseModel):
    """Token allocation for a research phase."""

    input_tokens: int = Field(description="Maximum input tokens")
    output_tokens: int = Field(description="Maximum output tokens")
    per_source_tokens: int = Field(description="Tokens per source")
    emergency_reserve: int = Field(description="Emergency reserve tokens")
    timeout_seconds: float = Field(description="Processing timeout")


class AdaptiveModelSelector:
    """Intelligent model selection based on time budgets and task complexity."""

    def __init__(self, openrouter_provider: OpenRouterProvider):
        self.provider = openrouter_provider
        self.performance_cache = {}  # Cache model performance metrics

    def select_model_for_time_budget(
        self,
        task_type: TaskType,
        time_remaining_seconds: float,
        complexity_score: float,
        content_size_tokens: int,
        confidence_threshold: float = 0.8,
        current_confidence: float = 0.0,
    ) -> ModelConfiguration:
        """Select optimal model based on available time and requirements."""

        # Time pressure categories with adaptive thresholds
        if time_remaining_seconds < 10:
            return self._select_emergency_model(task_type, content_size_tokens)
        elif time_remaining_seconds < 25:
            return self._select_fast_quality_model(task_type, complexity_score)
        elif time_remaining_seconds < 45:
            return self._select_balanced_model(
                task_type, complexity_score, current_confidence
            )
        else:
            return self._select_optimal_model(
                task_type, complexity_score, confidence_threshold
            )

    def _select_emergency_model(
        self, task_type: TaskType, content_size: int
    ) -> ModelConfiguration:
        """Ultra-fast models for time-critical situations."""
        # OPTIMIZATION: Prioritize speed with increased batch sizes
        if content_size > 20000:  # Large content needs fast + capable models
            return ModelConfiguration(
                model_id="google/gemini-2.5-flash",  # 199 tokens/sec - fastest available
                max_tokens=min(800, content_size // 25),  # Adaptive token limit
                temperature=0.05,  # OPTIMIZATION: Minimal temp for deterministic fast response
                timeout_seconds=5,  # OPTIMIZATION: Reduced from 8s
                parallel_batch_size=8,  # OPTIMIZATION: Doubled for faster processing
            )
        else:
            return ModelConfiguration(
                model_id="openai/gpt-4o-mini",  # 126 tokens/sec - excellent speed/cost balance
                max_tokens=min(500, content_size // 20),
                temperature=0.03,  # OPTIMIZATION: Near-zero for fastest response
                timeout_seconds=4,  # OPTIMIZATION: Reduced from 6s
                parallel_batch_size=10,  # OPTIMIZATION: Doubled for maximum parallelism
            )

    def _select_fast_quality_model(
        self, task_type: TaskType, complexity_score: float
    ) -> ModelConfiguration:
        """Balance speed and quality for time-constrained situations."""
        if complexity_score > 0.7 or task_type == TaskType.COMPLEX_REASONING:
            # Complex tasks - use fast model with good quality
            return ModelConfiguration(
                model_id="openai/gpt-4o-mini",  # 126 tokens/sec + good quality
                max_tokens=1200,
                temperature=0.1,  # OPTIMIZATION: Reduced for faster response
                timeout_seconds=10,  # OPTIMIZATION: Reduced from 18s
                parallel_batch_size=6,  # OPTIMIZATION: Doubled for better parallelism
            )
        else:
            # Simple tasks - use the fastest model available
            return ModelConfiguration(
                model_id="google/gemini-2.5-flash",  # 199 tokens/sec - fastest
                max_tokens=1000,
                temperature=0.1,  # OPTIMIZATION: Reduced for faster response
                timeout_seconds=8,  # OPTIMIZATION: Reduced from 12s
                parallel_batch_size=8,  # OPTIMIZATION: Doubled for maximum speed
            )

    def _select_balanced_model(
        self, task_type: TaskType, complexity_score: float, current_confidence: float
    ) -> ModelConfiguration:
        """Standard mode with cost-effectiveness focus."""
        # If confidence is already high, use fastest models for validation
        if current_confidence > 0.7:
            return ModelConfiguration(
                model_id="google/gemini-2.5-flash",  # 199 tokens/sec - fastest validation
                max_tokens=1500,
                temperature=0.25,
                timeout_seconds=20,  # Reduced for fastest model
                parallel_batch_size=4,  # Increased for speed
            )

        # Standard balanced approach - prioritize speed-optimized models
        if task_type in [TaskType.DEEP_RESEARCH, TaskType.RESULT_SYNTHESIS]:
            return ModelConfiguration(
                model_id="openai/gpt-4o-mini",  # Speed + quality balance for research
                max_tokens=2000,
                temperature=0.3,
                timeout_seconds=25,  # Reduced for faster model
                parallel_batch_size=3,  # Increased for speed
            )
        else:
            return ModelConfiguration(
                model_id="google/gemini-2.5-flash",  # Fastest for general tasks
                max_tokens=1500,
                temperature=0.25,
                timeout_seconds=20,  # Reduced for fastest model
                parallel_batch_size=4,  # Increased for speed
            )

    def _select_optimal_model(
        self, task_type: TaskType, complexity_score: float, confidence_threshold: float
    ) -> ModelConfiguration:
        """Comprehensive mode for complex analysis."""
        # Use premium models for the most complex tasks when time allows
        if complexity_score > 0.8 and task_type == TaskType.DEEP_RESEARCH:
            return ModelConfiguration(
                model_id="google/gemini-2.5-pro",
                max_tokens=3000,
                temperature=0.3,
                timeout_seconds=45,
                parallel_batch_size=1,  # Deep thinking models work better individually
            )

        # High-quality cost-effective models for standard comprehensive analysis
        return ModelConfiguration(
            model_id="anthropic/claude-sonnet-4",
            max_tokens=2500,
            temperature=0.3,
            timeout_seconds=40,
            parallel_batch_size=2,
        )

    def calculate_task_complexity(
        self, content: str, task_type: TaskType, focus_areas: list[str] | None = None
    ) -> float:
        """Calculate complexity score based on content and task requirements."""
        if not content:
            return 0.3  # Default low complexity

        content_lower = content.lower()

        # Financial complexity indicators
        complexity_indicators = {
            "financial_jargon": len(
                re.findall(
                    r"\b(?:ebitda|dcf|roic?|wacc|beta|volatility|sharpe)\b",
                    content_lower,
                )
            ),
            "numerical_data": len(re.findall(r"\$?[\d,]+\.?\d*[%kmbKMB]?", content)),
            "comparative_analysis": len(
                re.findall(
                    r"\b(?:versus|compared|relative|outperform|underperform)\b",
                    content_lower,
                )
            ),
            "temporal_analysis": len(
                re.findall(r"\b(?:quarterly|q[1-4]|fy|yoy|qoq|annual)\b", content_lower)
            ),
            "market_terms": len(
                re.findall(
                    r"\b(?:bullish|bearish|catalyst|headwind|tailwind)\b", content_lower
                )
            ),
            "technical_terms": len(
                re.findall(
                    r"\b(?:support|resistance|breakout|rsi|macd|sma|ema)\b",
                    content_lower,
                )
            ),
        }

        # Calculate base complexity
        total_indicators = sum(complexity_indicators.values())
        content_length = len(content.split())
        base_complexity = min(total_indicators / max(content_length / 100, 1), 1.0)

        # Task-specific complexity adjustments
        task_multipliers = {
            TaskType.DEEP_RESEARCH: 1.4,
            TaskType.COMPLEX_REASONING: 1.6,
            TaskType.RESULT_SYNTHESIS: 1.2,
            TaskType.TECHNICAL_ANALYSIS: 1.3,
            TaskType.SENTIMENT_ANALYSIS: 0.8,
            TaskType.QUICK_ANSWER: 0.5,
        }

        # Focus area adjustments
        focus_multiplier = 1.0
        if focus_areas:
            complex_focus_areas = [
                "competitive_analysis",
                "fundamental_analysis",
                "complex_reasoning",
            ]
            if any(area in focus_areas for area in complex_focus_areas):
                focus_multiplier = 1.2

        final_complexity = (
            base_complexity * task_multipliers.get(task_type, 1.0) * focus_multiplier
        )
        return min(final_complexity, 1.0)
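
# Illustrative usage sketch (comment only, not executed): a caller with an
# OpenRouterProvider instance `provider` and raw research text `text` (both
# hypothetical names) might pick a model configuration like this:
#
#   selector = AdaptiveModelSelector(provider)
#   complexity = selector.calculate_task_complexity(text, TaskType.DEEP_RESEARCH)
#   config = selector.select_model_for_time_budget(
#       task_type=TaskType.DEEP_RESEARCH,
#       time_remaining_seconds=20.0,         # lands in the "fast quality" tier
#       complexity_score=complexity,
#       content_size_tokens=len(text) // 4,  # rough 4-chars-per-token estimate
#   )
#   # config.model_id, config.max_tokens and config.timeout_seconds then drive
#   # the downstream LLM call.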


class ProgressiveTokenBudgeter:
    """Manages token budgets across research phases with time awareness."""

    def __init__(
        self, total_time_budget_seconds: float, confidence_target: float = 0.75
    ):
        self.total_time_budget = total_time_budget_seconds
        self.confidence_target = confidence_target
        self.phase_budgets = self._calculate_base_phase_budgets()
        self.time_started = time.time()

    def _calculate_base_phase_budgets(self) -> dict[ResearchPhase, int]:
        """Calculate base token budgets for each research phase."""
        # Allocate tokens based on typical phase requirements
        if self.total_time_budget < 30:
            # Emergency mode - minimal tokens
            return {
                ResearchPhase.SEARCH: 500,
                ResearchPhase.CONTENT_ANALYSIS: 2000,
                ResearchPhase.SYNTHESIS: 800,
                ResearchPhase.VALIDATION: 300,
            }
        elif self.total_time_budget < 60:
            # Fast mode
            return {
                ResearchPhase.SEARCH: 1000,
                ResearchPhase.CONTENT_ANALYSIS: 4000,
                ResearchPhase.SYNTHESIS: 1500,
                ResearchPhase.VALIDATION: 500,
            }
        else:
            # Standard mode
            return {
                ResearchPhase.SEARCH: 1500,
                ResearchPhase.CONTENT_ANALYSIS: 6000,
                ResearchPhase.SYNTHESIS: 2500,
                ResearchPhase.VALIDATION: 1000,
            }

    def allocate_tokens_for_phase(
        self,
        phase: ResearchPhase,
        sources_count: int,
        current_confidence: float,
        complexity_score: float = 0.5,
    ) -> TokenAllocation:
        """Allocate tokens for a research phase based on current state."""

        time_elapsed = time.time() - self.time_started
        time_remaining = max(0, self.total_time_budget - time_elapsed)

        base_budget = self.phase_budgets[phase]

        # Confidence-based scaling
        if current_confidence > self.confidence_target:
            # High confidence - focus on validation with fewer tokens
            confidence_multiplier = 0.7
        elif current_confidence < 0.4:
            # Low confidence - increase token usage if time allows
            confidence_multiplier = 1.3 if time_remaining > 30 else 0.9
        else:
            confidence_multiplier = 1.0

        # Time pressure scaling
        time_multiplier = self._calculate_time_multiplier(time_remaining)

        # Complexity scaling
        complexity_multiplier = 0.8 + (complexity_score * 0.4)  # Range: 0.8 to 1.2

        # Source count scaling (diminishing returns)
        if sources_count > 0:
            source_multiplier = min(1.0 + (sources_count - 3) * 0.05, 1.3)
        else:
            source_multiplier = 1.0

        # Calculate final budget
        final_budget = int(
            base_budget
            * confidence_multiplier
            * time_multiplier
            * complexity_multiplier
            * source_multiplier
        )

        # Calculate timeout based on available time and token budget
        base_timeout = min(time_remaining * 0.8, 45)  # Max 45 seconds per phase
        adjusted_timeout = base_timeout * (final_budget / base_budget) ** 0.5

        return TokenAllocation(
            input_tokens=min(int(final_budget * 0.75), 15000),  # Cap input tokens
            output_tokens=min(int(final_budget * 0.25), 3000),  # Cap output tokens
            per_source_tokens=final_budget // max(sources_count, 1)
            if sources_count > 0
            else final_budget,
            emergency_reserve=200,  # Always keep emergency reserve
            timeout_seconds=max(adjusted_timeout, 5),  # Minimum 5 seconds
        )
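
    # Worked example (illustrative): with a 45s total budget (fast mode, base
    # 4000 tokens for CONTENT_ANALYSIS), ~35s remaining (time multiplier 0.9),
    # confidence 0.5 and complexity 0.5 (both multipliers 1.0), and 6 sources
    # (multiplier 1.15), the final budget is int(4000 * 0.9 * 1.15) = 4140,
    # split into 3105 input / 1035 output tokens by the caps above.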

    def get_next_allocation(
        self,
        sources_remaining: int,
        current_confidence: float,
        time_elapsed_seconds: float,
    ) -> dict[str, Any]:
        """Get the next token allocation for processing sources."""
        time_remaining = max(0, self.total_time_budget - time_elapsed_seconds)

        # Determine priority based on confidence and time pressure
        if current_confidence < 0.4 and time_remaining > 30:
            priority = "high"
        elif current_confidence < 0.6 and time_remaining > 15:
            priority = "medium"
        else:
            priority = "low"

        # Calculate time budget per remaining source
        if sources_remaining > 0:
            time_per_source = time_remaining / sources_remaining
        else:
            time_per_source = 0

        # Calculate token budget
        base_tokens = self.phase_budgets.get(ResearchPhase.CONTENT_ANALYSIS, 2000)

        if priority == "high":
            max_tokens = min(int(base_tokens * 1.2), 4000)
        elif priority == "medium":
            max_tokens = base_tokens
        else:
            max_tokens = int(base_tokens * 0.8)

        return {
            "time_budget": min(time_per_source, 30.0),  # Cap at 30 seconds
            "max_tokens": max_tokens,
            "priority": priority,
            "sources_remaining": sources_remaining,
        }

    def _calculate_time_multiplier(self, time_remaining: float) -> float:
        """Scale token budget based on time pressure."""
        if time_remaining < 5:
            return 0.2  # Extreme emergency mode
        elif time_remaining < 15:
            return 0.4  # Emergency mode
        elif time_remaining < 30:
            return 0.7  # Time-constrained
        elif time_remaining < 60:
            return 0.9  # Slightly reduced
        else:
            return 1.0  # Full budget available
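
# Illustrative usage sketch (comment only): the budgeter is expected to be
# created once per research request and consulted before each phase; the loop
# driving it lives outside this module.
#
#   budgeter = ProgressiveTokenBudgeter(total_time_budget_seconds=45.0)
#   allocation = budgeter.allocate_tokens_for_phase(
#       phase=ResearchPhase.CONTENT_ANALYSIS,
#       sources_count=6,
#       current_confidence=0.45,
#       complexity_score=0.6,
#   )
#   # allocation.output_tokens and allocation.timeout_seconds bound the LLM call
#   # for this phase; allocation.emergency_reserve is held back for synthesis.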


class ParallelLLMProcessor:
    """Handles parallel LLM operations with intelligent load balancing."""

    def __init__(
        self,
        openrouter_provider: OpenRouterProvider,
        max_concurrent: int = 5,  # OPTIMIZATION: Increased from 3
    ):
        self.provider = openrouter_provider
        self.max_concurrent = max_concurrent
        self.semaphore = asyncio.BoundedSemaphore(
            max_concurrent
        )  # OPTIMIZATION: Use BoundedSemaphore
        self.model_selector = AdaptiveModelSelector(openrouter_provider)
        self.orchestration_logger = get_orchestration_logger("ParallelLLMProcessor")
        # OPTIMIZATION: Track active requests for better coordination
        self._active_requests = 0
        self._request_lock = asyncio.Lock()

    @log_method_call(component="ParallelLLMProcessor", include_timing=True)
    async def parallel_content_analysis(
        self,
        sources: list[dict],
        analysis_type: str,
        persona: str,
        time_budget_seconds: float,
        current_confidence: float = 0.0,
    ) -> list[dict]:
        """Analyze multiple sources in parallel with adaptive optimization."""

        if not sources:
            return []

        self.orchestration_logger.set_request_context(
            analysis_type=analysis_type,
            source_count=len(sources),
            time_budget=time_budget_seconds,
        )

        # Calculate complexity for all sources
        combined_content = "\n".join(
            [source.get("content", "")[:1000] for source in sources[:5]]
        )
        overall_complexity = self.model_selector.calculate_task_complexity(
            combined_content,
            TaskType.SENTIMENT_ANALYSIS
            if analysis_type == "sentiment"
            else TaskType.MARKET_ANALYSIS,
        )

        # Determine optimal batching strategy
        model_config = self.model_selector.select_model_for_time_budget(
            task_type=TaskType.SENTIMENT_ANALYSIS
            if analysis_type == "sentiment"
            else TaskType.MARKET_ANALYSIS,
            time_remaining_seconds=time_budget_seconds,
            complexity_score=overall_complexity,
            content_size_tokens=len(combined_content) // 4,
            current_confidence=current_confidence,
        )

        # Create batches based on model configuration
        batches = self._create_optimal_batches(
            sources, model_config.parallel_batch_size
        )

        self.orchestration_logger.info(
            "🔄 PARALLEL_ANALYSIS_START",
            total_sources=len(sources),
            batch_count=len(batches),
        )

        # OPTIMIZATION: Process batches using create_task for immediate parallelism
        running_tasks = []
        for i, batch in enumerate(batches):
            # Create task immediately without awaiting
            task_future = asyncio.create_task(
                self._analyze_source_batch(
                    batch=batch,
                    batch_id=i,
                    analysis_type=analysis_type,
                    persona=persona,
                    model_config=model_config,
                    overall_complexity=overall_complexity,
                )
            )
            running_tasks.append((i, task_future))  # Track batch ID with future

            # OPTIMIZATION: Minimal stagger to prevent API overload
            if i < len(batches) - 1:  # Don't delay after last batch
                await asyncio.sleep(0.01)  # 10ms micro-delay

        # OPTIMIZATION: Use as_completed for progressive result handling
        batch_results = [None] * len(batches)  # Pre-allocate results list
        timeout_at = time.time() + (time_budget_seconds * 0.9)

        try:
            for batch_id, task_future in running_tasks:
                remaining_time = timeout_at - time.time()
                if remaining_time <= 0:
                    raise TimeoutError()

                try:
                    result = await asyncio.wait_for(task_future, timeout=remaining_time)
                    batch_results[batch_id] = result
                except Exception as e:
                    batch_results[batch_id] = e
        except TimeoutError:
            self.orchestration_logger.warning(
                "⏰ PARALLEL_ANALYSIS_TIMEOUT", timeout=time_budget_seconds
            )
            return self._create_fallback_results(sources)

        # Flatten and process results
        final_results = []
        successful_batches = 0
        for i, batch_result in enumerate(batch_results):
            if isinstance(batch_result, Exception):
                self.orchestration_logger.warning(
                    "⚠️ BATCH_FAILED", batch_id=i, error=str(batch_result)
                )
                # Add fallback results for failed batch
                final_results.extend(self._create_fallback_results(batches[i]))
            else:
                final_results.extend(batch_result)
                successful_batches += 1

        self.orchestration_logger.info(
            "✅ PARALLEL_ANALYSIS_COMPLETE",
            successful_batches=successful_batches,
            results_count=len(final_results),
        )

        return final_results

    def _create_optimal_batches(
        self, sources: list[dict], batch_size: int
    ) -> list[list[dict]]:
        """Create optimal batches for parallel processing."""
        if batch_size <= 1:
            return [[source] for source in sources]

        batches = []
        for i in range(0, len(sources), batch_size):
            batch = sources[i : i + batch_size]
            batches.append(batch)

        return batches

    async def _analyze_source_batch(
        self,
        batch: list[dict],
        batch_id: int,
        analysis_type: str,
        persona: str,
        model_config: ModelConfiguration,
        overall_complexity: float,
    ) -> list[dict]:
        """Analyze a batch of sources with optimized LLM call."""

        # OPTIMIZATION: Track active requests for better coordination
        async with self._request_lock:
            self._active_requests += 1

        try:
            # OPTIMIZATION: Acquire semaphore without blocking other task creation
            await self.semaphore.acquire()
            try:
                # Create batch analysis prompt
                batch_prompt = self._create_batch_analysis_prompt(
                    batch, analysis_type, persona, model_config.max_tokens
                )

                # Get LLM instance
                llm = self.provider.get_llm(
                    model_override=model_config.model_id,
                    temperature=model_config.temperature,
                    max_tokens=model_config.max_tokens,
                )

                # Execute with timeout
                start_time = time.time()
                result = await asyncio.wait_for(
                    llm.ainvoke(
                        [
                            SystemMessage(
                                content="You are a financial analyst. Provide structured, concise analysis."
                            ),
                            HumanMessage(content=batch_prompt),
                        ]
                    ),
                    timeout=model_config.timeout_seconds,
                )

                execution_time = time.time() - start_time

                # Parse batch results
                parsed_results = self._parse_batch_analysis_result(
                    result.content, batch
                )

                self.orchestration_logger.debug(
                    "✨ BATCH_SUCCESS",
                    batch_id=batch_id,
                    duration=f"{execution_time:.2f}s",
                )

                return parsed_results

            except TimeoutError:
                self.orchestration_logger.warning(
                    "⏰ BATCH_TIMEOUT",
                    batch_id=batch_id,
                    timeout=model_config.timeout_seconds,
                )
                return self._create_fallback_results(batch)
            except Exception as e:
                self.orchestration_logger.error(
                    "💥 BATCH_ERROR", batch_id=batch_id, error=str(e)
                )
                return self._create_fallback_results(batch)
            finally:
                # OPTIMIZATION: Always release semaphore
                self.semaphore.release()
        finally:
            # OPTIMIZATION: Track active requests
            async with self._request_lock:
                self._active_requests -= 1

    def _create_batch_analysis_prompt(
        self, batch: list[dict], analysis_type: str, persona: str, max_tokens: int
    ) -> str:
        """Create optimized prompt for batch analysis."""

        # Determine prompt style based on token budget
        if max_tokens < 800:
            style = "ultra_concise"
        elif max_tokens < 1500:
            style = "concise"
        else:
            style = "detailed"

        prompt_templates = {
            "ultra_concise": """URGENT BATCH ANALYSIS - {analysis_type} for {persona} investor.

Analyze {source_count} sources. For EACH source, provide:
SOURCE_N: SENTIMENT:Bull/Bear/Neutral|CONFIDENCE:0-1|INSIGHT:one key point|RISK:main risk

{sources}

Keep total response under 500 words.""",
            "concise": """BATCH ANALYSIS - {analysis_type} for {persona} investor perspective.

Analyze these {source_count} sources. For each source provide:
- Sentiment: Bull/Bear/Neutral + confidence (0-1)
- Key insight (1 sentence)
- Main risk (1 sentence)
- Relevance score (0-1)

{sources}

Format consistently. Target ~100 words per source.""",
            "detailed": """Comprehensive {analysis_type} analysis for {persona} investor.

Analyze these {source_count} sources with structured output for each:

{sources}

For each source provide:
1. Sentiment (direction, confidence 0-1, brief reasoning)
2. Key insights (2-3 main points)
3. Risk factors (1-2 key risks)
4. Opportunities (1-2 opportunities if any)
5. Credibility assessment (0-1 score)
6. Relevance score (0-1)

Maintain {persona} investor perspective throughout.""",
        }

        # Format sources for prompt
        sources_text = ""
        for i, source in enumerate(batch, 1):
            content = source.get("content", "")[:1500]  # Limit content length
            title = source.get("title", f"Source {i}")
            sources_text += f"\nSOURCE {i} - {title}:\n{content}\n{'---' * 20}\n"

        template = prompt_templates[style]
        return template.format(
            analysis_type=analysis_type,
            persona=persona,
            source_count=len(batch),
            sources=sources_text.strip(),
        )

    def _parse_batch_analysis_result(
        self, result_content: str, batch: list[dict]
    ) -> list[dict]:
        """Parse LLM batch analysis result into structured data."""

        results = []

        # Try structured parsing first
        source_sections = re.split(r"\n(?:SOURCE\s+\d+|---+)", result_content)

        if len(source_sections) >= len(batch):
            # Structured parsing successful
            for _i, (source, section) in enumerate(
                zip(batch, source_sections[1 : len(batch) + 1], strict=False)
            ):
                parsed = self._parse_source_analysis(section, source)
                results.append(parsed)
        else:
            # Fallback to simple parsing
            for i, source in enumerate(batch):
                fallback_analysis = self._create_simple_fallback_analysis(
                    result_content, source, i
                )
                results.append(fallback_analysis)

        return results

    def _parse_source_analysis(self, analysis_text: str, source: dict) -> dict:
        """Parse analysis text for a single source."""

        # Extract sentiment
        sentiment_match = re.search(
            r"sentiment:?\s*(\w+)[,\s]*(?:confidence:?\s*([\d.]+))?",
            analysis_text.lower(),
        )
        if sentiment_match:
            direction = sentiment_match.group(1).lower()
            confidence = float(sentiment_match.group(2) or 0.5)

            # Map common sentiment terms
            if direction in ["bull", "bullish", "positive"]:
                direction = "bullish"
            elif direction in ["bear", "bearish", "negative"]:
                direction = "bearish"
            else:
                direction = "neutral"
        else:
            direction = "neutral"
            confidence = 0.5

        # Extract other information
        insights = self._extract_insights(analysis_text)
        risks = self._extract_risks(analysis_text)
        opportunities = self._extract_opportunities(analysis_text)

        # Extract scores
        relevance_match = re.search(r"relevance:?\s*([\d.]+)", analysis_text.lower())
        relevance_score = float(relevance_match.group(1)) if relevance_match else 0.6

        credibility_match = re.search(
            r"credibility:?\s*([\d.]+)", analysis_text.lower()
        )
        credibility_score = (
            float(credibility_match.group(1)) if credibility_match else 0.7
        )

        return {
            **source,
            "analysis": {
                "insights": insights,
                "sentiment": {"direction": direction, "confidence": confidence},
                "risk_factors": risks,
                "opportunities": opportunities,
                "credibility_score": credibility_score,
                "relevance_score": relevance_score,
                "analysis_timestamp": datetime.now(),
                "batch_processed": True,
            },
        }

    def _extract_insights(self, text: str) -> list[str]:
        """Extract insights from analysis text."""
        insights = []

        # Look for insight patterns
        insight_patterns = [
            r"insight:?\s*([^.\n]+)",
            r"key point:?\s*([^.\n]+)",
            r"main finding:?\s*([^.\n]+)",
        ]

        for pattern in insight_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            insights.extend([m.strip() for m in matches if m.strip()])

        # If no structured insights found, extract bullet points
        if not insights:
            bullet_matches = re.findall(r"[•\-\*]\s*([^.\n]+)", text)
            insights.extend([m.strip() for m in bullet_matches if m.strip()][:3])

        return insights[:5]  # Limit to 5 insights

    def _extract_risks(self, text: str) -> list[str]:
        """Extract risk factors from analysis text."""
        risk_patterns = [
            r"risk:?\s*([^.\n]+)",
            r"concern:?\s*([^.\n]+)",
            r"headwind:?\s*([^.\n]+)",
        ]

        risks = []
        for pattern in risk_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            risks.extend([m.strip() for m in matches if m.strip()])

        return risks[:3]

    def _extract_opportunities(self, text: str) -> list[str]:
        """Extract opportunities from analysis text."""
        opp_patterns = [
            r"opportunit(?:y|ies):?\s*([^.\n]+)",
            r"catalyst:?\s*([^.\n]+)",
            r"tailwind:?\s*([^.\n]+)",
        ]

        opportunities = []
        for pattern in opp_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            opportunities.extend([m.strip() for m in matches if m.strip()])

        return opportunities[:3]

    def _create_simple_fallback_analysis(
        self, full_analysis: str, source: dict, index: int
    ) -> dict:
        """Create simple fallback analysis when parsing fails."""

        # Basic sentiment analysis from text
        analysis_lower = full_analysis.lower()

        positive_words = ["positive", "bullish", "strong", "growth", "opportunity"]
        negative_words = ["negative", "bearish", "weak", "decline", "risk"]

        pos_count = sum(1 for word in positive_words if word in analysis_lower)
        neg_count = sum(1 for word in negative_words if word in analysis_lower)

        if pos_count > neg_count:
            sentiment = "bullish"
            confidence = 0.6
        elif neg_count > pos_count:
            sentiment = "bearish"
            confidence = 0.6
        else:
            sentiment = "neutral"
            confidence = 0.5

        return {
            **source,
            "analysis": {
                "insights": [f"Analysis based on source content (index {index})"],
                "sentiment": {"direction": sentiment, "confidence": confidence},
                "risk_factors": ["Unable to extract specific risks"],
                "opportunities": ["Unable to extract specific opportunities"],
                "credibility_score": 0.5,
                "relevance_score": 0.5,
                "analysis_timestamp": datetime.now(),
                "fallback_used": True,
                "batch_processed": True,
            },
        }

    def _create_fallback_results(self, sources: list[dict]) -> list[dict]:
        """Create fallback results when batch processing fails."""
        results = []
        for source in sources:
            fallback_result = {
                **source,
                "analysis": {
                    "insights": ["Analysis failed - using fallback"],
                    "sentiment": {"direction": "neutral", "confidence": 0.3},
                    "risk_factors": ["Analysis timeout - unable to assess risks"],
                    "opportunities": [],
                    "credibility_score": 0.5,
                    "relevance_score": 0.5,
                    "analysis_timestamp": datetime.now(),
                    "fallback_used": True,
                    "batch_timeout": True,
                },
            }
            results.append(fallback_result)
        return results
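
# Illustrative async usage sketch (comment only): `provider` and `sources` are
# hypothetical; each source is a dict with "title" and "content" keys, as
# _create_batch_analysis_prompt above expects.
#
#   processor = ParallelLLMProcessor(provider, max_concurrent=5)
#   analyzed = await processor.parallel_content_analysis(
#       sources=sources,
#       analysis_type="sentiment",
#       persona="moderate",
#       time_budget_seconds=30.0,
#   )
#   # Each returned dict carries an "analysis" payload (sentiment, insights,
#   # risks, opportunities) or a fallback marker when its batch failed or timed out.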


class OptimizedPromptEngine:
    """Creates optimized prompts for different time constraints and confidence levels."""

    def __init__(self):
        self.prompt_cache = {}  # Cache for generated prompts

        self.prompt_templates = {
            "emergency": {
                "content_analysis": """URGENT: Quick 3-point analysis of financial content for {persona} investor.

Content: {content}

Provide ONLY:
1. SENTIMENT: Bull/Bear/Neutral + confidence (0-1)
2. KEY_RISK: Primary risk factor
3. KEY_OPPORTUNITY: Main opportunity (if any)

Format: SENTIMENT:Bull|0.8 KEY_RISK:Market volatility KEY_OPPORTUNITY:Earnings growth
Max 50 words total. No explanations.""",
                "synthesis": """URGENT: 2-sentence summary from {source_count} sources for {persona} investor.

Key findings: {key_points}

Provide: 1) Overall sentiment direction 2) Primary investment implication
Max 40 words total.""",
            },
            "fast": {
                "content_analysis": """Quick financial analysis for {persona} investor - 5 points max.

Content: {content}

Provide concisely:
• Sentiment: Bull/Bear/Neutral (confidence 0-1)
• Key insight (1 sentence)
• Main risk (1 sentence)
• Main opportunity (1 sentence)
• Relevance score (0-1)

Target: Under 150 words total.""",
                "synthesis": """Synthesize research findings for {persona} investor.

Sources: {source_count} | Key insights: {insights}

4-part summary:
1. Overall sentiment + confidence
2. Top 2 opportunities
3. Top 2 risks
4. Recommended action

Limit: 200 words max.""",
            },
            "standard": {
                "content_analysis": """Financial content analysis for {persona} investor.

Content: {content}
Focus areas: {focus_areas}

Structured analysis:
- Sentiment (direction, confidence 0-1, brief reasoning)
- Key insights (3-5 bullet points)
- Risk factors (2-3 main risks)
- Opportunities (2-3 opportunities)
- Credibility assessment (0-1)
- Relevance score (0-1)

Target: 300-500 words.""",
                "synthesis": """Comprehensive research synthesis for {persona} investor.

Research Summary:
- Sources analyzed: {source_count}
- Key insights: {insights}
- Time horizon: {time_horizon}

Provide detailed analysis:
1. Executive Summary (2-3 sentences)
2. Key Findings (5-7 bullet points)
3. Investment Implications
4. Risk Assessment
5. Recommended Actions
6. Confidence Level + reasoning

Tailor specifically for {persona} investment characteristics.""",
            },
        }

    def get_optimized_prompt(
        self,
        prompt_type: str,
        time_remaining: float,
        confidence_level: float,
        **context,
    ) -> str:
        """Generate optimized prompt based on time constraints and confidence."""

        # Create cache key
        cache_key = f"{prompt_type}_{time_remaining:.0f}_{confidence_level:.1f}_{hash(str(sorted(context.items())))}"

        if cache_key in self.prompt_cache:
            return self.prompt_cache[cache_key]

        # Select template based on time pressure
        if time_remaining < 15:
            template_category = "emergency"
        elif time_remaining < 45:
            template_category = "fast"
        else:
            template_category = "standard"

        template = self.prompt_templates[template_category].get(prompt_type)

        if not template:
            # Fallback to fast template
            template = self.prompt_templates["fast"].get(
                prompt_type, "Analyze the content quickly and provide key insights."
            )

        # Add confidence-based instructions
        confidence_instructions = ""
        if confidence_level > 0.7:
            confidence_instructions = "\n\nNOTE: High confidence already achieved. Focus on validation and contradictory evidence."
        elif confidence_level < 0.4:
            confidence_instructions = "\n\nNOTE: Low confidence. Look for strong supporting evidence to build confidence."

        # Format template with context
        formatted_prompt = template.format(**context) + confidence_instructions

        # Cache the result
        self.prompt_cache[cache_key] = formatted_prompt

        return formatted_prompt

    def create_time_optimized_synthesis_prompt(
        self,
        sources: list[dict],
        persona: str,
        time_remaining: float,
        current_confidence: float,
    ) -> str:
        """Create synthesis prompt optimized for available time."""

        # Extract key information from sources
        insights = []
        sentiments = []
        for source in sources:
            analysis = source.get("analysis", {})
            insights.extend(analysis.get("insights", [])[:2])  # Limit per source
            sentiment = analysis.get("sentiment", {})
            if sentiment:
                sentiments.append(sentiment.get("direction", "neutral"))

        # Prepare context
        context = {
            "persona": persona,
            "source_count": len(sources),
            "insights": "; ".join(insights[:8]),  # Top 8 insights
            "key_points": "; ".join(insights[:8]),  # For backward compatibility
            "time_horizon": "short-term" if time_remaining < 30 else "medium-term",
        }

        return self.get_optimized_prompt(
            "synthesis", time_remaining, current_confidence, **context
        )
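
# Illustrative usage sketch (comment only): context keys must cover the
# placeholders of whichever template the time budget selects (the emergency
# content_analysis template only needs "persona" and "content"); `article_text`
# is a hypothetical variable.
#
#   engine = OptimizedPromptEngine()
#   prompt = engine.get_optimized_prompt(
#       "content_analysis",
#       time_remaining=12.0,       # < 15s selects the "emergency" template
#       confidence_level=0.3,
#       persona="conservative",
#       content=article_text,
#   )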


class ConfidenceTracker:
    """Tracks research confidence and triggers early termination when appropriate."""

    def __init__(
        self,
        target_confidence: float = 0.75,
        min_sources: int = 3,
        max_sources: int = 15,
    ):
        self.target_confidence = target_confidence
        self.min_sources = min_sources
        self.max_sources = max_sources
        self.confidence_history = []
        self.evidence_history = []
        self.source_count = 0
        self.sources_analyzed = 0  # For backward compatibility
        self.last_significant_improvement = 0
        self.sentiment_votes = {"bullish": 0, "bearish": 0, "neutral": 0}

    def update_confidence(
        self,
        new_evidence: dict,
        source_credibility: float | None = None,
        credibility_score: float | None = None,
    ) -> dict[str, Any]:
        """Update confidence based on new evidence and return continuation decision."""

        # Handle both parameter names for backward compatibility
        if source_credibility is None and credibility_score is not None:
            source_credibility = credibility_score
        elif source_credibility is None and credibility_score is None:
            source_credibility = 0.5  # Default value

        self.source_count += 1
        self.sources_analyzed += 1  # Keep both for compatibility

        # Store evidence
        self.evidence_history.append(
            {
                "evidence": new_evidence,
                "credibility": source_credibility,
                "timestamp": datetime.now(),
            }
        )

        # Update sentiment voting
        sentiment = new_evidence.get("sentiment", {})
        direction = sentiment.get("direction", "neutral")
        confidence = sentiment.get("confidence", 0.5)

        # Weight vote by source credibility and sentiment confidence
        vote_weight = source_credibility * confidence
        self.sentiment_votes[direction] += vote_weight

        # Calculate evidence strength
        evidence_strength = self._calculate_evidence_strength(
            new_evidence, source_credibility
        )

        # Update confidence using Bayesian-style updating
        current_confidence = self._update_bayesian_confidence(evidence_strength)
        self.confidence_history.append(current_confidence)

        # Check for significant improvement
        if len(self.confidence_history) >= 2:
            improvement = current_confidence - self.confidence_history[-2]
            if improvement > 0.1:  # 10% improvement
                self.last_significant_improvement = self.source_count

        # Make continuation decision
        should_continue = self._should_continue_research(current_confidence)

        return {
            "current_confidence": current_confidence,
            "should_continue": should_continue,
            "sources_processed": self.source_count,
            "sources_analyzed": self.source_count,  # For backward compatibility
            "confidence_trend": self._calculate_confidence_trend(),
            "early_termination_reason": None
            if should_continue
            else self._get_termination_reason(current_confidence),
            "sentiment_consensus": self._calculate_sentiment_consensus(),
        }

    def _calculate_evidence_strength(self, evidence: dict, credibility: float) -> float:
        """Calculate the strength of new evidence."""

        # Base strength from sentiment confidence
        sentiment = evidence.get("sentiment", {})
        sentiment_confidence = sentiment.get("confidence", 0.5)

        # Adjust for source credibility
        credibility_adjusted = sentiment_confidence * credibility

        # Factor in evidence richness
        insights_count = len(evidence.get("insights", []))
        risk_factors_count = len(evidence.get("risk_factors", []))
        opportunities_count = len(evidence.get("opportunities", []))

        # Evidence richness score (0-1)
        evidence_richness = min(
            (insights_count + risk_factors_count + opportunities_count) / 12, 1.0
        )

        # Relevance factor
        relevance_score = evidence.get("relevance_score", 0.5)

        # Final evidence strength calculation
        final_strength = credibility_adjusted * (
            0.5 + 0.3 * evidence_richness + 0.2 * relevance_score
        )

        return min(final_strength, 1.0)

    def _update_bayesian_confidence(self, evidence_strength: float) -> float:
        """Update confidence using Bayesian approach."""

        if not self.confidence_history:
            # First evidence - base confidence
            return evidence_strength

        # Current prior
        prior = self.confidence_history[-1]

        # Bayesian-style update: treat evidence strength as the likelihood and blend it
        # with the prior via a weighted average whose prior weight decays geometrically
        # as more sources are processed
        decay_factor = 0.9 ** (self.source_count - 1)

        updated = prior * decay_factor + evidence_strength * (1 - decay_factor)

        # Ensure within bounds
        return max(0.1, min(updated, 0.95))

    def _should_continue_research(self, current_confidence: float) -> bool:
        """Determine if research should continue based on multiple factors."""

        # Always process minimum sources
        if self.source_count < self.min_sources:
            return True

        # Stop at maximum sources
        if self.source_count >= self.max_sources:
            return False

        # High confidence reached
        if current_confidence >= self.target_confidence:
            return False

        # Check for diminishing returns
        if self.source_count - self.last_significant_improvement > 4:
            # No significant improvement in last 4 sources
            return False

        # Check sentiment consensus
        consensus_score = self._calculate_sentiment_consensus()
        if consensus_score > 0.8 and self.source_count >= 5:
            # Strong consensus with adequate sample
            return False

        # Check confidence plateau
        if len(self.confidence_history) >= 3:
            recent_change = abs(current_confidence - self.confidence_history[-3])
            if recent_change < 0.03:  # Less than 3% change in last 3 sources
                return False

        return True

    def _calculate_confidence_trend(self) -> str:
        """Calculate the trend in confidence over recent sources."""

        if len(self.confidence_history) < 3:
            return "insufficient_data"

        recent = self.confidence_history[-3:]

        # Calculate trend
        if recent[-1] > recent[0] + 0.05:
            return "increasing"
        elif recent[-1] < recent[0] - 0.05:
            return "decreasing"
        else:
            return "stable"

    def _calculate_sentiment_consensus(self) -> float:
        """Calculate how much sources agree on sentiment."""

        total_votes = sum(self.sentiment_votes.values())
        if total_votes == 0:
            return 0.0

        # Calculate consensus as max vote share
        max_votes = max(self.sentiment_votes.values())
        consensus = max_votes / total_votes

        return consensus

    def _get_termination_reason(self, current_confidence: float) -> str:
        """Get reason for early termination."""

        if current_confidence >= self.target_confidence:
            return "target_confidence_reached"
        elif self.source_count >= self.max_sources:
            return "max_sources_reached"
        elif self._calculate_sentiment_consensus() > 0.8:
            return "strong_consensus"
        elif self.source_count - self.last_significant_improvement > 4:
            return "diminishing_returns"
        else:
            return "confidence_plateau"


class IntelligentContentFilter:
    """Pre-filters and prioritizes content to reduce LLM processing overhead."""

    def __init__(self):
        self.relevance_keywords = {
            "fundamental": {
                "high": [
                    "earnings",
                    "revenue",
                    "profit",
                    "ebitda",
                    "cash flow",
                    "debt",
                    "valuation",
                ],
                "medium": [
                    "balance sheet",
                    "income statement",
                    "financial",
                    "quarterly",
                    "annual",
                ],
                "context": ["company", "business", "financial results", "guidance"],
            },
            "technical": {
                "high": [
                    "price",
                    "chart",
                    "trend",
                    "support",
                    "resistance",
                    "breakout",
                ],
                "medium": ["volume", "rsi", "macd", "moving average", "pattern"],
                "context": ["technical analysis", "trading", "momentum"],
            },
            "sentiment": {
                "high": ["rating", "upgrade", "downgrade", "buy", "sell", "hold"],
                "medium": ["analyst", "recommendation", "target price", "outlook"],
                "context": ["opinion", "sentiment", "market mood"],
            },
            "competitive": {
                "high": [
                    "market share",
                    "competitor",
                    "competitive advantage",
                    "industry",
                ],
                "medium": ["peer", "comparison", "market position", "sector"],
                "context": ["competitive landscape", "industry analysis"],
            },
        }

        self.domain_credibility_scores = {
            "reuters.com": 0.95,
            "bloomberg.com": 0.95,
            "wsj.com": 0.90,
            "ft.com": 0.90,
            "marketwatch.com": 0.85,
            "cnbc.com": 0.80,
            "yahoo.com": 0.75,
            "seekingalpha.com": 0.80,
            "fool.com": 0.70,
            "investing.com": 0.75,
        }

    async def filter_and_prioritize_sources(
        self,
        sources: list[dict],
        research_focus: str,
        time_budget: float,
        target_source_count: int | None = None,
        current_confidence: float = 0.0,
    ) -> list[dict]:
        """Filter and prioritize sources based on relevance, quality, and time constraints."""

        if not sources:
            return []

        # Determine target count based on time budget and confidence
        if target_source_count is None:
            target_source_count = self._calculate_optimal_source_count(
                time_budget, current_confidence, len(sources)
            )

        # Quick relevance scoring without LLM
        scored_sources = []
        for source in sources:
            relevance_score = self._calculate_relevance_score(source, research_focus)
            credibility_score = self._get_source_credibility(source)
            recency_score = self._calculate_recency_score(source.get("published_date"))

            # Combined score with weights
            combined_score = (
                relevance_score * 0.5 + credibility_score * 0.3 + recency_score * 0.2
            )

            if combined_score > 0.3:  # Relevance threshold
                scored_sources.append((combined_score, source))

        # Sort by combined score
        scored_sources.sort(key=lambda x: x[0], reverse=True)

        # Select diverse sources
        selected_sources = self._select_diverse_sources(
            scored_sources, target_source_count, research_focus
        )

        # Pre-process content for faster LLM processing
        processed_sources = []
        for score, source in selected_sources:
            processed_source = self._preprocess_content(
                source, research_focus, time_budget
            )
            processed_source["relevance_score"] = score
            processed_sources.append(processed_source)

        return processed_sources

    def _calculate_optimal_source_count(
        self, time_budget: float, current_confidence: float, available_sources: int
    ) -> int:
        """Calculate optimal number of sources to process given constraints."""

        # Base count from time budget
        if time_budget < 20:
            base_count = 3
        elif time_budget < 40:
            base_count = 6
        elif time_budget < 80:
            base_count = 10
        else:
            base_count = 15

        # Adjust for confidence level
        if current_confidence > 0.7:
            # High confidence - fewer sources needed
            confidence_multiplier = 0.7
        elif current_confidence < 0.4:
            # Low confidence - more sources helpful
            confidence_multiplier = 1.2
        else:
            confidence_multiplier = 1.0

        # Final calculation
        target_count = int(base_count * confidence_multiplier)

        # Ensure we don't exceed available sources
        return min(target_count, available_sources, 20)  # Cap at 20

    def _calculate_relevance_score(self, source: dict, research_focus: str) -> float:
        """Calculate relevance score using keyword matching and heuristics."""

        content = source.get("content", "").lower()
        title = source.get("title", "").lower()

        if not content and not title:
            return 0.0

        focus_keywords = self.relevance_keywords.get(research_focus, {})

        # High-value keywords
        high_keywords = focus_keywords.get("high", [])
        high_score = sum(1 for keyword in high_keywords if keyword in content) / max(
            len(high_keywords), 1
        )

        # Medium-value keywords
        medium_keywords = focus_keywords.get("medium", [])
        medium_score = sum(
            1 for keyword in medium_keywords if keyword in content
        ) / max(len(medium_keywords), 1)

        # Context keywords
        context_keywords = focus_keywords.get("context", [])
        context_score = sum(
            1 for keyword in context_keywords if keyword in content
        ) / max(len(context_keywords), 1)

        # Title relevance (titles are more focused)
        title_high_score = sum(
            1 for keyword in high_keywords if keyword in title
        ) / max(len(high_keywords), 1)

        # Combine scores with weights
        relevance_score = (
            high_score * 0.4
            + medium_score * 0.25
            + context_score * 0.15
            + title_high_score * 0.2
        )

        # Boost for very relevant titles
        if any(keyword in title for keyword in high_keywords):
            relevance_score *= 1.2

        return min(relevance_score, 1.0)

    def _get_source_credibility(self, source: dict) -> float:
        """Calculate source credibility based on domain and other factors."""

        url = source.get("url", "").lower()

        # Domain-based credibility
        domain_score = 0.5  # Default
        for domain, score in self.domain_credibility_scores.items():
            if domain in url:
                domain_score = score
                break

        # Boost for specific high-quality indicators
        if any(indicator in url for indicator in [".gov", ".edu", "sec.gov"]):
            domain_score = min(domain_score + 0.2, 1.0)

        # Penalty for low-quality indicators
        if any(indicator in url for indicator in ["blog", "forum", "reddit"]):
            domain_score *= 0.8

        return domain_score

    def _calculate_recency_score(self, published_date: str | None) -> float:
        """Calculate recency score based on publication date."""

        if not published_date:
            return 0.5  # Default for unknown dates

        try:
            # Parse date (handle various formats)
            if "T" in published_date:
                pub_date = datetime.fromisoformat(published_date.replace("Z", "+00:00"))
            else:
                pub_date = datetime.strptime(published_date, "%Y-%m-%d")

            # Calculate days old
            days_old = (datetime.now() - pub_date.replace(tzinfo=None)).days

            # Scoring based on age
            if days_old <= 1:
                return 1.0  # Very recent
            elif days_old <= 7:
                return 0.9  # Recent
            elif days_old <= 30:
                return 0.7  # Fairly recent
            elif days_old <= 90:
                return 0.5  # Moderately old
            else:
                return 0.3  # Old

        except (ValueError, TypeError):
            return 0.5  # Default for unparseable dates

    def _select_diverse_sources(
        self,
        scored_sources: list[tuple[float, dict]],
        target_count: int,
        research_focus: str,
    ) -> list[tuple[float, dict]]:
        """Select diverse sources to avoid redundancy."""

        if len(scored_sources) <= target_count:
            return scored_sources

        selected = []
        used_domains = set()

        # First pass: select high-scoring diverse sources
        for score, source in scored_sources:
            if len(selected) >= target_count:
                break

            url = source.get("url", "")
            domain = self._extract_domain(url)

            # Ensure diversity by domain (max 2 from same domain initially)
            domain_count = sum(
                1
                for _, s in selected
                if self._extract_domain(s.get("url", "")) == domain
            )

            if domain_count < 2 or len(selected) < target_count // 2:
                selected.append((score, source))
                used_domains.add(domain)

        # Second pass: fill remaining slots with the best sources not already selected
        # (a positional slice could re-add sources chosen during the first pass)
        remaining_needed = target_count - len(selected)
        if remaining_needed > 0:
            selected_ids = {id(source) for _, source in selected}
            remaining_sources = [
                item for item in scored_sources if id(item[1]) not in selected_ids
            ]
            selected.extend(remaining_sources[:remaining_needed])

        return selected[:target_count]

    def _extract_domain(self, url: str) -> str:
        """Extract domain from URL."""
        try:
            if "//" in url:
                domain = url.split("//")[1].split("/")[0]
                return domain.replace("www.", "")
            return url
        except Exception:
            return url

    def _preprocess_content(
        self, source: dict, research_focus: str, time_budget: float
    ) -> dict:
        """Pre-process content to optimize for LLM analysis."""

        content = source.get("content", "")
        if not content:
            return source

        # Determine content length limit based on time budget
        if time_budget < 30:
            max_length = 800  # Emergency mode
        elif time_budget < 60:
            max_length = 1200  # Fast mode
        else:
            max_length = 2000  # Standard mode

        # If content is already short enough, return as-is
        if len(content) <= max_length:
            source_copy = source.copy()
            source_copy["original_length"] = len(content)
            source_copy["filtered"] = False
            return source_copy

        # Extract most relevant sentences/paragraphs
        sentences = re.split(r"[.!?]+", content)
        focus_keywords = self.relevance_keywords.get(research_focus, {})
        all_keywords = (
            focus_keywords.get("high", [])
            + focus_keywords.get("medium", [])
            + focus_keywords.get("context", [])
        )

        # Score sentences by keyword relevance
        scored_sentences = []
        for sentence in sentences:
            if len(sentence.strip()) < 20:  # Skip very short sentences
                continue

            sentence_lower = sentence.lower()
            keyword_count = sum(
                1 for keyword in all_keywords if keyword in sentence_lower
            )

            # Boost for financial numbers and percentages
            has_numbers = bool(re.search(r"\$?[\d,]+\.?\d*[%kmbKMB]?", sentence))
            number_boost = 0.5 if has_numbers else 0

            sentence_score = keyword_count + number_boost
            if sentence_score > 0:
                scored_sentences.append((sentence_score, sentence.strip()))

        # Sort by relevance and select top sentences
        scored_sentences.sort(key=lambda x: x[0], reverse=True)

        # Build filtered content
        filtered_content = ""
        for _score, sentence in scored_sentences:
            if len(filtered_content) + len(sentence) > max_length:
                break
            filtered_content += sentence + ". "

        # If no relevant sentences found, take first part of original content
        if not filtered_content:
            filtered_content = content[:max_length]

        # Create processed source
        source_copy = source.copy()
        source_copy["content"] = filtered_content.strip()
        source_copy["original_length"] = len(content)
        source_copy["filtered_length"] = len(filtered_content)
        source_copy["filtered"] = True
        source_copy["compression_ratio"] = len(filtered_content) / len(content)

        return source_copy
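

# NOTE: Hypothetical usage sketch, not part of the original module. It pre-filters raw
# search results before LLM analysis; the source dict fields ("url", "title", "content",
# "published_date") follow what the methods above read, and the budget is illustrative.
async def _example_prefilter_sources(raw_sources: list[dict]) -> list[dict]:
    content_filter = IntelligentContentFilter()
    return await content_filter.filter_and_prioritize_sources(
        sources=raw_sources,
        research_focus="fundamental",
        time_budget=45.0,
        current_confidence=0.4,
    )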


# Export main classes for integration
__all__ = [
    "AdaptiveModelSelector",
    "ProgressiveTokenBudgeter",
    "ParallelLLMProcessor",
    "OptimizedPromptEngine",
    "ConfidenceTracker",
    "IntelligentContentFilter",
    "ModelConfiguration",
    "TokenAllocation",
    "ResearchPhase",
]

```

--------------------------------------------------------------------------------
/maverick_mcp/data/models.py:
--------------------------------------------------------------------------------

```python
"""
SQLAlchemy models for MaverickMCP.

This module defines database models for financial data storage and analysis,
including PriceCache and Maverick screening models.
"""

from __future__ import annotations

import logging
import os
import threading
import uuid
from collections.abc import AsyncGenerator, Sequence
from datetime import UTC, date, datetime, timedelta
from decimal import Decimal

import pandas as pd
from sqlalchemy import (
    JSON,
    BigInteger,
    Boolean,
    Column,
    Date,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    String,
    Text,
    UniqueConstraint,
    Uuid,
    create_engine,
    inspect,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, relationship, sessionmaker
from sqlalchemy.pool import NullPool, QueuePool

from maverick_mcp.config.settings import get_settings
from maverick_mcp.database.base import Base

# Set up logging
logger = logging.getLogger("maverick_mcp.data.models")
settings = get_settings()


# Helper function to get the right integer type for autoincrement primary keys
def get_primary_key_type():
    """Get the appropriate primary key type based on database backend."""
    # SQLite works better with INTEGER for autoincrement, PostgreSQL can use BIGINT
    if "sqlite" in DATABASE_URL:
        return Integer
    else:
        return BigInteger


# Database connection setup
# Try multiple possible environment variable names
# Use SQLite in-memory for GitHub Actions or test environments
if os.getenv("GITHUB_ACTIONS") == "true" or os.getenv("CI") == "true":
    DATABASE_URL = "sqlite:///:memory:"
else:
    DATABASE_URL = (
        os.getenv("DATABASE_URL")
        or os.getenv("POSTGRES_URL")
        or "sqlite:///maverick_mcp.db"  # Default to SQLite
    )

# Database configuration from settings
DB_POOL_SIZE = settings.db.pool_size
DB_MAX_OVERFLOW = settings.db.pool_max_overflow
DB_POOL_TIMEOUT = settings.db.pool_timeout
DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))  # 1 hour
DB_POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
DB_ECHO = os.getenv("DB_ECHO", "false").lower() == "true"
DB_USE_POOLING = os.getenv("DB_USE_POOLING", "true").lower() == "true"

# Log the connection string (without password) for debugging
if DATABASE_URL:
    # Mask password in URL for logging
    masked_url = DATABASE_URL
    if "@" in DATABASE_URL and "://" in DATABASE_URL:
        parts = DATABASE_URL.split("://", 1)
        if len(parts) == 2 and "@" in parts[1]:
            user_pass, host_db = parts[1].split("@", 1)
            if ":" in user_pass:
                user, _ = user_pass.split(":", 1)
                masked_url = f"{parts[0]}://{user}:****@{host_db}"
    logger.info(f"Using database URL: {masked_url}")
    logger.info(f"Connection pooling: {'ENABLED' if DB_USE_POOLING else 'DISABLED'}")
    if DB_USE_POOLING:
        logger.info(
            f"Pool config: size={DB_POOL_SIZE}, max_overflow={DB_MAX_OVERFLOW}, "
            f"timeout={DB_POOL_TIMEOUT}s, recycle={DB_POOL_RECYCLE}s"
        )

# Create engine with configurable connection pooling
if DB_USE_POOLING:
    # Prepare connection arguments based on database type
    if "postgresql" in DATABASE_URL:
        # PostgreSQL-specific connection args
        sync_connect_args = {
            "connect_timeout": 10,
            "application_name": "maverick_mcp",
            "options": f"-c statement_timeout={settings.db.statement_timeout}",
        }
    elif "sqlite" in DATABASE_URL:
        # SQLite-specific args - no SSL parameters
        sync_connect_args = {"check_same_thread": False}
    else:
        # Default - no connection args
        sync_connect_args = {}

    # Use QueuePool for production environments
    engine = create_engine(
        DATABASE_URL,
        poolclass=QueuePool,
        pool_size=DB_POOL_SIZE,
        max_overflow=DB_MAX_OVERFLOW,
        pool_timeout=DB_POOL_TIMEOUT,
        pool_recycle=DB_POOL_RECYCLE,
        pool_pre_ping=DB_POOL_PRE_PING,
        echo=DB_ECHO,
        connect_args=sync_connect_args,
    )
else:
    # Prepare minimal connection arguments for NullPool
    if "sqlite" in DATABASE_URL:
        sync_connect_args = {"check_same_thread": False}
    else:
        sync_connect_args = {}

    # Use NullPool for serverless/development environments
    engine = create_engine(
        DATABASE_URL,
        poolclass=NullPool,
        echo=DB_ECHO,
        connect_args=sync_connect_args,
    )

# Create session factory
_session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)

_schema_lock = threading.Lock()
_schema_initialized = False


def ensure_database_schema(force: bool = False) -> bool:
    """Ensure the database schema exists for the configured engine.

    Args:
        force: When ``True`` the schema will be (re)created even if it appears
            to exist already.

    Returns:
        ``True`` if the schema creation routine executed, ``False`` otherwise.
    """

    global _schema_initialized

    # Fast path: skip inspection once the schema has been verified unless the
    # caller explicitly requests a forced refresh.
    if not force and _schema_initialized:
        return False

    with _schema_lock:
        if not force and _schema_initialized:
            return False

        try:
            inspector = inspect(engine)
            existing_tables = set(inspector.get_table_names())
        except SQLAlchemyError as exc:  # pragma: no cover - safety net
            logger.warning(
                "Unable to inspect database schema; attempting to create tables anyway",
                exc_info=exc,
            )
            existing_tables = set()

        defined_tables = set(Base.metadata.tables.keys())
        missing_tables = defined_tables - existing_tables

        should_create = force or bool(missing_tables)
        if should_create:
            if missing_tables:
                logger.info(
                    "Creating missing database tables: %s",
                    ", ".join(sorted(missing_tables)),
                )
            else:
                logger.info("Ensuring database schema is up to date")

            Base.metadata.create_all(bind=engine)
            _schema_initialized = True
            return True

        _schema_initialized = True
        return False


class _SessionFactoryWrapper:
    """Session factory that ensures the schema exists before creating sessions."""

    def __init__(self, factory: sessionmaker):
        self._factory = factory

    def __call__(self, *args, **kwargs):
        ensure_database_schema()
        return self._factory(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(self._factory, name)


SessionLocal = _SessionFactoryWrapper(_session_factory)
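

# NOTE: Hypothetical usage sketch, not part of the original module. Because of the
# wrapper above, a plain SessionLocal() call ensures the schema exists before the
# session is handed out; the query here is illustrative only.
def _example_sync_session_usage() -> None:
    from sqlalchemy import text  # local import keeps the sketch self-contained

    with SessionLocal() as session:
        session.execute(text("SELECT 1"))
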

# Create async engine - cached globally for reuse
_async_engine = None
_async_session_factory = None


def _get_async_engine():
    """Get or create the async engine singleton."""
    global _async_engine
    if _async_engine is None:
        # Convert sync URL to async URL
        if DATABASE_URL.startswith("sqlite://"):
            async_url = DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
        else:
            async_url = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://")

        # Create async engine - don't specify poolclass for async engines
        # SQLAlchemy will use the appropriate async pool automatically
        if DB_USE_POOLING:
            # Prepare connection arguments based on database type
            if "postgresql" in async_url:
                # PostgreSQL-specific connection args
                async_connect_args = {
                    "server_settings": {
                        "application_name": "maverick_mcp_async",
                        "statement_timeout": str(settings.db.statement_timeout),
                    }
                }
            elif "sqlite" in async_url:
                # SQLite-specific args - no SSL parameters
                async_connect_args = {"check_same_thread": False}
            else:
                # Default - no connection args
                async_connect_args = {}

            _async_engine = create_async_engine(
                async_url,
                # Don't specify poolclass - let SQLAlchemy choose the async pool
                pool_size=DB_POOL_SIZE,
                max_overflow=DB_MAX_OVERFLOW,
                pool_timeout=DB_POOL_TIMEOUT,
                pool_recycle=DB_POOL_RECYCLE,
                pool_pre_ping=DB_POOL_PRE_PING,
                echo=DB_ECHO,
                connect_args=async_connect_args,
            )
        else:
            # Prepare minimal connection arguments for NullPool
            if "sqlite" in async_url:
                async_connect_args = {"check_same_thread": False}
            else:
                async_connect_args = {}

            _async_engine = create_async_engine(
                async_url,
                poolclass=NullPool,
                echo=DB_ECHO,
                connect_args=async_connect_args,
            )
        logger.info("Created async database engine")
    return _async_engine


def _get_async_session_factory():
    """Get or create the async session factory singleton."""
    global _async_session_factory
    if _async_session_factory is None:
        engine = _get_async_engine()
        _async_session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        logger.info("Created async session factory")
    return _async_session_factory


def get_db():
    """Get database session."""
    ensure_database_schema()
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Async database support - imports moved to top of file


async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Get an async database session using the cached engine."""
    # Get the cached session factory
    async_session_factory = _get_async_session_factory()

    # Create and yield a session
    async with async_session_factory() as session:
        try:
            yield session
        finally:
            await session.close()
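

# NOTE: Hypothetical usage sketch, not part of the original module. get_async_db() is an
# async generator, so outside a FastAPI-style dependency it can be driven directly with
# "async for"; the query against Stock is illustrative only.
async def _example_async_session_usage() -> None:
    from sqlalchemy import select  # local import keeps the sketch self-contained

    async for session in get_async_db():
        result = await session.execute(select(Stock).limit(5))
        _ = result.scalars().all()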


async def close_async_db_connections():
    """Close the async database engine and cleanup connections."""
    global _async_engine, _async_session_factory
    if _async_engine:
        await _async_engine.dispose()
        _async_engine = None
        _async_session_factory = None
        logger.info("Closed async database engine")


def init_db():
    """Initialize database by creating all tables."""

    ensure_database_schema(force=True)


class TimestampMixin:
    """Mixin for created_at and updated_at timestamps."""

    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        nullable=False,
    )
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(UTC),
        onupdate=lambda: datetime.now(UTC),
        nullable=False,
    )


class Stock(Base, TimestampMixin):
    """Stock model for storing basic stock information."""

    __tablename__ = "mcp_stocks"

    stock_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    ticker_symbol = Column(String(10), unique=True, nullable=False, index=True)
    company_name = Column(String(255))
    description = Column(Text)
    sector = Column(String(100))
    industry = Column(String(100))
    exchange = Column(String(50))
    country = Column(String(50))
    currency = Column(String(3))
    isin = Column(String(12))

    # Additional stock metadata
    market_cap = Column(BigInteger)
    shares_outstanding = Column(BigInteger)
    is_etf = Column(Boolean, default=False)
    is_active = Column(Boolean, default=True, index=True)

    # Relationships
    price_caches = relationship(
        "PriceCache",
        back_populates="stock",
        cascade="all, delete-orphan",
        lazy="selectin",  # Eager load price caches to prevent N+1 queries
    )
    maverick_stocks = relationship(
        "MaverickStocks", back_populates="stock", cascade="all, delete-orphan"
    )
    maverick_bear_stocks = relationship(
        "MaverickBearStocks", back_populates="stock", cascade="all, delete-orphan"
    )
    supply_demand_stocks = relationship(
        "SupplyDemandBreakoutStocks",
        back_populates="stock",
        cascade="all, delete-orphan",
    )
    technical_cache = relationship(
        "TechnicalCache", back_populates="stock", cascade="all, delete-orphan"
    )

    def __repr__(self):
        return f"<Stock(ticker={self.ticker_symbol}, name={self.company_name})>"

    @classmethod
    def get_or_create(cls, session: Session, ticker_symbol: str, **kwargs) -> Stock:
        """Get existing stock or create new one."""
        stock = (
            session.query(cls).filter_by(ticker_symbol=ticker_symbol.upper()).first()
        )
        if not stock:
            stock = cls(ticker_symbol=ticker_symbol.upper(), **kwargs)
            session.add(stock)
            session.commit()
        return stock
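

# NOTE: Hypothetical usage sketch, not part of the original module. The ticker and
# metadata keyword arguments are illustrative; they map onto the optional columns above.
def _example_get_or_create_stock() -> Stock:
    with SessionLocal() as session:
        return Stock.get_or_create(
            session,
            "AAPL",
            company_name="Apple Inc.",
            sector="Technology",
        )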


class PriceCache(Base, TimestampMixin):
    """Cache for historical stock price data."""

    __tablename__ = "mcp_price_cache"
    __table_args__ = (
        UniqueConstraint("stock_id", "date", name="mcp_price_cache_stock_date_unique"),
        Index("mcp_price_cache_stock_id_date_idx", "stock_id", "date"),
        Index("mcp_price_cache_ticker_date_idx", "stock_id", "date"),
    )

    price_cache_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    stock_id = Column(Uuid, ForeignKey("mcp_stocks.stock_id"), nullable=False)
    date = Column(Date, nullable=False)
    open_price = Column(Numeric(12, 4))
    high_price = Column(Numeric(12, 4))
    low_price = Column(Numeric(12, 4))
    close_price = Column(Numeric(12, 4))
    volume = Column(BigInteger)

    # Relationships
    stock = relationship(
        "Stock", back_populates="price_caches", lazy="joined"
    )  # Eager load stock info

    def __repr__(self):
        return f"<PriceCache(stock_id={self.stock_id}, date={self.date}, close={self.close_price})>"

    @classmethod
    def get_price_data(
        cls,
        session: Session,
        ticker_symbol: str,
        start_date: str,
        end_date: str | None = None,
    ) -> pd.DataFrame:
        """
        Return a pandas DataFrame of price data for the specified symbol and date range.

        Args:
            session: Database session
            ticker_symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format (default: today)

        Returns:
            DataFrame with OHLCV data indexed by date
        """
        if not end_date:
            end_date = datetime.now(UTC).strftime("%Y-%m-%d")

        # Query with join to get ticker symbol
        query = (
            session.query(
                cls.date,
                cls.open_price.label("open"),
                cls.high_price.label("high"),
                cls.low_price.label("low"),
                cls.close_price.label("close"),
                cls.volume,
            )
            .join(Stock)
            .filter(
                Stock.ticker_symbol == ticker_symbol.upper(),
                cls.date >= pd.to_datetime(start_date).date(),
                cls.date <= pd.to_datetime(end_date).date(),
            )
            .order_by(cls.date)
        )

        # Convert to DataFrame
        df = pd.DataFrame(query.all())

        if not df.empty:
            df["date"] = pd.to_datetime(df["date"])
            df.set_index("date", inplace=True)

            # Convert decimal types to float
            for col in ["open", "high", "low", "close"]:
                df[col] = df[col].astype(float)

            df["volume"] = df["volume"].astype(int)
            df["symbol"] = ticker_symbol.upper()

        return df
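

# NOTE: Hypothetical usage sketch, not part of the original module. It loads cached OHLCV
# rows into a DataFrame for indicator or backtest work; the ticker and dates are
# illustrative only.
def _example_load_price_history() -> pd.DataFrame:
    with SessionLocal() as session:
        return PriceCache.get_price_data(
            session,
            "AAPL",
            start_date="2024-01-01",
            end_date="2024-06-30",
        )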


class MaverickStocks(Base, TimestampMixin):
    """Maverick stocks screening results - self-contained model."""

    __tablename__ = "mcp_maverick_stocks"
    __table_args__ = (
        Index("mcp_maverick_stocks_combined_score_idx", "combined_score"),
        Index(
            "mcp_maverick_stocks_momentum_score_idx", "momentum_score"
        ),  # formerly rs_rating_idx
        Index("mcp_maverick_stocks_date_analyzed_idx", "date_analyzed"),
        Index("mcp_maverick_stocks_stock_date_idx", "stock_id", "date_analyzed"),
    )

    id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
    stock_id = Column(
        Uuid,
        ForeignKey("mcp_stocks.stock_id"),
        nullable=False,
        index=True,
    )
    date_analyzed = Column(
        Date, nullable=False, default=lambda: datetime.now(UTC).date()
    )
    # OHLCV Data
    open_price = Column(Numeric(12, 4), default=0)
    high_price = Column(Numeric(12, 4), default=0)
    low_price = Column(Numeric(12, 4), default=0)
    close_price = Column(Numeric(12, 4), default=0)
    volume = Column(BigInteger, default=0)

    # Technical Indicators
    ema_21 = Column(Numeric(12, 4), default=0)
    sma_50 = Column(Numeric(12, 4), default=0)
    sma_150 = Column(Numeric(12, 4), default=0)
    sma_200 = Column(Numeric(12, 4), default=0)
    momentum_score = Column(Numeric(5, 2), default=0)  # formerly rs_rating
    avg_vol_30d = Column(Numeric(15, 2), default=0)
    adr_pct = Column(Numeric(5, 2), default=0)
    atr = Column(Numeric(12, 4), default=0)

    # Pattern Analysis
    pattern_type = Column(String(50))  # 'pat' field
    squeeze_status = Column(String(50))  # 'sqz' field
    consolidation_status = Column(String(50))  # formerly vcp_status, 'vcp' field
    entry_signal = Column(String(50))  # 'entry' field

    # Scoring
    compression_score = Column(Integer, default=0)
    pattern_detected = Column(Integer, default=0)
    combined_score = Column(Integer, default=0)

    # Relationships
    stock = relationship("Stock", back_populates="maverick_stocks")

    def __repr__(self):
        return f"<MaverickStock(stock_id={self.stock_id}, close={self.close_price}, score={self.combined_score})>"

    @classmethod
    def get_top_stocks(
        cls, session: Session, limit: int = 20
    ) -> Sequence[MaverickStocks]:
        """Get top maverick stocks by combined score."""
        return (
            session.query(cls)
            .join(Stock)
            .order_by(cls.combined_score.desc())
            .limit(limit)
            .all()
        )

    @classmethod
    def get_latest_analysis(
        cls, session: Session, days_back: int = 1
    ) -> Sequence[MaverickStocks]:
        """Get latest maverick analysis within specified days."""
        cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
        return (
            session.query(cls)
            .join(Stock)
            .filter(cls.date_analyzed >= cutoff_date)
            .order_by(cls.combined_score.desc())
            .all()
        )

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "stock_id": str(self.stock_id),
            "ticker": self.stock.ticker_symbol if self.stock else None,
            "date_analyzed": self.date_analyzed.isoformat()
            if self.date_analyzed
            else None,
            "close": float(self.close_price) if self.close_price else 0,
            "volume": self.volume,
            "momentum_score": float(self.momentum_score)
            if self.momentum_score
            else 0,  # formerly rs_rating
            "adr_pct": float(self.adr_pct) if self.adr_pct else 0,
            "pattern": self.pattern_type,
            "squeeze": self.squeeze_status,
            "consolidation": self.consolidation_status,  # formerly vcp
            "entry": self.entry_signal,
            "combined_score": self.combined_score,
            "compression_score": self.compression_score,
            "pattern_detected": self.pattern_detected,
            "ema_21": float(self.ema_21) if self.ema_21 else 0,
            "sma_50": float(self.sma_50) if self.sma_50 else 0,
            "sma_150": float(self.sma_150) if self.sma_150 else 0,
            "sma_200": float(self.sma_200) if self.sma_200 else 0,
            "atr": float(self.atr) if self.atr else 0,
            "avg_vol_30d": float(self.avg_vol_30d) if self.avg_vol_30d else 0,
        }
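

# NOTE: Hypothetical usage sketch, not part of the original module. It fetches the
# current top-ranked maverick candidates and serializes them for a JSON response.
def _example_top_maverick_stocks(limit: int = 10) -> list[dict]:
    with SessionLocal() as session:
        return [
            row.to_dict() for row in MaverickStocks.get_top_stocks(session, limit=limit)
        ]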


class MaverickBearStocks(Base, TimestampMixin):
    """Maverick bear stocks screening results - self-contained model."""

    __tablename__ = "mcp_maverick_bear_stocks"
    __table_args__ = (
        Index("mcp_maverick_bear_stocks_score_idx", "score"),
        Index(
            "mcp_maverick_bear_stocks_momentum_score_idx", "momentum_score"
        ),  # formerly rs_rating_idx
        Index("mcp_maverick_bear_stocks_date_analyzed_idx", "date_analyzed"),
        Index("mcp_maverick_bear_stocks_stock_date_idx", "stock_id", "date_analyzed"),
    )

    id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
    stock_id = Column(
        Uuid,
        ForeignKey("mcp_stocks.stock_id"),
        nullable=False,
        index=True,
    )
    date_analyzed = Column(
        Date, nullable=False, default=lambda: datetime.now(UTC).date()
    )

    # OHLCV Data
    open_price = Column(Numeric(12, 4), default=0)
    high_price = Column(Numeric(12, 4), default=0)
    low_price = Column(Numeric(12, 4), default=0)
    close_price = Column(Numeric(12, 4), default=0)
    volume = Column(BigInteger, default=0)

    # Technical Indicators
    momentum_score = Column(Numeric(5, 2), default=0)  # formerly rs_rating
    ema_21 = Column(Numeric(12, 4), default=0)
    sma_50 = Column(Numeric(12, 4), default=0)
    sma_200 = Column(Numeric(12, 4), default=0)
    rsi_14 = Column(Numeric(5, 2), default=0)

    # MACD Indicators
    macd = Column(Numeric(12, 6), default=0)
    macd_signal = Column(Numeric(12, 6), default=0)
    macd_histogram = Column(Numeric(12, 6), default=0)

    # Additional Bear Market Indicators
    dist_days_20 = Column(Integer, default=0)  # Days from 20 SMA
    adr_pct = Column(Numeric(5, 2), default=0)
    atr_contraction = Column(Boolean, default=False)
    atr = Column(Numeric(12, 4), default=0)
    avg_vol_30d = Column(Numeric(15, 2), default=0)
    big_down_vol = Column(Boolean, default=False)

    # Pattern Analysis
    squeeze_status = Column(String(50))  # 'sqz' field
    consolidation_status = Column(String(50))  # formerly vcp_status, 'vcp' field

    # Scoring
    score = Column(Integer, default=0)

    # Relationships
    stock = relationship("Stock", back_populates="maverick_bear_stocks")

    def __repr__(self):
        return f"<MaverickBearStock(stock_id={self.stock_id}, close={self.close_price}, score={self.score})>"

    @classmethod
    def get_top_stocks(
        cls, session: Session, limit: int = 20
    ) -> Sequence[MaverickBearStocks]:
        """Get top maverick bear stocks by score."""
        return (
            session.query(cls).join(Stock).order_by(cls.score.desc()).limit(limit).all()
        )

    @classmethod
    def get_latest_analysis(
        cls, session: Session, days_back: int = 1
    ) -> Sequence[MaverickBearStocks]:
        """Get latest bear analysis within specified days."""
        cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
        return (
            session.query(cls)
            .join(Stock)
            .filter(cls.date_analyzed >= cutoff_date)
            .order_by(cls.score.desc())
            .all()
        )

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "stock_id": str(self.stock_id),
            "ticker": self.stock.ticker_symbol if self.stock else None,
            "date_analyzed": self.date_analyzed.isoformat()
            if self.date_analyzed
            else None,
            "close": float(self.close_price) if self.close_price else 0,
            "volume": self.volume,
            "momentum_score": float(self.momentum_score)
            if self.momentum_score
            else 0,  # formerly rs_rating
            "rsi_14": float(self.rsi_14) if self.rsi_14 else 0,
            "macd": float(self.macd) if self.macd else 0,
            "macd_signal": float(self.macd_signal) if self.macd_signal else 0,
            "macd_histogram": float(self.macd_histogram) if self.macd_histogram else 0,
            "adr_pct": float(self.adr_pct) if self.adr_pct else 0,
            "atr": float(self.atr) if self.atr else 0,
            "atr_contraction": self.atr_contraction,
            "avg_vol_30d": float(self.avg_vol_30d) if self.avg_vol_30d else 0,
            "big_down_vol": self.big_down_vol,
            "score": self.score,
            "squeeze": self.squeeze_status,
            "consolidation": self.consolidation_status,  # formerly vcp
            "ema_21": float(self.ema_21) if self.ema_21 else 0,
            "sma_50": float(self.sma_50) if self.sma_50 else 0,
            "sma_200": float(self.sma_200) if self.sma_200 else 0,
            "dist_days_20": self.dist_days_20,
        }


class SupplyDemandBreakoutStocks(Base, TimestampMixin):
    """Supply/demand breakout stocks screening results - self-contained model.

    This model identifies stocks experiencing accumulation breakouts with strong relative strength,
    indicating a potential shift from supply to demand dominance in the market structure.
    """

    __tablename__ = "mcp_supply_demand_breakouts"
    __table_args__ = (
        Index(
            "mcp_supply_demand_breakouts_momentum_score_idx", "momentum_score"
        ),  # formerly rs_rating_idx
        Index("mcp_supply_demand_breakouts_date_analyzed_idx", "date_analyzed"),
        Index(
            "mcp_supply_demand_breakouts_stock_date_idx", "stock_id", "date_analyzed"
        ),
        Index(
            "mcp_supply_demand_breakouts_ma_filter_idx",
            "close_price",
            "sma_50",
            "sma_150",
            "sma_200",
        ),
    )

    id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
    stock_id = Column(
        Uuid,
        ForeignKey("mcp_stocks.stock_id"),
        nullable=False,
        index=True,
    )
    date_analyzed = Column(
        Date, nullable=False, default=lambda: datetime.now(UTC).date()
    )

    # OHLCV Data
    open_price = Column(Numeric(12, 4), default=0)
    high_price = Column(Numeric(12, 4), default=0)
    low_price = Column(Numeric(12, 4), default=0)
    close_price = Column(Numeric(12, 4), default=0)
    volume = Column(BigInteger, default=0)

    # Technical Indicators
    ema_21 = Column(Numeric(12, 4), default=0)
    sma_50 = Column(Numeric(12, 4), default=0)
    sma_150 = Column(Numeric(12, 4), default=0)
    sma_200 = Column(Numeric(12, 4), default=0)
    momentum_score = Column(Numeric(5, 2), default=0)  # formerly rs_rating
    avg_volume_30d = Column(Numeric(15, 2), default=0)
    adr_pct = Column(Numeric(5, 2), default=0)
    atr = Column(Numeric(12, 4), default=0)

    # Pattern Analysis
    pattern_type = Column(String(50))  # 'pat' field
    squeeze_status = Column(String(50))  # 'sqz' field
    consolidation_status = Column(String(50))  # formerly vcp_status, 'vcp' field
    entry_signal = Column(String(50))  # 'entry' field

    # Supply/Demand Analysis
    accumulation_rating = Column(Numeric(5, 2), default=0)
    distribution_rating = Column(Numeric(5, 2), default=0)
    breakout_strength = Column(Numeric(5, 2), default=0)

    # Relationships
    stock = relationship("Stock", back_populates="supply_demand_stocks")

    def __repr__(self):
        return f"<SupplyDemandBreakoutStock(stock_id={self.stock_id}, close={self.close_price}, momentum={self.momentum_score})>"  # formerly rs

    @classmethod
    def get_top_stocks(
        cls, session: Session, limit: int = 20
    ) -> Sequence[SupplyDemandBreakoutStocks]:
        """Get top supply/demand breakout stocks by momentum score."""  # formerly relative strength rating
        return (
            session.query(cls)
            .join(Stock)
            .order_by(cls.momentum_score.desc())  # formerly rs_rating
            .limit(limit)
            .all()
        )

    @classmethod
    def get_stocks_above_moving_averages(
        cls, session: Session
    ) -> Sequence[SupplyDemandBreakoutStocks]:
        """Get stocks in demand expansion phase - trading above all major moving averages.

        This identifies stocks with:
        - Price above 50, 150, and 200-day moving averages (demand zone)
        - Upward trending moving averages (accumulation structure)
        - Indicates institutional accumulation and supply absorption
        """
        return (
            session.query(cls)
            .join(Stock)
            .filter(
                cls.close_price > cls.sma_50,
                cls.close_price > cls.sma_150,
                cls.close_price > cls.sma_200,
                cls.sma_50 > cls.sma_150,
                cls.sma_150 > cls.sma_200,
            )
            .order_by(cls.momentum_score.desc())  # formerly rs_rating
            .all()
        )

    @classmethod
    def get_latest_analysis(
        cls, session: Session, days_back: int = 1
    ) -> Sequence[SupplyDemandBreakoutStocks]:
        """Get latest supply/demand analysis within specified days."""
        cutoff_date = datetime.now(UTC).date() - timedelta(days=days_back)
        return (
            session.query(cls)
            .join(Stock)
            .filter(cls.date_analyzed >= cutoff_date)
            .order_by(cls.momentum_score.desc())  # formerly rs_rating
            .all()
        )

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "stock_id": str(self.stock_id),
            "ticker": self.stock.ticker_symbol if self.stock else None,
            "date_analyzed": self.date_analyzed.isoformat()
            if self.date_analyzed
            else None,
            "close": float(self.close_price) if self.close_price else 0,
            "volume": self.volume,
            "momentum_score": float(self.momentum_score)
            if self.momentum_score
            else 0,  # formerly rs_rating
            "adr_pct": float(self.adr_pct) if self.adr_pct else 0,
            "pattern": self.pattern_type,
            "squeeze": self.squeeze_status,
            "consolidation": self.consolidation_status,  # formerly vcp
            "entry": self.entry_signal,
            "ema_21": float(self.ema_21) if self.ema_21 else 0,
            "sma_50": float(self.sma_50) if self.sma_50 else 0,
            "sma_150": float(self.sma_150) if self.sma_150 else 0,
            "sma_200": float(self.sma_200) if self.sma_200 else 0,
            "atr": float(self.atr) if self.atr else 0,
            "avg_volume_30d": float(self.avg_volume_30d) if self.avg_volume_30d else 0,
            "accumulation_rating": float(self.accumulation_rating)
            if self.accumulation_rating
            else 0,
            "distribution_rating": float(self.distribution_rating)
            if self.distribution_rating
            else 0,
            "breakout_strength": float(self.breakout_strength)
            if self.breakout_strength
            else 0,
        }
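

# NOTE: Hypothetical usage sketch, not part of the original module. It applies the
# "demand expansion" filter described above (price above stacked 50/150/200-day MAs).
def _example_demand_expansion_candidates() -> list[dict]:
    with SessionLocal() as session:
        rows = SupplyDemandBreakoutStocks.get_stocks_above_moving_averages(session)
        return [row.to_dict() for row in rows]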


class TechnicalCache(Base, TimestampMixin):
    """Cache for calculated technical indicators."""

    __tablename__ = "mcp_technical_cache"
    __table_args__ = (
        UniqueConstraint(
            "stock_id",
            "date",
            "indicator_type",
            name="mcp_technical_cache_stock_date_indicator_unique",
        ),
        Index("mcp_technical_cache_stock_date_idx", "stock_id", "date"),
        Index("mcp_technical_cache_indicator_idx", "indicator_type"),
        Index("mcp_technical_cache_date_idx", "date"),
    )

    id = Column(get_primary_key_type(), primary_key=True, autoincrement=True)
    stock_id = Column(Uuid, ForeignKey("mcp_stocks.stock_id"), nullable=False)
    date = Column(Date, nullable=False)
    indicator_type = Column(
        String(50), nullable=False
    )  # 'SMA_20', 'EMA_21', 'RSI_14', etc.

    # Flexible indicator values
    value = Column(Numeric(20, 8))  # Primary indicator value
    value_2 = Column(Numeric(20, 8))  # Secondary value (e.g., MACD signal)
    value_3 = Column(Numeric(20, 8))  # Tertiary value (e.g., MACD histogram)

    # Text values for complex indicators
    meta_data = Column(Text)  # JSON string for additional metadata

    # Calculation parameters
    period = Column(Integer)  # Period used (20 for SMA_20, etc.)
    parameters = Column(Text)  # JSON string for additional parameters

    # Relationships
    stock = relationship("Stock", back_populates="technical_cache")

    def __repr__(self):
        return (
            f"<TechnicalCache(stock_id={self.stock_id}, date={self.date}, "
            f"indicator={self.indicator_type}, value={self.value})>"
        )

    @classmethod
    def get_indicator(
        cls,
        session: Session,
        ticker_symbol: str,
        indicator_type: str,
        start_date: str,
        end_date: str | None = None,
    ) -> pd.DataFrame:
        """
        Get technical indicator data for a symbol and date range.

        Args:
            session: Database session
            ticker_symbol: Stock ticker symbol
            indicator_type: Type of indicator (e.g., 'SMA_20', 'RSI_14')
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format (default: today)

        Returns:
            DataFrame with indicator data indexed by date
        """
        if not end_date:
            end_date = datetime.now(UTC).strftime("%Y-%m-%d")

        query = (
            session.query(
                cls.date,
                cls.value,
                cls.value_2,
                cls.value_3,
                cls.meta_data,
                cls.parameters,
            )
            .join(Stock)
            .filter(
                Stock.ticker_symbol == ticker_symbol.upper(),
                cls.indicator_type == indicator_type,
                cls.date >= pd.to_datetime(start_date).date(),
                cls.date <= pd.to_datetime(end_date).date(),
            )
            .order_by(cls.date)
        )

        df = pd.DataFrame(query.all())

        if not df.empty:
            df["date"] = pd.to_datetime(df["date"])
            df.set_index("date", inplace=True)

            # Convert decimal types to float
            for col in ["value", "value_2", "value_3"]:
                if col in df.columns:
                    df[col] = df[col].astype(float)

            df["symbol"] = ticker_symbol.upper()
            df["indicator_type"] = indicator_type

        return df

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "stock_id": str(self.stock_id),
            "date": self.date.isoformat() if self.date else None,
            "indicator_type": self.indicator_type,
            "value": float(self.value) if self.value else None,
            "value_2": float(self.value_2) if self.value_2 else None,
            "value_3": float(self.value_3) if self.value_3 else None,
            "period": self.period,
            "meta_data": self.meta_data,
            "parameters": self.parameters,
        }
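

# NOTE: Hypothetical usage sketch, not part of the original module. It reads a cached
# indicator series as a DataFrame; the 'RSI_14' name follows the indicator_type
# convention noted above, and the ticker/date are illustrative.
def _example_cached_rsi_series() -> pd.DataFrame:
    with SessionLocal() as session:
        return TechnicalCache.get_indicator(
            session,
            "AAPL",
            indicator_type="RSI_14",
            start_date="2024-01-01",
        )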


# Backtesting Models


class BacktestResult(Base, TimestampMixin):
    """Main backtest results table with comprehensive metrics."""

    __tablename__ = "mcp_backtest_results"
    __table_args__ = (
        Index("mcp_backtest_results_symbol_idx", "symbol"),
        Index("mcp_backtest_results_strategy_idx", "strategy_type"),
        Index("mcp_backtest_results_date_idx", "backtest_date"),
        Index("mcp_backtest_results_sharpe_idx", "sharpe_ratio"),
        Index("mcp_backtest_results_total_return_idx", "total_return"),
        Index("mcp_backtest_results_symbol_strategy_idx", "symbol", "strategy_type"),
    )

    backtest_id = Column(Uuid, primary_key=True, default=uuid.uuid4)

    # Basic backtest metadata
    symbol = Column(String(10), nullable=False, index=True)
    strategy_type = Column(String(50), nullable=False)
    backtest_date = Column(
        DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC)
    )

    # Date range and setup
    start_date = Column(Date, nullable=False)
    end_date = Column(Date, nullable=False)
    initial_capital = Column(Numeric(15, 2), default=10000.0)

    # Trading costs and parameters
    fees = Column(Numeric(6, 4), default=0.001)  # 0.1% default
    slippage = Column(Numeric(6, 4), default=0.001)  # 0.1% default

    # Strategy parameters (stored as JSON for flexibility)
    parameters = Column(JSON)

    # Key Performance Metrics
    total_return = Column(Numeric(10, 4))  # Total return percentage
    annualized_return = Column(Numeric(10, 4))  # Annualized return percentage
    sharpe_ratio = Column(Numeric(8, 4))
    sortino_ratio = Column(Numeric(8, 4))
    calmar_ratio = Column(Numeric(8, 4))

    # Risk Metrics
    max_drawdown = Column(Numeric(8, 4))  # Maximum drawdown percentage
    max_drawdown_duration = Column(Integer)  # Days
    volatility = Column(Numeric(8, 4))  # Annualized volatility
    downside_volatility = Column(Numeric(8, 4))  # Downside deviation

    # Trade Statistics
    total_trades = Column(Integer, default=0)
    winning_trades = Column(Integer, default=0)
    losing_trades = Column(Integer, default=0)
    win_rate = Column(Numeric(5, 4))  # Win rate percentage

    # P&L Statistics
    profit_factor = Column(Numeric(8, 4))  # Gross profit / Gross loss
    average_win = Column(Numeric(12, 4))
    average_loss = Column(Numeric(12, 4))
    largest_win = Column(Numeric(12, 4))
    largest_loss = Column(Numeric(12, 4))

    # Portfolio Value Metrics
    final_portfolio_value = Column(Numeric(15, 2))
    peak_portfolio_value = Column(Numeric(15, 2))

    # Additional Analysis
    beta = Column(Numeric(8, 4))  # Market beta
    alpha = Column(Numeric(8, 4))  # Alpha vs market

    # Time series data (stored as JSON for efficient queries)
    equity_curve = Column(JSON)  # Daily portfolio values
    drawdown_series = Column(JSON)  # Daily drawdown values

    # Execution metadata
    execution_time_seconds = Column(Numeric(8, 3))  # How long the backtest took
    data_points = Column(Integer)  # Number of data points used

    # Status and notes
    status = Column(String(20), default="completed")  # completed, failed, in_progress
    error_message = Column(Text)  # Error details if status = failed
    notes = Column(Text)  # User notes

    # Relationships
    trades = relationship(
        "BacktestTrade",
        back_populates="backtest_result",
        cascade="all, delete-orphan",
        lazy="selectin",
    )
    optimization_results = relationship(
        "OptimizationResult",
        back_populates="backtest_result",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return (
            f"<BacktestResult(id={self.backtest_id}, symbol={self.symbol}, "
            f"strategy={self.strategy_type}, return={self.total_return})>"
        )

    @classmethod
    def get_by_symbol_and_strategy(
        cls, session: Session, symbol: str, strategy_type: str, limit: int = 10
    ) -> Sequence[BacktestResult]:
        """Get recent backtests for a specific symbol and strategy."""
        return (
            session.query(cls)
            .filter(cls.symbol == symbol.upper(), cls.strategy_type == strategy_type)
            .order_by(cls.backtest_date.desc())
            .limit(limit)
            .all()
        )

    @classmethod
    def get_best_performing(
        cls, session: Session, metric: str = "sharpe_ratio", limit: int = 20
    ) -> Sequence[BacktestResult]:
        """Get best performing backtests by specified metric."""
        metric_column = getattr(cls, metric, cls.sharpe_ratio)
        return (
            session.query(cls)
            .filter(cls.status == "completed")
            .order_by(metric_column.desc())
            .limit(limit)
            .all()
        )

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "backtest_id": str(self.backtest_id),
            "symbol": self.symbol,
            "strategy_type": self.strategy_type,
            "backtest_date": self.backtest_date.isoformat()
            if self.backtest_date
            else None,
            "start_date": self.start_date.isoformat() if self.start_date else None,
            "end_date": self.end_date.isoformat() if self.end_date else None,
            "initial_capital": float(self.initial_capital)
            if self.initial_capital
            else 0,
            "total_return": float(self.total_return) if self.total_return else 0,
            "sharpe_ratio": float(self.sharpe_ratio) if self.sharpe_ratio else 0,
            "max_drawdown": float(self.max_drawdown) if self.max_drawdown else 0,
            "win_rate": float(self.win_rate) if self.win_rate else 0,
            "total_trades": self.total_trades,
            "parameters": self.parameters,
            "status": self.status,
        }
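

# Illustrative usage sketch: querying persisted backtests with the helpers
# above. Assumes ``SessionLocal`` (used later in this module) is the session
# factory; the symbol and strategy name are placeholders.
def summarize_recent_backtests_example(
    symbol: str = "AAPL", strategy_type: str = "sma_cross"
) -> list[dict]:
    """Return dictionaries for the five most recent backtests of a strategy."""
    with SessionLocal() as session:
        results = BacktestResult.get_by_symbol_and_strategy(
            session, symbol, strategy_type, limit=5
        )
        return [result.to_dict() for result in results]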


class BacktestTrade(Base, TimestampMixin):
    """Individual trade records from backtests."""

    __tablename__ = "mcp_backtest_trades"
    __table_args__ = (
        Index("mcp_backtest_trades_backtest_idx", "backtest_id"),
        Index("mcp_backtest_trades_entry_date_idx", "entry_date"),
        Index("mcp_backtest_trades_exit_date_idx", "exit_date"),
        Index("mcp_backtest_trades_pnl_idx", "pnl"),
        Index("mcp_backtest_trades_backtest_entry_idx", "backtest_id", "entry_date"),
    )

    trade_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    backtest_id = Column(
        Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
    )

    # Trade identification
    trade_number = Column(
        Integer, nullable=False
    )  # Sequential trade number in backtest

    # Entry details
    entry_date = Column(Date, nullable=False)
    entry_price = Column(Numeric(12, 4), nullable=False)
    entry_time = Column(DateTime(timezone=True))  # For intraday backtests

    # Exit details
    exit_date = Column(Date)
    exit_price = Column(Numeric(12, 4))
    exit_time = Column(DateTime(timezone=True))

    # Position details
    position_size = Column(Numeric(15, 2))  # Number of shares/units
    direction = Column(String(5), nullable=False)  # 'long' or 'short'

    # P&L and performance
    pnl = Column(Numeric(12, 4))  # Profit/Loss in currency
    pnl_percent = Column(Numeric(8, 4))  # P&L as percentage

    # Risk metrics for this trade
    mae = Column(Numeric(8, 4))  # Maximum Adverse Excursion
    mfe = Column(Numeric(8, 4))  # Maximum Favorable Excursion

    # Trade duration
    duration_days = Column(Integer)
    duration_hours = Column(Numeric(8, 2))  # For intraday precision

    # Exit reason and fees
    exit_reason = Column(String(50))  # stop_loss, take_profit, signal, time_exit
    fees_paid = Column(Numeric(10, 4), default=0)
    slippage_cost = Column(Numeric(10, 4), default=0)

    # Relationships
    backtest_result = relationship(
        "BacktestResult", back_populates="trades", lazy="joined"
    )

    def __repr__(self):
        return (
            f"<BacktestTrade(id={self.trade_id}, backtest_id={self.backtest_id}, "
            f"pnl={self.pnl}, duration={self.duration_days}d)>"
        )

    @classmethod
    def get_trades_for_backtest(
        cls, session: Session, backtest_id: str
    ) -> Sequence[BacktestTrade]:
        """Get all trades for a specific backtest."""
        return (
            session.query(cls)
            .filter(cls.backtest_id == backtest_id)
            .order_by(cls.entry_date, cls.trade_number)
            .all()
        )

    @classmethod
    def get_winning_trades(
        cls, session: Session, backtest_id: str
    ) -> Sequence[BacktestTrade]:
        """Get winning trades for a backtest."""
        return (
            session.query(cls)
            .filter(cls.backtest_id == backtest_id, cls.pnl > 0)
            .order_by(cls.pnl.desc())
            .all()
        )

    @classmethod
    def get_losing_trades(
        cls, session: Session, backtest_id: str
    ) -> Sequence[BacktestTrade]:
        """Get losing trades for a backtest."""
        return (
            session.query(cls)
            .filter(cls.backtest_id == backtest_id, cls.pnl < 0)
            .order_by(cls.pnl)
            .all()
        )
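

# Illustrative usage sketch: recomputing a simple win rate from stored trades,
# e.g. to cross-check the aggregate stored on ``BacktestResult``. Assumes a
# valid ``backtest_id`` from a completed run.
def compute_win_rate_example(session: Session, backtest_id: str) -> float | None:
    """Return winning trades / closed trades for a backtest, or None if empty."""
    trades = BacktestTrade.get_trades_for_backtest(session, backtest_id)
    closed = [t for t in trades if t.pnl is not None]
    if not closed:
        return None
    winners = sum(1 for t in closed if t.pnl > 0)
    return winners / len(closed)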


class OptimizationResult(Base, TimestampMixin):
    """Parameter optimization results for strategies."""

    __tablename__ = "mcp_optimization_results"
    __table_args__ = (
        Index("mcp_optimization_results_backtest_idx", "backtest_id"),
        Index("mcp_optimization_results_param_set_idx", "parameter_set"),
        Index("mcp_optimization_results_objective_idx", "objective_value"),
    )

    optimization_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    backtest_id = Column(
        Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
    )

    # Optimization metadata
    optimization_date = Column(
        DateTime(timezone=True), default=lambda: datetime.now(UTC)
    )
    parameter_set = Column(Integer, nullable=False)  # Set number in optimization run

    # Parameters tested (JSON for flexibility)
    parameters = Column(JSON, nullable=False)

    # Optimization objective and results
    objective_function = Column(
        String(50)
    )  # sharpe_ratio, total_return, profit_factor, etc.
    objective_value = Column(Numeric(12, 6))  # Value of objective function

    # Key metrics for this parameter set
    total_return = Column(Numeric(10, 4))
    sharpe_ratio = Column(Numeric(8, 4))
    max_drawdown = Column(Numeric(8, 4))
    win_rate = Column(Numeric(5, 4))
    profit_factor = Column(Numeric(8, 4))
    total_trades = Column(Integer)

    # Ranking within optimization
    rank = Column(Integer)  # 1 = best, 2 = second best, etc.

    # Statistical significance
    is_statistically_significant = Column(Boolean, default=False)
    p_value = Column(Numeric(8, 6))  # Statistical significance test result

    # Relationships
    backtest_result = relationship(
        "BacktestResult", back_populates="optimization_results", lazy="joined"
    )

    def __repr__(self):
        return (
            f"<OptimizationResult(id={self.optimization_id}, "
            f"objective={self.objective_value}, rank={self.rank})>"
        )

    @classmethod
    def get_best_parameters(
        cls, session: Session, backtest_id: str, limit: int = 5
    ) -> Sequence[OptimizationResult]:
        """Get top performing parameter sets for a backtest."""
        return (
            session.query(cls)
            .filter(cls.backtest_id == backtest_id)
            .order_by(cls.rank)
            .limit(limit)
            .all()
        )
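

# Illustrative usage sketch: extracting the best parameter set found during an
# optimization run. Assumes ``backtest_id`` refers to a backtest that has
# associated ``OptimizationResult`` rows with ranks populated.
def best_parameters_example(session: Session, backtest_id: str) -> dict | None:
    """Return the top-ranked parameter dictionary for a backtest, if any."""
    top = OptimizationResult.get_best_parameters(session, backtest_id, limit=1)
    return top[0].parameters if top else None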


class WalkForwardTest(Base, TimestampMixin):
    """Walk-forward validation test results."""

    __tablename__ = "mcp_walk_forward_tests"
    __table_args__ = (
        Index("mcp_walk_forward_tests_parent_idx", "parent_backtest_id"),
        Index("mcp_walk_forward_tests_period_idx", "test_period_start"),
        Index("mcp_walk_forward_tests_performance_idx", "out_of_sample_return"),
    )

    walk_forward_id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    parent_backtest_id = Column(
        Uuid, ForeignKey("mcp_backtest_results.backtest_id"), nullable=False
    )

    # Test configuration
    test_date = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
    window_size_months = Column(Integer, nullable=False)  # Training window size
    step_size_months = Column(Integer, nullable=False)  # Step size for walking forward

    # Time periods
    training_start = Column(Date, nullable=False)
    training_end = Column(Date, nullable=False)
    test_period_start = Column(Date, nullable=False)
    test_period_end = Column(Date, nullable=False)

    # Optimization results from training period
    optimal_parameters = Column(JSON)  # Best parameters from training
    training_performance = Column(Numeric(10, 4))  # Training period return

    # Out-of-sample test results
    out_of_sample_return = Column(Numeric(10, 4))
    out_of_sample_sharpe = Column(Numeric(8, 4))
    out_of_sample_drawdown = Column(Numeric(8, 4))
    out_of_sample_trades = Column(Integer)

    # Performance vs training expectations
    performance_ratio = Column(Numeric(8, 4))  # Out-of-sample return / training return
    degradation_factor = Column(Numeric(8, 4))  # How much performance degraded

    # Statistical validation
    is_profitable = Column(Boolean)
    is_statistically_significant = Column(Boolean, default=False)

    # Relationships
    parent_backtest = relationship(
        "BacktestResult", foreign_keys=[parent_backtest_id], lazy="joined"
    )

    def __repr__(self):
        return (
            f"<WalkForwardTest(id={self.walk_forward_id}, "
            f"return={self.out_of_sample_return}, ratio={self.performance_ratio})>"
        )

    @classmethod
    def get_walk_forward_results(
        cls, session: Session, parent_backtest_id: str
    ) -> Sequence[WalkForwardTest]:
        """Get all walk-forward test results for a backtest."""
        return (
            session.query(cls)
            .filter(cls.parent_backtest_id == parent_backtest_id)
            .order_by(cls.test_period_start)
            .all()
        )
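

# Illustrative usage sketch: gauging out-of-sample robustness by averaging the
# performance ratio across walk-forward windows. Assumes the parent backtest
# has walk-forward rows with ``performance_ratio`` populated.
def average_walk_forward_ratio_example(
    session: Session, parent_backtest_id: str
) -> float | None:
    """Return the mean out-of-sample/training performance ratio, or None."""
    tests = WalkForwardTest.get_walk_forward_results(session, parent_backtest_id)
    ratios = [
        float(t.performance_ratio)
        for t in tests
        if t.performance_ratio is not None
    ]
    return sum(ratios) / len(ratios) if ratios else None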


class BacktestPortfolio(Base, TimestampMixin):
    """Portfolio-level backtests with multiple symbols."""

    __tablename__ = "mcp_backtest_portfolios"
    __table_args__ = (
        Index("mcp_backtest_portfolios_name_idx", "portfolio_name"),
        Index("mcp_backtest_portfolios_date_idx", "backtest_date"),
        Index("mcp_backtest_portfolios_return_idx", "total_return"),
    )

    portfolio_backtest_id = Column(Uuid, primary_key=True, default=uuid.uuid4)

    # Portfolio identification
    portfolio_name = Column(String(100), nullable=False)
    description = Column(Text)

    # Test metadata
    backtest_date = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
    start_date = Column(Date, nullable=False)
    end_date = Column(Date, nullable=False)

    # Portfolio composition
    symbols = Column(JSON, nullable=False)  # List of symbols
    weights = Column(JSON)  # Portfolio weights (if not equal weight)
    rebalance_frequency = Column(String(20))  # daily, weekly, monthly, quarterly

    # Portfolio parameters
    initial_capital = Column(Numeric(15, 2), default=100000.0)
    max_positions = Column(Integer)  # Maximum concurrent positions
    position_sizing_method = Column(
        String(50)
    )  # equal_weight, volatility_weighted, etc.

    # Risk management
    portfolio_stop_loss = Column(Numeric(6, 4))  # Portfolio-level stop loss
    max_sector_allocation = Column(Numeric(5, 4))  # Maximum allocation per sector
    correlation_threshold = Column(
        Numeric(5, 4)
    )  # Maximum correlation between holdings

    # Performance metrics (portfolio level)
    total_return = Column(Numeric(10, 4))
    annualized_return = Column(Numeric(10, 4))
    sharpe_ratio = Column(Numeric(8, 4))
    sortino_ratio = Column(Numeric(8, 4))
    max_drawdown = Column(Numeric(8, 4))
    volatility = Column(Numeric(8, 4))

    # Portfolio-specific metrics
    diversification_ratio = Column(Numeric(8, 4))  # Portfolio vol / Weighted avg vol
    concentration_index = Column(Numeric(8, 4))  # Herfindahl index
    turnover_rate = Column(Numeric(8, 4))  # Portfolio turnover

    # Individual component backtests (JSON references)
    component_backtest_ids = Column(JSON)  # List of individual backtest IDs

    # Time series data
    portfolio_equity_curve = Column(JSON)
    portfolio_weights_history = Column(JSON)  # Historical weights over time

    # Status
    status = Column(String(20), default="completed")
    notes = Column(Text)

    def __repr__(self):
        return (
            f"<BacktestPortfolio(id={self.portfolio_backtest_id}, "
            f"name={self.portfolio_name}, return={self.total_return})>"
        )

    @classmethod
    def get_portfolio_backtests(
        cls, session: Session, portfolio_name: str | None = None, limit: int = 10
    ) -> Sequence[BacktestPortfolio]:
        """Get portfolio backtests, optionally filtered by name."""
        query = session.query(cls).order_by(cls.backtest_date.desc())
        if portfolio_name:
            query = query.filter(cls.portfolio_name == portfolio_name)
        return query.limit(limit).all()

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "portfolio_backtest_id": str(self.portfolio_backtest_id),
            "portfolio_name": self.portfolio_name,
            "symbols": self.symbols,
            "start_date": self.start_date.isoformat() if self.start_date else None,
            "end_date": self.end_date.isoformat() if self.end_date else None,
            "total_return": float(self.total_return) if self.total_return else 0,
            "sharpe_ratio": float(self.sharpe_ratio) if self.sharpe_ratio else 0,
            "max_drawdown": float(self.max_drawdown) if self.max_drawdown else 0,
            "status": self.status,
        }
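

# Illustrative usage sketch: listing recent portfolio-level backtests as plain
# dictionaries, e.g. for a JSON response. The portfolio name is a placeholder.
def recent_portfolio_backtests_example(
    portfolio_name: str | None = "Core Momentum",
) -> list[dict]:
    """Return serialized portfolio backtests, newest first."""
    with SessionLocal() as session:
        backtests = BacktestPortfolio.get_portfolio_backtests(
            session, portfolio_name=portfolio_name, limit=5
        )
        return [bt.to_dict() for bt in backtests]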


# Helper functions for working with the models
def bulk_insert_price_data(
    session: Session, ticker_symbol: str, df: pd.DataFrame
) -> int:
    """
    Bulk insert price data from a DataFrame.

    Args:
        session: Database session
        ticker_symbol: Stock ticker symbol
        df: DataFrame with OHLCV data (must have date index)

    Returns:
        Number of new records inserted (0 if every row was already cached)
    """
    if df.empty:
        return 0

    # Get or create stock
    stock = Stock.get_or_create(session, ticker_symbol)

    # Determine which dates already exist in the cache so they can be skipped
    if hasattr(df.index[0], "date"):
        dates_to_check = [d.date() for d in df.index]
    else:
        dates_to_check = list(df.index)

    existing_query = session.query(PriceCache.date).filter(
        PriceCache.stock_id == stock.stock_id, PriceCache.date.in_(dates_to_check)
    )
    existing_dates = {row[0] for row in existing_query.all()}

    # Prepare data for bulk insert
    records = []
    new_count = 0
    for date_idx, row in df.iterrows():
        # Handle different index types - datetime index vs date index
        if hasattr(date_idx, "date") and callable(date_idx.date):
            date_val = date_idx.date()  # type: ignore[attr-defined]
        elif hasattr(date_idx, "to_pydatetime") and callable(date_idx.to_pydatetime):
            date_val = date_idx.to_pydatetime().date()  # type: ignore[attr-defined]
        else:
            # Assume it's already a date-like object
            date_val = date_idx

        # Skip if already exists
        if date_val in existing_dates:
            continue

        new_count += 1

        # Handle both lowercase and capitalized column names from yfinance
        open_val = row.get("open", row.get("Open", 0))
        high_val = row.get("high", row.get("High", 0))
        low_val = row.get("low", row.get("Low", 0))
        close_val = row.get("close", row.get("Close", 0))
        volume_val = row.get("volume", row.get("Volume", 0))

        # Handle None values
        if volume_val is None:
            volume_val = 0

        records.append(
            {
                "stock_id": stock.stock_id,
                "date": date_val,
                "open_price": Decimal(str(open_val)),
                "high_price": Decimal(str(high_val)),
                "low_price": Decimal(str(low_val)),
                "close_price": Decimal(str(close_val)),
                "volume": int(volume_val),
                "created_at": datetime.now(UTC),
                "updated_at": datetime.now(UTC),
            }
        )

    # Only insert if there are new records
    if records:
        # Use database-specific upsert logic
        if "postgresql" in DATABASE_URL:
            from sqlalchemy.dialects.postgresql import insert

            stmt = insert(PriceCache).values(records)
            stmt = stmt.on_conflict_do_nothing(index_elements=["stock_id", "date"])
        else:
            # Non-PostgreSQL path (e.g. SQLite): the generic insert construct
            # has no on_conflict_do_nothing(), so fall back to INSERT OR IGNORE
            from sqlalchemy import insert

            stmt = insert(PriceCache).values(records)
            stmt = stmt.prefix_with("OR IGNORE")

        result = session.execute(stmt)
        session.commit()

        # Log if rowcount differs from expected
        if result.rowcount != new_count:
            logger.warning(
                f"Expected to insert {new_count} records but rowcount was {result.rowcount}"
            )

        return result.rowcount
    else:
        logger.debug(
            f"All {len(df)} records already exist in cache for {ticker_symbol}"
        )
        return 0
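

# Illustrative usage sketch: inserting a small synthetic OHLCV frame through
# ``bulk_insert_price_data``. Real callers would typically pass a DataFrame
# fetched from a provider such as yfinance; the values below are placeholders.
def insert_synthetic_prices_example(ticker: str = "TEST") -> int:
    """Insert three business days of made-up prices and return the row count."""
    index = pd.date_range("2024-01-02", periods=3, freq="B")
    df = pd.DataFrame(
        {
            "open": [100.0, 101.0, 102.0],
            "high": [101.5, 102.5, 103.5],
            "low": [99.5, 100.5, 101.5],
            "close": [101.0, 102.0, 103.0],
            "volume": [1_000_000, 1_100_000, 1_200_000],
        },
        index=index,
    )
    with SessionLocal() as session:
        return bulk_insert_price_data(session, ticker, df)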


def get_latest_maverick_screening(days_back: int = 1) -> dict:
    """Get latest screening results from all maverick tables."""
    with SessionLocal() as session:
        results = {
            "maverick_stocks": [
                stock.to_dict()
                for stock in MaverickStocks.get_latest_analysis(
                    session, days_back=days_back
                )
            ],
            "maverick_bear_stocks": [
                stock.to_dict()
                for stock in MaverickBearStocks.get_latest_analysis(
                    session, days_back=days_back
                )
            ],
            "supply_demand_breakouts": [
                stock.to_dict()
                for stock in SupplyDemandBreakoutStocks.get_latest_analysis(
                    session, days_back=days_back
                )
            ],
        }

    return results
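

# Illustrative usage sketch: pulling the latest screening snapshot and counting
# candidates per category. Assumes the screening tables have been populated by
# a prior screening run.
def count_screening_candidates_example(days_back: int = 1) -> dict[str, int]:
    """Return the number of stocks surfaced by each screening strategy."""
    screening = get_latest_maverick_screening(days_back=days_back)
    return {category: len(rows) for category, rows in screening.items()}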


def bulk_insert_screening_data(
    session: Session,
    model_class,
    screening_data: list[dict],
    date_analyzed: date | None = None,
) -> int:
    """
    Bulk insert screening data for any screening model.

    Args:
        session: Database session
        model_class: The screening model class (MaverickStocks, etc.)
        screening_data: List of screening result dictionaries
        date_analyzed: Date of analysis (default: today)

    Returns:
        Number of records inserted
    """
    if not screening_data:
        return 0

    if date_analyzed is None:
        date_analyzed = datetime.now(UTC).date()

    # Remove existing data for this date
    session.query(model_class).filter(
        model_class.date_analyzed == date_analyzed
    ).delete()

    inserted_count = 0
    for data in screening_data:
        # Get or create stock
        ticker = data.get("ticker") or data.get("symbol")
        if not ticker:
            continue

        stock = Stock.get_or_create(session, ticker)

        # Create screening record
        record_data = {
            "stock_id": stock.stock_id,
            "date_analyzed": date_analyzed,
        }

        # Map common fields
        field_mapping = {
            "open": "open_price",
            "high": "high_price",
            "low": "low_price",
            "close": "close_price",
            "pat": "pattern_type",
            "sqz": "squeeze_status",
            "vcp": "consolidation_status",
            "entry": "entry_signal",
        }

        for key, value in data.items():
            if key in ["ticker", "symbol"]:
                continue
            mapped_key = field_mapping.get(key, key)
            if hasattr(model_class, mapped_key):
                record_data[mapped_key] = value

        record = model_class(**record_data)
        session.add(record)
        inserted_count += 1

    session.commit()
    return inserted_count
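

# Illustrative usage sketch: persisting a couple of screening rows. The field
# names follow the mapping above ("close" -> "close_price", "pat" ->
# "pattern_type", etc.); the tickers and values are placeholders, and any key
# without a matching column is skipped by the helper.
def insert_screening_rows_example(session: Session) -> int:
    """Insert two example MaverickStocks rows for today's analysis date."""
    sample = [
        {"ticker": "AAPL", "close": 185.20, "pat": "Cup and Handle", "entry": "Buy"},
        {"ticker": "MSFT", "close": 410.55, "sqz": "Firing", "entry": "Watch"},
    ]
    return bulk_insert_screening_data(session, MaverickStocks, sample)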


# ============================================================================
# Portfolio Management Models
# ============================================================================


class UserPortfolio(TimestampMixin, Base):
    """
    User portfolio for tracking investment holdings.

    Follows the single-user design of the personal MaverickMCP server, where
    user_id defaults to "default". Stores portfolio metadata and the
    relationship to position records.

    Attributes:
        id: Unique portfolio identifier (UUID)
        user_id: User identifier (default: "default" for single-user)
        name: Portfolio display name
        positions: Relationship to PortfolioPosition records
    """

    __tablename__ = "mcp_portfolios"

    id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    user_id = Column(String(100), nullable=False, default="default", index=True)
    name = Column(String(200), nullable=False, default="My Portfolio")

    # Relationships
    positions = relationship(
        "PortfolioPosition",
        back_populates="portfolio",
        cascade="all, delete-orphan",
        lazy="selectin",  # Efficient loading
    )

    # Indexes for queries
    __table_args__ = (
        Index("idx_portfolio_user", "user_id"),
        UniqueConstraint("user_id", "name", name="uq_user_portfolio_name"),
    )

    def __repr__(self):
        return f"<UserPortfolio(id={self.id}, name='{self.name}', positions={len(self.positions)})>"


class PortfolioPosition(TimestampMixin, Base):
    """
    Individual position within a portfolio with cost basis tracking.

    Stores position details with high-precision Decimal types for financial accuracy.
    Uses average cost basis method for educational simplicity.

    Attributes:
        id: Unique position identifier (UUID)
        portfolio_id: Foreign key to parent portfolio
        ticker: Stock ticker symbol (e.g., "AAPL")
        shares: Number of shares owned (supports fractional shares)
        average_cost_basis: Average cost per share
        total_cost: Total capital invested (shares × average_cost_basis)
        purchase_date: Earliest purchase date for this position
        notes: Optional user notes about the position
    """

    __tablename__ = "mcp_portfolio_positions"

    id = Column(Uuid, primary_key=True, default=uuid.uuid4)
    portfolio_id = Column(
        Uuid, ForeignKey("mcp_portfolios.id", ondelete="CASCADE"), nullable=False
    )

    # Position details with financial precision
    ticker = Column(String(20), nullable=False, index=True)
    shares = Column(
        Numeric(20, 8), nullable=False
    )  # High precision for fractional shares
    average_cost_basis = Column(
        Numeric(12, 4), nullable=False
    )  # 4 decimal places (cents)
    total_cost = Column(Numeric(20, 4), nullable=False)  # Total capital invested
    purchase_date = Column(DateTime(timezone=True), nullable=False)  # Earliest purchase
    notes = Column(Text, nullable=True)  # Optional user notes

    # Relationships
    portfolio = relationship("UserPortfolio", back_populates="positions")

    # Indexes for efficient queries
    __table_args__ = (
        Index("idx_position_portfolio", "portfolio_id"),
        Index("idx_position_ticker", "ticker"),
        Index("idx_position_portfolio_ticker", "portfolio_id", "ticker"),
        UniqueConstraint("portfolio_id", "ticker", name="uq_portfolio_position_ticker"),
    )

    def __repr__(self):
        return f"<PortfolioPosition(ticker='{self.ticker}', shares={self.shares}, cost_basis={self.average_cost_basis})>"


# Auth models removed for personal use - no multi-user functionality needed

# Create tables when this module is executed directly (not on import)
if __name__ == "__main__":
    logger.info("Creating database tables...")
    init_db()
    logger.info("Database tables created successfully!")

```