This is page 33 of 39. Use http://codebase.md/wshobson/maverick-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ ├── question.md
│ │ └── security_report.md
│ ├── pull_request_template.md
│ └── workflows
│ ├── claude-code-review.yml
│ └── claude.yml
├── .gitignore
├── .python-version
├── .vscode
│ ├── launch.json
│ └── settings.json
├── alembic
│ ├── env.py
│ ├── script.py.mako
│ └── versions
│ ├── 001_initial_schema.py
│ ├── 003_add_performance_indexes.py
│ ├── 006_rename_metadata_columns.py
│ ├── 008_performance_optimization_indexes.py
│ ├── 009_rename_to_supply_demand.py
│ ├── 010_self_contained_schema.py
│ ├── 011_remove_proprietary_terms.py
│ ├── 013_add_backtest_persistence_models.py
│ ├── 014_add_portfolio_models.py
│ ├── 08e3945a0c93_merge_heads.py
│ ├── 9374a5c9b679_merge_heads_for_testing.py
│ ├── abf9b9afb134_merge_multiple_heads.py
│ ├── adda6d3fd84b_merge_proprietary_terms_removal_with_.py
│ ├── e0c75b0bdadb_fix_financial_data_precision_only.py
│ ├── f0696e2cac15_add_essential_performance_indexes.py
│ └── fix_database_integrity_issues.py
├── alembic.ini
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── DATABASE_SETUP.md
├── docker-compose.override.yml.example
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── api
│ │ └── backtesting.md
│ ├── BACKTESTING.md
│ ├── COST_BASIS_SPECIFICATION.md
│ ├── deep_research_agent.md
│ ├── exa_research_testing_strategy.md
│ ├── PORTFOLIO_PERSONALIZATION_PLAN.md
│ ├── PORTFOLIO.md
│ ├── SETUP_SELF_CONTAINED.md
│ └── speed_testing_framework.md
├── examples
│ ├── complete_speed_validation.py
│ ├── deep_research_integration.py
│ ├── llm_optimization_example.py
│ ├── llm_speed_demo.py
│ ├── monitoring_example.py
│ ├── parallel_research_example.py
│ ├── speed_optimization_demo.py
│ └── timeout_fix_demonstration.py
├── LICENSE
├── Makefile
├── MANIFEST.in
├── maverick_mcp
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── circuit_breaker.py
│ │ ├── deep_research.py
│ │ ├── market_analysis.py
│ │ ├── optimized_research.py
│ │ ├── supervisor.py
│ │ └── technical_analysis.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── api_server.py
│ │ ├── connection_manager.py
│ │ ├── dependencies
│ │ │ ├── __init__.py
│ │ │ ├── stock_analysis.py
│ │ │ └── technical_analysis.py
│ │ ├── error_handling.py
│ │ ├── inspector_compatible_sse.py
│ │ ├── inspector_sse.py
│ │ ├── middleware
│ │ │ ├── error_handling.py
│ │ │ ├── mcp_logging.py
│ │ │ ├── rate_limiting_enhanced.py
│ │ │ └── security.py
│ │ ├── openapi_config.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── backtesting.py
│ │ │ ├── data_enhanced.py
│ │ │ ├── data.py
│ │ │ ├── health_enhanced.py
│ │ │ ├── health_tools.py
│ │ │ ├── health.py
│ │ │ ├── intelligent_backtesting.py
│ │ │ ├── introspection.py
│ │ │ ├── mcp_prompts.py
│ │ │ ├── monitoring.py
│ │ │ ├── news_sentiment_enhanced.py
│ │ │ ├── performance.py
│ │ │ ├── portfolio.py
│ │ │ ├── research.py
│ │ │ ├── screening_ddd.py
│ │ │ ├── screening_parallel.py
│ │ │ ├── screening.py
│ │ │ ├── technical_ddd.py
│ │ │ ├── technical_enhanced.py
│ │ │ ├── technical.py
│ │ │ └── tool_registry.py
│ │ ├── server.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── base_service.py
│ │ │ ├── market_service.py
│ │ │ ├── portfolio_service.py
│ │ │ ├── prompt_service.py
│ │ │ └── resource_service.py
│ │ ├── simple_sse.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── insomnia_export.py
│ │ └── postman_export.py
│ ├── application
│ │ ├── __init__.py
│ │ ├── commands
│ │ │ └── __init__.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_dto.py
│ │ ├── queries
│ │ │ ├── __init__.py
│ │ │ └── get_technical_analysis.py
│ │ └── screening
│ │ ├── __init__.py
│ │ ├── dtos.py
│ │ └── queries.py
│ ├── backtesting
│ │ ├── __init__.py
│ │ ├── ab_testing.py
│ │ ├── analysis.py
│ │ ├── batch_processing_stub.py
│ │ ├── batch_processing.py
│ │ ├── model_manager.py
│ │ ├── optimization.py
│ │ ├── persistence.py
│ │ ├── retraining_pipeline.py
│ │ ├── strategies
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ml
│ │ │ │ ├── __init__.py
│ │ │ │ ├── adaptive.py
│ │ │ │ ├── ensemble.py
│ │ │ │ ├── feature_engineering.py
│ │ │ │ └── regime_aware.py
│ │ │ ├── ml_strategies.py
│ │ │ ├── parser.py
│ │ │ └── templates.py
│ │ ├── strategy_executor.py
│ │ ├── vectorbt_engine.py
│ │ └── visualization.py
│ ├── config
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── database_self_contained.py
│ │ ├── database.py
│ │ ├── llm_optimization_config.py
│ │ ├── logging_settings.py
│ │ ├── plotly_config.py
│ │ ├── security_utils.py
│ │ ├── security.py
│ │ ├── settings.py
│ │ ├── technical_constants.py
│ │ ├── tool_estimation.py
│ │ └── validation.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── technical_analysis.py
│ │ └── visualization.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── cache_manager.py
│ │ ├── cache.py
│ │ ├── django_adapter.py
│ │ ├── health.py
│ │ ├── models.py
│ │ ├── performance.py
│ │ ├── session_management.py
│ │ └── validation.py
│ ├── database
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── optimization.py
│ ├── dependencies.py
│ ├── domain
│ │ ├── __init__.py
│ │ ├── entities
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis.py
│ │ ├── events
│ │ │ └── __init__.py
│ │ ├── portfolio.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ ├── entities.py
│ │ │ ├── services.py
│ │ │ └── value_objects.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ └── technical_analysis_service.py
│ │ ├── stock_analysis
│ │ │ ├── __init__.py
│ │ │ └── stock_analysis_service.py
│ │ └── value_objects
│ │ ├── __init__.py
│ │ └── technical_indicators.py
│ ├── exceptions.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── cache
│ │ │ └── __init__.py
│ │ ├── caching
│ │ │ ├── __init__.py
│ │ │ └── cache_management_service.py
│ │ ├── connection_manager.py
│ │ ├── data_fetching
│ │ │ ├── __init__.py
│ │ │ └── stock_data_service.py
│ │ ├── health
│ │ │ ├── __init__.py
│ │ │ └── health_checker.py
│ │ ├── persistence
│ │ │ ├── __init__.py
│ │ │ └── stock_repository.py
│ │ ├── providers
│ │ │ └── __init__.py
│ │ ├── screening
│ │ │ ├── __init__.py
│ │ │ └── repositories.py
│ │ └── sse_optimizer.py
│ ├── langchain_tools
│ │ ├── __init__.py
│ │ ├── adapters.py
│ │ └── registry.py
│ ├── logging_config.py
│ ├── memory
│ │ ├── __init__.py
│ │ └── stores.py
│ ├── monitoring
│ │ ├── __init__.py
│ │ ├── health_check.py
│ │ ├── health_monitor.py
│ │ ├── integration_example.py
│ │ ├── metrics.py
│ │ ├── middleware.py
│ │ └── status_dashboard.py
│ ├── providers
│ │ ├── __init__.py
│ │ ├── dependencies.py
│ │ ├── factories
│ │ │ ├── __init__.py
│ │ │ ├── config_factory.py
│ │ │ └── provider_factory.py
│ │ ├── implementations
│ │ │ ├── __init__.py
│ │ │ ├── cache_adapter.py
│ │ │ ├── macro_data_adapter.py
│ │ │ ├── market_data_adapter.py
│ │ │ ├── persistence_adapter.py
│ │ │ └── stock_data_adapter.py
│ │ ├── interfaces
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ ├── config.py
│ │ │ ├── macro_data.py
│ │ │ ├── market_data.py
│ │ │ ├── persistence.py
│ │ │ └── stock_data.py
│ │ ├── llm_factory.py
│ │ ├── macro_data.py
│ │ ├── market_data.py
│ │ ├── mocks
│ │ │ ├── __init__.py
│ │ │ ├── mock_cache.py
│ │ │ ├── mock_config.py
│ │ │ ├── mock_macro_data.py
│ │ │ ├── mock_market_data.py
│ │ │ ├── mock_persistence.py
│ │ │ └── mock_stock_data.py
│ │ ├── openrouter_provider.py
│ │ ├── optimized_screening.py
│ │ ├── optimized_stock_data.py
│ │ └── stock_data.py
│ ├── README.md
│ ├── tests
│ │ ├── __init__.py
│ │ ├── README_INMEMORY_TESTS.md
│ │ ├── test_cache_debug.py
│ │ ├── test_fixes_validation.py
│ │ ├── test_in_memory_routers.py
│ │ ├── test_in_memory_server.py
│ │ ├── test_macro_data_provider.py
│ │ ├── test_mailgun_email.py
│ │ ├── test_market_calendar_caching.py
│ │ ├── test_mcp_tool_fixes_pytest.py
│ │ ├── test_mcp_tool_fixes.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_models_functional.py
│ │ ├── test_server.py
│ │ ├── test_stock_data_enhanced.py
│ │ ├── test_stock_data_provider.py
│ │ └── test_technical_analysis.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── performance_monitoring.py
│ │ ├── portfolio_manager.py
│ │ ├── risk_management.py
│ │ └── sentiment_analysis.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── agent_errors.py
│ │ ├── batch_processing.py
│ │ ├── cache_warmer.py
│ │ ├── circuit_breaker_decorators.py
│ │ ├── circuit_breaker_services.py
│ │ ├── circuit_breaker.py
│ │ ├── data_chunking.py
│ │ ├── database_monitoring.py
│ │ ├── debug_utils.py
│ │ ├── fallback_strategies.py
│ │ ├── llm_optimization.py
│ │ ├── logging_example.py
│ │ ├── logging_init.py
│ │ ├── logging.py
│ │ ├── mcp_logging.py
│ │ ├── memory_profiler.py
│ │ ├── monitoring_middleware.py
│ │ ├── monitoring.py
│ │ ├── orchestration_logging.py
│ │ ├── parallel_research.py
│ │ ├── parallel_screening.py
│ │ ├── quick_cache.py
│ │ ├── resource_manager.py
│ │ ├── shutdown.py
│ │ ├── stock_helpers.py
│ │ ├── structured_logger.py
│ │ ├── tool_monitoring.py
│ │ ├── tracing.py
│ │ └── yfinance_pool.py
│ ├── validation
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── data.py
│ │ ├── middleware.py
│ │ ├── portfolio.py
│ │ ├── responses.py
│ │ ├── screening.py
│ │ └── technical.py
│ └── workflows
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── market_analyzer.py
│ │ ├── optimizer_agent.py
│ │ ├── strategy_selector.py
│ │ └── validator_agent.py
│ ├── backtesting_workflow.py
│ └── state.py
├── PLANS.md
├── pyproject.toml
├── pyrightconfig.json
├── README.md
├── scripts
│ ├── dev.sh
│ ├── INSTALLATION_GUIDE.md
│ ├── load_example.py
│ ├── load_market_data.py
│ ├── load_tiingo_data.py
│ ├── migrate_db.py
│ ├── README_TIINGO_LOADER.md
│ ├── requirements_tiingo.txt
│ ├── run_stock_screening.py
│ ├── run-migrations.sh
│ ├── seed_db.py
│ ├── seed_sp500.py
│ ├── setup_database.sh
│ ├── setup_self_contained.py
│ ├── setup_sp500_database.sh
│ ├── test_seeded_data.py
│ ├── test_tiingo_loader.py
│ ├── tiingo_config.py
│ └── validate_setup.py
├── SECURITY.md
├── server.json
├── setup.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── core
│ │ └── test_technical_analysis.py
│ ├── data
│ │ └── test_portfolio_models.py
│ ├── domain
│ │ ├── conftest.py
│ │ ├── test_portfolio_entities.py
│ │ └── test_technical_analysis_service.py
│ ├── fixtures
│ │ └── orchestration_fixtures.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── README.md
│ │ ├── run_integration_tests.sh
│ │ ├── test_api_technical.py
│ │ ├── test_chaos_engineering.py
│ │ ├── test_config_management.py
│ │ ├── test_full_backtest_workflow_advanced.py
│ │ ├── test_full_backtest_workflow.py
│ │ ├── test_high_volume.py
│ │ ├── test_mcp_tools.py
│ │ ├── test_orchestration_complete.py
│ │ ├── test_portfolio_persistence.py
│ │ ├── test_redis_cache.py
│ │ ├── test_security_integration.py.disabled
│ │ └── vcr_setup.py
│ ├── performance
│ │ ├── __init__.py
│ │ ├── test_benchmarks.py
│ │ ├── test_load.py
│ │ ├── test_profiling.py
│ │ └── test_stress.py
│ ├── providers
│ │ └── test_stock_data_simple.py
│ ├── README.md
│ ├── test_agents_router_mcp.py
│ ├── test_backtest_persistence.py
│ ├── test_cache_management_service.py
│ ├── test_cache_serialization.py
│ ├── test_circuit_breaker.py
│ ├── test_database_pool_config_simple.py
│ ├── test_database_pool_config.py
│ ├── test_deep_research_functional.py
│ ├── test_deep_research_integration.py
│ ├── test_deep_research_parallel_execution.py
│ ├── test_error_handling.py
│ ├── test_event_loop_integrity.py
│ ├── test_exa_research_integration.py
│ ├── test_exception_hierarchy.py
│ ├── test_financial_search.py
│ ├── test_graceful_shutdown.py
│ ├── test_integration_simple.py
│ ├── test_langgraph_workflow.py
│ ├── test_market_data_async.py
│ ├── test_market_data_simple.py
│ ├── test_mcp_orchestration_functional.py
│ ├── test_ml_strategies.py
│ ├── test_optimized_research_agent.py
│ ├── test_orchestration_integration.py
│ ├── test_orchestration_logging.py
│ ├── test_orchestration_tools_simple.py
│ ├── test_parallel_research_integration.py
│ ├── test_parallel_research_orchestrator.py
│ ├── test_parallel_research_performance.py
│ ├── test_performance_optimizations.py
│ ├── test_production_validation.py
│ ├── test_provider_architecture.py
│ ├── test_rate_limiting_enhanced.py
│ ├── test_runner_validation.py
│ ├── test_security_comprehensive.py.disabled
│ ├── test_security_cors.py
│ ├── test_security_enhancements.py.disabled
│ ├── test_security_headers.py
│ ├── test_security_penetration.py
│ ├── test_session_management.py
│ ├── test_speed_optimization_validation.py
│ ├── test_stock_analysis_dependencies.py
│ ├── test_stock_analysis_service.py
│ ├── test_stock_data_fetching_service.py
│ ├── test_supervisor_agent.py
│ ├── test_supervisor_functional.py
│ ├── test_tool_estimation_config.py
│ ├── test_visualization.py
│ └── utils
│ ├── test_agent_errors.py
│ ├── test_logging.py
│ ├── test_parallel_screening.py
│ └── test_quick_cache.py
├── tools
│ ├── check_orchestration_config.py
│ ├── experiments
│ │ ├── validation_examples.py
│ │ └── validation_fixed.py
│ ├── fast_dev.sh
│ ├── hot_reload.py
│ ├── quick_test.py
│ └── templates
│ ├── new_router_template.py
│ ├── new_tool_template.py
│ ├── screening_strategy_template.py
│ └── test_template.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/integration/test_portfolio_persistence.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Comprehensive integration tests for portfolio persistence layer.
3 |
4 | Tests the database CRUD operations, relationships, constraints, and data integrity
5 | for the portfolio management system. Uses pytest fixtures with database sessions
6 | and SQLite for testing without external dependencies.
7 |
8 | Test Coverage:
9 | - Database CRUD operations (Create, Read, Update, Delete)
10 | - Relationship management (portfolio -> positions)
11 | - Unique constraints (user+portfolio name, portfolio+ticker)
12 | - Cascade deletes (portfolio deletion removes positions)
13 | - Data integrity (Decimal precision, timezone-aware datetimes)
14 | - Query performance (selectin loading, filtering)
15 | """
16 |
17 | import uuid
18 | from datetime import UTC, datetime, timedelta
19 | from decimal import Decimal
20 |
21 | import pytest
22 | from sqlalchemy import exc
23 | from sqlalchemy.orm import Session
24 |
25 | from maverick_mcp.data.models import PortfolioPosition, UserPortfolio
26 |
27 | pytestmark = pytest.mark.integration
28 |
29 |
30 | class TestPortfolioCreation:
31 | """Test suite for creating portfolios."""
32 |
33 | def test_create_portfolio_with_defaults(self, db_session: Session):
34 | """Test creating a portfolio with default values."""
35 | portfolio = UserPortfolio(
36 | user_id="default",
37 | name="My Portfolio",
38 | )
39 | db_session.add(portfolio)
40 | db_session.commit()
41 |
42 | # Verify creation
43 | retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
44 | assert retrieved is not None
45 | assert retrieved.user_id == "default"
46 | assert retrieved.name == "My Portfolio"
47 | assert retrieved.positions == []
48 | assert retrieved.created_at is not None
49 | assert retrieved.updated_at is not None
50 |
51 | def test_create_portfolio_with_custom_user(self, db_session: Session):
52 | """Test creating a portfolio for a specific user."""
53 | portfolio = UserPortfolio(
54 | user_id="user123",
55 | name="User Portfolio",
56 | )
57 | db_session.add(portfolio)
58 | db_session.commit()
59 |
60 | retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
61 | assert retrieved.user_id == "user123"
62 | assert retrieved.name == "User Portfolio"
63 |
64 | def test_create_multiple_portfolios_for_same_user(self, db_session: Session):
65 | """Test creating multiple portfolios for the same user."""
66 | portfolio1 = UserPortfolio(user_id="user1", name="Portfolio 1")
67 | portfolio2 = UserPortfolio(user_id="user1", name="Portfolio 2")
68 |
69 | db_session.add_all([portfolio1, portfolio2])
70 | db_session.commit()
71 |
72 | portfolios = db_session.query(UserPortfolio).filter_by(user_id="user1").all()
73 | assert len(portfolios) == 2
74 | assert {p.name for p in portfolios} == {"Portfolio 1", "Portfolio 2"}
75 |
76 | def test_portfolio_timestamps_created(self, db_session: Session):
77 | """Test that portfolio timestamps are set on creation."""
78 | before = datetime.now(UTC)
79 | portfolio = UserPortfolio(user_id="default", name="Test")
80 | db_session.add(portfolio)
81 | db_session.commit()
82 | after = datetime.now(UTC)
83 |
84 | retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
85 | assert before <= retrieved.created_at <= after
86 | assert before <= retrieved.updated_at <= after
87 |
88 |
89 | class TestPortfolioPositionCreation:
90 | """Test suite for creating positions within portfolios."""
91 |
92 | @pytest.fixture
93 | def portfolio(self, db_session: Session):
94 | """Create a portfolio for position tests."""
95 | # Use unique name with UUID to avoid constraint violations across tests
96 | unique_name = f"Test Portfolio {uuid.uuid4()}"
97 | portfolio = UserPortfolio(user_id="default", name=unique_name)
98 | db_session.add(portfolio)
99 | db_session.commit()
100 | return portfolio
101 |
102 | def test_create_position_basic(self, db_session: Session, portfolio: UserPortfolio):
103 | """Test creating a basic position in a portfolio."""
104 | position = PortfolioPosition(
105 | portfolio_id=portfolio.id,
106 | ticker="AAPL",
107 | shares=Decimal("10.00000000"),
108 | average_cost_basis=Decimal("150.0000"),
109 | total_cost=Decimal("1500.0000"),
110 | purchase_date=datetime.now(UTC),
111 | )
112 | db_session.add(position)
113 | db_session.commit()
114 |
115 | retrieved = (
116 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
117 | )
118 | assert retrieved.ticker == "AAPL"
119 | assert retrieved.shares == Decimal("10.00000000")
120 | assert retrieved.average_cost_basis == Decimal("150.0000")
121 | assert retrieved.total_cost == Decimal("1500.0000")
122 | assert retrieved.notes is None
123 |
124 | def test_create_position_with_notes(
125 | self, db_session: Session, portfolio: UserPortfolio
126 | ):
127 | """Test creating a position with optional notes."""
128 | notes = "Accumulated during bear market. Strong technicals."
129 | position = PortfolioPosition(
130 | portfolio_id=portfolio.id,
131 | ticker="MSFT",
132 | shares=Decimal("5.50000000"),
133 | average_cost_basis=Decimal("380.0000"),
134 | total_cost=Decimal("2090.0000"),
135 | purchase_date=datetime.now(UTC),
136 | notes=notes,
137 | )
138 | db_session.add(position)
139 | db_session.commit()
140 |
141 | retrieved = (
142 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
143 | )
144 | assert retrieved.notes == notes
145 |
146 | def test_create_position_with_fractional_shares(
147 | self, db_session: Session, portfolio: UserPortfolio
148 | ):
149 | """Test that positions support fractional shares."""
150 | position = PortfolioPosition(
151 | portfolio_id=portfolio.id,
152 | ticker="GOOG",
153 | shares=Decimal("2.33333333"), # Fractional shares
154 | average_cost_basis=Decimal("2750.0000"),
155 | total_cost=Decimal("6408.3333"),
156 | purchase_date=datetime.now(UTC),
157 | )
158 | db_session.add(position)
159 | db_session.commit()
160 |
161 | retrieved = (
162 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
163 | )
164 | assert retrieved.shares == Decimal("2.33333333")
165 |
166 | def test_create_position_with_high_precision_prices(
167 | self, db_session: Session, portfolio: UserPortfolio
168 | ):
169 | """Test that positions maintain Decimal precision for prices."""
170 | position = PortfolioPosition(
171 | portfolio_id=portfolio.id,
172 | ticker="TSLA",
173 | shares=Decimal("1.50000000"),
174 | average_cost_basis=Decimal("245.1234"), # 4 decimal places
175 | total_cost=Decimal("367.6851"),
176 | purchase_date=datetime.now(UTC),
177 | )
178 | db_session.add(position)
179 | db_session.commit()
180 |
181 | retrieved = (
182 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
183 | )
184 | assert retrieved.average_cost_basis == Decimal("245.1234")
185 | assert retrieved.total_cost == Decimal("367.6851")
186 |
187 | def test_position_gets_portfolio_relationship(
188 | self, db_session: Session, portfolio: UserPortfolio
189 | ):
190 | """Test that position relationship to portfolio is properly loaded."""
191 | position = PortfolioPosition(
192 | portfolio_id=portfolio.id,
193 | ticker="AAPL",
194 | shares=Decimal("10.00000000"),
195 | average_cost_basis=Decimal("150.0000"),
196 | total_cost=Decimal("1500.0000"),
197 | purchase_date=datetime.now(UTC),
198 | )
199 | db_session.add(position)
200 | db_session.commit()
201 |
202 | # Query fresh without expunging to verify relationship loading
203 | retrieved_position = (
204 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
205 | )
206 | assert retrieved_position.portfolio is not None
207 | assert retrieved_position.portfolio.id == portfolio.id
208 | assert retrieved_position.portfolio.name == portfolio.name
209 |
210 |
211 | class TestPortfolioRead:
212 | """Test suite for reading portfolio data."""
213 |
214 | @pytest.fixture
215 | def portfolio_with_positions(self, db_session: Session):
216 | """Create a portfolio with multiple positions."""
217 | unique_name = f"Mixed Portfolio {uuid.uuid4()}"
218 | portfolio = UserPortfolio(user_id="default", name=unique_name)
219 | db_session.add(portfolio)
220 | db_session.commit()
221 |
222 | positions = [
223 | PortfolioPosition(
224 | portfolio_id=portfolio.id,
225 | ticker="AAPL",
226 | shares=Decimal("10.00000000"),
227 | average_cost_basis=Decimal("150.0000"),
228 | total_cost=Decimal("1500.0000"),
229 | purchase_date=datetime.now(UTC),
230 | ),
231 | PortfolioPosition(
232 | portfolio_id=portfolio.id,
233 | ticker="MSFT",
234 | shares=Decimal("5.00000000"),
235 | average_cost_basis=Decimal("380.0000"),
236 | total_cost=Decimal("1900.0000"),
237 | purchase_date=datetime.now(UTC) - timedelta(days=30),
238 | ),
239 | PortfolioPosition(
240 | portfolio_id=portfolio.id,
241 | ticker="GOOG",
242 | shares=Decimal("2.50000000"),
243 | average_cost_basis=Decimal("2750.0000"),
244 | total_cost=Decimal("6875.0000"),
245 | purchase_date=datetime.now(UTC) - timedelta(days=60),
246 | ),
247 | ]
248 | db_session.add_all(positions)
249 | db_session.commit()
250 |
251 | return portfolio
252 |
253 | def test_read_portfolio_with_eager_loaded_positions(
254 | self, db_session: Session, portfolio_with_positions: UserPortfolio
255 | ):
256 | """Test that positions are eagerly loaded with portfolio (selectin)."""
257 | portfolio = (
258 | db_session.query(UserPortfolio)
259 | .filter_by(id=portfolio_with_positions.id)
260 | .first()
261 | )
262 | assert len(portfolio.positions) == 3
263 | assert {p.ticker for p in portfolio.positions} == {"AAPL", "MSFT", "GOOG"}
264 |
265 | def test_read_position_by_ticker(
266 | self, db_session: Session, portfolio_with_positions: UserPortfolio
267 | ):
268 | """Test filtering positions by ticker."""
269 | position = (
270 | db_session.query(PortfolioPosition)
271 | .filter_by(portfolio_id=portfolio_with_positions.id, ticker="MSFT")
272 | .first()
273 | )
274 | assert position is not None
275 | assert position.ticker == "MSFT"
276 | assert position.shares == Decimal("5.00000000")
277 | assert position.average_cost_basis == Decimal("380.0000")
278 |
279 | def test_read_all_positions_for_portfolio(
280 | self, db_session: Session, portfolio_with_positions: UserPortfolio
281 | ):
282 | """Test reading all positions for a portfolio."""
283 | positions = (
284 | db_session.query(PortfolioPosition)
285 | .filter_by(portfolio_id=portfolio_with_positions.id)
286 | .order_by(PortfolioPosition.ticker)
287 | .all()
288 | )
289 | assert len(positions) == 3
290 | assert positions[0].ticker == "AAPL"
291 | assert positions[1].ticker == "GOOG"
292 | assert positions[2].ticker == "MSFT"
293 |
294 | def test_read_portfolio_by_user_and_name(self, db_session: Session):
295 | """Test reading portfolio by user_id and name."""
296 | portfolio = UserPortfolio(user_id="user1", name="Specific Portfolio")
297 | db_session.add(portfolio)
298 | db_session.commit()
299 |
300 | retrieved = (
301 | db_session.query(UserPortfolio)
302 | .filter_by(user_id="user1", name="Specific Portfolio")
303 | .first()
304 | )
305 | assert retrieved is not None
306 | assert retrieved.id == portfolio.id
307 |
308 | def test_read_multiple_portfolios_for_user(self, db_session: Session):
309 | """Test reading multiple portfolios for the same user."""
310 | user_id = "user_multi"
311 | portfolios = [
312 | UserPortfolio(user_id=user_id, name=f"Portfolio {i}") for i in range(3)
313 | ]
314 | db_session.add_all(portfolios)
315 | db_session.commit()
316 |
317 | retrieved_portfolios = (
318 | db_session.query(UserPortfolio)
319 | .filter_by(user_id=user_id)
320 | .order_by(UserPortfolio.name)
321 | .all()
322 | )
323 | assert len(retrieved_portfolios) == 3
324 |
325 |
326 | class TestPortfolioUpdate:
327 | """Test suite for updating portfolio data."""
328 |
329 | @pytest.fixture
330 | def portfolio_with_position(self, db_session: Session):
331 | """Create portfolio with a position for update tests."""
332 | unique_name = f"Update Test {uuid.uuid4()}"
333 | portfolio = UserPortfolio(user_id="default", name=unique_name)
334 | db_session.add(portfolio)
335 | db_session.commit()
336 |
337 | position = PortfolioPosition(
338 | portfolio_id=portfolio.id,
339 | ticker="AAPL",
340 | shares=Decimal("10.00000000"),
341 | average_cost_basis=Decimal("150.0000"),
342 | total_cost=Decimal("1500.0000"),
343 | purchase_date=datetime.now(UTC),
344 | notes="Initial purchase",
345 | )
346 | db_session.add(position)
347 | db_session.commit()
348 |
349 | return portfolio, position
350 |
351 | def test_update_portfolio_name(
352 | self, db_session: Session, portfolio_with_position: tuple
353 | ):
354 | """Test updating portfolio name."""
355 | portfolio, _ = portfolio_with_position
356 |
357 | portfolio.name = "Updated Portfolio Name"
358 | db_session.commit()
359 |
360 | retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
361 | assert retrieved.name == "Updated Portfolio Name"
362 |
363 | def test_update_position_shares_and_cost(
364 | self, db_session: Session, portfolio_with_position: tuple
365 | ):
366 | """Test updating position shares and cost (simulating averaging)."""
367 | _, position = portfolio_with_position
368 |
369 | # Simulate adding shares with cost basis averaging
370 | position.shares = Decimal("20.00000000")
371 | position.average_cost_basis = Decimal("160.0000") # Averaged cost
372 | position.total_cost = Decimal("3200.0000")
373 | db_session.commit()
374 |
375 | retrieved = (
376 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
377 | )
378 | assert retrieved.shares == Decimal("20.00000000")
379 | assert retrieved.average_cost_basis == Decimal("160.0000")
380 | assert retrieved.total_cost == Decimal("3200.0000")
381 |
382 | def test_update_position_notes(
383 | self, db_session: Session, portfolio_with_position: tuple
384 | ):
385 | """Test updating position notes."""
386 | _, position = portfolio_with_position
387 |
388 | new_notes = "Sold 5 shares at $180, added 5 at $140"
389 | position.notes = new_notes
390 | db_session.commit()
391 |
392 | retrieved = (
393 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
394 | )
395 | assert retrieved.notes == new_notes
396 |
397 | def test_update_position_clears_notes(
398 | self, db_session: Session, portfolio_with_position: tuple
399 | ):
400 | """Test clearing position notes."""
401 | _, position = portfolio_with_position
402 |
403 | position.notes = None
404 | db_session.commit()
405 |
406 | retrieved = (
407 | db_session.query(PortfolioPosition).filter_by(id=position.id).first()
408 | )
409 | assert retrieved.notes is None
410 |
411 | def test_portfolio_updated_timestamp_changes(
412 | self, db_session: Session, portfolio_with_position: tuple
413 | ):
414 | """Test that updated_at timestamp changes when portfolio is modified."""
415 | portfolio, _ = portfolio_with_position
416 |
417 | # Small delay to ensure timestamp changes
418 | import time
419 |
420 | time.sleep(0.01)
421 |
422 | portfolio.name = "New Name"
423 | db_session.commit()
424 |
425 | retrieved = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
426 | # Note: updated_at may not always change depending on DB precision
427 | # This test verifies the column exists and is updateable
428 | assert retrieved.updated_at is not None
429 |
430 |
431 | class TestPortfolioDelete:
432 | """Test suite for deleting portfolios and positions."""
433 |
434 | @pytest.fixture
435 | def portfolio_with_positions(self, db_session: Session):
436 | """Create portfolio with positions for deletion tests."""
437 | unique_name = f"Delete Test {uuid.uuid4()}"
438 | portfolio = UserPortfolio(user_id="default", name=unique_name)
439 | db_session.add(portfolio)
440 | db_session.commit()
441 |
442 | positions = [
443 | PortfolioPosition(
444 | portfolio_id=portfolio.id,
445 | ticker="AAPL",
446 | shares=Decimal("10.00000000"),
447 | average_cost_basis=Decimal("150.0000"),
448 | total_cost=Decimal("1500.0000"),
449 | purchase_date=datetime.now(UTC),
450 | ),
451 | PortfolioPosition(
452 | portfolio_id=portfolio.id,
453 | ticker="MSFT",
454 | shares=Decimal("5.00000000"),
455 | average_cost_basis=Decimal("380.0000"),
456 | total_cost=Decimal("1900.0000"),
457 | purchase_date=datetime.now(UTC),
458 | ),
459 | ]
460 | db_session.add_all(positions)
461 | db_session.commit()
462 |
463 | return portfolio
464 |
465 | def test_delete_single_position(
466 | self, db_session: Session, portfolio_with_positions: UserPortfolio
467 | ):
468 | """Test deleting a single position from a portfolio."""
469 | position = (
470 | db_session.query(PortfolioPosition)
471 | .filter_by(portfolio_id=portfolio_with_positions.id, ticker="AAPL")
472 | .first()
473 | )
474 | position_id = position.id
475 |
476 | db_session.delete(position)
477 | db_session.commit()
478 |
479 | retrieved = (
480 | db_session.query(PortfolioPosition).filter_by(id=position_id).first()
481 | )
482 | assert retrieved is None
483 |
484 | # Verify other position still exists
485 | other_position = (
486 | db_session.query(PortfolioPosition)
487 | .filter_by(portfolio_id=portfolio_with_positions.id, ticker="MSFT")
488 | .first()
489 | )
490 | assert other_position is not None
491 |
492 | def test_delete_all_positions_from_portfolio(
493 | self, db_session: Session, portfolio_with_positions: UserPortfolio
494 | ):
495 | """Test deleting all positions from a portfolio."""
496 | positions = (
497 | db_session.query(PortfolioPosition)
498 | .filter_by(portfolio_id=portfolio_with_positions.id)
499 | .all()
500 | )
501 |
502 | for position in positions:
503 | db_session.delete(position)
504 | db_session.commit()
505 |
506 | remaining = (
507 | db_session.query(PortfolioPosition)
508 | .filter_by(portfolio_id=portfolio_with_positions.id)
509 | .all()
510 | )
511 | assert len(remaining) == 0
512 |
513 | # Portfolio should still exist
514 | portfolio = (
515 | db_session.query(UserPortfolio)
516 | .filter_by(id=portfolio_with_positions.id)
517 | .first()
518 | )
519 | assert portfolio is not None
520 |
def test_cascade_delete_portfolio_removes_positions(
    self, db_session: Session, portfolio_with_positions: UserPortfolio
):
    """Test that deleting a portfolio cascades delete to positions."""
    portfolio_id = portfolio_with_positions.id

    # Without existing positions the cascade check below would pass vacuously.
    positions_before = (
        db_session.query(PortfolioPosition)
        .filter_by(portfolio_id=portfolio_id)
        .count()
    )
    assert positions_before > 0, "fixture should provide at least one position"

    db_session.delete(portfolio_with_positions)
    db_session.commit()

    # Portfolio should be deleted
    portfolio = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
    assert portfolio is None

    # Positions should also be deleted
    positions = (
        db_session.query(PortfolioPosition)
        .filter_by(portfolio_id=portfolio_id)
        .all()
    )
    assert len(positions) == 0
541 |
def test_delete_portfolio_doesnt_affect_other_portfolios(self, db_session: Session):
    """Test that deleting one portfolio doesn't affect others.

    Both portfolios hold a position, so the test checks two things:
    - the deleted portfolio is gone;
    - the sibling portfolio and its position survive.
    """
    user_id = f"user1_{uuid.uuid4()}"
    portfolio1 = UserPortfolio(user_id=user_id, name=f"Portfolio 1 {uuid.uuid4()}")
    portfolio2 = UserPortfolio(user_id=user_id, name=f"Portfolio 2 {uuid.uuid4()}")
    db_session.add_all([portfolio1, portfolio2])
    db_session.commit()

    # Give each portfolio a position; the same ticker is allowed across
    # different portfolios.
    position1 = PortfolioPosition(
        portfolio_id=portfolio1.id,
        ticker="AAPL",
        shares=Decimal("10.00000000"),
        average_cost_basis=Decimal("150.0000"),
        total_cost=Decimal("1500.0000"),
        purchase_date=datetime.now(UTC),
    )
    position2 = PortfolioPosition(
        portfolio_id=portfolio2.id,
        ticker="AAPL",
        shares=Decimal("5.00000000"),
        average_cost_basis=Decimal("160.0000"),
        total_cost=Decimal("800.0000"),
        purchase_date=datetime.now(UTC),
    )
    db_session.add_all([position1, position2])
    db_session.commit()

    portfolio1_id = portfolio1.id
    portfolio2_id = portfolio2.id
    expected_name = portfolio2.name  # generated above, so read it back

    # Delete portfolio1
    db_session.delete(portfolio1)
    db_session.commit()

    # Portfolio1 is actually gone
    assert db_session.query(UserPortfolio).filter_by(id=portfolio1_id).first() is None

    # Portfolio2 and its position should still exist
    p2 = db_session.query(UserPortfolio).filter_by(id=portfolio2_id).first()
    assert p2 is not None
    assert p2.name == expected_name
    surviving = (
        db_session.query(PortfolioPosition)
        .filter_by(portfolio_id=portfolio2_id, ticker="AAPL")
        .first()
    )
    assert surviving is not None
570 |
571 |
class TestUniqueConstraints:
    """Test suite for unique constraint enforcement.

    Covers two uniqueness rules:
    - portfolio names are unique per user;
    - position tickers are unique per portfolio.

    Tests that expect an ``IntegrityError`` roll the session back afterwards,
    so the failed transaction does not leak into fixture teardown.
    """

    def test_duplicate_portfolio_name_for_same_user_fails(self, db_session: Session):
        """Test that duplicate portfolio names for same user fail."""
        user_id = f"user1_{uuid.uuid4()}"
        name = f"My Portfolio {uuid.uuid4()}"

        portfolio1 = UserPortfolio(user_id=user_id, name=name)
        db_session.add(portfolio1)
        db_session.commit()

        # Try to create duplicate
        portfolio2 = UserPortfolio(user_id=user_id, name=name)
        db_session.add(portfolio2)

        with pytest.raises(exc.IntegrityError):
            db_session.commit()
        # A failed flush leaves the session unusable until it is explicitly
        # rolled back (otherwise later use, e.g. in teardown, can raise).
        db_session.rollback()

    def test_same_portfolio_name_different_users_succeeds(self, db_session: Session):
        """Test that same portfolio name is allowed for different users."""
        name = f"My Portfolio {uuid.uuid4()}"

        portfolio1 = UserPortfolio(user_id=f"user1_{uuid.uuid4()}", name=name)
        portfolio2 = UserPortfolio(user_id=f"user2_{uuid.uuid4()}", name=name)
        db_session.add_all([portfolio1, portfolio2])
        db_session.commit()

        # Both should exist
        p1 = (
            db_session.query(UserPortfolio)
            .filter_by(user_id=portfolio1.user_id, name=name)
            .first()
        )
        p2 = (
            db_session.query(UserPortfolio)
            .filter_by(user_id=portfolio2.user_id, name=name)
            .first()
        )
        assert p1 is not None
        assert p2 is not None
        assert p1.id != p2.id

    def test_duplicate_ticker_in_same_portfolio_fails(self, db_session: Session):
        """Test that duplicate tickers in same portfolio fail."""
        unique_name = f"Test {uuid.uuid4()}"
        portfolio = UserPortfolio(user_id="default", name=unique_name)
        db_session.add(portfolio)
        db_session.commit()

        position1 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position1)
        db_session.commit()

        # Try to create duplicate ticker
        position2 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("160.0000"),
            total_cost=Decimal("800.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position2)

        with pytest.raises(exc.IntegrityError):
            db_session.commit()
        # Clear the failed transaction so the session is usable again.
        db_session.rollback()

    def test_same_ticker_different_portfolios_succeeds(self, db_session: Session):
        """Test that same ticker is allowed in different portfolios."""
        user_id = f"user1_{uuid.uuid4()}"
        portfolio1 = UserPortfolio(user_id=user_id, name=f"Portfolio 1 {uuid.uuid4()}")
        portfolio2 = UserPortfolio(user_id=user_id, name=f"Portfolio 2 {uuid.uuid4()}")
        db_session.add_all([portfolio1, portfolio2])
        db_session.commit()

        position1 = PortfolioPosition(
            portfolio_id=portfolio1.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
        )
        position2 = PortfolioPosition(
            portfolio_id=portfolio2.id,
            ticker="AAPL",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("160.0000"),
            total_cost=Decimal("800.0000"),
            purchase_date=datetime.now(UTC),
        )
        db_session.add_all([position1, position2])
        db_session.commit()

        # Both should exist
        p1 = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio1.id, ticker="AAPL")
            .first()
        )
        p2 = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio2.id, ticker="AAPL")
            .first()
        )
        assert p1 is not None
        assert p2 is not None
        assert p1.id != p2.id
688 |
689 |
class TestDataIntegrity:
    """Test suite for data integrity and precision.

    Round-trips positions through the database and checks that the following
    come back as stored:
    - Decimal values;
    - timezone-aware datetimes;
    - NULL and empty-string notes.
    """

    @pytest.fixture
    def portfolio(self, db_session: Session):
        """Create a portfolio for integrity tests."""
        unique_name = f"Integrity Test {uuid.uuid4()}"
        portfolio = UserPortfolio(user_id="default", name=unique_name)
        db_session.add(portfolio)
        db_session.commit()
        return portfolio

    def test_decimal_precision_preserved(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that Decimal precision is maintained through round-trip."""
        # Use precision that matches database columns:
        # shares: Numeric(20, 8), cost_basis: Numeric(12, 4), total_cost: Numeric(20, 4)
        shares = Decimal("1.12345678")
        cost_basis = Decimal("2345.6789")
        # Arbitrary value at the column's full 4-decimal scale; it is not meant
        # to equal shares * cost_basis.
        total_cost = Decimal("2637.4012")

        position = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="TEST",
            shares=shares,
            average_cost_basis=cost_basis,
            total_cost=total_cost,
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position)
        db_session.commit()

        retrieved = (
            db_session.query(PortfolioPosition).filter_by(id=position.id).first()
        )
        assert retrieved.shares == shares
        assert retrieved.average_cost_basis == cost_basis
        assert retrieved.total_cost == total_cost

    def test_timezone_aware_datetime_preserved(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that timezone-aware datetimes are preserved."""
        purchase_date = datetime(2024, 1, 15, 14, 30, 45, 123456, tzinfo=UTC)

        position = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=purchase_date,
        )
        db_session.add(position)
        db_session.commit()

        retrieved = (
            db_session.query(PortfolioPosition).filter_by(id=position.id).first()
        )
        stored = retrieved.purchase_date
        assert stored.tzinfo is not None
        # Normalize to UTC before comparing. A timezone-aware column may be
        # returned in the connection's session time zone rather than UTC, in
        # which case comparing raw hour/minute fields would fail spuriously.
        # Microseconds are dropped because not every backend preserves them.
        assert stored.astimezone(UTC).replace(microsecond=0) == purchase_date.replace(
            microsecond=0
        )

    def test_null_notes_allowed(self, db_session: Session, portfolio: UserPortfolio):
        """Test that NULL notes are properly handled."""
        position1 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
            notes=None,
        )
        position2 = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="MSFT",
            shares=Decimal("5.00000000"),
            average_cost_basis=Decimal("380.0000"),
            total_cost=Decimal("1900.0000"),
            purchase_date=datetime.now(UTC),
            notes="Some notes",
        )
        db_session.add_all([position1, position2])
        db_session.commit()

        p1 = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio.id, ticker="AAPL")
            .first()
        )
        p2 = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio.id, ticker="MSFT")
            .first()
        )
        assert p1.notes is None
        assert p2.notes == "Some notes"

    def test_empty_notes_string_stored(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test that empty string notes are stored (if provided)."""
        position = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="AAPL",
            shares=Decimal("10.00000000"),
            average_cost_basis=Decimal("150.0000"),
            total_cost=Decimal("1500.0000"),
            purchase_date=datetime.now(UTC),
            notes="",
        )
        db_session.add(position)
        db_session.commit()

        retrieved = (
            db_session.query(PortfolioPosition).filter_by(id=position.id).first()
        )
        assert retrieved.notes == ""

    def test_large_decimal_values(self, db_session: Session, portfolio: UserPortfolio):
        """Test handling of large Decimal values at each column's maximum width."""
        position = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="HUGE",
            shares=Decimal("999999999999.99999999"),  # Large shares
            average_cost_basis=Decimal("9999.9999"),  # Large price
            total_cost=Decimal("9999999999999999.9999"),  # Large total
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position)
        db_session.commit()

        retrieved = (
            db_session.query(PortfolioPosition).filter_by(id=position.id).first()
        )
        assert retrieved.shares == Decimal("999999999999.99999999")
        assert retrieved.average_cost_basis == Decimal("9999.9999")
        assert retrieved.total_cost == Decimal("9999999999999999.9999")

    def test_very_small_decimal_values(
        self, db_session: Session, portfolio: UserPortfolio
    ):
        """Test handling of very small Decimal values.

        Note: total_cost uses Numeric(20, 4), so it cannot hold values finer
        than 0.0001. The true product here (1e-12) is therefore passed in
        already reduced to 0.0000. Whether a backend rounds or truncates
        extra digits is not exercised by this test.
        """
        position = PortfolioPosition(
            portfolio_id=portfolio.id,
            ticker="TINY",
            shares=Decimal("0.00000001"),  # Very small shares (supports 8 decimals)
            average_cost_basis=Decimal("0.0001"),  # Minimum price precision
            total_cost=Decimal("0.0000"),  # Smallest representable at scale 4
            purchase_date=datetime.now(UTC),
        )
        db_session.add(position)
        db_session.commit()

        retrieved = (
            db_session.query(PortfolioPosition).filter_by(id=position.id).first()
        )
        assert retrieved.shares == Decimal("0.00000001")
        assert retrieved.average_cost_basis == Decimal("0.0001")
        assert retrieved.total_cost == Decimal("0.0000")
862 |
863 |
class TestQueryPerformance:
    """Query-shape tests for the portfolio tables.

    These run the filters and orderings the indexes are meant to serve.
    They check result correctness only; they do not inspect query plans
    or count issued queries.
    """

    @pytest.fixture
    def large_portfolio(self, db_session: Session):
        """Portfolio holding one position per ticker in a fixed eight-ticker list.

        Position ``k`` (0-based) has:
        - ``10 + k`` shares;
        - a cost basis of ``100 + 10*k``;
        - a purchase date ``k`` days ago.
        """
        portfolio = UserPortfolio(
            user_id="default", name=f"Large Portfolio {uuid.uuid4()}"
        )
        db_session.add(portfolio)
        db_session.commit()

        symbols = ["AAPL", "MSFT", "GOOG", "AMZN", "TSLA", "META", "NVDA", "NFLX"]
        holdings = []
        for offset, symbol in enumerate(symbols):
            share_count = 10 + offset
            unit_price = 100 + offset * 10
            holdings.append(
                PortfolioPosition(
                    portfolio_id=portfolio.id,
                    ticker=symbol,
                    shares=Decimal(f"{share_count}.00000000"),
                    average_cost_basis=Decimal(f"{unit_price}.0000"),
                    total_cost=Decimal(f"{share_count * unit_price}.0000"),
                    purchase_date=datetime.now(UTC) - timedelta(days=offset),
                )
            )
        db_session.add_all(holdings)
        db_session.commit()

        return portfolio

    def test_selectin_loading_of_positions(
        self, db_session: Session, large_portfolio: UserPortfolio
    ):
        """Positions are reachable from a freshly queried portfolio.

        The intent (per the relationship's selectin loading) is that this
        triggers no per-position queries; the test itself does not count them.
        """
        loaded = (
            db_session.query(UserPortfolio).filter_by(id=large_portfolio.id).first()
        )
        holdings = loaded.positions
        assert len(holdings) > 0
        assert all(holding.ticker is not None for holding in holdings)

    def test_filter_by_ticker_uses_index(
        self, db_session: Session, large_portfolio: UserPortfolio
    ):
        """Filtering by portfolio and ticker returns only that ticker."""
        matches = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=large_portfolio.id, ticker="AAPL")
            .all()
        )
        assert len(matches) >= 1
        for match in matches:
            assert match.ticker == "AAPL"

    def test_filter_by_portfolio_id_uses_index(
        self, db_session: Session, large_portfolio: UserPortfolio
    ):
        """Filtering by portfolio_id returns only that portfolio's rows."""
        owner_id = large_portfolio.id
        rows = db_session.query(PortfolioPosition).filter_by(portfolio_id=owner_id).all()
        assert len(rows) > 0
        for row in rows:
            assert row.portfolio_id == owner_id

    def test_combined_filter_portfolio_and_ticker(
        self, db_session: Session, large_portfolio: UserPortfolio
    ):
        """Combined portfolio_id + ticker lookup finds the single matching row."""
        found = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=large_portfolio.id, ticker="MSFT")
            .first()
        )
        assert found is not None
        assert found.ticker == "MSFT"

    def test_query_user_portfolios_by_user_id(self, db_session: Session):
        """All portfolios created for a user are returned by a user_id filter."""
        owner = f"user_perf_{uuid.uuid4()}"
        db_session.add_all(
            [
                UserPortfolio(user_id=owner, name=f"Portfolio {n}_{uuid.uuid4()}")
                for n in range(5)
            ]
        )
        db_session.commit()

        found = db_session.query(UserPortfolio).filter_by(user_id=owner).all()
        assert len(found) == 5

    def test_order_by_ticker_works(
        self, db_session: Session, large_portfolio: UserPortfolio
    ):
        """Ordering by ticker yields a non-decreasing ticker sequence."""
        ordered = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=large_portfolio.id)
            .order_by(PortfolioPosition.ticker)
            .all()
        )
        assert len(ordered) > 0
        symbols = [row.ticker for row in ordered]
        assert all(left <= right for left, right in zip(symbols, symbols[1:]))
971 |
972 |
class TestPortfolioIntegration:
    """End-to-end integration tests combining multiple operations."""

    def test_complete_portfolio_lifecycle(self, db_session: Session):
        """Test complete portfolio lifecycle from creation to deletion.

        Covers these steps in order:
        - create a portfolio and add positions;
        - read them back;
        - update one position and delete another;
        - delete the portfolio and confirm its positions are gone.
        """
        # Create portfolio
        unique_name = f"Lifecycle Portfolio {uuid.uuid4()}"
        portfolio = UserPortfolio(user_id="test_user", name=unique_name)
        db_session.add(portfolio)
        db_session.commit()
        portfolio_id = portfolio.id

        # Add positions
        positions_data = [
            ("AAPL", Decimal("10"), Decimal("150.0000"), Decimal("1500.0000")),
            ("MSFT", Decimal("5"), Decimal("380.0000"), Decimal("1900.0000")),
        ]

        for ticker, shares, price, total in positions_data:
            db_session.add(
                PortfolioPosition(
                    portfolio_id=portfolio_id,
                    ticker=ticker,
                    shares=shares,
                    average_cost_basis=price,
                    total_cost=total,
                    purchase_date=datetime.now(UTC),
                )
            )
        db_session.commit()

        # Read and verify
        portfolio = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
        assert len(portfolio.positions) == 2
        assert {p.ticker for p in portfolio.positions} == {"AAPL", "MSFT"}

        # Update position
        msft_position = next(p for p in portfolio.positions if p.ticker == "MSFT")
        msft_position.shares = Decimal("10")  # Double shares
        msft_position.average_cost_basis = Decimal("370.0000")  # Averaged price
        msft_position.total_cost = Decimal("3700.0000")
        db_session.commit()

        # Delete one position
        aapl_position = next(p for p in portfolio.positions if p.ticker == "AAPL")
        db_session.delete(aapl_position)
        db_session.commit()

        # Verify state
        portfolio = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
        assert len(portfolio.positions) == 1
        assert portfolio.positions[0].ticker == "MSFT"
        assert portfolio.positions[0].shares == Decimal("10")

        # Delete portfolio
        db_session.delete(portfolio)
        db_session.commit()

        # Verify deletion
        portfolio = db_session.query(UserPortfolio).filter_by(id=portfolio_id).first()
        assert portfolio is None

        positions = (
            db_session.query(PortfolioPosition)
            .filter_by(portfolio_id=portfolio_id)
            .all()
        )
        assert len(positions) == 0

    def test_portfolio_with_various_decimal_precision(self, db_session: Session):
        """Test portfolio with positions of varying decimal precisions.

        Note: total_cost uses Numeric(20, 4). Every value below already fits
        that scale, so each one should round-trip unchanged. Decimal equality
        is value-based, so e.g. Decimal("100.00") == Decimal("100.0000").
        """
        unique_name = f"Mixed Precision {uuid.uuid4()}"
        portfolio = UserPortfolio(user_id="default", name=unique_name)
        db_session.add(portfolio)
        db_session.commit()

        positions_data = [
            ("AAPL", Decimal("1"), Decimal("100.00"), Decimal("100.00")),
            ("MSFT", Decimal("1.5"), Decimal("200.5000"), Decimal("300.7500")),
            (
                "GOOG",
                Decimal("0.33333333"),
                Decimal("2750.1234"),
                # Supplied directly, not derived from shares * price (which is
                # ~916.71); it already has 4 decimals so it round-trips as-is.
                Decimal("917.5041"),
            ),
            ("AMZN", Decimal("100"), Decimal("150.1"), Decimal("15010")),
        ]

        for ticker, shares, price, total in positions_data:
            db_session.add(
                PortfolioPosition(
                    portfolio_id=portfolio.id,
                    ticker=ticker,
                    shares=shares,
                    average_cost_basis=price,
                    total_cost=total,
                    purchase_date=datetime.now(UTC),
                )
            )
        db_session.commit()

        # Verify all positions preserved their precision
        portfolio = db_session.query(UserPortfolio).filter_by(id=portfolio.id).first()
        assert len(portfolio.positions) == 4

        for (
            expected_ticker,
            expected_shares,
            expected_price,
            expected_total,
        ) in positions_data:
            position = next(
                p for p in portfolio.positions if p.ticker == expected_ticker
            )
            assert position.shares == expected_shares
            assert position.average_cost_basis == expected_price
            assert position.total_cost == expected_total
1091 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/research.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Deep research tools with adaptive timeout handling and comprehensive optimization.
3 |
4 | This module provides timeout-protected research tools with LLM optimization
5 | to prevent hanging and ensure reliable responses to Claude Desktop.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | import uuid
11 | from datetime import datetime
12 | from typing import Any
13 |
14 | from fastmcp import FastMCP
15 | from pydantic import BaseModel, Field
16 |
17 | from maverick_mcp.agents.base import INVESTOR_PERSONAS
18 | from maverick_mcp.agents.deep_research import DeepResearchAgent
19 | from maverick_mcp.api.middleware.mcp_logging import get_tool_logger
20 | from maverick_mcp.config.settings import get_settings
21 | from maverick_mcp.providers.llm_factory import get_llm
22 | from maverick_mcp.providers.openrouter_provider import TaskType
23 | from maverick_mcp.utils.orchestration_logging import (
24 | log_performance_metrics,
25 | log_tool_invocation,
26 | )
27 |
logger = logging.getLogger(__name__)
# Settings are loaded once at import time; get_research_agent() reads
# settings.research.exa_api_key from this object.
settings = get_settings()

# Initialize LLM and agent
# Default LLM, created at import time with get_llm()'s defaults. It backs the
# shared (singleton) research agent built in get_research_agent().
llm = get_llm()
# Shared DeepResearchAgent, created lazily by get_research_agent() on the first
# call that does not take the adaptive per-request path.
research_agent: DeepResearchAgent | None = None
34 |
35 |
36 | # Request models for tool registration
class ResearchRequest(BaseModel):
    """Request model for comprehensive research"""

    # Free-text topic to research (the only required field).
    query: str = Field(description="Research query or topic")
    # Optional tuning knobs; each falls back to the default shown.
    persona: str | None = Field(
        "moderate",
        description="Investor persona (conservative, moderate, aggressive, day_trader)",
    )
    research_scope: str | None = Field(
        "standard",
        description="Research scope (basic, standard, comprehensive, exhaustive)",
    )
    # The 1-30 range is documented in the description only; it is not enforced.
    max_sources: int | None = Field(10, description="Maximum sources to analyze (1-30)")
    timeframe: str | None = Field(
        "1m", description="Time frame for search (1d, 1w, 1m, 3m)"
    )
55 |
56 |
class CompanyResearchRequest(BaseModel):
    """Request model for company research"""

    # Ticker of the company to research (required).
    symbol: str = Field(description="Stock ticker symbol")
    # Competitive analysis is opt-in.
    include_competitive_analysis: bool = Field(
        False, description="Include competitive analysis"
    )
    persona: str | None = Field(
        "moderate", description="Investor persona for analysis perspective"
    )
67 |
68 |
class SentimentAnalysisRequest(BaseModel):
    """Request model for sentiment analysis"""

    # Subject whose sentiment should be analyzed (required).
    topic: str = Field(description="Topic for sentiment analysis")
    # Optional knobs; each falls back to the default shown.
    timeframe: str | None = Field("1w", description="Time frame for analysis")
    persona: str | None = Field("moderate", description="Investor persona")
    session_id: str | None = Field(None, description="Session identifier")
76 |
77 |
def get_research_agent(
    query: str | None = None,
    research_scope: str = "standard",
    timeout_budget: float = 240.0,  # Default timeout for standard research (4 minutes)
    max_sources: int = 15,
) -> DeepResearchAgent:
    """
    Return a research agent suited to the request.

    There are two paths:
    - Adaptive: when a non-empty ``query`` is given and ``timeout_budget`` is
      under 300 seconds, a fresh per-request agent is built. Its LLM is chosen
      by ``_get_adaptive_llm_for_research``, and it is configured with
      ``research_scope`` and ``max_sources``.
    - Shared: every other call returns one lazily created module-level agent
      that uses the default ``llm`` with fixed settings. For this agent the
      ``research_scope`` and ``max_sources`` arguments are ignored.

    Either way the agent is flagged with ``_needs_initialization`` so that
    initialization happens on first use rather than here.

    Args:
        query: Research query used to pick an adaptive LLM (optional)
        research_scope: Research scope for the adaptive agent
        timeout_budget: Available timeout budget in seconds
        max_sources: Maximum sources for the adaptive agent

    Returns:
        A DeepResearchAgent, either per-request or the shared singleton
    """
    global research_agent

    if not (query and timeout_budget < 300):
        # Shared path: build the singleton once, then keep reusing it.
        if research_agent is None:
            research_agent = DeepResearchAgent(
                llm=llm,
                persona="moderate",
                max_sources=25,  # Reduced for faster execution
                research_depth="standard",  # Reduced depth for speed
                exa_api_key=settings.research.exa_api_key,
            )
            research_agent._needs_initialization = True
        return research_agent

    # Adaptive path: time-constrained request (< 5 minutes) with a query.
    tuned_llm = _get_adaptive_llm_for_research(
        query, research_scope, timeout_budget, max_sources
    )
    per_request_agent = DeepResearchAgent(
        llm=tuned_llm,
        persona="moderate",
        max_sources=max_sources,
        research_depth=research_scope,
        exa_api_key=settings.research.exa_api_key,
    )
    per_request_agent._needs_initialization = True
    return per_request_agent
132 |
133 |
134 | def _get_timeout_for_research_scope(research_scope: str) -> float:
135 | """
136 | Calculate timeout based on research scope complexity.
137 |
138 | Args:
139 | research_scope: Research scope (basic, standard, comprehensive, exhaustive)
140 |
141 | Returns:
142 | Timeout in seconds appropriate for the research scope
143 | """
144 | timeout_mapping = {
145 | "basic": 120.0, # 2 minutes - generous for basic research
146 | "standard": 240.0, # 4 minutes - standard research with detailed analysis
147 | "comprehensive": 360.0, # 6 minutes - comprehensive research with thorough analysis
148 | "exhaustive": 600.0, # 10 minutes - exhaustive research with validation
149 | }
150 |
151 | return timeout_mapping.get(
152 | research_scope.lower(), 240.0
153 | ) # Default to standard (4 minutes)
154 |
155 |
156 | def _optimize_sources_for_timeout(
157 | research_scope: str, requested_sources: int, timeout_budget: float
158 | ) -> int:
159 | """
160 | Optimize the number of sources based on timeout constraints and research scope.
161 |
162 | This implements intelligent source limiting to maximize quality within time constraints.
163 |
164 | Args:
165 | research_scope: Research scope (basic, standard, comprehensive, exhaustive)
166 | requested_sources: Originally requested number of sources
167 | timeout_budget: Available timeout in seconds
168 |
169 | Returns:
170 | Optimized number of sources that can realistically be processed within timeout
171 | """
172 | # Estimate processing time per source based on scope complexity
173 | processing_time_per_source = {
174 | "basic": 1.5, # 1.5 seconds per source (minimal analysis)
175 | "standard": 2.5, # 2.5 seconds per source (moderate analysis)
176 | "comprehensive": 4.0, # 4 seconds per source (deep analysis)
177 | "exhaustive": 6.0, # 6 seconds per source (maximum analysis)
178 | }
179 |
180 | estimated_time_per_source = processing_time_per_source.get(
181 | research_scope.lower(), 2.5
182 | )
183 |
184 | # Reserve 20% of timeout for search, synthesis, and overhead
185 | available_time_for_sources = timeout_budget * 0.8
186 |
187 | # Calculate maximum sources within timeout
188 | max_sources_for_timeout = int(
189 | available_time_for_sources / estimated_time_per_source
190 | )
191 |
192 | # Apply quality-based limits (better to have fewer high-quality sources)
193 | quality_limits = {
194 | "basic": 8, # Focus on most relevant sources
195 | "standard": 15, # Balanced approach
196 | "comprehensive": 20, # More sources for deep research
197 | "exhaustive": 25, # Maximum sources for exhaustive research
198 | }
199 |
200 | scope_limit = quality_limits.get(research_scope.lower(), 15)
201 |
202 | # Return the minimum of: requested, timeout-constrained, and scope-limited
203 | optimized_sources = min(requested_sources, max_sources_for_timeout, scope_limit)
204 |
205 | # Ensure minimum of 3 sources for meaningful analysis
206 | return max(optimized_sources, 3)
207 |
208 |
def _estimate_research_complexity(
    query: str, research_scope: str | None, max_sources: int
) -> float:
    """
    Score how demanding a research request is, on a 0.0-1.0 scale.

    Four heuristic contributions are added in this order and the sum is
    capped at 1.0:
    - query length;
    - multi-topic keyword hits;
    - research scope;
    - number of sources to synthesize.
    """
    complexity_score = 0.0

    # Query length factor (longer queries often indicate complexity)
    query_length = len(query)
    if query_length > 200:
        complexity_score += 0.3
    elif query_length > 100:
        complexity_score += 0.2
    elif query_length > 50:
        complexity_score += 0.1

    # Multi-topic queries (multiple companies/concepts). This is a plain
    # substring test, so e.g. "market" also matches "supermarket".
    complexity_keywords = (
        "vs",
        "versus",
        "compare",
        "analysis",
        "forecast",
        "outlook",
        "trends",
        "market",
        "competition",
    )
    lowered_query = query.lower()  # computed once rather than per keyword
    keyword_matches = sum(
        1 for keyword in complexity_keywords if keyword in lowered_query
    )
    complexity_score += min(keyword_matches * 0.1, 0.4)

    # Research scope complexity; None/unknown scopes count as "standard".
    scope_complexity = {
        "basic": 0.1,
        "standard": 0.3,
        "comprehensive": 0.6,
        "exhaustive": 0.9,
    }
    scope = (research_scope or "standard").strip().lower()
    complexity_score += scope_complexity.get(scope, 0.3)

    # Source count complexity (more sources = more synthesis required)
    if max_sources > 20:
        complexity_score += 0.3
    elif max_sources > 10:
        complexity_score += 0.2
    elif max_sources > 5:
        complexity_score += 0.1

    # Normalize to 0-1 range
    return min(complexity_score, 1.0)


def _time_pressure_for_budget(timeout_budget: float) -> float:
    """Map a timeout budget (seconds) to a time-pressure factor; lower means more pressure."""
    if timeout_budget < 120:
        return 0.2  # Emergency mode - need fastest models (below basic timeout)
    if timeout_budget < 240:
        return 0.5  # High pressure - prefer fast models (basic to standard)
    if timeout_budget < 360:
        return 0.7  # Moderate pressure - balanced selection (standard to comprehensive)
    return 1.0  # Low pressure - can use premium models (comprehensive and above)


def _get_adaptive_llm_for_research(
    query: str,
    research_scope: str | None,
    timeout_budget: float,
    max_sources: int,
) -> Any:
    """
    Get an adaptively selected LLM optimized for research performance within timeout constraints.

    Selection rules, checked in order:
    - budget < 120s: emergency mode; prefer fast + cheap, and log it;
    - budget < 240s and complexity <= 0.4: prefer fast + cheap;
    - complexity >= 0.7 and budget >= 360s: prefer quality;
    - otherwise: balanced; prefer cheap, not forced fast.

    Args:
        query: Research query to analyze complexity
        research_scope: Research scope (basic, standard, comprehensive, exhaustive);
            None is treated as "standard"
        timeout_budget: Available timeout in seconds
        max_sources: Number of sources to analyze

    Returns:
        Optimally selected LLM instance for the research task
    """
    complexity_score = _estimate_research_complexity(query, research_scope, max_sources)
    time_pressure = _time_pressure_for_budget(timeout_budget)

    # Model selection strategy with timeout budget consideration
    if time_pressure <= 0.3 or timeout_budget < 120:
        # Emergency mode: prioritize speed above all for <120s timeouts (below basic)
        logger.info(
            f"Emergency fast model selection triggered - timeout budget: {timeout_budget}s"
        )
        return get_llm(
            task_type=TaskType.DEEP_RESEARCH,
            prefer_fast=True,
            prefer_cheap=True,  # which concrete model this yields is decided by get_llm
            prefer_quality=False,
        )
    if time_pressure <= 0.6 and complexity_score <= 0.4:
        # Fast mode for simple queries under time pressure (120s <= budget < 240s)
        return get_llm(
            task_type=TaskType.DEEP_RESEARCH,
            prefer_fast=True,
            prefer_cheap=True,
            prefer_quality=False,
        )
    if complexity_score >= 0.7 and time_pressure >= 0.8:
        # Complex query with time available (budget >= 360s): use premium models
        return get_llm(
            task_type=TaskType.DEEP_RESEARCH,
            prefer_fast=False,
            prefer_cheap=False,
            prefer_quality=True,
        )
    # Balanced approach: cost-effective quality models
    return get_llm(
        task_type=TaskType.DEEP_RESEARCH,
        prefer_fast=False,
        prefer_cheap=True,  # Default cost-effective
        prefer_quality=False,
    )
336 |
337 |
async def _execute_research_with_direct_timeout(
    agent,
    query: str,
    session_id: str,
    research_scope: str,
    max_sources: int,
    timeframe: str,
    total_timeout: float,
    tool_logger,
) -> dict[str, Any]:
    """
    Execute research with direct timeout enforcement using asyncio.wait_for.

    Runs ``agent.research_topic(...)`` under a hard ``total_timeout`` and
    converts timeouts and ordinary exceptions into structured response dicts,
    so the caller always receives a dict it can inspect with ``.get``.

    Args:
        agent: Research agent exposing an async ``research_topic`` method.
        query: Research query forwarded to the agent.
        session_id: Session identifier forwarded to the agent.
        research_scope: Scope name forwarded to the agent (also echoed in the
            timeout suggestions).
        max_sources: Source budget forwarded to the agent.
        timeframe: Search timeframe forwarded to the agent.
        total_timeout: Hard limit in seconds for the whole research call; also
            passed to the agent as ``timeout_budget`` for phase allocation.
        tool_logger: Step logger exposing ``step(name, message)`` and
            ``error(name, exc)``.

    Returns:
        On success, the agent's result dict with ``elapsed_time`` and
        ``timeout_warning`` added. A dict with ``status == "timeout"`` when the
        limit is exceeded. A dict with ``status == "error"`` when the agent
        raises or returns something other than a dict.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation propagates.
    """
    # get_running_loop() is the recommended way to reach the loop from inside a
    # coroutine; fetch it once and reuse it for every timestamp.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    # Granular timing for bottleneck identification
    timing_log: dict[str, Any] = {
        "research_start": start_time,
        "phase_timings": {},
        "cumulative_time": 0.0,
    }

    def log_phase_timing(phase_name: str) -> None:
        """Record how long the phase ending now took, plus the running total."""
        current_time = loop.time()
        # Duration of this phase = time elapsed since the previous checkpoint.
        phase_duration = current_time - start_time - timing_log["cumulative_time"]
        timing_log["phase_timings"][phase_name] = {
            "duration": phase_duration,
            "cumulative": current_time - start_time,
        }
        timing_log["cumulative_time"] = current_time - start_time
        logger.debug(
            f"TIMING: {phase_name} took {phase_duration:.2f}s (cumulative: {timing_log['cumulative_time']:.2f}s)"
        )

    try:
        tool_logger.step(
            "timeout_enforcement",
            f"Starting research with {total_timeout}s hard timeout",
        )
        log_phase_timing("initialization")

        # Use direct asyncio.wait_for for hard timeout enforcement
        logger.info(
            f"TIMING: Starting research execution phase (budget: {total_timeout}s)"
        )

        result = await asyncio.wait_for(
            agent.research_topic(
                query=query,
                session_id=session_id,
                research_scope=research_scope,
                max_sources=max_sources,
                timeframe=timeframe,
                timeout_budget=total_timeout,  # Pass timeout budget for phase allocation
            ),
            timeout=total_timeout,
        )

        log_phase_timing("research_execution")

        elapsed_time = loop.time() - start_time
        tool_logger.step(
            "research_completed", f"Research completed in {elapsed_time:.1f}s"
        )

        # Log detailed timing breakdown
        logger.info(
            f"RESEARCH_TIMING_BREAKDOWN: "
            f"Total={elapsed_time:.2f}s, "
            f"Phases={timing_log['phase_timings']}"
        )

        if not isinstance(result, dict):
            # The declared contract is a dict, and the caller reads the result
            # with .get(); passing a non-dict through would fail there with an
            # opaque AttributeError, so report it as a structured error here.
            return {
                "status": "error",
                "content": (
                    "Research failed due to error: unexpected result type "
                    f"{type(result).__name__}"
                ),
                "research_confidence": 0.0,
                "sources_found": 0,
                "timeout_warning": False,
                "elapsed_time": elapsed_time,
                "completion_percentage": 0,
                "error": f"Unexpected research result type: {type(result).__name__}",
                "error_type": "TypeError",
            }

        # Add timing information to successful results
        result["elapsed_time"] = elapsed_time
        # Flag runs that consumed at least 80% of the budget.
        result["timeout_warning"] = elapsed_time >= (total_timeout * 0.8)

        return result

    except TimeoutError:
        # asyncio.wait_for raises TimeoutError (asyncio.TimeoutError is an
        # alias of the builtin since Python 3.11).
        elapsed_time = loop.time() - start_time
        log_phase_timing("timeout_exceeded")

        # Log timeout timing analysis
        logger.warning(
            f"RESEARCH_TIMEOUT: "
            f"Exceeded {total_timeout}s limit after {elapsed_time:.2f}s, "
            f"Phases={timing_log['phase_timings']}"
        )

        tool_logger.step(
            "timeout_exceeded",
            f"Research timed out after {elapsed_time:.1f}s (limit: {total_timeout}s)",
        )

        # Return structured timeout response instead of raising
        return {
            "status": "timeout",
            "content": f"Research operation timed out after {total_timeout} seconds",
            "research_confidence": 0.0,
            "sources_found": 0,
            "timeout_warning": True,
            "elapsed_time": elapsed_time,
            "completion_percentage": 0,
            "timing_breakdown": timing_log["phase_timings"],
            "actionable_insights": [
                "Research was terminated due to timeout",
                "Consider reducing scope or query complexity",
                f"Try using 'basic' or 'standard' scope instead of '{research_scope}'",
            ],
            "content_analysis": {
                "consensus_view": {
                    "direction": "neutral",
                    "confidence": 0.0,
                },
                "key_themes": ["Timeout occurred"],
                "contrarian_views": [],
            },
            "persona_insights": {
                "summary": "Analysis terminated due to timeout - consider simplifying the query"
            },
            "error": "timeout_exceeded",
        }

    except asyncio.CancelledError:
        tool_logger.step("research_cancelled", "Research operation was cancelled")
        raise
    except Exception as e:
        elapsed_time = loop.time() - start_time
        tool_logger.error("research_execution_error", e)

        # Return structured error response
        return {
            "status": "error",
            "content": f"Research failed due to error: {str(e)}",
            "research_confidence": 0.0,
            "sources_found": 0,
            "timeout_warning": False,
            "elapsed_time": elapsed_time,
            "completion_percentage": 0,
            "error": str(e),
            "error_type": type(e).__name__,
        }
484 |
485 |
async def comprehensive_research(
    query: str,
    persona: str = "moderate",
    research_scope: str = "standard",
    max_sources: int = 15,
    timeframe: str = "1m",
) -> dict[str, Any]:
    """
    Enhanced comprehensive research with adaptive timeout protection and step-by-step logging.

    This tool provides reliable research capabilities with:
    - Generous timeout based on research scope (basic: 120s, standard: 240s, comprehensive: 360s, exhaustive: 600s)
    - Step-by-step execution logging
    - Guaranteed JSON-RPC responses
    - Optimized scope for faster execution
    - Circuit breaker protection

    The search-provider configuration is validated before any optimization or
    agent construction happens, so a missing Exa API key fails fast with setup
    instructions.

    Args:
        query: Research query or topic
        persona: Investor persona (conservative, moderate, aggressive, day_trader);
            any other value leaves the agent's persona unchanged
        research_scope: Research scope (basic, standard, comprehensive, exhaustive)
        max_sources: Requested maximum sources to analyze; the effective count
            comes from ``_optimize_sources_for_timeout`` for the scope's budget
        timeframe: Time frame for search (1d, 1w, 1m, 3m)

    Returns:
        Dictionary with a ``success`` flag. On success it carries
        ``research_results`` and ``research_metadata`` (plus an optional
        ``warning``); on failure it carries ``error`` and related details.
    """
    tool_logger = get_tool_logger("comprehensive_research")
    request_id = str(uuid.uuid4())

    # Log incoming parameters
    logger.info(
        f"📥 RESEARCH_REQUEST: query='{query[:50]}...', scope='{research_scope}', max_sources={max_sources}, timeframe='{timeframe}'"
    )

    try:
        # Step 1: Validate search provider configuration before doing any
        # other work, so a missing API key fails fast without building an agent.
        tool_logger.step(
            "provider_validation", "Validating search provider configuration"
        )

        exa_available = bool(settings.research.exa_api_key)

        if not exa_available:
            return {
                "success": False,
                "error": "Research functionality unavailable - Exa search provider not configured",
                "details": {
                    "required_configuration": "Exa search provider API key is required",
                    "exa_api_key": "Missing (configure EXA_API_KEY environment variable)",
                    "setup_instructions": "Get a free API key from: Exa (exa.ai)",
                },
                "query": query,
                "request_id": request_id,
                "timestamp": datetime.now().isoformat(),
            }

        # Log available provider
        tool_logger.step(
            "provider_available",
            "Exa search provider available",
        )

        # Step 2: Calculate optimization parameters
        tool_logger.step(
            "optimization_calculation",
            f"Calculating adaptive optimization parameters for scope='{research_scope}' with max_sources={max_sources}",
        )
        adaptive_timeout = _get_timeout_for_research_scope(research_scope)
        optimized_sources = _optimize_sources_for_timeout(
            research_scope, max_sources, adaptive_timeout
        )

        # Log the timeout calculation result explicitly
        logger.info(
            f"🔧 TIMEOUT_CONFIGURATION: scope='{research_scope}' → timeout={adaptive_timeout}s (was requesting {max_sources} sources, optimized to {optimized_sources})"
        )

        # Step 3: Log optimization setup (components initialized in underlying research system)
        tool_logger.step(
            "optimization_setup",
            f"Configuring LLM optimizations (budget: {adaptive_timeout}s, parallel: {optimized_sources > 3})",
        )

        # Step 4: Initialize agent with adaptive optimizations
        tool_logger.step(
            "agent_initialization",
            f"Initializing optimized research agent (timeout: {adaptive_timeout}s, sources: {optimized_sources})",
        )
        agent = get_research_agent(
            query=query,
            research_scope=research_scope,
            timeout_budget=adaptive_timeout,
            max_sources=optimized_sources,
        )

        # Only recognised personas are applied; unknown values keep the
        # agent's existing persona.
        if persona in ["conservative", "moderate", "aggressive", "day_trader"]:
            agent.persona = INVESTOR_PERSONAS.get(
                persona, INVESTOR_PERSONAS["moderate"]
            )

        # Step 5: Execute research
        session_id = f"enhanced_research_{datetime.now().timestamp()}"
        tool_logger.step(
            "source_optimization",
            f"Optimized sources: {max_sources} → {optimized_sources} for {research_scope} scope within {adaptive_timeout}s",
        )
        tool_logger.step(
            "research_execution",
            f"Starting progressive research with session {session_id[:12]} (timeout: {adaptive_timeout}s, sources: {optimized_sources})",
        )

        # Execute with direct timeout enforcement for reliable operation
        result = await _execute_research_with_direct_timeout(
            agent=agent,
            query=query,
            session_id=session_id,
            research_scope=research_scope,
            max_sources=optimized_sources,  # Use optimized source count
            timeframe=timeframe,
            total_timeout=adaptive_timeout,
            tool_logger=tool_logger,
        )

        # Step 6: Process results
        tool_logger.step("result_processing", "Processing research results")

        # Handle timeout or error results
        if result.get("status") == "timeout":
            return {
                "success": False,
                "error": "Research operation timed out",
                "timeout_details": {
                    "timeout_seconds": adaptive_timeout,
                    "elapsed_time": result.get("elapsed_time", 0),
                    "suggestions": result.get("actionable_insights", []),
                },
                "query": query,
                "request_id": request_id,
                "timestamp": datetime.now().isoformat(),
            }

        # NOTE(review): the mere presence of an "error" key (even with a None
        # value) is treated as failure; confirm the agent never includes one
        # in successful results.
        if result.get("status") == "error" or "error" in result:
            return {
                "success": False,
                "error": result.get("error", "Unknown research error"),
                "error_type": result.get("error_type", "UnknownError"),
                "query": query,
                "request_id": request_id,
                "timestamp": datetime.now().isoformat(),
            }

        # Step 7: Format response with timeout support
        tool_logger.step("response_formatting", "Formatting final response")

        # Check if this is a partial result or has warnings
        is_partial = result.get("status") == "partial_success"
        has_timeout_warning = result.get("timeout_warning", False)

        response = {
            "success": True,
            "query": query,
            "research_results": {
                "summary": result.get("content", "Research completed successfully"),
                "confidence_score": result.get("research_confidence", 0.0),
                "sources_analyzed": result.get("sources_found", 0),
                "key_insights": result.get("actionable_insights", [])[
                    :5
                ],  # Limit for size
                "sentiment": result.get("content_analysis", {}).get(
                    "consensus_view", {}
                ),
                "key_themes": result.get("content_analysis", {}).get("key_themes", [])[
                    :3
                ],
            },
            "research_metadata": {
                "persona": persona,
                "scope": research_scope,
                "timeframe": timeframe,
                "max_sources_requested": max_sources,
                "max_sources_optimized": optimized_sources,
                "sources_actually_used": result.get("sources_found", optimized_sources),
                "execution_mode": "progressive_timeout_protected",
                "is_partial_result": is_partial,
                "timeout_warning": has_timeout_warning,
                "elapsed_time": result.get("elapsed_time", 0),
                "completion_percentage": result.get(
                    "completion_percentage", 100 if not is_partial else 60
                ),
                "optimization_features": [
                    "adaptive_model_selection",
                    "progressive_token_budgeting",
                    "parallel_llm_processing",
                    "intelligent_source_optimization",
                    "timeout_monitoring",
                ],
                "parallel_processing": {
                    "enabled": True,
                    "max_concurrent_requests": min(4, optimized_sources // 2 + 1),
                    "batch_processing": optimized_sources > 3,
                },
            },
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
        }

        # Add warning message for partial results
        if is_partial:
            response["warning"] = {
                "type": "partial_result",
                "message": "Research was partially completed due to timeout constraints",
                "suggestions": [
                    f"Try reducing research scope from '{research_scope}' to 'standard' or 'basic'",
                    f"Reduce max_sources from {max_sources} to {min(15, optimized_sources)} or fewer",
                    "Use more specific keywords to focus the search",
                    f"Note: Sources were automatically optimized from {max_sources} to {optimized_sources} for better performance",
                ],
            }
        elif has_timeout_warning:
            response["warning"] = {
                "type": "timeout_warning",
                "message": "Research completed but took longer than expected",
                "suggestions": [
                    "Consider reducing scope for faster results in the future"
                ],
            }

        tool_logger.complete(f"Research completed for query: {query[:50]}")
        return response

    except TimeoutError:
        # Calculate timeout for error reporting (adaptive_timeout may be
        # unbound if the timeout surfaced before it was computed).
        used_timeout = _get_timeout_for_research_scope(research_scope)
        tool_logger.error(
            "research_timeout",
            TimeoutError(f"Research operation timed out after {used_timeout}s"),
        )
        # Calculate optimized sources for error reporting
        timeout_optimized_sources = _optimize_sources_for_timeout(
            research_scope, max_sources, used_timeout
        )

        return {
            "success": False,
            "error": f"Research operation timed out after {used_timeout} seconds",
            "details": f"Consider using a more specific query, reducing the scope from '{research_scope}', or decreasing max_sources from {max_sources}",
            "suggestions": {
                "reduce_scope": "Try 'basic' or 'standard' instead of 'comprehensive'",
                "reduce_sources": f"Try max_sources={min(10, timeout_optimized_sources)} instead of {max_sources}",
                "narrow_query": "Use more specific keywords to focus the search",
            },
            "optimization_info": {
                "sources_requested": max_sources,
                "sources_auto_optimized": timeout_optimized_sources,
                "note": "Sources were automatically reduced for better performance, but timeout still occurred",
            },
            "query": query,
            "request_id": request_id,
            "timeout_seconds": used_timeout,
            "research_scope": research_scope,
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        tool_logger.error(
            "research_error", e, f"Unexpected error in research: {str(e)}"
        )
        return {
            "success": False,
            "error": f"Research error: {str(e)}",
            "error_type": type(e).__name__,
            "query": query,
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
        }
762 |
763 |
async def company_comprehensive_research(
    symbol: str,
    include_competitive_analysis: bool = False,  # Disabled by default for speed
    persona: str = "moderate",
) -> dict[str, Any]:
    """
    Enhanced company research with timeout protection and optimized scope.

    This tool provides reliable company analysis with:
    - Adaptive timeout protection
    - Streamlined analysis for faster execution
    - Step-by-step logging for debugging
    - Guaranteed responses to Claude Desktop
    - Focus on core financial metrics

    The research query targets the current calendar year. When
    ``include_competitive_analysis`` is set, the query is broadened with
    competitive-landscape terms, so the ``competitive_analysis_included``
    metadata flag reflects what was actually searched for.

    Args:
        symbol: Stock ticker symbol
        include_competitive_analysis: Add competitive-landscape terms to the
            research query (disabled by default for speed)
        persona: Investor persona for analysis perspective

    Returns:
        Dictionary containing company research results or error information.
        Failures reported by ``comprehensive_research`` are passed through
        with ``symbol`` and ``analysis_type`` added.
    """
    tool_logger = get_tool_logger("company_comprehensive_research")
    request_id = str(uuid.uuid4())

    try:
        # Step 1: Initialize
        tool_logger.step("initialization", f"Starting company research for {symbol}")

        # Create focused research query. Derive the year at call time instead
        # of hard-coding it, so the query keeps targeting current outlooks.
        query = f"{symbol} stock financial analysis outlook {datetime.now().year}"
        if include_competitive_analysis:
            # Honour the flag in the actual search, not just in the metadata.
            query += " competitive landscape competitors"

        # Execute streamlined research
        result = await comprehensive_research(
            query=query,
            persona=persona,
            research_scope="standard",  # Focused scope
            max_sources=10,  # Reduced sources for speed
            timeframe="1m",
        )

        # Step 2: Enhance with symbol-specific formatting
        tool_logger.step("formatting", "Formatting company-specific response")

        if not result.get("success", False):
            return {
                **result,
                "symbol": symbol,
                "analysis_type": "company_comprehensive",
            }

        # Reformat for company analysis
        company_response = {
            "success": True,
            "symbol": symbol,
            "company_analysis": {
                "investment_summary": result["research_results"].get("summary", ""),
                "confidence_score": result["research_results"].get(
                    "confidence_score", 0.0
                ),
                "key_insights": result["research_results"].get("key_insights", []),
                "financial_sentiment": result["research_results"].get("sentiment", {}),
                "analysis_themes": result["research_results"].get("key_themes", []),
                "sources_analyzed": result["research_results"].get(
                    "sources_analyzed", 0
                ),
            },
            "analysis_metadata": {
                **result["research_metadata"],
                "symbol": symbol,
                "competitive_analysis_included": include_competitive_analysis,
                "analysis_type": "company_comprehensive",
            },
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
        }

        tool_logger.complete(f"Company analysis completed for {symbol}")
        return company_response

    except Exception as e:
        tool_logger.error(
            "company_research_error", e, f"Company research failed: {str(e)}"
        )
        return {
            "success": False,
            "error": f"Company research error: {str(e)}",
            "error_type": type(e).__name__,
            "symbol": symbol,
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
        }
857 |
858 |
async def analyze_market_sentiment(
    topic: str, timeframe: str = "1w", persona: str = "moderate"
) -> dict[str, Any]:
    """
    Run a lightweight, sentiment-focused research pass on *topic*.

    Delegates to ``comprehensive_research`` with a "basic" scope and a budget
    of 8 sources, then reshapes its output around sentiment fields.

    Args:
        topic: Subject whose market sentiment should be assessed.
        timeframe: Search time window forwarded to ``comprehensive_research``.
        persona: Investor persona forwarded to ``comprehensive_research``.

    Returns:
        On success, a dict with ``success=True``, a ``sentiment_analysis``
        section and ``analysis_metadata``. When the underlying research fails,
        its failure dict tagged with ``topic`` and ``analysis_type``. When an
        exception is raised, an error dict describing it.
    """
    log = get_tool_logger("analyze_market_sentiment")
    req_id = str(uuid.uuid4())

    try:
        log.step("query_creation", f"Creating sentiment query for {topic}")

        research = await comprehensive_research(
            query=f"{topic} market sentiment analysis investor opinion",
            persona=persona,
            research_scope="basic",  # smallest scope is enough for sentiment
            max_sources=8,  # kept small for speed
            timeframe=timeframe,
        )

        log.step("sentiment_formatting", "Extracting sentiment data")

        if not research.get("success", False):
            # Pass the failure through, labelled with this analysis' identity.
            return {**research, "topic": topic, "analysis_type": "market_sentiment"}

        findings = research["research_results"]
        payload = {
            "success": True,
            "topic": topic,
            "sentiment_analysis": {
                "overall_sentiment": findings.get("sentiment", {}),
                "sentiment_confidence": findings.get("confidence_score", 0.0),
                "key_themes": findings.get("key_themes", []),
                "market_insights": findings.get("key_insights", [])[:3],
                "sources_analyzed": findings.get("sources_analyzed", 0),
            },
            "analysis_metadata": {
                **research["research_metadata"],
                "topic": topic,
                "analysis_type": "market_sentiment",
            },
            "request_id": req_id,
            "timestamp": datetime.now().isoformat(),
        }

        log.complete(f"Sentiment analysis completed for {topic}")
        return payload

    except Exception as exc:
        log.error("sentiment_error", exc, f"Sentiment analysis failed: {str(exc)}")
        return {
            "success": False,
            "error": f"Sentiment analysis error: {str(exc)}",
            "error_type": type(exc).__name__,
            "topic": topic,
            "request_id": req_id,
            "timestamp": datetime.now().isoformat(),
        }
945 |
946 |
def create_research_router(mcp: FastMCP | None = None) -> FastMCP:
    """Create and configure the research router.

    Registers three MCP tools on *mcp*: ``research_comprehensive_research``,
    ``research_company_comprehensive`` and ``research_analyze_market_sentiment``.
    Each one substitutes defaults for ``None`` arguments and delegates to the
    corresponding module-level research coroutine.

    Args:
        mcp: Existing FastMCP server to register the tools on. When ``None``,
            a new server named "Deep Research Tools" is created.

    Returns:
        The FastMCP instance the tools were registered on.
    """

    if mcp is None:
        mcp = FastMCP("Deep Research Tools")

    @mcp.tool()
    async def research_comprehensive_research(
        query: str,
        persona: str | None = "moderate",
        research_scope: str | None = "standard",
        max_sources: int | None = 10,
        timeframe: str | None = "1m",
    ) -> dict[str, Any]:
        """
        Perform comprehensive research on any financial topic using web search and AI analysis.

        Enhanced features:
        - Generous timeout (basic: 120s, standard: 240s, comprehensive: 360s, exhaustive: 600s)
        - Intelligent source optimization
        - Parallel LLM processing
        - Progressive token budgeting
        - Partial results on timeout

        Args:
            query: Research query or topic
            persona: Investor persona (conservative, moderate, aggressive, day_trader)
            research_scope: Research scope (basic, standard, comprehensive, exhaustive)
            max_sources: Maximum sources to analyze (1-50)
            timeframe: Time frame for search (1d, 1w, 1m, 3m)

        Returns:
            Comprehensive research results with insights, sentiment, and recommendations
        """
        # Entry trace. This fires on every normal call, so it belongs at DEBUG;
        # logging it at ERROR polluted error logs and error-based alerting.
        logger.debug(
            f"🚨 TOOL CALLED: research_comprehensive_research with query: {query[:50]}"
        )

        # Log tool invocation
        log_tool_invocation(
            "research_comprehensive_research",
            {
                "query": query[:100],  # Truncate for logging
                "persona": persona,
                "research_scope": research_scope,
                "max_sources": max_sources,
                "timeframe": timeframe,
            },
        )

        start_time = datetime.now()

        try:
            # Execute enhanced research. None/falsy arguments fall back to the
            # defaults advertised in this tool's signature, so an explicit null
            # behaves the same as an omitted argument.
            result = await comprehensive_research(
                query=query,
                persona=persona or "moderate",
                research_scope=research_scope or "standard",
                max_sources=max_sources or 10,
                timeframe=timeframe or "1m",
            )

            # Calculate execution metrics (milliseconds)
            execution_time = (datetime.now() - start_time).total_seconds() * 1000

            # Log performance metrics
            log_performance_metrics(
                "research_comprehensive_research",
                {
                    "execution_time_ms": execution_time,
                    "sources_analyzed": result.get("research_results", {}).get(
                        "sources_analyzed", 0
                    ),
                    "confidence_score": result.get("research_results", {}).get(
                        "confidence_score", 0.0
                    ),
                    "success": result.get("success", False),
                },
            )

            return result

        except Exception as e:
            logger.error(
                f"Research error: {str(e)}",
                exc_info=True,
                extra={"query": query[:100]},
            )
            return {
                "success": False,
                "error": f"Research failed: {str(e)}",
                "error_type": type(e).__name__,
                "query": query,
                "timestamp": datetime.now().isoformat(),
            }

    @mcp.tool()
    async def research_company_comprehensive(
        symbol: str,
        include_competitive_analysis: bool = False,
        persona: str | None = "moderate",
    ) -> dict[str, Any]:
        """
        Perform comprehensive research on a specific company.

        Features:
        - Financial metrics analysis
        - Market sentiment assessment
        - Competitive positioning
        - Investment recommendations

        Args:
            symbol: Stock ticker symbol
            include_competitive_analysis: Include competitive analysis
            persona: Investor persona for analysis perspective

        Returns:
            Company-specific research with financial insights
        """
        return await company_comprehensive_research(
            symbol=symbol,
            include_competitive_analysis=include_competitive_analysis,
            persona=persona or "moderate",
        )

    @mcp.tool()
    async def research_analyze_market_sentiment(
        topic: str,
        timeframe: str | None = "1w",
        persona: str | None = "moderate",
    ) -> dict[str, Any]:
        """
        Analyze market sentiment for a specific topic or sector.

        Features:
        - Real-time sentiment extraction
        - News and social media analysis
        - Investor opinion aggregation
        - Trend identification

        Args:
            topic: Topic for sentiment analysis
            timeframe: Time frame for analysis
            persona: Investor persona

        Returns:
            Sentiment analysis with market insights
        """
        return await analyze_market_sentiment(
            topic=topic,
            timeframe=timeframe or "1w",
            persona=persona or "moderate",
        )

    return mcp


# Create the router instance at import time, with all research tools registered.
research_router = create_research_router()
1106 |
```
--------------------------------------------------------------------------------
/maverick_mcp/api/routers/portfolio.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Portfolio analysis router for Maverick-MCP.
3 |
4 | This module contains all portfolio-related tools including:
5 | - Portfolio management (add, get, remove, clear positions)
6 | - Risk analysis and comparisons
7 | - Optimization functions
8 | """
9 |
10 | import logging
11 | from datetime import UTC, datetime, timedelta
12 | from decimal import Decimal
13 | from typing import Any
14 |
15 | import pandas as pd
16 | import pandas_ta as ta
17 | from fastmcp import FastMCP
18 | from sqlalchemy.orm import Session
19 |
20 | from maverick_mcp.data.models import PortfolioPosition, UserPortfolio, get_db
21 | from maverick_mcp.domain.portfolio import Portfolio
22 | from maverick_mcp.providers.stock_data import StockDataProvider
23 | from maverick_mcp.utils.stock_helpers import get_stock_dataframe
24 |
# Module-level logger for the portfolio router.
logger = logging.getLogger(__name__)

# Create the portfolio router: the FastMCP server instance for portfolio tools.
portfolio_router: FastMCP = FastMCP("Portfolio_Analysis")

# Initialize data provider. Created once at import time and shared by this
# module's functions (e.g. risk_adjusted_analysis fetches price history via
# stock_provider.get_stock_data).
stock_provider = StockDataProvider()
32 |
33 |
34 | def _normalize_ticker(ticker: str) -> str:
35 | """Normalize ticker symbol to uppercase and strip whitespace."""
36 | return ticker.strip().upper()
37 |
38 |
39 | def _validate_ticker(ticker: str) -> tuple[bool, str | None]:
40 | """
41 | Validate ticker symbol format.
42 |
43 | Returns:
44 | Tuple of (is_valid, error_message)
45 | """
46 | if not ticker or not ticker.strip():
47 | return False, "Ticker symbol cannot be empty"
48 |
49 | normalized = ticker.strip().upper()
50 |
51 | # Basic validation: 1-5 alphanumeric characters
52 | if not normalized.isalnum():
53 | return (
54 | False,
55 | f"Invalid ticker symbol '{ticker}': must contain only letters and numbers",
56 | )
57 |
58 | if len(normalized) > 10:
59 | return False, f"Invalid ticker symbol '{ticker}': too long (max 10 characters)"
60 |
61 | return True, None
62 |
63 |
64 | def risk_adjusted_analysis(
65 | ticker: str,
66 | risk_level: float | str | None = 50.0,
67 | user_id: str = "default",
68 | portfolio_name: str = "My Portfolio",
69 | ) -> dict[str, Any]:
70 | """
71 | Perform risk-adjusted stock analysis with position sizing.
72 |
73 | DISCLAIMER: This analysis is for educational purposes only and does not
74 | constitute investment advice. All investments carry risk of loss. Always
75 | consult with qualified financial professionals before making investment decisions.
76 |
77 | This tool analyzes a stock with risk parameters tailored to different investment
78 | styles. It provides:
79 | - Position sizing recommendations based on ATR
80 | - Stop loss suggestions
81 | - Entry points with scaling
82 | - Risk/reward ratio calculations
83 | - Confidence score based on technicals
84 |
85 | **Portfolio Integration:** If you already own this stock, the analysis includes:
86 | - Current position details (shares, cost basis, unrealized P&L)
87 | - Position sizing relative to existing holdings
88 | - Recommendations for averaging up/down
89 |
90 | The risk_level parameter (0-100) adjusts the analysis from conservative (low)
91 | to aggressive (high).
92 |
93 | Args:
94 | ticker: The ticker symbol to analyze
95 | risk_level: Risk tolerance from 0 (conservative) to 100 (aggressive)
96 | user_id: User identifier (defaults to "default")
97 | portfolio_name: Portfolio name (defaults to "My Portfolio")
98 |
99 | Returns:
100 | Dictionary containing risk-adjusted analysis results with optional position context
101 | """
102 | try:
103 | # Convert risk_level to float if it's a string
104 | if isinstance(risk_level, str):
105 | try:
106 | risk_level = float(risk_level)
107 | except ValueError:
108 | risk_level = 50.0
109 |
110 | # Use explicit date range to avoid weekend/holiday issues
111 | from datetime import UTC, datetime, timedelta
112 |
113 | end_date = (datetime.now(UTC) - timedelta(days=7)).strftime(
114 | "%Y-%m-%d"
115 | ) # Last week to be safe
116 | start_date = (datetime.now(UTC) - timedelta(days=365)).strftime(
117 | "%Y-%m-%d"
118 | ) # 1 year ago
119 | df = stock_provider.get_stock_data(
120 | ticker, start_date=start_date, end_date=end_date
121 | )
122 |
123 | # Validate dataframe has required columns (check for both upper and lower case)
124 | required_cols = ["high", "low", "close"]
125 | actual_cols_lower = [col.lower() for col in df.columns]
126 | if df.empty or not all(col in actual_cols_lower for col in required_cols):
127 | return {
128 | "error": f"Insufficient data for {ticker}",
129 | "details": "Unable to retrieve required price data (High, Low, Close) for analysis",
130 | "ticker": ticker,
131 | "required_data": ["High", "Low", "Close", "Volume"],
132 | "available_columns": list(df.columns),
133 | }
134 |
135 | df["atr"] = ta.atr(df["High"], df["Low"], df["Close"], length=20)
136 | atr = df["atr"].iloc[-1]
137 | current_price = df["Close"].iloc[-1]
138 | risk_factor = (risk_level or 50.0) / 100 # Convert to 0-1 scale
139 | account_size = 100000
140 | analysis = {
141 | "ticker": ticker,
142 | "current_price": round(current_price, 2),
143 | "atr": round(atr, 2),
144 | "risk_level": risk_level,
145 | "position_sizing": {
146 | "suggested_position_size": round(account_size * 0.01 * risk_factor, 2),
147 | "max_shares": int((account_size * 0.01 * risk_factor) / current_price),
148 | "position_value": round(account_size * 0.01 * risk_factor, 2),
149 | "percent_of_portfolio": round(1 * risk_factor, 2),
150 | },
151 | "risk_management": {
152 | "stop_loss": round(current_price - (atr * (2 - risk_factor)), 2),
153 | "stop_loss_percent": round(
154 | ((atr * (2 - risk_factor)) / current_price) * 100, 2
155 | ),
156 | "max_risk_amount": round(account_size * 0.01 * risk_factor, 2),
157 | },
158 | "entry_strategy": {
159 | "immediate_entry": round(current_price, 2),
160 | "scale_in_levels": [
161 | round(current_price, 2),
162 | round(current_price - (atr * 0.5), 2),
163 | round(current_price - atr, 2),
164 | ],
165 | },
166 | "targets": {
167 | "price_target": round(current_price + (atr * 3 * risk_factor), 2),
168 | "profit_potential": round(atr * 3 * risk_factor, 2),
169 | "risk_reward_ratio": round(3 * risk_factor, 2),
170 | },
171 | "analysis": {
172 | "confidence_score": round(70 * risk_factor, 2),
173 | "strategy_type": "aggressive"
174 | if (risk_level or 50.0) > 70
175 | else "moderate"
176 | if (risk_level or 50.0) > 30
177 | else "conservative",
178 | "time_horizon": "short-term"
179 | if (risk_level or 50.0) > 70
180 | else "medium-term"
181 | if (risk_level or 50.0) > 30
182 | else "long-term",
183 | },
184 | }
185 |
186 | # Check if user already owns this position
187 | db: Session = next(get_db())
188 | try:
189 | portfolio = (
190 | db.query(UserPortfolio)
191 | .filter(
192 | UserPortfolio.user_id == user_id,
193 | UserPortfolio.name == portfolio_name,
194 | )
195 | .first()
196 | )
197 |
198 | if portfolio:
199 | existing_position = next(
200 | (
201 | pos
202 | for pos in portfolio.positions
203 | if pos.ticker.upper() == ticker.upper()
204 | ),
205 | None,
206 | )
207 |
208 | if existing_position:
209 | # Calculate unrealized P&L
210 | unrealized_pnl = (
211 | current_price - float(existing_position.average_cost_basis)
212 | ) * float(existing_position.shares)
213 | unrealized_pnl_pct = (
214 | (current_price - float(existing_position.average_cost_basis))
215 | / float(existing_position.average_cost_basis)
216 | ) * 100
217 |
218 | analysis["existing_position"] = {
219 | "shares_owned": float(existing_position.shares),
220 | "average_cost_basis": float(
221 | existing_position.average_cost_basis
222 | ),
223 | "total_invested": float(existing_position.total_cost),
224 | "current_value": float(existing_position.shares)
225 | * current_price,
226 | "unrealized_pnl": round(unrealized_pnl, 2),
227 | "unrealized_pnl_pct": round(unrealized_pnl_pct, 2),
228 | "position_recommendation": "Consider averaging down"
229 | if current_price < float(existing_position.average_cost_basis)
230 | else "Consider taking partial profits"
231 | if unrealized_pnl_pct > 20
232 | else "Hold current position",
233 | }
234 | finally:
235 | db.close()
236 |
237 | return analysis
238 | except Exception as e:
239 | logger.error(f"Error performing risk analysis for {ticker}: {e}")
240 | return {"error": str(e)}
241 |
242 |
def compare_tickers(
    tickers: list[str] | None = None,
    days: int = 90,
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
) -> dict[str, Any]:
    """
    Compare multiple tickers using technical and fundamental metrics.

    For every ticker this reports, side by side:
    - Price performance over the window (change %, high/low, annualized volatility)
    - Technical indicators (RSI value/signal, trend strength and label)
    - Volume characteristics (latest, average, ~1 month change and trend label)
    - Relative rankings by performance and by trend strength

    **Portfolio Integration:** When ``tickers`` is None or empty, the tickers of
    every position in the given portfolio are compared instead, making it easy
    to see which holdings are performing best.

    Args:
        tickers: Ticker symbols to compare (at least 2). If None/empty, the
            portfolio's holdings are used.
        days: Number of days of historical data to analyze (default: 90)
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")

    Returns:
        Dictionary with per-ticker results under "comparison", the best
        performer and strongest trend, and "portfolio_context" when the
        portfolio was used. On failure, a dict with "error" and
        "status": "error".

    Example:
        >>> compare_tickers()  # Automatically compares all portfolio holdings
        >>> compare_tickers(["AAPL", "MSFT", "GOOGL"])  # Manual comparison
    """
    try:
        if not tickers:
            # No explicit tickers: fall back to the user's portfolio holdings.
            session: Session = next(get_db())
            try:
                portfolio = (
                    session.query(UserPortfolio)
                    .filter(
                        UserPortfolio.user_id == user_id,
                        UserPortfolio.name == portfolio_name,
                    )
                    .first()
                )
                if not portfolio or len(portfolio.positions) < 2:
                    return {
                        "error": "No portfolio found or insufficient positions for comparison",
                        "details": "Please provide at least 2 tickers manually or add more positions to your portfolio",
                        "status": "error",
                    }
                tickers = [position.ticker for position in portfolio.positions]
                portfolio_context = {
                    "using_portfolio": True,
                    "portfolio_name": portfolio_name,
                    "position_count": len(tickers),
                }
            finally:
                session.close()
        else:
            portfolio_context = {"using_portfolio": False}

        if len(tickers) < 2:
            raise ValueError("At least two tickers are required for comparison")

        from maverick_mcp.core.technical_analysis import analyze_rsi, analyze_trend

        comparison: dict[str, Any] = {}
        for symbol in tickers:
            frame = get_stock_dataframe(symbol, days)

            latest_close = frame["close"].iloc[-1]
            rsi_info = analyze_rsi(frame)
            trend_score = analyze_trend(frame)

            # Performance over the whole window
            closes = frame["close"]
            first_close = closes.iloc[0]
            change_pct = (latest_close - first_close) / first_close * 100

            # Annualized volatility of daily returns, in percent
            annual_vol = closes.pct_change().dropna().std() * 252**0.5 * 100

            # Volume change versus ~1 trading month (22 sessions) ago
            volumes = frame["volume"]
            vol_change = 0.0
            if len(frame) >= 22 and volumes.iloc[-22] > 0:
                vol_change = float((volumes.iloc[-1] / volumes.iloc[-22] - 1) * 100)

            if trend_score >= 6:
                trend_label = "Strong Uptrend"
            elif trend_score >= 4:
                trend_label = "Uptrend"
            elif trend_score >= 3:
                trend_label = "Neutral"
            else:
                trend_label = "Downtrend"

            if vol_change > 20:
                volume_label = "Increasing"
            elif vol_change < -20:
                volume_label = "Decreasing"
            else:
                volume_label = "Stable"

            has_rsi = bool(rsi_info)
            comparison[symbol] = {
                "current_price": float(latest_close),
                "performance": {
                    "price_change_pct": round(change_pct, 2),
                    "period_high": float(frame["high"].max()),
                    "period_low": float(frame["low"].min()),
                    "volatility_annual": round(annual_vol, 2),
                },
                "technical": {
                    "rsi": rsi_info["current"]
                    if has_rsi and "current" in rsi_info
                    else None,
                    "rsi_signal": rsi_info["signal"]
                    if has_rsi and "signal" in rsi_info
                    else "unavailable",
                    "trend_strength": trend_score,
                    "trend_description": trend_label,
                },
                "volume": {
                    "current_volume": int(volumes.iloc[-1]),
                    "avg_volume": int(volumes.mean()),
                    "volume_change_pct": vol_change,
                    "volume_trend": volume_label,
                },
            }

        # Relative rankings (1 = best); sorting is stable for ties.
        symbols = list(comparison)
        by_performance = sorted(
            symbols,
            key=lambda s: float(comparison[s]["performance"]["price_change_pct"]),
            reverse=True,
        )
        by_trend = sorted(
            symbols,
            key=lambda s: float(comparison[s]["technical"]["trend_strength"]),
            reverse=True,
        )
        trend_position = {s: rank for rank, s in enumerate(by_trend, start=1)}
        for rank, s in enumerate(by_performance, start=1):
            comparison[s]["rankings"] = {
                "performance_rank": rank,
                "trend_rank": trend_position[s],
            }

        response: dict[str, Any] = {
            "comparison": comparison,
            "period_days": days,
            "as_of": datetime.now(UTC).isoformat(),
            "best_performer": by_performance[0],
            "strongest_trend": by_trend[0],
        }
        if portfolio_context["using_portfolio"]:
            response["portfolio_context"] = portfolio_context
        return response
    except Exception as e:
        logger.error(f"Error comparing tickers {tickers}: {str(e)}")
        return {"error": str(e), "status": "error"}
417 |
418 |
def portfolio_correlation_analysis(
    tickers: list[str] | None = None,
    days: int = 252,
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
) -> dict[str, Any]:
    """
    Analyze correlation between multiple securities.

    DISCLAIMER: This correlation analysis is for educational purposes only.
    Past correlations do not guarantee future relationships between securities.
    Always diversify appropriately and consult with financial professionals.

    This tool calculates the correlation matrix of daily returns for a set of
    stocks, helping to identify:
    - Highly correlated positions (diversification issues)
    - Negative correlations (natural hedges)
    - Overall portfolio correlation metrics

    **Portfolio Integration:** If no tickers are provided, automatically analyzes
    correlation between all positions in your portfolio, helping you understand
    diversification and identify concentration risk.

    Tickers whose price data cannot be fetched are dropped. The analysis
    proceeds as long as at least two tickers have data.

    Args:
        tickers: List of ticker symbols to analyze. If None/empty, uses portfolio holdings.
        days: Number of days for correlation calculation (default: 252 for 1 year)
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")

    Returns:
        Dictionary containing correlation analysis with optional portfolio
        context. If the analysis cannot be performed, returns a dict with
        "error" and "status": "error".

    Example:
        >>> portfolio_correlation_analysis()  # Automatically analyzes portfolio
        >>> portfolio_correlation_analysis(["AAPL", "MSFT", "GOOGL"])  # Manual analysis
    """
    try:
        # Auto-fill tickers from portfolio if not provided
        if tickers is None or len(tickers) == 0:
            db: Session = next(get_db())
            try:
                portfolio = (
                    db.query(UserPortfolio)
                    .filter(
                        UserPortfolio.user_id == user_id,
                        UserPortfolio.name == portfolio_name,
                    )
                    .first()
                )

                if not portfolio or len(portfolio.positions) < 2:
                    return {
                        "error": "No portfolio found or insufficient positions for correlation analysis",
                        "details": "Please provide at least 2 tickers manually or add more positions to your portfolio",
                        "status": "error",
                    }

                tickers = [pos.ticker for pos in portfolio.positions]
                portfolio_context = {
                    "using_portfolio": True,
                    "portfolio_name": portfolio_name,
                    "position_count": len(tickers),
                }
            finally:
                db.close()
        else:
            portfolio_context = {"using_portfolio": False}

        if len(tickers) < 2:
            raise ValueError("At least two tickers required for correlation analysis")

        # Fetch closing prices for all tickers
        end_date = datetime.now(UTC)
        start_date = end_date - timedelta(days=days)

        price_data: dict[str, pd.Series] = {}
        failed_tickers: list[str] = []
        for ticker in tickers:
            try:
                df = stock_provider.get_stock_data(
                    ticker,
                    start_date.strftime("%Y-%m-%d"),
                    end_date.strftime("%Y-%m-%d"),
                )
                # Other tools in this module read df["Close"] from the same
                # provider, so match the close column case-insensitively
                # instead of hard-coding one spelling.
                close_column = next(
                    (col for col in df.columns if str(col).lower() == "close"),
                    None,
                )
                if df.empty or close_column is None:
                    failed_tickers.append(ticker)
                else:
                    price_data[ticker] = df[close_column]
            except Exception as e:
                logger.warning(f"Failed to fetch data for {ticker}: {e}")
                failed_tickers.append(ticker)

        # Check if we have enough valid tickers
        if len(price_data) < 2:
            return {
                "error": f"Insufficient valid price data (need 2+ tickers, got {len(price_data)})",
                "details": f"Failed tickers: {', '.join(failed_tickers)}"
                if failed_tickers
                else "No tickers provided sufficient data",
                "status": "error",
            }

        # Align prices by date and compute daily returns
        prices_df = pd.DataFrame(price_data)
        returns_df = prices_df.pct_change().dropna()

        # Check for sufficient data points
        if len(returns_df) < 30:
            return {
                "error": "Insufficient data points for correlation analysis",
                "details": f"Need at least 30 data points, got {len(returns_df)}. Try increasing the days parameter.",
                "status": "error",
            }

        correlation_matrix = returns_df.corr()

        # Reject NaN (e.g. zero-variance series) or out-of-range values.
        # Vectorized check; avoids the deprecated DataFrame.applymap.
        if (
            correlation_matrix.isnull().any().any()
            or not (correlation_matrix.abs() <= 1.0).all().all()
        ):
            return {
                "error": "Invalid correlation values detected",
                "details": "Correlation matrix contains NaN or invalid values. This may indicate insufficient price variation.",
                "status": "error",
            }

        # Walk the matrix by its own labels. Tickers that failed to fetch are
        # absent from it, so positional indexing via `tickers` would misalign
        # pairs or run past the matrix bounds.
        symbols = list(correlation_matrix.columns)
        high_correlation_pairs = []
        low_correlation_pairs = []
        pair_correlations: list[float] = []

        for i, first in enumerate(symbols):
            for second in symbols[i + 1 :]:
                corr = float(correlation_matrix.at[first, second])
                pair_correlations.append(corr)
                pair = (first, second)

                if corr > 0.7:
                    high_correlation_pairs.append(
                        {
                            "pair": pair,
                            "correlation": round(corr, 3),
                            "interpretation": "High positive correlation",
                        }
                    )
                elif corr < -0.3:
                    low_correlation_pairs.append(
                        {
                            "pair": pair,
                            "correlation": round(corr, 3),
                            "interpretation": "Negative correlation (potential hedge)",
                        }
                    )

        # Average over the unique off-diagonal pairs. The matrix is symmetric,
        # so this equals the off-diagonal mean without wrongly discarding
        # genuine off-diagonal correlations of exactly 1.
        avg_correlation = sum(pair_correlations) / len(pair_correlations)

        response = {
            "correlation_matrix": correlation_matrix.round(3).to_dict(),
            "average_portfolio_correlation": round(avg_correlation, 3),
            "high_correlation_pairs": high_correlation_pairs,
            "low_correlation_pairs": low_correlation_pairs,
            "diversification_score": round((1 - avg_correlation) * 100, 1),
            "recommendation": "Well diversified"
            if avg_correlation < 0.3
            else "Moderately diversified"
            if avg_correlation < 0.5
            else "Consider adding uncorrelated assets",
            "period_days": days,
            "data_points": len(returns_df),
        }

        # Add portfolio context if applicable
        if portfolio_context["using_portfolio"]:
            response["portfolio_context"] = portfolio_context

        return response

    except Exception as e:
        logger.error(f"Error in correlation analysis: {str(e)}")
        return {"error": str(e), "status": "error"}
605 |
606 |
607 | # ============================================================================
608 | # Portfolio Management Tools
609 | # ============================================================================
610 |
611 |
def add_portfolio_position(
    ticker: str,
    shares: float,
    purchase_price: float,
    purchase_date: str | None = None,
    notes: str | None = None,
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
) -> dict[str, Any]:
    """
    Add a stock position to your portfolio.

    This tool adds a new position or increases an existing position in your
    portfolio. If the ticker already exists, the cost basis is averaged
    automatically. The portfolio is created if it does not exist yet.

    Args:
        ticker: Stock ticker symbol (e.g., "AAPL", "MSFT")
        shares: Number of shares (supports fractional shares); must be a
            number in (0, 1_000_000_000]
        purchase_price: Price per share at purchase; must be a number in
            (0, 1_000_000]
        purchase_date: Purchase date in YYYY-MM-DD format (defaults to now, UTC).
            Naive dates are treated as UTC.
        notes: Optional notes about this position
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")

    Returns:
        Dictionary containing the updated position information. On a
        validation or database failure, returns a dict with "error" and
        "status": "error".

    Example:
        >>> add_portfolio_position("AAPL", 10, 150.50, "2024-01-15", "Long-term hold")
    """
    import math

    try:
        # Validate and normalize ticker
        is_valid, error_msg = _validate_ticker(ticker)
        if not is_valid:
            return {"error": error_msg, "status": "error"}

        ticker = _normalize_ticker(ticker)

        # Validate shares. NaN compares False against every bound, so without
        # this explicit check it would pass the range checks below and be
        # stored as Decimal("NaN").
        if math.isnan(shares):
            return {"error": "Shares must be a valid number", "status": "error"}
        if shares <= 0:
            return {"error": "Shares must be greater than zero", "status": "error"}
        if shares > 1_000_000_000:  # Sanity check
            return {
                "error": "Shares value too large (max 1 billion shares)",
                "status": "error",
            }

        # Validate purchase price (same NaN caveat as shares)
        if math.isnan(purchase_price):
            return {
                "error": "Purchase price must be a valid number",
                "status": "error",
            }
        if purchase_price <= 0:
            return {
                "error": "Purchase price must be greater than zero",
                "status": "error",
            }
        if purchase_price > 1_000_000:  # Sanity check
            return {
                "error": "Purchase price too large (max $1M per share)",
                "status": "error",
            }

        # Parse purchase date; naive values are interpreted as UTC
        if purchase_date:
            try:
                parsed_date = datetime.fromisoformat(
                    purchase_date.replace("Z", "+00:00")
                )
                if parsed_date.tzinfo is None:
                    parsed_date = parsed_date.replace(tzinfo=UTC)
            except ValueError:
                return {
                    "error": "Invalid date format. Use YYYY-MM-DD",
                    "status": "error",
                }
        else:
            parsed_date = datetime.now(UTC)

        db: Session = next(get_db())
        try:
            # Get or create portfolio
            portfolio_db = (
                db.query(UserPortfolio)
                .filter_by(user_id=user_id, name=portfolio_name)
                .first()
            )

            if not portfolio_db:
                portfolio_db = UserPortfolio(user_id=user_id, name=portfolio_name)
                db.add(portfolio_db)
                db.flush()  # assign portfolio_db.id before it is referenced below

            # Get existing position if any
            existing_position = (
                db.query(PortfolioPosition)
                .filter_by(portfolio_id=portfolio_db.id, ticker=ticker.upper())
                .first()
            )

            # str() round-trip keeps the Decimal equal to the value as written,
            # rather than the float's binary expansion.
            total_cost = Decimal(str(shares)) * Decimal(str(purchase_price))

            if existing_position:
                # Update existing position (weighted-average cost basis)
                old_total = (
                    existing_position.shares * existing_position.average_cost_basis
                )
                new_total = old_total + total_cost
                new_shares = existing_position.shares + Decimal(str(shares))
                new_avg_cost = new_total / new_shares

                existing_position.shares = new_shares
                existing_position.average_cost_basis = new_avg_cost
                existing_position.total_cost = new_total
                # NOTE(review): this replaces the stored purchase date with the
                # date of the latest lot; confirm whether the earliest should be kept.
                existing_position.purchase_date = parsed_date
                if notes:
                    existing_position.notes = notes

                position_result = existing_position
            else:
                # Create new position
                position_result = PortfolioPosition(
                    portfolio_id=portfolio_db.id,
                    ticker=ticker.upper(),
                    shares=Decimal(str(shares)),
                    average_cost_basis=Decimal(str(purchase_price)),
                    total_cost=total_cost,
                    purchase_date=parsed_date,
                    notes=notes,
                )
                db.add(position_result)

            db.commit()

            return {
                "status": "success",
                "message": f"Added {shares} shares of {ticker.upper()}",
                "position": {
                    "ticker": position_result.ticker,
                    "shares": float(position_result.shares),
                    "average_cost_basis": float(position_result.average_cost_basis),
                    "total_cost": float(position_result.total_cost),
                    "purchase_date": position_result.purchase_date.isoformat(),
                    "notes": position_result.notes,
                },
                "portfolio": {
                    "name": portfolio_db.name,
                    "user_id": portfolio_db.user_id,
                },
            }

        finally:
            db.close()

    except Exception as e:
        logger.error(f"Error adding position {ticker}: {str(e)}")
        return {"error": str(e), "status": "error"}
765 |
766 |
def get_my_portfolio(
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
    include_current_prices: bool = True,
) -> dict[str, Any]:
    """
    Get your complete portfolio with all positions and performance metrics.

    The response covers:
    - Every stock position with its cost basis
    - Current market values (when a recent close price can be fetched)
    - Per-position unrealized profit/loss
    - Portfolio-wide metrics from the domain ``Portfolio`` model

    Args:
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")
        include_current_prices: Whether to fetch live prices for P&L (default: True)

    Returns:
        A dict whose "status" is:
        - "success" with "portfolio", "positions", "metrics" and "as_of";
        - "empty" if the portfolio is missing or has no positions;
        - "error" with an "error" message if anything raised.

    Example:
        >>> get_my_portfolio()
    """
    try:
        session: Session = next(get_db())
        try:
            portfolio_row = (
                session.query(UserPortfolio)
                .filter_by(user_id=user_id, name=portfolio_name)
                .first()
            )
            if not portfolio_row:
                return {
                    "status": "empty",
                    "message": f"No portfolio found for user '{user_id}' with name '{portfolio_name}'",
                    "positions": [],
                    "total_invested": 0.0,
                }

            position_rows = (
                session.query(PortfolioPosition)
                .filter_by(portfolio_id=portfolio_row.id)
                .all()
            )
            if not position_rows:
                return {
                    "status": "empty",
                    "message": "Portfolio is empty",
                    "portfolio": {
                        "name": portfolio_row.name,
                        "user_id": portfolio_row.user_id,
                    },
                    "positions": [],
                    "total_invested": 0.0,
                }

            # Mirror the rows into the domain model, which computes aggregate metrics.
            domain_portfolio = Portfolio(
                portfolio_id=str(portfolio_row.id),
                user_id=portfolio_row.user_id,
                name=portfolio_row.name,
            )
            for row in position_rows:
                domain_portfolio.add_position(
                    row.ticker,
                    row.shares,
                    row.average_cost_basis,
                    row.purchase_date,
                )

            # Latest close from the past week of data; failures are logged and skipped.
            latest_prices: dict[str, Decimal] = {}
            if include_current_prices:
                for row in position_rows:
                    try:
                        history = stock_provider.get_stock_data(
                            row.ticker,
                            start_date=(datetime.now(UTC) - timedelta(days=7)).strftime(
                                "%Y-%m-%d"
                            ),
                            end_date=datetime.now(UTC).strftime("%Y-%m-%d"),
                        )
                        if not history.empty:
                            latest_prices[row.ticker] = Decimal(
                                str(history["Close"].iloc[-1])
                            )
                    except Exception as e:
                        logger.warning(
                            f"Could not fetch price for {row.ticker}: {str(e)}"
                        )

            metrics = domain_portfolio.calculate_portfolio_metrics(latest_prices)

            cent = Decimal("0.01")
            rendered_positions = []
            for row in position_rows:
                entry = {
                    "ticker": row.ticker,
                    "shares": float(row.shares),
                    "average_cost_basis": float(row.average_cost_basis),
                    "total_cost": float(row.total_cost),
                    "purchase_date": row.purchase_date.isoformat(),
                    "notes": row.notes,
                }

                price = latest_prices.get(row.ticker)
                if price is not None:
                    # Decimal arithmetic, rounded to cents, before converting to float.
                    market_value = (row.shares * price).quantize(cent)
                    gain_loss = (market_value - row.total_cost).quantize(cent)
                    gain_loss_float = float(gain_loss)

                    entry["current_price"] = float(price)
                    entry["current_value"] = float(market_value)
                    entry["unrealized_gain_loss"] = gain_loss_float
                    entry["unrealized_gain_loss_percent"] = (
                        gain_loss_float / float(row.total_cost)
                    ) * 100

                rendered_positions.append(entry)

            summary = {
                "total_invested": metrics["total_invested"],
                "total_current_value": metrics["total_current_value"],
                "total_unrealized_gain_loss": metrics["total_unrealized_gain_loss"],
                "total_return_percent": metrics["total_return_percent"],
                "number_of_positions": len(rendered_positions),
            }
            return {
                "status": "success",
                "portfolio": {
                    "name": portfolio_row.name,
                    "user_id": portfolio_row.user_id,
                    "created_at": portfolio_row.created_at.isoformat(),
                },
                "positions": rendered_positions,
                "metrics": summary,
                "as_of": datetime.now(UTC).isoformat(),
            }
        finally:
            session.close()
    except Exception as e:
        logger.error(f"Error getting portfolio: {str(e)}")
        return {"error": str(e), "status": "error"}
925 |
926 |
def remove_portfolio_position(
    ticker: str,
    shares: float | None = None,
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
) -> dict[str, Any]:
    """
    Remove shares from a position in your portfolio.

    This tool removes some or all shares of a stock from your portfolio.
    If no share count is specified, or the count is at least the number of
    shares held, the entire position is removed. A partial removal keeps the
    average cost basis and recomputes the position's total cost.

    Args:
        ticker: Stock ticker symbol
        shares: Number of shares to remove (None = remove entire position);
            must be a number greater than zero when given
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")

    Returns:
        Dictionary containing the updated or removed position. On a
        validation or database failure, returns a dict with "error" and
        "status": "error".

    Example:
        >>> remove_portfolio_position("AAPL", 5)  # Remove 5 shares
        >>> remove_portfolio_position("MSFT")  # Remove entire position
    """
    import math

    try:
        # Validate and normalize ticker
        is_valid, error_msg = _validate_ticker(ticker)
        if not is_valid:
            return {"error": error_msg, "status": "error"}

        ticker = _normalize_ticker(ticker)

        # Validate shares if provided. NaN compares False against every bound,
        # so it must be rejected explicitly or it would corrupt the position.
        if shares is not None:
            if math.isnan(shares):
                return {
                    "error": "Shares to remove must be a valid number",
                    "status": "error",
                }
            if shares <= 0:
                return {
                    "error": "Shares to remove must be greater than zero",
                    "status": "error",
                }

        # Open exactly one session, and always close it in `finally`.
        db: Session = next(get_db())
        try:
            # Get portfolio
            portfolio_db = (
                db.query(UserPortfolio)
                .filter_by(user_id=user_id, name=portfolio_name)
                .first()
            )

            if not portfolio_db:
                return {
                    "error": f"Portfolio '{portfolio_name}' not found for user '{user_id}'",
                    "status": "error",
                }

            # Get position
            position_db = (
                db.query(PortfolioPosition)
                .filter_by(portfolio_id=portfolio_db.id, ticker=ticker.upper())
                .first()
            )

            if not position_db:
                return {
                    "error": f"Position {ticker.upper()} not found in portfolio",
                    "status": "error",
                }

            # Remove entire position or partial shares
            if shares is None or shares >= float(position_db.shares):
                # Remove entire position
                removed_shares = float(position_db.shares)
                db.delete(position_db)
                db.commit()

                return {
                    "status": "success",
                    "message": f"Removed entire position of {removed_shares} shares of {ticker.upper()}",
                    "removed_shares": removed_shares,
                    "position_fully_closed": True,
                }

            # Remove partial shares; cost basis per share is unchanged
            new_shares = position_db.shares - Decimal(str(shares))
            new_total_cost = new_shares * position_db.average_cost_basis

            position_db.shares = new_shares
            position_db.total_cost = new_total_cost
            db.commit()

            return {
                "status": "success",
                "message": f"Removed {shares} shares of {ticker.upper()}",
                "removed_shares": shares,
                "position_fully_closed": False,
                "remaining_position": {
                    "ticker": position_db.ticker,
                    "shares": float(position_db.shares),
                    "average_cost_basis": float(position_db.average_cost_basis),
                    "total_cost": float(position_db.total_cost),
                },
            }

        finally:
            db.close()

    except Exception as e:
        logger.error(f"Error removing position {ticker}: {str(e)}")
        return {"error": str(e), "status": "error"}
1040 |
1041 |
def clear_my_portfolio(
    user_id: str = "default",
    portfolio_name: str = "My Portfolio",
    confirm: bool = False,
) -> dict[str, Any]:
    """
    Clear all positions from your portfolio.

    CAUTION: Every position in the named portfolio is deleted and the change
    is committed; it cannot be undone. The portfolio record itself is kept.

    Args:
        user_id: User identifier (defaults to "default")
        portfolio_name: Portfolio name (defaults to "My Portfolio")
        confirm: Must be True to perform the deletion (safety check)

    Returns:
        A dict describing the outcome. It has "status": "success" and
        "positions_cleared" when the portfolio was cleared or already empty.
        Otherwise it has "status": "error" and an "error" message: when
        confirm is not set, the portfolio is missing, or a failure occurred.

    Example:
        >>> clear_my_portfolio(confirm=True)
    """
    try:
        if not confirm:
            return {
                "error": "Must set confirm=True to clear portfolio",
                "status": "error",
                "message": "This is a safety check to prevent accidental deletion",
            }

        session: Session = next(get_db())
        try:
            target = (
                session.query(UserPortfolio)
                .filter_by(user_id=user_id, name=portfolio_name)
                .first()
            )
            if not target:
                return {
                    "error": f"Portfolio '{portfolio_name}' not found for user '{user_id}'",
                    "status": "error",
                }

            # One query object serves both the count and the bulk delete;
            # each call issues its own SQL statement.
            positions_query = session.query(PortfolioPosition).filter_by(
                portfolio_id=target.id
            )
            existing = positions_query.count()
            if existing == 0:
                return {
                    "status": "success",
                    "message": "Portfolio was already empty",
                    "positions_cleared": 0,
                }

            positions_query.delete()
            session.commit()

            return {
                "status": "success",
                "message": f"Cleared all positions from portfolio '{portfolio_name}'",
                "positions_cleared": existing,
                "portfolio": {
                    "name": target.name,
                    "user_id": target.user_id,
                },
            }
        finally:
            session.close()
    except Exception as e:
        logger.error(f"Error clearing portfolio: {str(e)}")
        return {"error": str(e), "status": "error"}
1121 |
```