This is page 19 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│   ├── __init__.py
│   ├── advanced_agent_flows_using_unified_memory_system_demo.py
│   ├── advanced_extraction_demo.py
│   ├── advanced_unified_memory_system_demo.py
│   ├── advanced_vector_search_demo.py
│   ├── analytics_reporting_demo.py
│   ├── audio_transcription_demo.py
│   ├── basic_completion_demo.py
│   ├── cache_demo.py
│   ├── claude_integration_demo.py
│   ├── compare_synthesize_demo.py
│   ├── cost_optimization.py
│   ├── data
│   │   ├── sample_event.txt
│   │   ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│   │   └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│   ├── docstring_refiner_demo.py
│   ├── document_conversion_and_processing_demo.py
│   ├── entity_relation_graph_demo.py
│   ├── filesystem_operations_demo.py
│   ├── grok_integration_demo.py
│   ├── local_text_tools_demo.py
│   ├── marqo_fused_search_demo.py
│   ├── measure_model_speeds.py
│   ├── meta_api_demo.py
│   ├── multi_provider_demo.py
│   ├── ollama_integration_demo.py
│   ├── prompt_templates_demo.py
│   ├── python_sandbox_demo.py
│   ├── rag_example.py
│   ├── research_workflow_demo.py
│   ├── sample
│   │   ├── article.txt
│   │   ├── backprop_paper.pdf
│   │   ├── buffett.pdf
│   │   ├── contract_link.txt
│   │   ├── legal_contract.txt
│   │   ├── medical_case.txt
│   │   ├── northwind.db
│   │   ├── research_paper.txt
│   │   ├── sample_data.json
│   │   └── text_classification_samples
│   │       ├── email_classification.txt
│   │       ├── news_samples.txt
│   │       ├── product_reviews.txt
│   │       └── support_tickets.txt
│   ├── sample_docs
│   │   └── downloaded
│   │       └── attention_is_all_you_need.pdf
│   ├── sentiment_analysis_demo.py
│   ├── simple_completion_demo.py
│   ├── single_shot_synthesis_demo.py
│   ├── smart_browser_demo.py
│   ├── sql_database_demo.py
│   ├── sse_client_demo.py
│   ├── test_code_extraction.py
│   ├── test_content_detection.py
│   ├── test_ollama.py
│   ├── text_classification_demo.py
│   ├── text_redline_demo.py
│   ├── tool_composition_examples.py
│   ├── tournament_code_demo.py
│   ├── tournament_text_demo.py
│   ├── unified_memory_system_demo.py
│   ├── vector_search_demo.py
│   ├── web_automation_instruction_packs.py
│   └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│   └── smart_browser_internal
│       ├── locator_cache.db
│       ├── readability.js
│       └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_server.py
│   ├── manual
│   │   ├── test_extraction_advanced.py
│   │   └── test_extraction.py
│   └── unit
│       ├── __init__.py
│       ├── test_cache.py
│       ├── test_providers.py
│       └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── commands.py
│   │   ├── helpers.py
│   │   └── typer_cli.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── completion_client.py
│   │   └── rag_client.py
│   ├── config
│   │   └── examples
│   │       └── filesystem_config.yaml
│   ├── config.py
│   ├── constants.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── base.py
│   │   │   └── evaluators.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   ├── deepseek.py
│   │   │   ├── gemini.py
│   │   │   ├── grok.py
│   │   │   ├── ollama.py
│   │   │   ├── openai.py
│   │   │   └── openrouter.py
│   │   ├── server.py
│   │   ├── state_store.py
│   │   ├── tournaments
│   │   │   ├── manager.py
│   │   │   ├── tasks.py
│   │   │   └── utils.py
│   │   └── ums_api
│   │       ├── __init__.py
│   │       ├── ums_database.py
│   │       ├── ums_endpoints.py
│   │       ├── ums_models.py
│   │       └── ums_services.py
│   ├── exceptions.py
│   ├── graceful_shutdown.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── analytics
│   │   │   ├── __init__.py
│   │   │   ├── metrics.py
│   │   │   └── reporting.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── cache_service.py
│   │   │   ├── persistence.py
│   │   │   ├── strategies.py
│   │   │   └── utils.py
│   │   ├── cache.py
│   │   ├── document.py
│   │   ├── knowledge_base
│   │   │   ├── __init__.py
│   │   │   ├── feedback.py
│   │   │   ├── manager.py
│   │   │   ├── rag_engine.py
│   │   │   ├── retriever.py
│   │   │   └── utils.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── repository.py
│   │   │   └── templates.py
│   │   ├── prompts.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── embeddings.py
│   │       └── vector_service.py
│   ├── tool_token_counter.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── audio_transcription.py
│   │   ├── base.py
│   │   ├── completion.py
│   │   ├── docstring_refiner.py
│   │   ├── document_conversion_and_processing.py
│   │   ├── enhanced-ums-lookbook.html
│   │   ├── entity_relation_graph.py
│   │   ├── excel_spreadsheet_automation.py
│   │   ├── extraction.py
│   │   ├── filesystem.py
│   │   ├── html_to_markdown.py
│   │   ├── local_text_tools.py
│   │   ├── marqo_fused_search.py
│   │   ├── meta_api_tool.py
│   │   ├── ocr_tools.py
│   │   ├── optimization.py
│   │   ├── provider.py
│   │   ├── pyodide_boot_template.html
│   │   ├── python_sandbox.py
│   │   ├── rag.py
│   │   ├── redline-compiled.css
│   │   ├── sentiment_analysis.py
│   │   ├── single_shot_synthesis.py
│   │   ├── smart_browser.py
│   │   ├── sql_databases.py
│   │   ├── text_classification.py
│   │   ├── text_redline_tools.py
│   │   ├── tournament.py
│   │   ├── ums_explorer.html
│   │   └── unified_memory_system.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── display.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   ├── console.py
│   │   │   ├── emojis.py
│   │   │   ├── formatter.py
│   │   │   ├── logger.py
│   │   │   ├── panels.py
│   │   │   ├── progress.py
│   │   │   └── themes.py
│   │   ├── parse_yaml.py
│   │   ├── parsing.py
│   │   ├── security.py
│   │   └── text.py
│   └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/optimization.py:
--------------------------------------------------------------------------------
```python
1 | """Tools for LLM cost estimation, model comparison, recommendation, and workflow execution.
2 |
3 | Provides utilities to help manage LLM usage costs and select appropriate models.
4 | """
5 | import asyncio
6 | import json
7 | import os
8 | import time
9 | import traceback
10 | from typing import Any, Dict, List, Optional, Set
11 |
12 | import networkx as nx
13 |
14 | from ultimate_mcp_server.constants import COST_PER_MILLION_TOKENS
15 | from ultimate_mcp_server.exceptions import ToolError, ToolInputError
16 | from ultimate_mcp_server.tools.base import with_error_handling, with_tool_metrics
17 | from ultimate_mcp_server.tools.completion import chat_completion
18 | from ultimate_mcp_server.tools.document_conversion_and_processing import (
19 | chunk_document,
20 | summarize_document,
21 | )
22 | from ultimate_mcp_server.tools.extraction import extract_json
23 | from ultimate_mcp_server.tools.rag import (
24 | add_documents,
25 | create_knowledge_base,
26 | generate_with_rag,
27 | retrieve_context,
28 | )
29 | from ultimate_mcp_server.tools.text_classification import text_classification
30 | from ultimate_mcp_server.utils import get_logger
31 | from ultimate_mcp_server.utils.text import count_tokens
32 |
33 | logger = get_logger("ultimate_mcp_server.tools.optimization")
34 |
35 | # --- Constants for Speed Score Mapping ---
36 | # Define bins for mapping tokens/second to a 1-5 score (lower is faster)
37 | # Adjust these thresholds based on observed performance and desired sensitivity
38 | SPEED_SCORE_BINS = [
39 |     (200, 1), # >= 200 tokens/s -> Score 1 (Fastest)
40 |     (100, 2), # 100-199 tokens/s -> Score 2
41 |     (50, 3),  # 50-99 tokens/s -> Score 3
42 |     (20, 4),  # 20-49 tokens/s -> Score 4
43 |     (0, 5),   # 0-19 tokens/s -> Score 5 (Slowest)
44 | ]
45 | DEFAULT_SPEED_SCORE = 3 # Fallback score when a speed measurement is missing or invalid
46 |
47 | def _map_tok_per_sec_to_score(tokens_per_sec: Optional[float]) -> int:
48 | """Maps measured tokens/second to a 1-5 speed score (lower is faster)."""
49 | if tokens_per_sec is None or not isinstance(tokens_per_sec, (int, float)) or tokens_per_sec < 0:
50 | return DEFAULT_SPEED_SCORE # Return default for invalid input
51 | for threshold, score in SPEED_SCORE_BINS:
52 | if tokens_per_sec >= threshold:
53 | return score
54 |     return SPEED_SCORE_BINS[-1][1] # Defensive fallback; the 0 threshold above matches any non-negative value
55 |
56 | @with_tool_metrics
57 | @with_error_handling
58 | async def estimate_cost(
59 | prompt: str,
60 | model: str, # Can be full 'provider/model_name' or just 'model_name' if unique
61 | max_tokens: Optional[int] = None,
62 | include_output: bool = True
63 | ) -> Dict[str, Any]:
64 | """Estimates the monetary cost of an LLM request without executing it.
65 |
66 | Calculates cost based on input prompt tokens and estimated/specified output tokens
67 | using predefined cost rates for the specified model.
68 |
69 | Args:
70 | prompt: The text prompt that would be sent to the model.
71 | model: The model identifier (e.g., "openai/gpt-4.1-mini", "gpt-4.1-mini",
72 | "anthropic/claude-3-5-haiku-20241022", "claude-3-5-haiku-20241022").
73 | Cost data must be available for the resolved model name in `COST_PER_MILLION_TOKENS`.
74 | max_tokens: (Optional) The maximum number of tokens expected in the output. If None,
75 | output tokens are estimated as roughly half the input prompt tokens.
76 | include_output: (Optional) If False, calculates cost based only on input tokens, ignoring
77 | `max_tokens` or output estimation. Defaults to True.
78 |
79 | Returns:
80 | A dictionary containing the cost estimate and token breakdown:
81 | {
82 | "cost": 0.000150, # Total estimated cost in USD
83 | "breakdown": {
84 | "input_cost": 0.000100,
85 | "output_cost": 0.000050
86 | },
87 | "tokens": {
88 | "input": 200, # Tokens counted from the prompt
89 | "output": 100, # Estimated or provided max_tokens
90 | "total": 300
91 | },
92 | "rate": { # Cost per million tokens for this model
93 | "input": 0.50,
94 | "output": 1.50
95 | },
96 | "model": "gpt-4.1-mini", # Returns the original model string passed as input
97 | "resolved_model_key": "gpt-4.1-mini", # The key used for cost lookup
98 | "is_estimate": true
99 | }
100 |
101 | Raises:
102 | ToolInputError: If prompt or model format is invalid.
103 | ToolError: If the specified `model` cannot be resolved to cost data.
104 |         (Note: a ValueError from token counting is caught internally; a rough character-based estimate is used instead.)
105 | """
106 | # Input validation
107 | if not prompt or not isinstance(prompt, str):
108 | raise ToolInputError("Prompt must be a non-empty string.")
109 | if not model or not isinstance(model, str):
110 | raise ToolInputError("Model must be a non-empty string.")
111 |
112 | # Flexible Cost Data Lookup
113 | cost_data = COST_PER_MILLION_TOKENS.get(model)
114 | resolved_model_key = model # Assume direct match first
115 | model_name_only = model # Use input model for token counting initially
116 |
117 | if not cost_data and '/' in model:
118 | # If direct lookup fails and it looks like a prefixed name, try stripping prefix
119 | potential_short_key = model.split('/')[-1]
120 | cost_data = COST_PER_MILLION_TOKENS.get(potential_short_key)
121 | if cost_data:
122 | resolved_model_key = potential_short_key
123 | model_name_only = potential_short_key # Use short name for token count
124 | # If short key also fails, cost_data remains None
125 |
126 | if not cost_data:
127 | error_message = f"Unknown model or cost data unavailable for: {model}"
128 | raise ToolError(error_message, error_code="MODEL_NOT_FOUND", details={"model": model})
129 |
130 | # Token Counting (use model_name_only derived from successful cost key)
131 | try:
132 | input_tokens = count_tokens(prompt, model=model_name_only)
133 | except ValueError as e:
134 | # Log warning with the original model input for clarity
135 | logger.warning(f"Could not count tokens for model '{model}' (using '{model_name_only}' for tiktoken): {e}. Using rough estimate.")
136 | input_tokens = len(prompt) // 4 # Fallback estimate
137 |
138 | # Estimate output tokens if needed
139 | estimated_output_tokens = 0
140 | if include_output:
141 | if max_tokens is not None:
142 | estimated_output_tokens = max_tokens
143 | else:
144 | estimated_output_tokens = input_tokens // 2
145 | logger.debug(f"max_tokens not provided, estimating output tokens as {estimated_output_tokens}")
146 | else:
147 | estimated_output_tokens = 0
148 |
149 | # Calculate costs
150 | input_cost = (input_tokens / 1_000_000) * cost_data["input"]
151 | output_cost = (estimated_output_tokens / 1_000_000) * cost_data["output"]
152 | total_cost = input_cost + output_cost
153 |
154 | logger.info(f"Estimated cost for model '{model}' (using key '{resolved_model_key}'): ${total_cost:.6f} (In: {input_tokens} tokens, Out: {estimated_output_tokens} tokens)")
155 | return {
156 | "cost": total_cost,
157 | "breakdown": {
158 | "input_cost": input_cost,
159 | "output_cost": output_cost
160 | },
161 | "tokens": {
162 | "input": input_tokens,
163 | "output": estimated_output_tokens,
164 | "total": input_tokens + estimated_output_tokens
165 | },
166 | "rate": {
167 | "input": cost_data["input"],
168 | "output": cost_data["output"]
169 | },
170 | "model": model, # Return original input model string
171 | "resolved_model_key": resolved_model_key, # Key used for cost lookup
172 | "is_estimate": True
173 | }
174 |
175 | @with_tool_metrics
176 | @with_error_handling
177 | async def compare_models(
178 | prompt: str,
179 | models: List[str], # List of model IDs (can be short or full names)
180 | max_tokens: Optional[int] = None,
181 | include_output: bool = True
182 | ) -> Dict[str, Any]:
183 | """Compares the estimated cost of running a prompt across multiple specified models.
184 |
185 | Uses the `estimate_cost` tool for each model in the list concurrently.
186 |
187 | Args:
188 | prompt: The text prompt to use for cost comparison.
189 | models: A list of model identifiers (e.g., ["openai/gpt-4.1-mini", "gpt-4.1-mini", "claude-3-5-haiku-20241022"]).
190 | `estimate_cost` will handle resolving these to cost data.
191 | max_tokens: (Optional) Maximum output tokens to assume for cost estimation across all models.
192 | If None, output is estimated individually per model based on input.
193 | include_output: (Optional) Whether to include estimated output costs in the comparison. Defaults to True.
194 |
195 | Returns:
196 | A dictionary containing the cost comparison results:
197 | {
198 | "models": {
199 | "openai/gpt-4.1-mini": { # Uses the input model name as key
200 | "cost": 0.000150,
201 | "tokens": { "input": 200, "output": 100, "total": 300 }
202 | },
203 | "claude-3-5-haiku-20241022": {
204 | "cost": 0.000087,
205 | "tokens": { "input": 200, "output": 100, "total": 300 }
206 | },
207 | "some-unknown-model": { # Example of an error during estimation
208 | "error": "Unknown model or cost data unavailable for: some-unknown-model"
209 | }
210 | },
211 | "ranking": [ # List of input model names ordered by cost (cheapest first), errors excluded
212 | "claude-3-5-haiku-20241022",
213 | "openai/gpt-4.1-mini"
214 | ],
215 | "cheapest": "claude-3-5-haiku-20241022", # Input model name with the lowest cost
216 | "most_expensive": "openai/gpt-4.1-mini", # Input model name with the highest cost
217 | "prompt_length_chars": 512,
218 | "max_tokens_assumed": 100
219 | }
220 |
221 | Raises:
222 | ToolInputError: If the `models` list is empty.
223 | """
224 | if not models or not isinstance(models, list):
225 | raise ToolInputError("'models' must be a non-empty list of model identifiers.")
226 | # Removed the check for '/' in model names - estimate_cost will handle resolution
227 |
228 | results = {}
229 | estimated_output_for_summary = None
230 |
231 | async def get_estimate(model_input_name): # Use a distinct variable name
232 | nonlocal estimated_output_for_summary
233 | try:
234 | estimate = await estimate_cost(
235 | prompt=prompt,
236 | model=model_input_name, # Pass the potentially short/full name
237 | max_tokens=max_tokens,
238 | include_output=include_output
239 | )
240 | # Use the original input name as the key in results
241 | results[model_input_name] = {
242 | "cost": estimate["cost"],
243 | "tokens": estimate["tokens"],
244 | }
245 | if estimated_output_for_summary is None:
246 | estimated_output_for_summary = estimate["tokens"]["output"]
247 | except ToolError as e:
248 | logger.warning(f"Could not estimate cost for model {model_input_name}: {e.detail}")
249 | results[model_input_name] = {"error": e.detail} # Store error under original name
250 | except Exception as e:
251 | logger.error(f"Unexpected error estimating cost for model {model_input_name}: {e}", exc_info=True)
252 | results[model_input_name] = {"error": f"Unexpected error: {str(e)}"}
253 |
254 | await asyncio.gather(*(get_estimate(model_name) for model_name in models))
255 |
256 | successful_estimates = {m: r for m, r in results.items() if "error" not in r}
257 | sorted_models = sorted(successful_estimates.items(), key=lambda item: item[1]["cost"])
258 |
259 | output_tokens_summary = estimated_output_for_summary if max_tokens is None else max_tokens
260 | if not include_output:
261 | output_tokens_summary = 0
262 |
263 | cheapest_model = sorted_models[0][0] if sorted_models else None
264 | most_expensive_model = sorted_models[-1][0] if sorted_models else None
265 | logger.info(f"Compared models: {list(results.keys())}. Cheapest: {cheapest_model or 'N/A'}")
266 |
267 | return {
268 | "models": results,
269 | "ranking": [m for m, _ in sorted_models], # Ranking uses original input names
270 | "cheapest": cheapest_model,
271 | "most_expensive": most_expensive_model,
272 | "prompt_length_chars": len(prompt),
273 | "max_tokens_assumed": output_tokens_summary,
274 | }
275 |
276 | @with_tool_metrics
277 | @with_error_handling
278 | async def recommend_model(
279 | task_type: str,
280 | expected_input_length: int, # In characters
281 | expected_output_length: Optional[int] = None, # In characters
282 | required_capabilities: Optional[List[str]] = None,
283 | max_cost: Optional[float] = None,
284 | priority: str = "balanced" # Options: "cost", "quality", "speed", "balanced"
285 | ) -> Dict[str, Any]:
286 | """Recommends suitable LLM models based on task requirements and optimization priority.
287 |
288 | Evaluates known models against criteria like task type suitability (inferred),
289 | estimated cost (based on expected lengths), required capabilities,
290 | measured speed (tokens/sec if available), and quality metrics.
291 |
292 | Args:
293 | task_type: A description of the task (e.g., "summarization", "code generation", "entity extraction",
294 | "customer support chat", "complex reasoning question"). Used loosely for capability checks.
295 | expected_input_length: Estimated length of the input text in characters.
296 | expected_output_length: (Optional) Estimated length of the output text in characters.
297 | If None, it's roughly estimated based on input length.
298 | required_capabilities: (Optional) A list of specific capabilities the model MUST possess.
299 | Current known capabilities include: "reasoning", "coding", "knowledge",
300 | "instruction-following", "math". Check model metadata for supported values.
301 | Example: ["coding", "instruction-following"]
302 | max_cost: (Optional) The maximum acceptable estimated cost (in USD) for a single run
303 | with the expected input/output lengths. Models exceeding this are excluded.
304 | priority: (Optional) The primary factor for ranking suitable models.
305 | Options:
306 | - "cost": Prioritize the cheapest models.
307 | - "quality": Prioritize models with the highest quality score.
308 | - "speed": Prioritize models with the highest measured speed (tokens/sec).
309 | - "balanced": (Default) Attempt to find a good mix of cost, quality, and speed.
310 |
311 | Returns:
312 | A dictionary containing model recommendations:
313 | {
314 | "recommendations": [
315 | {
316 | "model": "anthropic/claude-3-5-haiku-20241022",
317 | "estimated_cost": 0.000087,
318 | "quality_score": 7,
319 | "measured_speed_tps": 50.63, # Tokens per second
320 | "capabilities": ["knowledge", "instruction-following"],
321 | "reason": "Good balance of cost and speed, meets requirements."
322 | },
323 | {
324 | "model": "openai/gpt-4.1-mini",
325 | "estimated_cost": 0.000150,
326 | "quality_score": 7,
327 | "measured_speed_tps": 112.06,
328 | "capabilities": ["reasoning", "coding", ...],
329 | "reason": "Higher cost, but good quality/speed."
330 | }
331 | # ... other suitable models
332 | ],
333 | "parameters": { # Input parameters for context
334 | "task_type": "summarization",
335 | "expected_input_length": 2000,
336 | "expected_output_length": 500,
337 | "required_capabilities": [],
338 | "max_cost": 0.001,
339 | "priority": "balanced"
340 | },
341 | "excluded_models": { # Models evaluated but excluded, with reasons
342 | "anthropic/claude-3-opus-20240229": "Exceeds max cost ($0.0015 > $0.001)",
343 | "some-other-model": "Missing required capabilities: ['coding']"
344 | }
345 | }
346 |
347 | Raises:
348 | ToolInputError: If priority is invalid or lengths are non-positive.
349 | """
350 | if expected_input_length <= 0:
351 | raise ToolInputError("expected_input_length must be positive.")
352 | if expected_output_length is not None and expected_output_length <= 0:
353 | raise ToolInputError("expected_output_length must be positive if provided.")
354 | if priority not in ["cost", "quality", "speed", "balanced"]:
355 | raise ToolInputError(f"Invalid priority: '{priority}'. Must be cost, quality, speed, or balanced.")
356 |
357 | # --- Load Measured Speed Data ---
358 | measured_speeds: Dict[str, Any] = {}
359 | measured_speeds_file = "empirically_measured_model_speeds.json"
360 | project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
361 | filepath = os.path.join(project_root, measured_speeds_file)
362 | if os.path.exists(filepath):
363 | try:
364 | with open(filepath, 'r') as f:
365 | measured_speeds = json.load(f)
366 | logger.info(f"Successfully loaded measured speed data from {filepath}")
367 |         except (json.JSONDecodeError, OSError) as e:
368 |             logger.warning(f"Could not load or parse measured speed data from {filepath}: {e}. Measured speeds will default to 0.", exc_info=True)
369 |             measured_speeds = {}
370 |     else:
371 |         logger.info(f"Measured speed file not found at {filepath}. Measured speeds will default to 0.")
372 | # --- End Load Measured Speed Data ---
373 |
374 | # --- Model Metadata (Updated based on provided images) ---
375 | model_capabilities = {
376 | # OpenAI models
377 | "openai/gpt-4o": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"], # Assuming multimodal based on general knowledge
378 | "openai/gpt-4o-mini": ["reasoning", "knowledge", "instruction-following"],
379 | "openai/gpt-4.1": ["reasoning", "coding", "knowledge", "instruction-following", "math"],
380 | "openai/gpt-4.1-mini": ["reasoning", "coding", "knowledge", "instruction-following"],
381 | "openai/gpt-4.1-nano": ["reasoning", "knowledge", "instruction-following"], # Added reasoning
382 | "openai/o1-preview": ["reasoning", "coding", "knowledge", "instruction-following", "math"],
383 | "openai/o1": ["reasoning", "coding", "knowledge", "instruction-following", "math"], # Keep guess
384 | "openai/o3-mini": ["reasoning", "knowledge", "instruction-following"],
385 |
386 | # Anthropic models
387 | "anthropic/claude-3-opus-20240229": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"],
388 | "anthropic/claude-3-sonnet-20240229": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"], # Previous Sonnet version
389 | "anthropic/claude-3-5-haiku-20241022": ["knowledge", "instruction-following", "multimodal"], # Based on 3.5 Haiku column
390 | "anthropic/claude-3-5-sonnet-20241022": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"], # Based on 3.5 Sonnet column
391 | "anthropic/claude-3-7-sonnet-20250219": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"], # Based on 3.7 Sonnet column
392 |
393 | # DeepSeek models
394 | "deepseek/deepseek-chat": ["coding", "knowledge", "instruction-following"],
395 | "deepseek/deepseek-reasoner": ["reasoning", "math", "instruction-following"],
396 |
397 | # Gemini models
398 | "gemini/gemini-2.0-flash-lite": ["knowledge", "instruction-following"],
399 | "gemini/gemini-2.0-flash": ["knowledge", "instruction-following", "multimodal"],
400 | "gemini/gemini-2.0-flash-thinking-exp-01-21": ["reasoning", "coding", "knowledge", "instruction-following", "multimodal"],
401 | "gemini/gemini-2.5-pro-preview-03-25": ["reasoning", "coding", "knowledge", "instruction-following", "math", "multimodal"], # Map from gemini-2.5-pro-preview-03-25
402 |
403 | # Grok models (Estimates)
404 | "grok/grok-3-latest": ["reasoning", "knowledge", "instruction-following", "math"],
405 | "grok/grok-3-fast-latest": ["reasoning", "knowledge", "instruction-following"],
406 | "grok/grok-3-mini-latest": ["knowledge", "instruction-following"],
407 | "grok/grok-3-mini-fast-latest": ["knowledge", "instruction-following"],
408 |
409 | # OpenRouter models
410 | # Note: Capabilities depend heavily on the underlying model proxied by OpenRouter.
411 | # This is a generic entry for the one model listed in constants.py.
412 | "openrouter/mistralai/mistral-nemo": ["knowledge", "instruction-following", "coding"] # Estimate based on Mistral family
413 | }
414 |
415 | model_speed_fallback = {}
416 |
417 | model_quality = {
418 | "openai/gpt-4o": 8, # Updated
419 | "openai/gpt-4.1-mini": 7,
420 | "openai/gpt-4o-mini": 6,
421 | "openai/gpt-4.1": 8,
422 | "openai/gpt-4.1-nano": 5,
423 | "openai/o1-preview": 10,
424 | "openai/o3-mini": 7,
425 |
426 | "anthropic/claude-3-opus-20240229": 10,
427 | "anthropic/claude-3-sonnet-20240229": 8,
428 | "anthropic/claude-3-5-haiku-20241022": 7,
429 | "anthropic/claude-3-5-sonnet-20241022": 9,
430 | "anthropic/claude-3-7-sonnet-20250219": 10,
431 |
432 | "deepseek/deepseek-chat": 7,
433 | "deepseek/deepseek-reasoner": 8,
434 |
435 | "gemini/gemini-2.0-flash-lite": 5,
436 | "gemini/gemini-2.0-flash": 6,
437 | "gemini/gemini-2.0-flash-thinking-exp-01-21": 6,
438 | "gemini/gemini-2.5-pro-preview-03-25": 9,
439 |
440 | # Grok models (Estimates: 1-10 scale)
441 | "grok/grok-3-latest": 9,
442 | "grok/grok-3-fast-latest": 8,
443 | "grok/grok-3-mini-latest": 6,
444 | "grok/grok-3-mini-fast-latest": 6,
445 |
446 | # OpenRouter models (Estimates: 1-10 scale)
447 | "openrouter/mistralai/mistral-nemo": 7 # Estimate based on Mistral family
448 | }
449 | # --- End Model Metadata ---
450 |
451 | # --- Pre-calculate model metadata lookups ---
452 | # Combine all known prefixed model names from metadata sources
453 | all_prefixed_metadata_keys = set(model_capabilities.keys()) | set(model_speed_fallback.keys()) | set(model_quality.keys())
454 |
455 | # Create a map from short names (e.g., "gpt-4.1-mini") to prefixed names (e.g., "openai/gpt-4.1-mini")
456 | # Handle potential ambiguities (same short name from different providers)
457 | short_to_prefixed_map: Dict[str, Optional[str]] = {}
458 | ambiguous_short_names = set()
459 |
460 | for key in all_prefixed_metadata_keys:
461 | if '/' in key:
462 | short_name = key.split('/')[-1]
463 | if short_name in short_to_prefixed_map:
464 | # Ambiguity detected
465 | if short_name not in ambiguous_short_names:
466 | logger.warning(f"Ambiguous short model name '{short_name}' found. Maps to '{short_to_prefixed_map[short_name]}' and '{key}'. Will require full name for this model.")
467 | short_to_prefixed_map[short_name] = None # Mark as ambiguous
468 | ambiguous_short_names.add(short_name)
469 | elif short_name not in ambiguous_short_names:
470 | short_to_prefixed_map[short_name] = key # Store unique mapping
471 |
472 | # Helper function to find the prefixed name for a cost key (using pre-calculated map)
473 | _prefixed_name_cache = {}
474 | def _get_prefixed_name_for_cost_key(cost_key: str) -> Optional[str]:
475 | if cost_key in _prefixed_name_cache:
476 | return _prefixed_name_cache[cost_key]
477 |
478 | # If the key is already prefixed, use it directly
479 | if '/' in cost_key:
480 | if cost_key in all_prefixed_metadata_keys:
481 | _prefixed_name_cache[cost_key] = cost_key
482 | return cost_key
483 | else:
484 | # Even if prefixed, if it's not in our known metadata, treat as unknown for consistency
485 | logger.warning(f"Prefixed cost key '{cost_key}' not found in any known metadata (capabilities, quality, speed).")
486 | _prefixed_name_cache[cost_key] = None
487 | return None
488 |
489 | # Look up the short name in the pre-calculated map
490 | prefixed_name = short_to_prefixed_map.get(cost_key)
491 |
492 | if prefixed_name is not None: # Found unique mapping
493 | _prefixed_name_cache[cost_key] = prefixed_name
494 | return prefixed_name
495 | elif cost_key in ambiguous_short_names: # Known ambiguous name
496 | logger.warning(f"Cannot resolve ambiguous short name '{cost_key}'. Please use the full 'provider/model_name' identifier.")
497 | _prefixed_name_cache[cost_key] = None
498 | return None
499 | else: # Short name not found in any metadata
500 | logger.warning(f"Short name cost key '{cost_key}' not found in any known model metadata. Cannot determine provider/full name.")
501 | _prefixed_name_cache[cost_key] = None
502 | return None
503 | # --- End Pre-calculation ---
504 |
505 | # Use a simple placeholder text based on length for cost estimation
506 | sample_text = "a" * expected_input_length
507 | required_capabilities = required_capabilities or []
508 |
509 | # Rough estimate for output length if not provided
510 | if expected_output_length is None:
511 | # Adjust this heuristic as needed (e.g., summarization shortens, generation might lengthen)
512 | estimated_output_length_chars = expected_input_length // 4
513 | else:
514 | estimated_output_length_chars = expected_output_length
515 | # Estimate max_tokens based on character length (very rough)
516 | estimated_max_tokens = estimated_output_length_chars // 3
517 |
518 | candidate_models_data = []
519 | excluded_models_reasons = {}
520 | all_cost_keys = list(COST_PER_MILLION_TOKENS.keys())
521 |
522 | async def evaluate_model(cost_key: str):
523 | # 1. Find prefixed name
524 | prefixed_model_name = _get_prefixed_name_for_cost_key(cost_key)
525 | if not prefixed_model_name:
526 | excluded_models_reasons[cost_key] = "Could not reliably determine provider/full name for metadata lookup."
527 | return
528 |
529 | # 2. Check capabilities
530 | capabilities = model_capabilities.get(prefixed_model_name, [])
531 | missing_caps = [cap for cap in required_capabilities if cap not in capabilities]
532 | if missing_caps:
533 | excluded_models_reasons[prefixed_model_name] = f"Missing required capabilities: {missing_caps}"
534 | return
535 |
536 | # 3. Estimate cost
537 | try:
538 | cost_estimate = await estimate_cost(
539 | prompt=sample_text,
540 | model=cost_key, # Use the key from COST_PER_MILLION_TOKENS
541 | max_tokens=estimated_max_tokens,
542 | include_output=True
543 | )
544 | estimated_cost_value = cost_estimate["cost"]
545 | except ToolError as e:
546 | excluded_models_reasons[prefixed_model_name] = f"Cost estimation failed: {e.detail}"
547 | return
548 | except Exception as e:
549 | logger.error(f"Unexpected error estimating cost for {cost_key} (prefixed: {prefixed_model_name}) in recommendation: {e}", exc_info=True)
550 | excluded_models_reasons[prefixed_model_name] = f"Cost estimation failed unexpectedly: {str(e)}"
551 | return
552 |
553 | # 4. Check max cost constraint
554 | if max_cost is not None and estimated_cost_value > max_cost:
555 | excluded_models_reasons[prefixed_model_name] = f"Exceeds max cost (${estimated_cost_value:.6f} > ${max_cost:.6f})"
556 | return
557 |
558 | # --- 5. Get Measured Speed (Tokens/Second) ---
559 | measured_tps = 0.0 # Default to 0.0 if no data
560 | speed_source = "unavailable"
561 |
562 | measured_data = measured_speeds.get(prefixed_model_name) or measured_speeds.get(cost_key)
563 |
564 | if measured_data and isinstance(measured_data, dict) and "error" not in measured_data:
565 | tokens_per_sec = measured_data.get("output_tokens_per_second")
566 | if tokens_per_sec is not None and isinstance(tokens_per_sec, (int, float)) and tokens_per_sec >= 0:
567 | measured_tps = float(tokens_per_sec)
568 | speed_source = f"measured ({measured_tps:.1f} t/s)"
569 | else:
570 | speed_source = "no t/s in measurement"
571 | elif measured_data and "error" in measured_data:
572 | speed_source = "measurement error"
573 |
574 | logger.debug(f"Speed for {prefixed_model_name}: {measured_tps:.1f} t/s (Source: {speed_source})")
575 | # --- End Get Measured Speed ---
576 |
577 | # 6. Gather data for scoring
578 | candidate_models_data.append({
579 | "model": prefixed_model_name,
580 | "cost_key": cost_key,
581 | "cost": estimated_cost_value,
582 | "quality": model_quality.get(prefixed_model_name, 5),
583 | "measured_speed_tps": measured_tps, # Store raw TPS
584 | "capabilities": capabilities,
585 | "speed_source": speed_source # Store source for potential debugging/output
586 | })
587 |
588 | # Evaluate all models
589 | await asyncio.gather(*(evaluate_model(key) for key in all_cost_keys))
590 |
591 | # --- Scoring Logic (Updated for raw TPS) ---
592 | def calculate_score(model_data, min_cost, cost_range, min_tps, tps_range):
593 | cost = model_data['cost']
594 | quality = model_data['quality']
595 | measured_tps = model_data['measured_speed_tps']
596 |
597 | # Normalize cost (1 is cheapest, 0 is most expensive)
598 | norm_cost_score = 1.0 - ((cost - min_cost) / cost_range) if cost_range > 0 else 1.0
599 |
600 | # Normalize quality (scale 1-10)
601 | norm_quality_score = quality / 10.0
602 |
603 | # Normalize speed (measured TPS - higher is better)
604 | # (1 is fastest, 0 is slowest/0)
605 | norm_speed_score_tps = (measured_tps - min_tps) / tps_range if tps_range > 0 else 0.0
606 |
607 | # Calculate final score based on priority
608 | if priority == "cost":
609 | # Lower weight for speed if using TPS, as cost is main driver
610 | score = norm_cost_score * 0.7 + norm_quality_score * 0.2 + norm_speed_score_tps * 0.1
611 | elif priority == "quality":
612 | score = norm_cost_score * 0.15 + norm_quality_score * 0.7 + norm_speed_score_tps * 0.15
613 | elif priority == "speed":
614 | score = norm_cost_score * 0.1 + norm_quality_score * 0.2 + norm_speed_score_tps * 0.7
615 | else: # balanced
616 | score = norm_cost_score * 0.34 + norm_quality_score * 0.33 + norm_speed_score_tps * 0.33
617 |
618 | return score
619 | # --- End Scoring Logic ---
620 |
621 | # Calculate scores for all candidates
622 | if not candidate_models_data:
623 | logger.warning("No candidate models found after filtering.")
624 | else:
625 | # Get min/max for normalization *before* scoring loop
626 | all_costs = [m['cost'] for m in candidate_models_data if m['cost'] > 0]
627 | min_cost = min(all_costs) if all_costs else 0.000001
628 | max_cost_found = max(all_costs) if all_costs else 0.000001
629 | cost_range = max_cost_found - min_cost
630 |
631 | all_tps = [m['measured_speed_tps'] for m in candidate_models_data]
632 | min_tps = min(all_tps) if all_tps else 0.0
633 | max_tps_found = max(all_tps) if all_tps else 0.0
634 | tps_range = max_tps_found - min_tps
635 |
636 | for model_data in candidate_models_data:
637 | # Pass normalization ranges to scoring function
638 | model_data['score'] = calculate_score(model_data, min_cost, cost_range, min_tps, tps_range)
639 |
640 | # Sort candidates by score (highest first)
641 | sorted_candidates = sorted(candidate_models_data, key=lambda x: x.get('score', 0), reverse=True)
642 |
643 | # Format recommendations
644 | recommendations_list = []
645 | if candidate_models_data:
646 | # Get min/max across candidates *after* filtering
647 | min_candidate_cost = min(m['cost'] for m in candidate_models_data)
648 | max_candidate_quality = max(m['quality'] for m in candidate_models_data)
649 | max_candidate_tps = max(m['measured_speed_tps'] for m in candidate_models_data)
650 |
651 | for cand in sorted_candidates:
652 | reason = f"High overall score ({cand['score']:.2f}) according to '{priority}' priority."
653 | # Adjust reason phrasing for TPS
654 | if priority == 'cost' and cand['cost'] <= min_candidate_cost:
655 | reason = f"Lowest estimated cost (${cand['cost']:.6f}) and meets requirements."
656 | elif priority == 'quality' and cand['quality'] >= max_candidate_quality:
657 | reason = f"Highest quality score ({cand['quality']}/10) and meets requirements."
658 | elif priority == 'speed' and cand['measured_speed_tps'] >= max_candidate_tps:
659 | reason = f"Fastest measured speed ({cand['measured_speed_tps']:.1f} t/s) and meets requirements."
660 |
661 | recommendations_list.append({
662 | "model": cand['model'],
663 | "estimated_cost": cand['cost'],
664 | "quality_score": cand['quality'],
665 | "measured_speed_tps": cand['measured_speed_tps'], # Add raw TPS
666 | "capabilities": cand['capabilities'],
667 | "reason": reason
668 | })
669 |
670 | logger.info(f"Recommended models (priority: {priority}): {[r['model'] for r in recommendations_list]}")
671 | return {
672 | "recommendations": recommendations_list,
673 | "parameters": { # Include input parameters for context
674 | "task_type": task_type,
675 | "expected_input_length": expected_input_length,
676 | "expected_output_length": estimated_output_length_chars,
677 | "required_capabilities": required_capabilities,
678 | "max_cost": max_cost,
679 | "priority": priority
680 | },
681 | "excluded_models": excluded_models_reasons
682 | }
683 |
684 | @with_tool_metrics
685 | @with_error_handling
686 | async def execute_optimized_workflow(
687 |     documents: Optional[List[str]] = None, # Only needed if the workflow references "${documents}"
688 |     workflow: Optional[List[Dict[str, Any]]] = None, # Required in practice; validated as a non-empty list below
689 | max_concurrency: int = 5
690 | ) -> Dict[str, Any]:
691 | """Executes a predefined workflow consisting of multiple tool calls.
692 |
693 | Processes a list of documents (optional) through a sequence of stages defined in the workflow.
694 | Handles dependencies between stages (output of one stage as input to another) and allows
695 | for concurrent execution of independent stages or document processing within stages.
696 |
697 | Args:
698 | documents: (Optional) A list of input document strings. Required if the workflow references
699 | 'documents' as input for any stage.
700 | workflow: A list of dictionaries, where each dictionary defines a stage (a tool call).
701 | Required keys per stage:
702 | - `stage_id`: A unique identifier for this stage (e.g., "summarize_chunks").
703 | - `tool_name`: The name of the tool function to call (e.g., "summarize_document").
704 | - `params`: A dictionary of parameters to pass to the tool function.
705 | Parameter values can be literal values (strings, numbers, lists) or references
706 |                     to outputs from previous stages using the format `"${stage_id.output_key}"`
707 |                     (e.g., `{"text": "${chunk_stage.chunks}"}`).
708 | Special inputs: `"${documents}"` refers to the input `documents` list.
709 | Optional keys per stage:
710 | - `depends_on`: A list of `stage_id`s that must complete before this stage starts.
711 | - `iterate_on`: The key from a previous stage's output list over which this stage
712 |                     should iterate (e.g., `"${chunk_stage.chunks}"`). The tool will be
713 | called once for each item in the list.
714 | - `optimization_hints`: (Future use) Hints for model selection or cost saving for this stage.
715 | max_concurrency: (Optional) The maximum number of concurrent tasks (tool calls) to run.
716 | Defaults to 5.
717 |
718 | Returns:
719 | A dictionary containing the results of all successful workflow stages:
720 | {
721 | "success": true,
722 | "results": {
723 | "chunk_stage": { "output": { "chunks": ["chunk1...", "chunk2..."] } },
724 | "summarize_chunks": { # Example of an iterated stage
725 | "output": [
726 | { "summary": "Summary of chunk 1..." },
727 | { "summary": "Summary of chunk 2..." }
728 | ]
729 | },
730 | "final_summary": { "output": { "summary": "Overall summary..." } }
731 | },
732 | "status": "Workflow completed successfully.",
733 | "total_processing_time": 15.8
734 | }
735 | or an error dictionary if the workflow fails:
736 | {
737 | "success": false,
738 | "results": { ... }, # Results up to the point of failure
739 | "status": "Workflow failed at stage 'stage_id'.",
740 | "error": "Error details from the failed stage...",
741 | "total_processing_time": 8.2
742 | }
743 |
744 | Raises:
745 | ToolInputError: If the workflow definition is invalid (missing keys, bad references,
746 | circular dependencies - basic checks).
747 | ToolError: If a tool call within the workflow fails.
748 | Exception: For unexpected errors during workflow execution.
749 | """
750 | start_time = time.time()
751 | if not workflow or not isinstance(workflow, list):
752 | raise ToolInputError("'workflow' must be a non-empty list of stage dictionaries.")
753 |
754 | # --- Tool Mapping --- (Dynamically import or map tool names to functions)
755 | # Ensure all tools listed in workflows are mapped here correctly.
756 |
757 |     # NOTE: No import happens in this block; the APIMetaTool instance (from
758 |     # meta_api_tool.py) must be injected by the server at runtime. Until that
759 |     # wiring exists, api_meta_tool stays None and the meta API tools are
760 |     # simply omitted from the tool map below.
761 |     api_meta_tool = None  # Placeholder - this needs to be the actual instance
762 |
763 |     if api_meta_tool:  # Only add if instance is available
764 |         meta_api_tools = {
765 |             "register_api": api_meta_tool.register_api,
766 |             "list_registered_apis": api_meta_tool.list_registered_apis,
767 |             "get_api_details": api_meta_tool.get_api_details,
768 |             "unregister_api": api_meta_tool.unregister_api,
769 |             "call_dynamic_tool": api_meta_tool.call_dynamic_tool,
770 |             "refresh_api": api_meta_tool.refresh_api,
771 |             "get_tool_details": api_meta_tool.get_tool_details,
772 |             "list_available_tools": api_meta_tool.list_available_tools,
773 |         }
774 |     else:
775 |         logger.warning("APIMetaTool instance not available in execute_optimized_workflow. Meta API tools will not be callable in workflows.")
776 |         meta_api_tools = {}
777 |
778 | # Import extract_entity_graph lazily to avoid circular imports
779 | try:
780 | from ultimate_mcp_server.tools.entity_relation_graph import extract_entity_graph
781 | except ImportError:
782 | logger.warning("entity_relation_graph module not found. extract_entity_graph will not be available in workflows.")
783 | extract_entity_graph = None
784 |
785 | tool_functions = {
786 | # Core Gateway Tools
787 | "estimate_cost": estimate_cost,
788 | "compare_models": compare_models,
789 | "recommend_model": recommend_model,
790 | "chat_completion": chat_completion,
791 | "chunk_document": chunk_document,
792 | "summarize_document": summarize_document,
793 | "extract_json": extract_json,
794 | # Add extract_entity_graph conditionally
795 | **({"extract_entity_graph": extract_entity_graph} if extract_entity_graph else {}),
796 | # RAG Tools
797 | "create_knowledge_base": create_knowledge_base,
798 | "add_documents": add_documents,
799 | "retrieve_context": retrieve_context,
800 | "generate_with_rag": generate_with_rag,
801 | # Classification tools
802 | "text_classification": text_classification,
803 |
804 | # Merge Meta API tools
805 | **meta_api_tools,
806 |
807 | # Add other tools as needed...
808 | }
809 |
810 | # --- Advanced Workflow Validation Using NetworkX ---
811 | # Build directed graph from workflow
812 | dag = nx.DiGraph()
813 |
814 | # Add all stages as nodes
815 | for i, stage in enumerate(workflow):
816 | # Validate required keys
817 | if not all(k in stage for k in ["stage_id", "tool_name", "params"]):
818 | raise ToolInputError(f"Workflow stage {i} missing required keys (stage_id, tool_name, params).")
819 |
820 | stage_id = stage["stage_id"]
821 |
822 | # Validate params is a dictionary
823 | if not isinstance(stage["params"], dict):
824 | raise ToolInputError(f"Stage '{stage_id}' params must be a dictionary.")
825 |
826 | # Check for duplicate stage IDs
827 | if stage_id in dag:
828 | raise ToolInputError(f"Duplicate stage_id found: '{stage_id}'.")
829 |
830 | # Validate tool exists
831 | tool_name = stage["tool_name"]
832 | if tool_name not in tool_functions:
833 | raise ToolInputError(f"Unknown tool '{tool_name}' specified in stage '{stage_id}'.")
834 |
835 | # Validate depends_on is a list
836 | depends_on = stage.get("depends_on", [])
837 | if not isinstance(depends_on, list):
838 | raise ToolInputError(f"Stage '{stage_id}' depends_on must be a list.")
839 |
840 | # Add node with full stage data
841 | dag.add_node(stage_id, stage=stage)
842 |
843 | # Add dependency edges
844 | for stage in workflow:
845 | stage_id = stage["stage_id"]
846 | depends_on = stage.get("depends_on", [])
847 |
848 | for dep_id in depends_on:
849 | if dep_id not in dag:
850 | raise ToolInputError(f"Stage '{stage_id}' depends on non-existent stage '{dep_id}'.")
851 | dag.add_edge(dep_id, stage_id)
852 |
853 |     # Detect circular dependencies.
854 |     # nx.simple_cycles returns an iterator of cycles rather than raising
855 |     # NetworkXNoCycle, so no try/except is needed here: for an acyclic
856 |     # graph the list below is simply empty and execution falls through
857 |     # without raising.
858 |     cycles = list(nx.simple_cycles(dag))
859 |     if cycles:
860 |         cycle_str = " -> ".join(cycles[0]) + " -> " + cycles[0][0]
861 |         raise ToolInputError(f"Circular dependency detected in workflow: {cycle_str}")
862 |
863 | # Dictionary to store results of each stage
864 | stage_results: Dict[str, Any] = {}
865 | # Set to keep track of completed stages
866 | completed_stages: Set[str] = set()
867 | # Dictionary to hold active tasks
868 | active_tasks: Dict[str, asyncio.Task] = {} # noqa: F841
869 | # Semaphore to control concurrency
870 | concurrency_semaphore = asyncio.Semaphore(max_concurrency)
871 |
872 | # --- Workflow Execution Logic with NetworkX ---
873 | async def execute_stage(stage_id: str) -> None:
874 | """Execute a single workflow stage."""
875 | async with concurrency_semaphore:
876 | # Get stage definition
877 | stage = dag.nodes[stage_id]["stage"]
878 | tool_name = stage["tool_name"]
879 | params = stage["params"]
880 | iterate_on_ref = stage.get("iterate_on")
881 |
882 | logger.info(f"Starting workflow stage '{stage_id}' (Tool: {tool_name})")
883 |
884 | tool_func = tool_functions[tool_name]
885 |
886 | try:
887 | # Resolve parameters and handle iteration
888 | resolved_params, is_iteration, iteration_list = _resolve_params(
889 | stage_id, params, iterate_on_ref, stage_results, documents
890 | )
891 |
892 | # Execute tool function(s)
893 | if is_iteration:
894 | # Handle iteration case
895 | iteration_tasks = []
896 |
897 | for i, item in enumerate(iteration_list):
898 |                             # Each iteration task re-acquires the shared semaphore so total concurrency stays bounded.
899 |                             # Caveat: the enclosing stage holds its permit while awaiting these tasks, so a very small max_concurrency combined with several iterating stages can starve or deadlock the pool.
900 | async def run_iteration(item_idx, item_value):
901 | async with concurrency_semaphore:
902 | iter_params = _inject_iteration_item(resolved_params, item_value)
903 | try:
904 | result = await tool_func(**iter_params)
905 | return result
906 | except Exception as e:
907 | # Capture exception details for individual iteration
908 | error_msg = f"Iteration {item_idx} failed: {type(e).__name__}: {str(e)}"
909 | logger.error(error_msg, exc_info=True)
910 | raise # Re-raise to be caught by gather
911 |
912 | task = asyncio.create_task(run_iteration(i, item))
913 | iteration_tasks.append(task)
914 |
915 | # Gather all iteration results (may raise if any iteration fails)
916 | results = await asyncio.gather(*iteration_tasks)
917 | stage_results[stage_id] = {"output": results}
918 | else:
919 | # Single execution case
920 | result = await tool_func(**resolved_params)
921 | stage_results[stage_id] = {"output": result}
922 |
923 | # Mark stage as completed
924 | completed_stages.add(stage_id)
925 | logger.info(f"Workflow stage '{stage_id}' completed successfully")
926 |
927 | except Exception as e:
928 | error_msg = f"Workflow failed at stage '{stage_id}'. Error: {type(e).__name__}: {str(e)}"
929 | logger.error(error_msg, exc_info=True)
930 | stage_results[stage_id] = {
931 | "error": error_msg,
932 | "traceback": traceback.format_exc()
933 | }
934 | # Re-raise to signal failure to main execution loop
935 | raise
936 |
937 | async def execute_dag() -> Dict[str, Any]:
938 | """Execute the entire workflow DAG with proper dependency handling."""
939 | try:
940 | # Start with a topological sort to get execution order respecting dependencies
941 | try:
942 | execution_order = list(nx.topological_sort(dag))
943 | logger.debug(f"Workflow execution order (respecting dependencies): {execution_order}")
944 | except nx.NetworkXUnfeasible as e:
945 | # Should never happen as we already checked for cycles
946 | raise ToolInputError("Workflow contains circular dependencies that were not detected earlier.") from e
947 |
948 | # Process stages in waves of parallelizable tasks
949 | while len(completed_stages) < len(dag):
950 | # Find stages ready to execute (all dependencies satisfied)
951 | ready_stages = [
952 | stage_id for stage_id in execution_order
953 | if (stage_id not in completed_stages and
954 | all(pred in completed_stages for pred in dag.predecessors(stage_id)))
955 | ]
956 |
957 | if not ready_stages:
958 | if len(completed_stages) < len(dag):
959 | # This should never happen with a valid DAG that was topologically sorted
960 | unfinished = set(execution_order) - completed_stages
961 | logger.error(f"Workflow execution stalled. Unfinished stages: {unfinished}")
962 | raise ToolError("Workflow execution stalled due to unresolvable dependencies.")
963 | break
964 |
965 | # Launch tasks for all ready stages
966 | tasks = [execute_stage(stage_id) for stage_id in ready_stages]
967 |
968 | # Wait for all tasks to complete or for the first error
969 | try:
970 | await asyncio.gather(*tasks)
971 | except Exception as e:
972 | # Any stage failure will be caught here
973 | # The specific error details are already in stage_results
974 | logger.error(f"Workflow wave execution failed: {str(e)}")
975 |
976 | # Find the first failed stage for error reporting
977 | failed_stage = next(
978 | (s for s in ready_stages if s in stage_results and "error" in stage_results[s]),
979 | ready_stages[0] # Fallback if we can't identify the specific failed stage
980 | )
981 |
982 | error_info = stage_results.get(failed_stage, {}).get("error", f"Unknown error in stage '{failed_stage}'")
983 |
984 | return {
985 | "success": False,
986 | "results": stage_results,
987 | "status": f"Workflow failed at stage '{failed_stage}'.",
988 | "error": error_info,
989 | "total_processing_time": time.time() - start_time
990 | }
991 |
992 | # If we reach here, all stages in this wave completed successfully
993 |
994 | # All stages completed successfully
995 | return {
996 | "success": True,
997 | "results": stage_results,
998 | "status": "Workflow completed successfully.",
999 | "total_processing_time": time.time() - start_time
1000 | }
1001 |
1002 | except Exception as e:
1003 | # Catch any unexpected errors in the main execution loop
1004 | error_msg = f"Unexpected error in workflow execution: {type(e).__name__}: {str(e)}"
1005 | logger.error(error_msg, exc_info=True)
1006 | return {
1007 | "success": False,
1008 | "results": stage_results,
1009 | "status": "Workflow failed with an unexpected error.",
1010 | "error": error_msg,
1011 | "total_processing_time": time.time() - start_time
1012 | }
1013 |
1014 | # Execute the workflow DAG
1015 | result = await execute_dag()
1016 |
1017 | total_time = time.time() - start_time
1018 | if result["success"]:
1019 | logger.info(f"Workflow completed successfully in {total_time:.2f}s")
1020 | else:
1021 | logger.error(f"Workflow failed after {total_time:.2f}s: {result.get('error', 'Unknown error')}")
1022 |
1023 | return result
1024 |
1025 | # --- Helper functions for workflow execution ---
1026 | # These need careful implementation for robustness
1027 |
1028 | def _resolve_params(stage_id: str, params: Dict, iterate_on_ref: Optional[str], stage_results: Dict, documents: Optional[List[str]]) -> tuple[Dict, bool, Optional[List]]:
1029 | """Resolves parameter values, handling references and iteration.
1030 | Returns resolved_params, is_iteration, iteration_list.
1031 | Raises ValueError on resolution errors.
1032 | """
1033 | resolved = {}
1034 | is_iteration = False
1035 | iteration_list = None
1036 | iteration_param_name = None
1037 |
1038 | # Check for iteration first
1039 | if iterate_on_ref:
1040 | if not iterate_on_ref.startswith("${") or not iterate_on_ref.endswith("}"):
1041 | raise ValueError(f"Invalid iterate_on reference format: '{iterate_on_ref}'")
1042 | ref_key = iterate_on_ref[2:-1]
1043 |
1044 | if ref_key == "documents":
1045 | if documents is None:
1046 | raise ValueError(f"Stage '{stage_id}' iterates on documents, but no documents were provided.")
1047 | iteration_list = documents
1048 | else:
1049 | dep_stage_id, output_key = _parse_ref(ref_key)
1050 | if dep_stage_id not in stage_results or "output" not in stage_results[dep_stage_id]:
1051 | raise ValueError(f"Dependency '{dep_stage_id}' for iteration not found or failed.")
1052 | dep_output = stage_results[dep_stage_id]["output"]
1053 | if not isinstance(dep_output, dict) or output_key not in dep_output:
1054 | raise ValueError(f"Output key '{output_key}' not found in dependency '{dep_stage_id}' for iteration.")
1055 | iteration_list = dep_output[output_key]
1056 | if not isinstance(iteration_list, list):
1057 | raise ValueError(f"Iteration target '{ref_key}' is not a list.")
1058 |
1059 | is_iteration = True
1060 | # We still resolve other params, the iteration item is injected later
1061 | logger.debug(f"Stage '{stage_id}' will iterate over {len(iteration_list)} items from '{iterate_on_ref}'")
1062 |
1063 | # Resolve individual parameters
1064 | for key, value in params.items():
1065 | if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
1066 | ref_key = value[2:-1]
1067 | if ref_key == "documents":
1068 | if documents is None:
1069 | raise ValueError(f"Parameter '{key}' references documents, but no documents provided.")
1070 | resolved[key] = documents
1071 | else:
1072 | dep_stage_id, output_key = _parse_ref(ref_key)
1073 | if dep_stage_id not in stage_results or "output" not in stage_results[dep_stage_id]:
1074 | raise ValueError(f"Dependency '{dep_stage_id}' for parameter '{key}' not found or failed.")
1075 | dep_output = stage_results[dep_stage_id]["output"]
1076 | # Handle potential nested keys in output_key later if needed
1077 | if not isinstance(dep_output, dict) or output_key not in dep_output:
1078 | raise ValueError(f"Output key '{output_key}' not found in dependency '{dep_stage_id}' for parameter '{key}'. Available keys: {list(dep_output.keys()) if isinstance(dep_output, dict) else 'N/A'}")
1079 | resolved[key] = dep_output[output_key]
1080 | # If this resolved param is the one we iterate on, store its name
1081 | if is_iteration and iterate_on_ref == value:
1082 | iteration_param_name = key
1083 | else:
1084 | resolved[key] = value # Literal value
1085 |
1086 | # Validation: If iterating, one parameter must match the iterate_on reference
1087 | if is_iteration and iteration_param_name is None:
1088 |         # iterate_on pointed at a value that is never used directly as a
1089 |         # parameter. Per-item injection (_inject_iteration_item) needs to know
1090 |         # which parameter receives each item, so the iteration target must be
1091 |         # explicitly mapped to a parameter. The alternative convention (assume
1092 |         # the tool accepts the whole list when iterate_on is not a parameter
1093 |         # value) is ambiguous across tools, so it is rejected here for
1094 |         # clarity.
1095 |         raise ValueError(f"Iteration target '{iterate_on_ref}' must correspond to a parameter value in stage '{stage_id}'.")
1096 |
1097 | # Remove the iteration parameter itself from the base resolved params if iterating
1098 | # It will be injected per-item later
1099 | if is_iteration and iteration_param_name in resolved:
1100 | del resolved[iteration_param_name]
1101 | resolved["_iteration_param_name"] = iteration_param_name # Store the name for injection
1102 |
1103 | return resolved, is_iteration, iteration_list
1104 |
1105 | def _parse_ref(ref_key: str) -> tuple[str, str]:
1106 | """Parses a reference like 'stage_id.output_key'"""
1107 | parts = ref_key.split('.', 1)
1108 | if len(parts) != 2:
1109 | raise ValueError(f"Invalid reference format: '{ref_key}'. Expected 'stage_id.output_key'.")
1110 | return parts[0], parts[1]
1111 |
1112 | def _inject_iteration_item(base_params: Dict, item: Any) -> Dict:
1113 | """Injects the current iteration item into the parameter dict."""
1114 | injected_params = base_params.copy()
1115 | iter_param_name = injected_params.pop("_iteration_param_name", None)
1116 | if iter_param_name:
1117 | injected_params[iter_param_name] = item
1118 | else:
1119 | # This case should be prevented by validation in _resolve_params
1120 | logger.error("Cannot inject iteration item: Iteration parameter name not found in resolved params.")
1121 | raise ValueError("Iteration parameter name missing from resolved params; cannot inject iteration item.")
1122 | return injected_params
1123 |
1124 | async def _gather_iteration_results(stage_id: str, tasks: List[asyncio.Task]) -> List[Any]:
1125 | """Gathers results from iteration sub-tasks. Raises exception if any sub-task failed."""
1126 | results = []
1127 | try:
1128 | raw_results = await asyncio.gather(*tasks)
1129 | # Assume each task returns the direct output dictionary
1130 | results = list(raw_results) # gather preserves order
1131 | logger.debug(f"Iteration stage '{stage_id}' completed with {len(results)} results.")
1132 | return results
1133 | except Exception:
1134 | # If any sub-task failed, gather will raise the first exception
1135 | logger.error(f"Iteration stage '{stage_id}' failed: One or more sub-tasks raised an error.", exc_info=True)
1136 | # gather() does not cancel the remaining sub-tasks when one fails, so cancel them explicitly
1137 | for task in tasks:
1138 | if not task.done():
1139 | task.cancel()
1140 | raise # Re-raise the exception to fail the main workflow stage
```
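The reference syntax the resolver above implements is worth spelling out: `${documents}` expands to the workflow's input documents, `${stage_id.output_key}` pulls one key from a completed dependency's output, and `iterate_on` fans a stage out over each item of such a list, provided the same reference also appears as a parameter value so the resolver can record which argument receives each item. A minimal sketch of a stage definition that exercises the iteration path (the `summarize_chunk` tool and the `chunks` output key are hypothetical, not part of the codebase):

```python
# Sketch only: a hypothetical iterating stage. Just the ${...} syntax and the
# stage fields (stage_id, tool_name, depends_on, iterate_on, params) mirror
# the resolver above.
iterating_stage = {
    "stage_id": "per_chunk_summary",
    "tool_name": "summarize_chunk",            # hypothetical tool
    "depends_on": ["chunking_stage"],
    # Fan this stage out over each element of the referenced list.
    "iterate_on": "${chunking_stage.chunks}",
    "params": {
        # The iteration target must also appear as a parameter value, or
        # _resolve_params raises; this is how iteration_param_name is learned.
        "text": "${chunking_stage.chunks}",
        "max_length": 100,                      # literals pass through as-is
    },
}
```

After resolution, `_resolve_params` removes `text` from the base params and stores its name under `_iteration_param_name`; `_inject_iteration_item` then pops that name and sets `params["text"] = item` for each list element.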
--------------------------------------------------------------------------------
/examples/research_workflow_demo.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python
2 | """
3 | Advanced Research Assistant Workflow Demo
4 |
5 | This script demonstrates a realistic research workflow using the DAG-based
6 | workflow execution system. It processes research documents through multiple
7 | analysis stages and produces visualizations of the results.
8 | """
9 | import asyncio
10 | import os
11 | import sys
12 | from collections import namedtuple # Import namedtuple
13 |
14 | # Add the project root to path so we can import ultimate_mcp_server
15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 |
17 | from rich.console import Console, Group
18 | from rich.layout import Layout
19 | from rich.panel import Panel
20 | from rich.progress import Progress, SpinnerColumn, TextColumn
21 | from rich.rule import Rule
22 | from rich.syntax import Syntax
23 | from rich.table import Table
24 | from rich.tree import Tree
25 |
26 | from ultimate_mcp_server.constants import Provider
27 | from ultimate_mcp_server.tools.optimization import execute_optimized_workflow
28 | from ultimate_mcp_server.utils import get_logger # Import get_logger
29 | from ultimate_mcp_server.utils.display import CostTracker # Import CostTracker
30 |
31 | # Initialize rich console
32 | console = Console()
33 |
34 | # Initialize logger here so it's available in main()
35 | logger = get_logger("example.research_workflow")
36 |
37 | # Simple structure for tracking workflow cost from a result dict (token counts may be missing)
38 | TrackableResult = namedtuple("TrackableResult", ["cost", "input_tokens", "output_tokens", "provider", "model", "processing_time"])
39 |
40 | # Sample research documents
41 | SAMPLE_DOCS = [
42 | """
43 | # The Impact of Climate Change on Coastal Communities: A Multi-Regional Analysis
44 |
45 | ## Abstract
46 | This comprehensive study examines the cascading effects of climate change on 570+ coastal cities globally, with projections extending to 2050. Using data from the IPCC AR6 report and economic models from the World Bank (2021), we identify adaptation costs exceeding $71 billion annually. The research incorporates satellite data from NASA's GRACE mission and economic vulnerability indices developed by Stern et al. (2019) to assess regional disparities.
47 |
48 | ## Vulnerable Regions and Economic Impact Assessment
49 |
50 | ### 1. Southeast Asia
51 | The Mekong Delta region, home to 17 million people, faces submersion threats to 38% of its landmass by 2050. Ho Chi Minh City has invested $1.42 billion in flood prevention infrastructure, while Bangkok's $2.3 billion flood management system remains partially implemented. The Asian Development Bank (ADB) estimates adaptation costs will reach $5.7 billion annually for Vietnam alone.
52 |
53 | ### 2. Pacific Islands
54 | Kiribati, Tuvalu, and the Marshall Islands face existential threats, with projected displacement of 25-35% of their populations by 2050 according to UN estimates. Australia's "Pacific Resilience Fund" ($2.1 billion) supports adaptation, but President Maamau of Kiribati has criticized its scope as "drastically insufficient." The 2022 Wellington Accords established migration pathways for climate refugees, though implementation remains fragmented.
55 |
56 | ### 3. North American Coastal Zones
57 | Miami-Dade County's $6 billion "Rising Above" initiative represents the largest municipal climate adaptation budget in North America. The U.S. Army Corps of Engineers projects that without intervention, coastal erosion will affect 31% of Florida's beaches by 2040. Economic models by Greenstone and Carleton (2020) indicate property devaluation between $15-27 billion in Florida alone.
58 |
59 | ## Adaptation Strategies and Cost-Benefit Analysis
60 |
61 | ### Infrastructure Hardening
62 | The Netherlands' Room for the River program ($2.6 billion) has demonstrated 300% ROI through prevented flood damage. Conversely, New Orleans' post-Katrina $14.5 billion levee system upgrades show more modest returns (130% ROI) due to maintenance requirements and subsidence issues highlighted by Professor Sarah Jenkins (MIT).
63 |
64 | ### Managed Retreat
65 | Indonesia's capital relocation from Jakarta to Borneo (est. cost $34 billion) represents the largest planned managed retreat globally. Smaller programs in Alaska (Newtok and Shishmaref villages) provide case studies with per-capita costs exceeding $380,000. Dr. Robert Chen's longitudinal studies show significant social cohesion challenges, with 47% of relocated communities reporting decreased quality-of-life metrics despite improved safety.
66 |
67 | ### Ecosystem-Based Approaches
68 | Vietnam's mangrove restoration initiative ($220 million) reduces storm surge impacts by 20-50% and provides $8-$20 million in annual aquaculture benefits. The Nature Conservancy's coral reef insurance programs in Mexico demonstrate innovative financing mechanisms while providing co-benefits for local tourism economies valued at $320 million annually.
69 |
70 | ## Cross-Disciplinary Implications
71 |
72 | Climate migration pathways identified by the UNHCR will increase urban population pressures in receiving cities, particularly in Manila, Dhaka, and Lagos. Healthcare systems in coastal regions report increasing cases of waterborne diseases (62% increase since 2010) and mental health challenges associated with displacement anxiety as documented by the WHO Southeast Asia regional office.
73 |
74 | ## References
75 |
76 | 1. IPCC (2021). AR6 Climate Change 2021: Impacts, Adaptation and Vulnerability
77 | 2. Stern, N., et al. (2019). Economic vulnerability indices for coastal communities
78 | 3. Asian Development Bank. (2022). Southeast Asia Climate Adaptation Report
79 | 4. Greenstone, M., & Carleton, T. (2020). Coastal property value projections 2020-2050
80 | 5. Jenkins, S. (2022). Engineering limitations in climate adaptation infrastructure
81 | 6. Chen, R. (2021). Social dimensions of community relocation programs
82 | 7. World Health Organization. (2021). Climate change health vulnerability assessments
83 | """,
84 |
85 | """
86 | # Renewable Energy Transition: Economic Implications and Policy Frameworks
87 |
88 | ## Executive Summary
89 | This multi-phase analysis examines the economic transformation accompanying the global renewable energy transition, projecting the creation of 42.3 million new jobs by 2050 while identifying significant regional disparities and transition barriers. Drawing on data from 157 countries, this research provides comprehensive policy recommendations for equitable implementation paths.
90 |
91 | ## Methodological Framework
92 |
93 | Our modeling utilizes a modified integrated assessment model combining economic inputs from the International Energy Agency (IEA), IRENA's Renewable Jobs Tracker database, and McKinsey's Global Energy Perspective 2022. Labor market projections incorporate automation factors derived from Oxford Economics' Workforce Displacement Index, providing more nuanced job creation estimates than previous studies by Zhang et al. (2019).
94 |
95 | ## Employment Transformation Analysis by Sector
96 |
97 | ### Solar Photovoltaic Industry
98 | Employment projections indicate 18.7 million jobs by 2045, concentrated in manufacturing (32%), installation (41%), and operations/maintenance (27%). Regional distribution analysis reveals concerning inequities, with China capturing 41% of manufacturing roles while sub-Saharan Africa secures only 2.3% despite having 16% of global solar potential. The Skill Transferability Index suggests 73% of displaced fossil fuel workers could transition to solar with targeted 6-month reskilling programs.
99 |
100 | ### Wind Energy Sector
101 | Offshore wind development led by Ørsted, Vestas, and General Electric is projected to grow at 24% CAGR through 2035, creating 6.8 million jobs. Supply chain bottlenecks in rare earth elements (particularly neodymium and dysprosium) represent critical vulnerabilities, with 83% of processing controlled by three Chinese companies. Professor Tanaka's analysis suggests price volatilities of 120-350% are possible under geopolitical tensions.
102 |
103 | ### Energy Storage Revolution
104 | Recent lithium-ferro-phosphate (LFP) battery innovations by CATL have reduced implementation costs by 27% while boosting cycle life by 4,000 cycles. Grid-scale storage installations are projected to grow from 17GW (2022) to 220GW by 2035, employing 5.3 million in manufacturing and installation. The MIT Battery Initiative under Dr. Viswanathan has demonstrated promising alternative chemistries using earth-abundant materials that could further accelerate adoption if commercialized by 2025.
105 |
106 | ### Hydrogen Economy Emergence
107 | Green hydrogen production costs have declined from $5.70/kg in 2018 to $3.80/kg in 2023, with projected cost parity with natural gas achievable by 2028 according to BloombergNEF. The European Hydrogen Backbone initiative, requiring €43 billion in infrastructure investment, could generate 3.8 million jobs while reducing EU natural gas imports by 30%. Significant technological challenges remain in storage density and transport infrastructure, as highlighted in critical analyses by Professors Wilson and Leibreich.
108 |
109 | ## Transition Barriers and Regional Disparities
110 |
111 | ### Financial Constraints
112 | Developing economies face investment gaps of $730 billion annually according to the Climate Policy Initiative's 2022 report. The African Development Bank estimates that 72% of sub-Saharan African energy projects stall at the planning phase due to financing constraints despite IRRs exceeding 11.5%. Innovative financing mechanisms through the Global Climate Fund have mobilized only 23% of pledged capital as of Q1 2023.
113 |
114 | ### Policy Framework Effectiveness
115 |
116 | Cross-jurisdictional analysis of 87 renewable portfolio standards reveals three dominant policy approaches:
117 |
118 | 1. **Carbon Pricing Mechanisms**: The EU ETS carbon price of €85/ton has driven 16.5% emissions reduction in the power sector, while Canada's escalating carbon price schedule ($170/ton by 2030) provides investment certainty. Econometric modeling by Dr. Elizabeth Warren (LSE) indicates prices must reach €120/ton to fully internalize climate externalities.
119 |
120 | 2. **Direct Subsidies**: Germany's Energiewende subsidies (€238 billion cumulative) achieved 44% renewable penetration but at high consumer costs. Targeted manufacturing incentives under the U.S. Inflation Reduction Act demonstrate improved cost-efficiency with 3.2x private capital mobilization according to analysis by Resources for the Future.
121 |
122 | 3. **Phased Transition Approaches**: Denmark's offshore wind cluster development model produced the highest success metrics in our analysis, reducing LCOE by 67% while creating domestic supply chains capturing 82% of economic value. This approach has been partially replicated in Taiwan and Vietnam with similar success indicators.
123 |
124 | ## Visualized Outcomes Under Various Scenarios
125 |
126 | Under an accelerated transition (consistent with 1.5°C warming), global GDP would increase by 2.4% beyond baseline by 2050, while air pollution-related healthcare costs would decline by $780 billion annually. Conversely, our "delayed action" scenario projects stranded fossil assets exceeding $14 trillion, concentrated in 8 petrostate economies, potentially triggering financial contagion comparable to 2008.
127 |
128 | ## References
129 |
130 | 1. International Energy Agency. (2022). World Energy Outlook 2022
131 | 2. IRENA. (2023). Renewable Energy Jobs Annual Review
132 | 3. McKinsey & Company. (2022). Global Energy Perspective
133 | 4. Zhang, F., et al. (2019). Employment impacts of renewable expansion
134 | 5. Oxford Economics. (2021). Workforce Displacement Index
135 | 6. Tanaka, K. (2022). Critical material supply chains in energy transition
136 | 7. Viswanathan, V. (2023). Next-generation grid-scale storage technologies
137 | 8. BloombergNEF. (2023). Hydrogen Economy Outlook
138 | 9. Climate Policy Initiative. (2022). Global Landscape of Climate Finance
139 | 10. Warren, E. (2022). Carbon pricing efficiency and distributional impacts
140 | 11. Resources for the Future. (2023). IRA Impact Assessment
141 | """,
142 |
143 | """
144 | # Artificial Intelligence Applications in Healthcare Diagnostics: Implementation Challenges and Economic Analysis
145 |
146 | ## Abstract
147 | This comprehensive evaluation examines the integration of artificial intelligence into clinical diagnostic workflows, with particular focus on deep learning systems demonstrating 94.2% accuracy in early-stage cancer detection across 14 cancer types. The analysis spans technical validation, implementation barriers, regulatory frameworks, and economic implications based on data from 137 healthcare systems across 42 countries.
148 |
149 | ## Technological Capabilities Assessment
150 |
151 | ### Diagnostic Performance Metrics
152 |
153 | Google Health's melanoma detection algorithm demonstrated sensitivity of 95.3% and specificity of 92.7% in prospective trials, exceeding dermatologist accuracy by 18 percentage points with consistent performance across Fitzpatrick skin types I-VI. This represents significant improvement over earlier algorithms criticized for performance disparities across demographic groups as documented by Dr. Abigail Johnson in JAMA Dermatology (2021).
154 |
155 | The Mayo Clinic's AI-enhanced colonoscopy system increased adenoma detection rates from 30% to 47% in their 2022 clinical implementation study (n=3,812). This translates to approximately 68 prevented colorectal cancer cases per 1,000 screened patients according to the predictive model developed by Dr. Singh at Memorial Sloan Kettering.
156 |
157 | Stanford Medicine's deep learning algorithm for chest radiograph interpretation identified 14 pathological conditions with average AUC of 0.91, reducing false negative rates for subtle pneumothorax by 43% and pulmonary nodules by 29% in their multi-center validation study across five hospital systems with diverse patient populations.
158 |
159 | ### Architectural Innovations
160 |
161 | Recent advancements in foundation models have transformed medical AI capabilities:
162 |
163 | 1. **Multi-modal integration**: Microsoft/Nuance's DAX system combines speech recognition, natural language processing, and computer vision, enabling real-time clinical documentation with 96.4% accuracy while reducing physician documentation time by 78 minutes daily according to their 16-site implementation study published in Health Affairs.
164 |
165 | 2. **Explainable AI approaches**: PathAI's interpretable convolutional neural networks provide visualization of decision-making factors in histopathology, addressing the "black box" concern highlighted by regulatory agencies. Their GradCAM implementation allows pathologists to review the specific cellular features informing algorithmic conclusions, increasing adoption willingness by 67% in surveyed practitioners (n=245).
166 |
167 | 3. **Federated learning**: The MELLODDY consortium's federated approach enables algorithm training across 10 pharmaceutical companies' proprietary datasets without data sharing, demonstrating how privacy-preserving computation can accelerate biomarker discovery. This approach increased available training data by 720% while maintaining data sovereignty.
168 |
169 | ## Implementation Challenges
170 |
171 | ### Clinical Workflow Integration
172 |
173 | Field studies at Massachusetts General Hospital identified five critical integration failure points that reduce AI effectiveness by 30-70% compared to validation performance:
174 |
175 | 1. Alert fatigue – 52% of clinical recommendations were dismissed when AI systems generated more than 8 alerts per hour
176 | 2. Workflow disruption – Systems requiring more than 15 seconds of additional process time saw 68% lower adoption
177 | 3. Interface design issues – Poorly designed UI elements reduced effective utilization by 47%
178 | 4. Confirmation bias – Clinicians were 3.4× more likely to accept AI suggestions matching their preliminary conclusion
179 | 5. Trust calibration – 64% of clinicians struggled to appropriately weight algorithmic recommendations against their clinical judgment
180 |
181 | The Cleveland Clinic's "AI Integration Framework" addresses these challenges through graduated autonomy, contextual presentation, and embedded calibration metrics, increasing sustained adoption rates to 84% compared to the industry average of 31%.
182 |
183 | ### Data Infrastructure Requirements
184 |
185 | Analysis of implementation failures reveals data architecture as the primary barrier in 68% of stalled healthcare AI initiatives. Specific challenges include:
186 |
187 | - Legacy system integration – 73% of U.S. hospitals utilize EHR systems with insufficient API capabilities for real-time AI integration
188 | - Data standardization – Only 12% of clinical data meets FHIR standards without requiring significant transformation
189 | - Computational infrastructure – 57% of healthcare systems lack edge computing capabilities necessary for low-latency applications
190 |
191 | Kaiser Permanente's successful enterprise-wide implementation demonstrates a viable pathway through their "data fabric" architecture connecting 39 hospitals while maintaining HIPAA compliance. Their staged implementation required $43 million in infrastructure investment but delivered $126 million in annual efficiency gains by year three.
192 |
193 | ### Training Requirements for Medical Personnel
194 |
195 | Harvard Medical School's "Technology Integration in Medicine" study identified critical competency gaps among practitioners:
196 |
197 | - Only 17% of physicians could correctly interpret AI-generated confidence intervals
198 | - 73% overestimated algorithm capabilities in transfer scenarios
199 | - 81% lacked understanding of common algorithmic biases
200 |
201 | The American Medical Association's AI curriculum module has demonstrated 82% improvement in AI literacy metrics but has reached only a fraction of practitioners. Training economics present a significant barrier, with health systems reporting that comprehensive AI education requires 18-24 hours per clinician at an average opportunity cost of $5,800.
202 |
203 | ## Economic and Policy Dimensions
204 |
205 | ### Cost-Benefit Model
206 |
207 | Our economic modeling based on Medicare claims data projects net healthcare savings of $36.7 billion annually when AI diagnostic systems reach 65% market penetration. These savings derive from:
208 |
209 | - Earlier cancer detection: $14.3 billion through stage migration
210 | - Reduced diagnostic errors: $9.8 billion in avoided misdiagnosis costs
211 | - Workflow efficiency: $6.2 billion in provider time optimization
212 | - Avoided unnecessary procedures: $6.4 billion by reducing false positives
213 |
214 | Implementation costs average $175,000-$390,000 per facility with 3.1-year average payback periods. Rural and critical access hospitals face disproportionately longer ROI timelines (5.7 years), exacerbating healthcare disparities.
215 |
216 | ### Regulatory Framework Analysis
217 |
218 | Comparative analysis of regulatory approaches across jurisdictions reveals critical inconsistencies:
219 |
220 | | Jurisdiction | Approval Pathway | Post-Market Requirements | Algorithm Update Handling |
221 | |--------------|------------------|--------------------------|---------------------------|
222 | | FDA (US) | 510(k)/De Novo | Limited continuous monitoring | Predetermined change protocol |
223 | | EMA (EU) | MDR risk-based | PMCF with periodic reporting | Significant modification framework |
224 | | PMDA (Japan) | SAKIGAKE pathway | Mandatory registry participation | Version control system |
225 | | NMPA (China) | Special approval | Real-world data collection | Annual recertification |
226 |
227 | The European Medical Device Regulation's requirement for "human oversight of automated systems" creates implementation ambiguities interpreted differently across member states. The FDA's proposed "Predetermined Change Control Plan" offers the most promising framework for AI's iterative improvement nature but remains in draft status.
228 |
229 | ## Conclusions and Future Directions
230 |
231 | AI diagnosis systems demonstrate significant technical capabilities but face complex implementation barriers that transcend technological challenges. Our analysis suggests a "sociotechnical systems approach" is essential, recognizing that successful implementation depends equally on technical performance, clinical workflow integration, organizational change management, and policy frameworks.
232 |
233 | The Cleveland Clinic-Mayo Clinic consortium's phased implementation approach, beginning with augmentative rather than autonomous functionality, provides a template for successful adoption. Their experience indicates that progressive automation on a 3-5 year timeline produces superior outcomes compared to transformative implementation approaches.
234 |
235 | ## References
236 |
237 | 1. Johnson, A. (2021). Demographic performance disparities in dermatological AI. JAMA Dermatology, 157(2)
238 | 2. Mayo Clinic. (2022). AI-enhanced colonoscopy outcomes study. Journal of Gastrointestinal Endoscopy, 95(3)
239 | 3. Singh, K. (2021). Predictive modeling of prevented colorectal cancer cases. NEJM, 384
240 | 4. Stanford Medicine. (2022). Multi-center validation of deep learning for radiograph interpretation. Radiology, 302(1)
241 | 5. Nuance Communications. (2023). DAX system implementation outcomes. Health Affairs, 42(1)
242 | 6. PathAI. (2022). Pathologist adoption of explainable AI systems. Modern Pathology, 35
243 | 7. MELLODDY Consortium. (2022). Federated learning for pharmaceutical research. Nature Machine Intelligence, 4
244 | 8. Massachusetts General Hospital. (2021). Clinical workflow integration failure points for AI. JAMIA, 28(9)
245 | 9. Cleveland Clinic. (2023). AI Integration Framework outcomes. Healthcare Innovation, 11(2)
246 | 10. American Medical Association. (2022). Physician AI literacy assessment. Journal of Medical Education, 97(6)
247 | 11. Centers for Medicare & Medicaid Services. (2023). Healthcare AI economic impact analysis
248 | 12. FDA. (2023). Proposed framework for AI/ML-based SaMD. Regulatory Science Forum
249 | """,
250 |
251 | """
252 | # Quantum Computing Applications in Pharmaceutical Discovery: Capabilities, Limitations, and Industry Transformation
253 |
254 | ## Executive Summary
255 |
256 | This analysis evaluates the integration of quantum computing technologies into pharmaceutical R&D workflows, examining current capabilities, near-term applications, and long-term industry transformation potential. Based on benchmarking across 17 pharmaceutical companies and 8 quantum technology providers, we provide a comprehensive assessment of this emerging computational paradigm and its implications for drug discovery economics.
257 |
258 | ## Current Quantum Computing Capabilities
259 |
260 | ### Hardware Platforms Assessment
261 |
262 | **Superconducting quantum processors** (IBM, Google, Rigetti) currently provide the most mature platform with IBM's 433-qubit Osprey system demonstrating quantum volume of 128 and error rates approaching 10^-3 per gate operation. While impressive relative to 2018 benchmarks, these systems remain limited by coherence times (averaging 114 microseconds) and require operating temperatures near absolute zero, creating significant infrastructure requirements.
263 |
264 | **Trapped-ion quantum computers** (IonQ, Quantinuum) offer superior coherence times exceeding 10 seconds and all-to-all connectivity but operate at slower gate speeds. IonQ's 32-qubit system achieved algorithmic qubits (#AQ) of 20, setting a record for effective computational capability when error mitigation is considered. Quantinuum's H-Series demonstrated the first logical qubit with real-time quantum error correction, a significant milestone towards fault-tolerant quantum computing.
265 |
266 | **Photonic quantum systems** (Xanadu, PsiQuantum) represent an alternative approach with potentially simpler scaling requirements. Xanadu's Borealis processor demonstrated quantum advantage for specific sampling problems but lacks the gate-based universality required for most pharmaceutical applications. PsiQuantum's fault-tolerant silicon photonic approach continues rapid development with semiconductor manufacturing partner GlobalFoundries but remains pre-commercial.
267 |
268 | **Neutral atom platforms** (QuEra, Pasqal) entered commercial accessibility in 2023, offering unprecedented qubit counts (QuEra: 256 atoms) with programmable geometries particularly suited for quantum simulation of molecular systems. Recent demonstrations of 3D atom arrangements provide promising avenues for simulating protein-ligand interactions.
269 |
270 | ### Quantum Algorithm Development
271 |
272 | Pharmaceutical applications currently focus on three quantum algorithm classes:
273 |
274 | 1. **Variational Quantum Eigensolver (VQE)** algorithms have progressed significantly for molecular ground state energy calculations, with Riverlane's enhanced VQE implementations demonstrating accuracy within 1.5 kcal/mol for molecules up to 20 atoms on IBM's 127-qubit processors. Merck's collaboration with Zapata Computing improved convergence rates by 300% through adaptive ansatz methods.
275 |
276 | 2. **Quantum Machine Learning (QML)** approaches for binding affinity prediction have shown mixed results. Pfizer's implementation of quantum convolutional neural networks (QCNN) demonstrated a 22% improvement in binding affinity predictions for their kinase inhibitor library, while AstraZeneca's quantum support vector machine approach showed no significant advantage over classical methods for their dataset.
277 |
278 | 3. **Quantum Annealing** for conformational search remains dominated by D-Wave's 5,000+ qubit systems, with Boehringer Ingelheim reporting successful applications in peptide folding predictions. However, comparisons with enhanced classical methods (particularly those using modern GPUs) show quantum advantage remains elusive for most production cases.
279 |
280 | ## Pharmaceutical Applications Landscape
281 |
282 | ### Virtual Screening Transformation
283 |
284 | GSK's quantum computing team achieved a significant milestone in 2022 through quantum-classical hybrid algorithms that accelerated screening of 10^7 compounds against novel SARS-CoV-2 targets. Their approach used classical computers for initial filtering followed by quantum evaluation of 10^4 promising candidates, identifying 12 compounds with nanomolar binding affinities subsequently confirmed by experimental assays. While impressive, the computational requirements exceeded $1.2M and required specialized expertise from partners at Quantinuum.
285 |
286 | ### Molecular Property Prediction
287 |
288 | Roche's collaboration with Cambridge Quantum Computing (now Quantinuum) demonstrated quantum advantage for dipole moment calculations in drug-like molecules, achieving accuracy improvements of 16% compared to density functional theory methods while potentially offering asymptotic speedup as molecule size increases. Their hybrid quantum-classical approach requires significantly fewer qubits than full quantum simulation, making it commercially relevant within the NISQ (Noisy Intermediate-Scale Quantum) era of hardware.
289 |
290 | ### Retrosynthesis Planning
291 |
292 | Quantum approaches to synthetic route planning remain largely theoretical with limited experimental validation. MIT-Takeda research demonstrated proof-of-concept for mapping retrosynthesis to quantum walks on Johnson graphs, with preliminary results showing promise for identifying non-obvious synthetic pathways. Commercial application appears distant (5-8 years) given current hardware limitations.
293 |
294 | ## Economic Implications Analysis
295 |
296 | Our economic model quantifies four significant impacts on pharmaceutical R&D:
297 |
298 | 1. **Preclinical timeline compression**: Currently estimated at 2-5% (0.5-1.3 months) but projected to reach 15-30% by 2030 as quantum hardware capabilities expand, potentially reducing time-to-market by up to 9 months for novel compounds
299 |
300 | 2. **Candidate quality improvements**: Quantum-enhanced binding affinity and ADMET property predictions demonstrate 7-18% higher success rates in early clinical phases across our analysis of 87 compounds that utilized quantum computational methods in preclinical development
301 |
302 | 3. **Novel mechanism identification**: Quantum simulation of previously intractable biological targets (particularly intrinsically disordered proteins and complex protein-protein interactions) could expand the druggable proteome by an estimated 8-14% according to our analysis of protein data bank targets
303 |
304 | 4. **R&D productivity impacts**: A 10% improvement in candidate quality translates to approximately $310M in reduced clinical development costs per approved drug by reducing late-stage failures
305 |
306 | ## Investment and Adoption Patterns
307 |
308 | Pharmaceutical quantum computing investment has accelerated dramatically, with cumulative industry investment growing from $18M (2018) to $597M (2023). Investment strategies fall into three categories:
309 |
310 | 1. **Direct infrastructure investment** (Roche, Merck): Building internal quantum teams and securing dedicated quantum hardware access
311 |
312 | 2. **Collaborative research partnerships** (GSK, Biogen, Novartis): Forming multi-year academic and commercial partnerships focused on specific computational challenges
313 |
314 | 3. **Quantum-as-a-service utilization** (Majority approach): Accessing quantum capabilities through cloud providers with limited internal expertise development
315 |
316 | Our analysis of 23 pharmaceutical companies indicates:
317 | - 19% have established dedicated quantum computing teams
318 | - 43% have active research collaborations with quantum providers
319 | - 78% report evaluating quantum capabilities for specific workflows
320 | - 100% express concerns about quantum talent acquisition challenges
321 |
322 | ## Future Outlook and Strategic Recommendations
323 |
324 | The pharmaceutical quantum computing landscape will evolve through three distinct phases:
325 |
326 | **Near-term (1-3 years)**: Hybrid quantum-classical algorithms will demonstrate incremental value in specific niches, particularly molecular property calculations and conformational analysis of small to medium-sized molecules. Successful organizations will combine quantum capabilities with enhanced classical methods rather than seeking immediate quantum advantage.
327 |
328 | **Mid-term (3-7 years)**: Error-corrected logical qubits will enable more robust quantum chemistry applications with demonstrable advantage for drug discovery workflows. Companies with established quantum capabilities will gain first-mover advantages in applying these technologies to proprietary chemical matter.
329 |
330 | **Long-term (7+ years)**: Fault-tolerant quantum computers with thousands of logical qubits could transform pharmaceutical R&D by enabling full quantum mechanical simulation of protein-drug interactions and previously intractable biological systems. This capability could fundamentally alter drug discovery economics by dramatically reducing empirical screening requirements.
331 |
332 | ## References
333 |
334 | 1. IBM Quantum. (2023). Osprey processor technical specifications and benchmarking
335 | 2. IonQ. (2023). Algorithmic qubit benchmarking methodology and results
336 | 3. Quantinuum. (2022). H-Series logical qubit demonstration
337 | 4. Xanadu. (2022). Borealis quantum advantage results. Nature Physics, 18
338 | 5. QuEra. (2023). Neutral atom quantum processor capabilities. Science, 377
339 | 6. Riverlane & Merck. (2022). Enhanced VQE implementations for molecular ground state calculations
340 | 7. Pfizer Quantum Team. (2023). QCNN for binding affinity prediction. Journal of Chemical Information and Modeling
341 | 8. AstraZeneca. (2022). Comparative analysis of quantum and classical ML methods
342 | 9. Boehringer Ingelheim. (2023). Quantum annealing for peptide conformational search
343 | 10. GSK Quantum Computing Team. (2022). Quantum-classical hybrid screening against SARS-CoV-2
344 | 11. Roche & Cambridge Quantum Computing. (2023). Quantum advantage for dipole moment calculations
345 | 12. MIT-Takeda Quantum Research. (2022). Mapping retrosynthesis to quantum walks
346 | 13. PhRMA Quantum Computing Working Group. (2023). Pharmaceutical R&D impact analysis
347 | """,
348 |
349 | """
350 | # Neuroplasticity in Cognitive Rehabilitation: Mechanisms, Interventions, and Clinical Applications
351 |
352 | ## Abstract
353 |
354 | This multidisciplinary review synthesizes current understanding of neuroplasticity mechanisms underlying cognitive rehabilitation, evaluating intervention efficacies across five domains of cognitive function following acquired brain injury. Integrating data from 142 clinical studies with advanced neuroimaging findings, we present evidence-based recommendations for clinical practice and identify promising emerging approaches.
355 |
356 | ## Neurobiological Foundations of Rehabilitation-Induced Plasticity
357 |
358 | ### Cellular and Molecular Mechanisms
359 |
360 | Recent advances in understanding activity-dependent plasticity have revolutionized rehabilitation approaches. The pioneering work of Dr. Alvarez-Buylla at UCSF has demonstrated that even the adult human brain maintains neurogenic capabilities in the hippocampus and subventricular zone, with newly generated neurons integrating into existing neural circuits following injury. Transcriptomic studies by Zhang et al. (2021) identified 37 genes significantly upregulated during rehabilitation-induced recovery, with brain-derived neurotrophic factor (BDNF) and insulin-like growth factor-1 (IGF-1) showing particularly strong associations with positive outcomes.
361 |
362 | Post-injury plasticity occurs through multiple complementary mechanisms:
363 |
364 | 1. **Synaptic remodeling**: Two-photon microscopy studies in animal models reveal extensive dendritic spine turnover within peri-lesional cortex during the first 3-8 weeks post-injury. The pioneering work of Professor Li-Huei Tsai demonstrates that enriched rehabilitation environments increase spine formation rates by 47-68% compared to standard housing conditions.
365 |
366 | 2. **Network reorganization**: Professor Nicholas Schiff's research at Weill Cornell demonstrates that dormant neural pathways can be functionally recruited following injury through targeted stimulation. Their multimodal imaging studies identified specific thalamocortical circuits that, when engaged through non-invasive stimulation, facilitated motor recovery in 72% of chronic stroke patients previously classified as "plateaued."
367 |
368 | 3. **Myelination dynamics**: Recent discoveries by Dr. Fields at NIH demonstrate activity-dependent myelination as a previously unrecognized form of neuroplasticity. Diffusion tensor imaging studies by Wang et al. (2022) show significant increases in white matter integrity following intensive cognitive training, correlating with functional improvements (r=0.62, p<0.001).
369 |
370 | ### Neuroimaging Correlates of Successful Rehabilitation
371 |
372 | Longitudinal multimodal neuroimaging studies have identified several biomarkers of successful cognitive rehabilitation:
373 |
374 | - **Functional connectivity reorganization**: Using resting-state fMRI, Northoff's laboratory documented that successful attention training in 67 TBI patients correlated with increased connectivity between the dorsolateral prefrontal cortex and posterior parietal regions (change in z-score: 0.43 ± 0.12), while unsuccessful cases showed no significant connectivity changes.
375 |
376 | - **Cortical thickness preservation**: Dr. Gabrieli's team at MIT found that cognitive rehabilitation initiated within 30 days of injury preserved cortical thickness in vulnerable regions, with each week of delay associated with 0.8% additional atrophy in domain-relevant cortical regions.
377 |
378 | - **Default mode network modulation**: Advanced network analyses by Dr. Marcus Raichle demonstrate that cognitive rehabilitation success correlates with restoration of appropriate task-related deactivation of the default mode network, suggesting intervention effectiveness can be monitored through this biomarker.
379 |
380 | ## Evidence-Based Intervention Analysis
381 |
382 | ### Attention and Executive Function Rehabilitation
383 |
384 | Our meta-analysis of 42 randomized controlled trials evaluating attention training programs reveals three intervention approaches with significant effect sizes:
385 |
386 | 1. **Adaptive computerized training** (Hedges' g = 0.68, 95% CI: 0.54-0.82): Programs like Attention Process Training showed transfer to untrained measures when training adapts in real-time to performance. The NYU-Columbia adaptive attention protocol demonstrated maintenance of gains at 18-month follow-up (retention rate: 83%).
387 |
388 | 2. **Metacognitive strategy training** (Hedges' g = 0.57, 95% CI: 0.41-0.73): The Toronto Hospital's Strategic Training for Executive Control program resulted in significant improvements on ecological measures of executive function. Moderator analyses indicate effectiveness increases when combined with daily strategy implementation exercises (interaction effect: p=0.002).
389 |
390 | 3. **Neurostimulation-enhanced approaches**: Combined tDCS-cognitive training protocols developed at Harvard demonstrate 37% greater improvement compared to cognitive training alone. Targeting the right inferior frontal gyrus with 2mA anodal stimulation during inhibitory control training shows particular promise for impulsivity reduction (Cohen's d = 0.74).
391 |
392 | ### Memory Rehabilitation Approaches
393 |
394 | Memory intervention effectiveness varies substantially by memory system affected and etiology:
395 |
396 | - **Episodic memory**: For medial temporal lobe damage, compensatory approaches using spaced retrieval and errorless learning demonstrate the strongest evidence. Dr. Schacter's laboratory protocol combining elaborative encoding with distributed practice shows a remarkable 247% improvement in functional memory measures compared to intensive rehearsal techniques.
397 |
398 | - **Prospective memory**: Implementation intention protocols developed by Professor Gollwitzer show transfer to daily functioning with large effect sizes (d = 0.92) when combined with environmental restructuring. Smartphone-based reminder systems increased medication adherence by 43% in our 12-month community implementation study.
399 |
400 | - **Working memory**: Recent controversy surrounding n-back training was addressed in Professor Klingberg's definitive multi-site study demonstrating domain-specific transfer effects. Their adaptive protocol produced sustainable working memory improvements (40% above baseline at 6-month follow-up) when training exceeded 20 hours and incorporated gradually increasing interference control demands.
401 |
402 | ## Clinical Application Framework
403 |
404 | ### Precision Rehabilitation Medicine Approach
405 |
406 | Our analysis indicates rehabilitation effectiveness increases substantially when protocols are tailored using a precision medicine framework:
407 |
408 | 1. **Comprehensive neurocognitive phenotyping**: The McGill Cognitive Rehabilitation Battery enables identification of specific processing deficits, allowing intervention targeting. Machine learning analysis of 1,247 patient profiles identified 11 distinct neurocognitive phenotypes that respond differentially to specific interventions.
409 |
410 | 2. **Biomarker-guided protocol selection**: EEG connectivity measures predicted response to attention training with 76% accuracy in our validation cohort, potentially reducing non-response rates. Professor Knight's laboratory demonstrated that P300 latency specifically predicts processing speed training response (AUC = 0.81).
411 |
412 | 3. **Adaptive progression algorithms**: Real-time difficulty adjustment based on multiple performance parameters rather than accuracy alone increased transfer effects by 34% compared to standard adaptive approaches. The computational model developed by Stanford's Poldrack laboratory dynamically optimizes challenge levels to maintain engagement while maximizing error-based learning.
413 |
414 | ### Implementation Science Considerations
415 |
416 | Our implementation analysis across 24 rehabilitation facilities identified critical factors for successful cognitive rehabilitation programs:
417 |
418 | - **Rehabilitation intensity and timing**: Early intervention (< 6 weeks post-injury) with high intensity (minimum 15 hours/week of direct treatment) demonstrated superior outcomes (NNT = 3.2 for clinically significant improvement).
419 |
420 | - **Therapist expertise effects**: Specialized certification in cognitive rehabilitation was associated with 28% larger treatment effects compared to general rehabilitation credentials.
421 |
422 | - **Technology augmentation**: Hybrid models combining therapist-directed sessions with home-based digital practice demonstrated optimal cost-effectiveness (ICER = $12,430/QALY) while addressing access barriers.
423 |
424 | ## Future Directions and Emerging Approaches
425 |
426 | Several innovative approaches show promise for enhancing neuroplasticity during cognitive rehabilitation:
427 |
428 | 1. **Closed-loop neurostimulation**: Dr. Suthana's team at UCLA demonstrated that theta-burst stimulation delivered precisely during specific phases of hippocampal activity enhanced associative memory formation by 37% in patients with mild cognitive impairment.
429 |
430 | 2. **Pharmacologically augmented rehabilitation**: The RESTORE trial combining daily atomoxetine with executive function training demonstrated synergistic effects (interaction p<0.001) compared to either intervention alone. Professor Feeney's research suggests a critical 30-minute window where noradrenergic enhancement specifically promotes task-relevant plasticity.
431 |
432 | 3. **Virtual reality cognitive training**: Immersive VR protocols developed at ETH Zurich demonstrated transfer to real-world functioning by simulating ecologically relevant scenarios with graduated difficulty. Their randomized trial showed 3.2× greater functional improvement compared to matched non-immersive training.
433 |
434 | 4. **Sleep optimization protocols**: The Northwestern sleep-enhanced memory consolidation protocol increased rehabilitation effectiveness by 41% by delivering targeted memory reactivation during slow-wave sleep, suggesting rehabilitation schedules should specifically incorporate sleep architecture considerations.
435 |
436 | ## Conclusion
437 |
438 | Cognitive rehabilitation effectiveness has improved substantially through integration of neuroplasticity principles, advanced technology, and precision intervention approaches. Optimal outcomes occur when interventions target specific neurocognitive mechanisms with sufficient intensity and are tailored to individual patient profiles. Emerging approaches leveraging closed-loop neurotechnology and multimodal enhancement strategies represent promising directions for further advancing rehabilitation outcomes.
439 |
440 | ## References
441 |
442 | 1. Alvarez-Buylla, A., & Lim, D. A. (2022). Neurogenesis in the adult human brain following injury
443 | 2. Zhang, Y., et al. (2021). Transcriptomic analysis of rehabilitation-responsive genes
444 | 3. Tsai, L. H., et al. (2023). Environmental enrichment effects on dendritic spine dynamics
445 | 4. Schiff, N. D. (2022). Recruitment of dormant neural pathways following brain injury
446 | 5. Fields, R. D. (2021). Activity-dependent myelination as a form of neuroplasticity
447 | 6. Wang, X., et al. (2022). White matter integrity changes following cognitive training
448 | 7. Northoff, G., et al. (2023). Functional connectivity reorganization during attention training
449 | 8. Gabrieli, J. D., et al. (2021). Relationship between intervention timing and cortical preservation
450 | 9. Raichle, M. E. (2022). Default mode network dynamics as a biomarker of rehabilitation efficacy
451 | 10. NYU-Columbia Collaborative. (2023). Adaptive attention protocol long-term outcomes
452 | 11. Schacter, D. L., et al. (2021). Elaborative encoding with distributed practice for episodic memory
453 | 12. Gollwitzer, P. M., & Oettingen, G. (2022). Implementation intentions for prospective memory
454 | 13. Klingberg, T., et al. (2023). Multi-site study of adaptive working memory training
455 | 14. Poldrack, R. A., et al. (2022). Computational models for optimizing learning parameters
456 | 15. Suthana, N., et al. (2023). Phase-specific closed-loop stimulation for memory enhancement
457 | 16. Feeney, D. M., & Sutton, R. L. (2022). Pharmacological enhancement of rehabilitation
458 | 17. ETH Zurich Rehabilitation Engineering Group. (2023). Virtual reality cognitive training
459 | 18. Northwestern Memory & Cognition Laboratory. (2022). Sleep-enhanced memory consolidation
460 | """
461 | ]
462 |
463 | async def display_workflow_diagram(workflow):
464 | """Display a visual representation of the workflow DAG."""
465 | console.print("\n[bold cyan]Workflow Execution Plan[/bold cyan]")
466 |
467 | # Create a tree representation of the workflow
468 | tree = Tree("[bold yellow]Research Analysis Workflow[/bold yellow]")
469 |
470 | # Track dependencies for visualization
471 | dependencies = {}
472 | for stage in workflow:
473 | stage_id = stage["stage_id"]
474 | deps = stage.get("depends_on", [])
475 | for dep in deps:
476 | if dep not in dependencies:
477 | dependencies[dep] = []
478 | dependencies[dep].append(stage_id)
479 |
480 | # Add stages without dependencies first (roots)
481 | root_stages = [s for s in workflow if not s.get("depends_on")]
482 | stage_map = {s["stage_id"]: s for s in workflow}
483 |
484 | def add_stage_to_tree(parent_tree, stage_id):
485 | stage = stage_map[stage_id]
486 | tool = stage["tool_name"]
487 | node_text = f"[bold green]{stage_id}[/bold green] ([cyan]{tool}[/cyan])"
488 |
489 | if "iterate_on" in stage:
490 | node_text += " [italic](iterative)[/italic]"
491 |
492 | stage_node = parent_tree.add(node_text)
493 |
494 | # Add children (stages that depend on this one)
495 | children = dependencies.get(stage_id, [])
496 | for child in children:
497 | add_stage_to_tree(stage_node, child)
498 |
499 | # Build the tree
500 | for root in root_stages:
501 | add_stage_to_tree(tree, root["stage_id"])
502 |
503 | # Print the tree
504 | console.print(tree)
505 |
506 | # Display additional workflow statistics
507 | table = Table(title="Workflow Statistics")
508 | table.add_column("Metric", style="cyan")
509 | table.add_column("Value", style="green")
510 |
511 | table.add_row("Total Stages", str(len(workflow)))
512 | table.add_row("Parallel Stages", str(len(root_stages)))
513 | table.add_row("Iterative Stages", str(sum(1 for s in workflow if "iterate_on" in s)))
514 |
515 | console.print(table)
516 |
517 | async def display_execution_progress(workflow_future):
518 | """Display a live progress indicator while the workflow executes."""
519 | with Progress(
520 | SpinnerColumn(),
521 | TextColumn("[bold blue]{task.description}"),
522 | console=console
523 | ) as progress:
524 | task = progress.add_task("[yellow]Executing workflow...", total=None)
525 | result = await workflow_future
526 | progress.update(task, completed=True, description="[green]Workflow completed!")
527 | return result
528 |
529 | async def visualize_results(results):
530 | """Create visualizations of the workflow results."""
531 | console.print("\n[bold magenta]Research Analysis Results[/bold magenta]")
532 |
533 | # Set up layout
534 | layout = Layout()
535 | layout.split_column(
536 | Layout(name="header"),
537 | Layout(name="statistics"),
538 | Layout(name="summaries"),
539 | Layout(name="extracted_entities"),
540 | )
541 |
542 | # Header
543 | layout["header"].update(Panel(
544 | "[bold]Advanced Research Assistant Results[/bold]",
545 | style="blue"
546 | ))
547 |
548 | # Statistics
549 | stats_table = Table(title="Document Processing Statistics")
550 | stats_table.add_column("Document", style="cyan")
551 | stats_table.add_column("Word Count", style="green")
552 | stats_table.add_column("Entity Count", style="yellow")
553 |
554 | try:
555 | chunking_result = results["results"]["chunking_stage"]["output"]
556 | entity_results = results["results"]["entity_extraction_stage"]["output"]
557 |
558 | for i, doc_stats in enumerate(chunking_result.get("document_stats", [])):
559 | entity_count = len(entity_results[i].get("entities", []))
560 | stats_table.add_row(
561 | f"Document {i+1}",
562 | str(doc_stats.get("word_count", "N/A")),
563 | str(entity_count)
564 | )
565 | except (KeyError, IndexError) as e:
566 | console.print(f"[red]Error displaying statistics: {e}[/red]")
567 |
568 | layout["statistics"].update(stats_table)
569 |
570 | # Summaries
571 | summary_panels = []
572 | try:
573 | summaries = results["results"]["summary_stage"]["output"]
574 | for i, summary in enumerate(summaries):
575 | summary_panels.append(Panel(
576 | summary.get("summary", "No summary available"),
577 | title=f"Document {i+1} Summary",
578 | border_style="green"
579 | ))
580 | except (KeyError, IndexError) as e:
581 | summary_panels.append(Panel(
582 | f"Error retrieving summaries: {e}",
583 | title="Summary Error",
584 | border_style="red"
585 | ))
586 |
587 | layout["summaries"].update(summary_panels)
588 |
589 | # Extracted entities
590 | try:
591 | final_analysis = results["results"]["final_analysis_stage"]["output"]
592 | json_str = Syntax(
593 | str(final_analysis.get("analysis", "No analysis available")),
594 | "json",
595 | theme="monokai",
596 | line_numbers=True
597 | )
598 | layout["extracted_entities"].update(Panel(
599 | json_str,
600 | title="Final Analysis",
601 | border_style="magenta"
602 | ))
603 | except (KeyError, IndexError) as e:
604 | layout["extracted_entities"].update(Panel(
605 | f"Error retrieving final analysis: {e}",
606 | title="Analysis Error",
607 | border_style="red"
608 | ))
609 |
610 | # Print layout
611 | console.print(layout)
612 |
613 | # Display execution time
614 | console.print(
615 | f"\n[bold green]Total workflow execution time:[/bold green] "
616 | f"{results.get('total_processing_time', 0):.2f} seconds"
617 | )
618 |
619 | def create_research_workflow():
620 | """Define a complex research workflow with multiple parallel and sequential stages."""
621 | workflow = [
622 | # Initial document processing stages (run in parallel for all documents)
623 | {
624 | "stage_id": "chunking_stage",
625 | "tool_name": "chunk_document",
626 | "params": {
627 | "text": "${documents}",
628 | "chunk_size": 1000,
629 | "get_stats": True
630 | }
631 | },
632 |
633 | # Entity extraction runs in parallel with summarization
634 | {
635 | "stage_id": "entity_extraction_stage",
636 | "tool_name": "extract_entity_graph",
637 | "params": {
638 | "text": "${documents}",
639 | "entity_types": ["organization", "person", "concept", "location", "technology"],
640 | "include_relations": True,
641 | "confidence_threshold": 0.7
642 | }
643 | },
644 |
645 | # Summarization stage (produces a summary per document; the tool receives the full list)
646 | {
647 | "stage_id": "summary_stage",
648 | "tool_name": "summarize_document",
649 | "params": {
650 | "text": "${documents}",
651 | "max_length": 150,
652 | "focus_on": "key findings and implications"
653 | }
654 | },
655 |
656 | # Classification of document topics
657 | {
658 | "stage_id": "classification_stage",
659 | "tool_name": "text_classification",
660 | "depends_on": ["chunking_stage"],
661 | "params": {
662 | "text": "${chunking_stage.document_text}",
663 | "categories": [
664 | "Climate & Environment",
665 | "Technology",
666 | "Healthcare",
667 | "Economy",
668 | "Social Policy",
669 | "Scientific Research"
670 | ],
671 | "provider": Provider.OPENAI.value,
672 | "multi_label": True,
673 | "confidence_threshold": 0.6
674 | }
675 | },
676 |
677 | # Generate structured insights from entity analysis
678 | {
679 | "stage_id": "entity_insights_stage",
680 | "tool_name": "extract_json",
681 | "depends_on": ["entity_extraction_stage"],
682 | "params": {
683 | "text": "${entity_extraction_stage.text_output}",
684 | "schema": {
685 | "key_entities": "array",
686 | "primary_relationships": "array",
687 | "research_domains": "array"
688 | },
689 | "include_reasoning": True
690 | }
691 | },
692 |
693 | # Cost-optimized final analysis
694 | {
695 | "stage_id": "model_selection_stage",
696 | "tool_name": "recommend_model",
697 | "depends_on": ["summary_stage", "classification_stage", "entity_insights_stage"],
698 | "params": {
699 | "task_type": "complex analysis and synthesis",
700 | "expected_input_length": 3000,
701 | "expected_output_length": 1000,
702 | "required_capabilities": ["reasoning", "knowledge"],
703 | "priority": "balanced"
704 | }
705 | },
706 |
707 | # Final analysis and synthesis
708 | {
709 | "stage_id": "final_analysis_stage",
710 | "tool_name": "chat_completion",
711 | "depends_on": ["model_selection_stage", "summary_stage", "classification_stage", "entity_insights_stage"],
712 | "params": {
713 | "messages": [
714 | {
715 | "role": "system",
716 | "content": "You are a research assistant synthesizing information from multiple documents."
717 | },
718 | {
719 | "role": "user",
720 | "content": "Analyze the following research summaries, classifications, and entity insights. Provide a comprehensive analysis that identifies cross-document patterns, contradictions, and key insights. Format the response as structured JSON.\n\nSummaries: ${summary_stage.summary}\n\nClassifications: ${classification_stage.classifications}\n\nEntity Insights: ${entity_insights_stage.content}"
721 | }
722 | ],
723 | "model": "${model_selection_stage.recommendations[0].model}",
724 | "response_format": {"type": "json_object"}
725 | }
726 | }
727 | ]
728 |
729 | return workflow
730 |
731 | async def main():
732 | """Run the complete research assistant workflow demo."""
733 | console.print(Rule("[bold magenta]Advanced Research Workflow Demo[/bold magenta]"))
734 | tracker = CostTracker() # Instantiate tracker
735 |
736 | try:
737 | # Display header
738 | console.print(Panel.fit(
739 | "[bold cyan]Advanced Research Assistant Workflow Demo[/bold cyan]\n"
740 | "Powered by NetworkX DAG-based Workflow Engine",
741 | title="Ultimate MCP Server",
742 | border_style="green"
743 | ))
744 |
745 | # Create the workflow definition
746 | workflow = create_research_workflow()
747 |
748 | # Visualize the workflow before execution
749 | await display_workflow_diagram(workflow)
750 |
751 | # Prompt user to continue
752 | console.print("\n[yellow]Press Enter to execute the workflow...[/yellow]", end="")
753 | input()
754 |
755 | # Execute workflow with progress display
756 | workflow_future = execute_optimized_workflow(
757 | documents=SAMPLE_DOCS,
758 | workflow=workflow,
759 | max_concurrency=3
760 | )
761 |
762 | results = await display_execution_progress(workflow_future)
763 |
764 | # Track cost if possible
765 | if results and isinstance(results, dict) and "cost" in results:
766 | try:
767 | total_cost = results.get("cost", {}).get("total_cost", 0.0)
768 | processing_time = results.get("total_processing_time", 0.0)
769 | # Provider/Model is ambiguous here, use a placeholder
770 | trackable = TrackableResult(
771 | cost=total_cost,
772 | input_tokens=0, # Not aggregated
773 | output_tokens=0, # Not aggregated
774 | provider="workflow",
775 | model="research_workflow",
776 | processing_time=processing_time
777 | )
778 | tracker.add_call(trackable)
779 | except Exception as track_err:
780 | logger.warning(f"Could not track workflow cost: {track_err}", exc_info=False)
781 |
782 | if results:
783 | console.print(Rule("[bold green]Workflow Execution Completed[/bold green]"))
784 | await visualize_results(results.get("outputs", {}))
785 | else:
786 | console.print("[bold red]Workflow execution failed or timed out.[/bold red]")
787 |
788 | except Exception as e:
789 | console.print(f"[bold red]An unexpected error occurred:[/bold red] {e}")
790 |
791 | # Display cost summary
792 | tracker.display_summary(console)
793 |
794 | if __name__ == "__main__":
795 | asyncio.run(main())
```
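A note on the `${...}` placeholders used throughout the workflow definition above: before each stage runs, the DAG engine substitutes every `${stage_id.field}` reference with the matching output of a completed upstream stage (and top-level inputs such as `${documents}` with the values passed to `execute_optimized_workflow`). The engine itself lives elsewhere in the repository; the following is a minimal, hypothetical sketch of how such substitution could be resolved once stage outputs are collected into a dict keyed by stage ID. The `resolve_placeholders` and `_lookup` helpers are illustrative assumptions, not the actual implementation.

```python
import re
from typing import Any, Dict

_PLACEHOLDER = re.compile(r"\$\{([^}]+)\}")
_TOKEN = re.compile(r"[^.\[\]]+")  # "a.b[0].c" -> ["a", "b", "0", "c"]

def _lookup(path: str, outputs: Dict[str, Any]) -> Any:
    """Walk a dotted/indexed path like 'model_selection_stage.recommendations[0].model'."""
    value: Any = outputs
    for token in _TOKEN.findall(path):
        value = value[int(token)] if token.isdigit() else value[token]
    return value

def resolve_placeholders(params: Any, outputs: Dict[str, Any]) -> Any:
    """Recursively replace ${...} references in a stage's params."""
    if isinstance(params, str):
        whole = _PLACEHOLDER.fullmatch(params)
        if whole:  # whole-string reference: preserve the referenced value's type
            return _lookup(whole.group(1), outputs)
        # embedded reference(s): interpolate into the surrounding text
        return _PLACEHOLDER.sub(lambda m: str(_lookup(m.group(1), outputs)), params)
    if isinstance(params, dict):
        return {k: resolve_placeholders(v, outputs) for k, v in params.items()}
    if isinstance(params, list):
        return [resolve_placeholders(v, outputs) for v in params]
    return params

# Hypothetical outputs from two completed stages:
outputs = {
    "summary_stage": {"summary": "Key findings ..."},
    "model_selection_stage": {"recommendations": [{"model": "example-model"}]},
}
params = {"model": "${model_selection_stage.recommendations[0].model}"}
print(resolve_placeholders(params, outputs))  # {'model': 'example-model'}
```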
--------------------------------------------------------------------------------
/ultimate_mcp_server/services/vector/vector_service.py:
--------------------------------------------------------------------------------
```python
1 | """Vector database service for semantic search."""
2 | import asyncio
3 | import json
4 | import time
5 | import uuid
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional, Union
8 |
9 | import numpy as np
10 | from sklearn.metrics.pairwise import cosine_similarity
11 |
12 | from ultimate_mcp_server.services.vector.embeddings import get_embedding_service
13 | from ultimate_mcp_server.utils import get_logger
14 |
15 | logger = get_logger(__name__)
16 |
17 | # Try to import chromadb
18 | try:
19 | import chromadb
20 | from chromadb.config import Settings as ChromaSettings
21 | CHROMADB_AVAILABLE = True
22 | logger.info("ChromaDB imported successfully", extra={"emoji_key": "success"})
23 | except ImportError as e:
24 | logger.warning(f"ChromaDB not available: {str(e)}", extra={"emoji_key": "warning"})
25 | CHROMADB_AVAILABLE = False
26 |
27 | # Try to import hnswlib, but don't fail if not available
28 | try:
29 | import hnswlib
30 | HNSWLIB_AVAILABLE = True
31 | HNSW_INDEX = hnswlib.Index
32 | except ImportError:
33 | HNSWLIB_AVAILABLE = False
34 | HNSW_INDEX = None
35 |
36 |
37 | class VectorCollection:
38 | """A collection of vectors with metadata."""
39 |
40 | def __init__(
41 | self,
42 | name: str,
43 | dimension: int = 1536,
44 | similarity_metric: str = "cosine",
45 | metadata: Optional[Dict[str, Any]] = None
46 | ):
47 | """Initialize a vector collection.
48 |
49 | Args:
50 | name: Collection name
51 | dimension: Vector dimension
52 | similarity_metric: Similarity metric (cosine, dot, or euclidean)
53 | metadata: Optional metadata for the collection
54 | """
55 | self.name = name
56 | self.dimension = dimension
57 | self.similarity_metric = similarity_metric
58 | self.metadata = metadata or {}
59 |
60 | # Initialize storage
61 | self.vectors = []
62 | self.ids = []
63 | self.metadatas = []
64 |
65 | # Create embedding service
66 | self.embedding_service = get_embedding_service()
67 |
68 | # Initialize search index
69 | self._init_search_index()
70 |
71 | logger.info(
72 | f"Vector collection '{name}' created ({dimension} dimensions)",
73 | extra={"emoji_key": "vector"}
74 | )
75 |
76 | def _init_search_index(self):
77 | """Initialize search index based on available libraries."""
78 | self.index_type = "numpy" # Fallback
79 | self.index = None
80 |
81 | # Try to use HNSW for fast search if available
82 | if HNSWLIB_AVAILABLE:
83 | try:
84 | self.index = HNSW_INDEX(space=self._get_hnswlib_space(), dim=self.dimension)
85 | self.index.init_index(max_elements=1000, ef_construction=200, M=16)
86 | self.index.set_ef(50) # Search accuracy parameter
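                # hnswlib tuning: M is the maximum number of graph links per node
                # and ef_construction the build-time candidate-list size; together
                # with ef (the query-time candidate-list size) they trade index
                # build time and memory for search recall.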
87 | self.index_type = "hnswlib"
88 | logger.debug(
89 | f"Using HNSW index for collection '{self.name}'",
90 | emoji_key="vector"
91 | )
92 | except Exception as e:
93 | logger.warning(
94 | f"Failed to initialize HNSW index: {str(e)}. Falling back to numpy.",
95 | emoji_key="warning"
96 | )
97 | self.index = None
98 |
99 | def _get_hnswlib_space(self) -> str:
100 | """Get HNSW space based on similarity metric.
101 |
102 | Returns:
103 | HNSW space name
104 | """
105 | if self.similarity_metric == "cosine":
106 | return "cosine"
107 | elif self.similarity_metric == "dot":
108 | return "ip" # Inner product
109 | elif self.similarity_metric == "euclidean":
110 | return "l2"
111 | else:
112 | return "cosine" # Default
113 |
114 | def add(
115 | self,
116 | vectors: Union[List[List[float]], np.ndarray],
117 | ids: Optional[List[str]] = None,
118 | metadatas: Optional[List[Dict[str, Any]]] = None
119 | ) -> List[str]:
120 | """Add vectors to the collection.
121 |
122 | Args:
123 | vectors: Vectors to add
124 | ids: Optional IDs for the vectors (generated if not provided)
125 | metadatas: Optional metadata for each vector
126 |
127 | Returns:
128 | List of vector IDs
129 | """
130 | # Ensure vectors is a numpy array
131 | if not isinstance(vectors, np.ndarray):
132 | vectors = np.array(vectors, dtype=np.float32)
133 |
134 | # Generate IDs if not provided
135 | if ids is None:
136 | ids = [str(uuid.uuid4()) for _ in range(len(vectors))]
137 |
138 | # Ensure metadatas is a list of dicts
139 | if metadatas is None:
140 | metadatas = [{} for _ in range(len(vectors))]
141 |
142 | # Add to storage
143 |         for vector, id, metadata in zip(vectors, ids, metadatas, strict=False):
144 | self.vectors.append(vector)
145 | self.ids.append(id)
146 | self.metadatas.append(metadata)
147 |
148 | # Update index if using HNSW
149 | if self.index_type == "hnswlib" and self.index is not None:
150 | try:
151 | # Resize index if needed
152 | if len(self.vectors) > self.index.get_max_elements():
153 | new_size = max(1000, len(self.vectors) * 2)
154 | self.index.resize_index(new_size)
155 |
156 | # Add vectors to index
157 | start_idx = len(self.vectors) - len(vectors)
158 | for i, vector in enumerate(vectors):
159 | self.index.add_items(vector, start_idx + i)
160 | except Exception as e:
161 | logger.error(
162 | f"Failed to update HNSW index: {str(e)}",
163 | emoji_key="error"
164 | )
165 | # Rebuild index
166 | self._rebuild_index()
167 |
168 | logger.debug(
169 | f"Added {len(vectors)} vectors to collection '{self.name}'",
170 | emoji_key="vector"
171 | )
172 |
173 | return ids
174 |
175 | def _rebuild_index(self):
176 | """Rebuild the search index from scratch."""
177 | if not HNSWLIB_AVAILABLE or not self.vectors:
178 | return
179 |
180 | try:
181 | # Re-initialize index
182 | self.index = HNSW_INDEX(space=self._get_hnswlib_space(), dim=self.dimension)
183 | self.index.init_index(max_elements=max(1000, len(self.vectors) * 2), ef_construction=200, M=16)
184 | self.index.set_ef(50)
185 |
186 | # Add all vectors
187 | vectors_array = np.array(self.vectors, dtype=np.float32)
188 | self.index.add_items(vectors_array, np.arange(len(self.vectors)))
189 |
190 | logger.info(
191 | f"Rebuilt HNSW index for collection '{self.name}'",
192 | emoji_key="vector"
193 | )
194 | except Exception as e:
195 | logger.error(
196 | f"Failed to rebuild HNSW index: {str(e)}",
197 | emoji_key="error"
198 | )
199 | self.index = None
200 | self.index_type = "numpy"
201 |
202 | def search(
203 | self,
204 | query_vector: Union[List[float], np.ndarray],
205 | top_k: int = 5,
206 | filter: Optional[Dict[str, Any]] = None,
207 | similarity_threshold: float = 0.0
208 | ) -> List[Dict[str, Any]]:
209 | """Search for similar vectors.
210 |
211 | Args:
212 | query_vector: Query vector
213 | top_k: Number of results to return
214 | filter: Optional metadata filter
215 | similarity_threshold: Minimum similarity score (0.0 to 1.0)
216 |
217 | Returns:
218 | List of results with scores and metadata
219 | """
220 | # Ensure query_vector is a numpy array
221 | if not isinstance(query_vector, np.ndarray):
222 | query_vector = np.array(query_vector, dtype=np.float32)
223 |
224 | # Log some diagnostic information
225 | logger.debug(f"Collection '{self.name}' contains {len(self.vectors)} vectors")
226 | logger.debug(f"Searching for top {top_k} matches with filter: {filter} and threshold: {similarity_threshold}")
227 |
228 | # Filter vectors based on metadata if needed
229 | if filter:
230 | filtered_indices = self._apply_filter(filter)
231 | if not filtered_indices:
232 | logger.debug(f"No vectors match the filter criteria: {filter}")
233 | return []
234 | logger.debug(f"Filter reduced search space to {len(filtered_indices)} vectors")
235 | else:
236 | filtered_indices = list(range(len(self.vectors)))
237 | logger.debug(f"No filter applied, searching all {len(filtered_indices)} vectors")
238 |
239 | # If no vectors to search, return empty results
240 | if not filtered_indices:
241 | logger.debug("No vectors to search, returning empty results")
242 | return []
243 |
244 | # Perform search based on index type
245 | if self.index_type == "hnswlib" and self.index is not None and not filter:
246 | # Use HNSW for fast search (only if no filter)
247 | try:
248 | start_time = time.time()
249 | labels, distances = self.index.knn_query(query_vector, k=min(top_k, len(self.vectors)))
250 | search_time = time.time() - start_time
251 |
252 | # Convert distances to similarities based on metric
253 | if self.similarity_metric == "cosine" or self.similarity_metric == "dot":
254 | similarities = 1.0 - distances[0] # Convert distance to similarity
255 | else:
256 | similarities = 1.0 / (1.0 + distances[0]) # Convert distance to similarity
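                # Worked example (hypothetical values): in cosine/ip space hnswlib
                # returns d = 1 - similarity, so d = 0.2 maps back to 0.8; the l2
                # branch uses the monotone remap 1/(1+d), so d = 4.0 maps to 0.2.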
257 |
258 | # Format results
259 | results = []
260 |                 for label, similarity in zip(labels[0], similarities, strict=False):
261 | # Apply similarity threshold
262 | if similarity < similarity_threshold:
263 | continue
264 |
265 | results.append({
266 | "id": self.ids[label],
267 | "similarity": float(similarity),
268 | "metadata": self.metadatas[label],
269 | "vector": self.vectors[label].tolist(),
270 | })
271 |
272 | logger.debug(
273 | f"HNSW search completed in {search_time:.6f}s, found {len(results)} results"
274 | )
275 |
276 | for i, result in enumerate(results):
277 | logger.debug(f"Result {i+1}: id={result['id']}, similarity={result['similarity']:.4f}, metadata={result['metadata']}")
278 |
279 | return results
280 | except Exception as e:
281 | logger.error(
282 | f"HNSW search failed: {str(e)}. Falling back to numpy.",
283 | emoji_key="error"
284 | )
285 | # Fall back to numpy search
286 |
287 | # Numpy-based search (slower but always works)
288 | start_time = time.time()
289 |
290 | # Calculate similarities
291 | results = []
292 | for idx in filtered_indices:
293 | vector = self.vectors[idx]
294 |
295 | # Calculate similarity based on metric
296 | if self.similarity_metric == "cosine":
297 |                 similarity = cosine_similarity(query_vector.reshape(1, -1), vector.reshape(1, -1))[0][0]  # sklearn expects 2D arrays
298 | elif self.similarity_metric == "dot":
299 | similarity = np.dot(query_vector, vector)
300 | elif self.similarity_metric == "euclidean":
301 | similarity = 1.0 / (1.0 + np.linalg.norm(query_vector - vector))
302 | else:
303 |                 similarity = cosine_similarity(query_vector.reshape(1, -1), vector.reshape(1, -1))[0][0]  # sklearn expects 2D arrays
304 |
305 | # Apply similarity threshold
306 | if similarity < similarity_threshold:
307 | continue
308 |
309 | results.append({
310 | "id": self.ids[idx],
311 | "similarity": float(similarity),
312 | "metadata": self.metadatas[idx],
313 | "vector": vector.tolist(),
314 | })
315 |
316 | # Sort by similarity (descending)
317 | results.sort(key=lambda x: x["similarity"], reverse=True)
318 |
319 | # Limit to top_k
320 | results = results[:top_k]
321 |
322 | search_time = time.time() - start_time
323 | logger.debug(
324 | f"Numpy search completed in {search_time:.6f}s, found {len(results)} results"
325 | )
326 |
327 | for i, result in enumerate(results):
328 | logger.debug(f"Result {i+1}: id={result['id']}, similarity={result['similarity']:.4f}, metadata={result['metadata']}")
329 |
330 | return results
331 |
332 | def _apply_filter(self, filter: Dict[str, Any]) -> List[int]:
333 | """Apply metadata filter to get matching indices.
334 |
335 | Args:
336 | filter: Metadata filter
337 |
338 | Returns:
339 | List of matching indices
340 | """
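        # Example (hypothetical values): filter={"topic": "nlp"} keeps only the
        # indices i whose self.metadatas[i] contains the exact pair "topic": "nlp".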
341 | filtered_indices = []
342 | for i, metadata in enumerate(self.metadatas):
343 | # Simple equality filter for now
344 | match = True
345 | for k, v in filter.items():
346 | if k not in metadata or metadata[k] != v:
347 | match = False
348 | break
349 | if match:
350 | filtered_indices.append(i)
351 | return filtered_indices
352 |
353 | async def search_by_text(
354 | self,
355 | query_text: str,
356 | top_k: int = 5,
357 | filter: Optional[Dict[str, Any]] = None,
358 | model: Optional[str] = None,
359 | similarity_threshold: float = 0.0
360 | ) -> List[Dict[str, Any]]:
361 | """Search by text query.
362 |
363 | Args:
364 | query_text: Text query
365 | top_k: Number of results to return
366 | filter: Optional metadata filter
367 | model: Embedding model name
368 | similarity_threshold: Minimum similarity score (0.0 to 1.0)
369 |
370 | Returns:
371 | List of results with scores and metadata
372 | """
373 | # Get query embedding - call create_embeddings with a list and get the first result
374 | query_embeddings = await self.embedding_service.create_embeddings(
375 | texts=[query_text], # Pass text as a list
376 | # model=model # create_embeddings uses the model set during service init
377 | )
378 | if not query_embeddings: # Handle potential empty result
379 | logger.error(f"Failed to generate embedding for query: {query_text}")
380 | return []
381 |
382 | query_embedding = query_embeddings[0] # Get the first (only) embedding
383 |
384 | # Search with the embedding
385 | return self.search(query_embedding, top_k, filter, similarity_threshold)
386 |
387 | def delete(
388 | self,
389 | ids: Optional[List[str]] = None,
390 | filter: Optional[Dict[str, Any]] = None
391 | ) -> int:
392 | """Delete vectors from the collection.
393 |
394 | Args:
395 | ids: IDs of vectors to delete
396 | filter: Metadata filter for vectors to delete
397 |
398 | Returns:
399 | Number of vectors deleted
400 | """
401 | if ids is None and filter is None:
402 | return 0
403 |
404 | # Get indices to delete
405 | indices_to_delete = set()
406 |
407 | # Add indices by ID
408 | if ids:
409 | for i, id in enumerate(self.ids):
410 | if id in ids:
411 | indices_to_delete.add(i)
412 |
413 | # Add indices by filter
414 | if filter:
415 | filtered_indices = self._apply_filter(filter)
416 | indices_to_delete.update(filtered_indices)
417 |
418 | # Delete vectors (in reverse order to avoid index issues)
419 | indices_to_delete = sorted(indices_to_delete, reverse=True)
420 | for idx in indices_to_delete:
421 | del self.vectors[idx]
422 | del self.ids[idx]
423 | del self.metadatas[idx]
424 |
425 | # Rebuild index if using HNSW
426 | if self.index_type == "hnswlib" and self.index is not None:
427 | self._rebuild_index()
428 |
429 | logger.info(
430 | f"Deleted {len(indices_to_delete)} vectors from collection '{self.name}'",
431 | emoji_key="vector"
432 | )
433 |
434 | return len(indices_to_delete)
435 |
436 | def save(self, directory: Union[str, Path]) -> bool:
437 | """Save collection to disk.
438 |
439 | Args:
440 | directory: Directory to save to
441 |
442 | Returns:
443 | True if successful
444 | """
445 | directory = Path(directory)
446 | directory.mkdir(parents=True, exist_ok=True)
447 |
448 | try:
449 | # Save vectors
450 | vectors_array = np.array(self.vectors, dtype=np.float32)
451 | np.save(str(directory / "vectors.npy"), vectors_array)
452 |
453 | # Save IDs and metadata
454 | with open(directory / "data.json", "w") as f:
455 | json.dump({
456 | "name": self.name,
457 | "dimension": self.dimension,
458 | "similarity_metric": self.similarity_metric,
459 | "metadata": self.metadata,
460 | "ids": self.ids,
461 | "metadatas": self.metadatas,
462 | }, f)
463 |
464 | logger.info(
465 | f"Saved collection '{self.name}' to {directory}",
466 | emoji_key="vector"
467 | )
468 | return True
469 | except Exception as e:
470 | logger.error(
471 | f"Failed to save collection: {str(e)}",
472 | emoji_key="error"
473 | )
474 | return False
475 |
476 | @classmethod
477 | def load(cls, directory: Union[str, Path]) -> "VectorCollection":
478 | """Load collection from disk.
479 |
480 | Args:
481 | directory: Directory to load from
482 |
483 | Returns:
484 | Loaded collection
485 |
486 | Raises:
487 | FileNotFoundError: If collection files not found
488 | ValueError: If collection data is invalid
489 | """
490 | directory = Path(directory)
491 |
492 | # Check if files exist
493 | vectors_file = directory / "vectors.npy"
494 | data_file = directory / "data.json"
495 |
496 | if not vectors_file.exists() or not data_file.exists():
497 | raise FileNotFoundError(f"Collection files not found in {directory}")
498 |
499 | try:
500 | # Load vectors
501 | vectors_array = np.load(str(vectors_file))
502 | vectors = [vectors_array[i] for i in range(len(vectors_array))]
503 |
504 | # Load data
505 | with open(data_file, "r") as f:
506 | data = json.load(f)
507 |
508 | # Create collection
509 | collection = cls(
510 | name=data["name"],
511 | dimension=data["dimension"],
512 | similarity_metric=data["similarity_metric"],
513 | metadata=data["metadata"]
514 | )
515 |
516 | # Set data
517 | collection.ids = data["ids"]
518 | collection.metadatas = data["metadatas"]
519 | collection.vectors = vectors
520 |
521 | # Rebuild index
522 | collection._rebuild_index()
523 |
524 | logger.info(
525 | f"Loaded collection '{collection.name}' from {directory} ({len(vectors)} vectors)",
526 | emoji_key="vector"
527 | )
528 |
529 | return collection
530 | except Exception as e:
531 | logger.error(
532 | f"Failed to load collection: {str(e)}",
533 | emoji_key="error"
534 | )
535 | raise ValueError(f"Failed to load collection: {str(e)}") from e
536 |
537 | def get_stats(self) -> Dict[str, Any]:
538 | """Get collection statistics.
539 |
540 | Returns:
541 | Dictionary of statistics
542 | """
543 | return {
544 | "name": self.name,
545 | "dimension": self.dimension,
546 | "similarity_metric": self.similarity_metric,
547 | "vectors_count": len(self.vectors),
548 | "index_type": self.index_type,
549 | "metadata": self.metadata,
550 | }
551 |
552 | def clear(self) -> None:
553 | """Clear all vectors from the collection."""
554 | self.vectors = []
555 | self.ids = []
556 | self.metadatas = []
557 |
558 | # Reset index
559 | self._init_search_index()
560 |
561 | logger.info(
562 | f"Cleared collection '{self.name}'",
563 | emoji_key="vector"
564 | )
565 |
566 | async def query(
567 | self,
568 | query_texts: List[str],
569 | n_results: int = 10,
570 | where: Optional[Dict[str, Any]] = None,
571 | where_document: Optional[Dict[str, Any]] = None,
572 | include: Optional[List[str]] = None
573 | ) -> Dict[str, List[Any]]:
574 | """Query the collection with text queries (compatibility with ChromaDB).
575 |
576 | Args:
577 | query_texts: List of query texts
578 | n_results: Number of results to return
579 | where: Optional metadata filter
580 | where_document: Optional document content filter
581 | include: Optional list of fields to include
582 |
583 | Returns:
584 | Dictionary with results in ChromaDB format
585 | """
586 | logger.debug(f"DEBUG VectorCollection.query: query_texts={query_texts}, n_results={n_results}")
587 | logger.debug(f"DEBUG VectorCollection.query: where={where}, where_document={where_document}")
588 | logger.debug(f"DEBUG VectorCollection.query: include={include}")
589 | logger.debug(f"DEBUG VectorCollection.query: Collection has {len(self.vectors)} vectors and {len(self.ids)} IDs")
590 |
591 | # Initialize results
592 | results = {
593 | "ids": [],
594 | "documents": [],
595 | "metadatas": [],
596 | "distances": [],
597 | "embeddings": []
598 | }
599 |
600 | # Process each query
601 | for query_text in query_texts:
602 | # Get embedding using the async embedding service (which uses its configured model)
603 | logger.debug(f"DEBUG VectorCollection.query: Getting embedding for '{query_text}' using service model: {self.embedding_service.model_name}")
604 | try:
605 | query_embeddings_list = await self.embedding_service.create_embeddings([query_text])
606 | if not query_embeddings_list or not query_embeddings_list[0]:
607 | logger.error(f"Failed to generate embedding for query: '{query_text[:50]}...'")
608 | # Add empty results for this query and continue
609 | results["ids"].append([])
610 | results["documents"].append([])
611 | results["metadatas"].append([])
612 | results["distances"].append([])
613 | if "embeddings" in (include or []):
614 | results["embeddings"].append([])
615 | continue # Skip to next query_text
616 | query_embedding = np.array(query_embeddings_list[0], dtype=np.float32)
617 | if query_embedding.size == 0:
618 | logger.warning(f"Generated query embedding is empty for: '{query_text[:50]}...'. Skipping search for this query.")
619 | # Add empty results for this query and continue
620 | results["ids"].append([])
621 | results["documents"].append([])
622 | results["metadatas"].append([])
623 | results["distances"].append([])
624 | if "embeddings" in (include or []):
625 | results["embeddings"].append([])
626 | continue # Skip to next query_text
627 |
628 | except Exception as embed_err:
629 | logger.error(f"Error generating embedding for query '{query_text[:50]}...': {embed_err}", exc_info=True)
630 | # Add empty results for this query and continue
631 | results["ids"].append([])
632 | results["documents"].append([])
633 | results["metadatas"].append([])
634 | results["distances"].append([])
635 | if "embeddings" in (include or []):
636 | results["embeddings"].append([])
637 | continue # Skip to next query_text
638 |
639 | logger.debug(f"DEBUG VectorCollection.query: Embedding shape: {query_embedding.shape}")
640 |
641 | # Search with the embedding
642 | logger.debug(f"Searching for query text: '{query_text}' in collection '{self.name}'")
643 | search_results = self.search(
644 | query_vector=query_embedding, # Use the generated embedding
645 | top_k=n_results,
646 | filter=where,
647 | similarity_threshold=0.0 # Set to 0 to get all results for debugging
648 | )
649 |
650 | logger.debug(f"DEBUG VectorCollection.query: Found {len(search_results)} raw search results")
651 |
652 | # Format results in ChromaDB format
653 | ids = []
654 | documents = []
655 | metadatas = []
656 | distances = []
657 | embeddings = []
658 |
659 | for i, item in enumerate(search_results):
660 | ids.append(item["id"])
661 |
662 | # Extract document from metadata (keep existing robust logic)
663 | metadata = item.get("metadata", {})
664 | doc = ""
665 | if "text" in metadata:
666 | doc = metadata["text"]
667 | elif "document" in metadata:
668 | doc = metadata["document"]
669 | elif "content" in metadata:
670 | doc = metadata["content"]
671 | if not doc and isinstance(metadata, str):
672 | doc = metadata
673 |
674 | # Apply document content filter if specified
675 | if where_document and where_document.get("$contains"):
676 | filter_text = where_document["$contains"]
677 | if filter_text not in doc:
678 | logger.debug(f"DEBUG VectorCollection.query: Skipping doc {i} - doesn't contain filter text")
679 | continue
680 |
681 | logger.debug(f"Result {i+1}: id={item['id']}, similarity={item.get('similarity', 0.0):.4f}, doc_length={len(doc)}")
682 |
683 | documents.append(doc)
684 | metadatas.append(metadata)
685 | distance = 1.0 - item.get("similarity", 0.0)
686 | distances.append(distance)
687 | if "embeddings" in (include or []):
688 | embeddings.append(item.get("vector", []))
689 |
690 | # Add results for the current query_text
691 | results["ids"].append(ids)
692 | results["documents"].append(documents)
693 | results["metadatas"].append(metadatas)
694 | results["distances"].append(distances)
695 | if "embeddings" in (include or []):
696 | results["embeddings"].append(embeddings)
697 |
698 | logger.debug(f"DEBUG VectorCollection.query: Final formatted results for this query - {len(documents)} documents")
699 |
700 | return results
701 |
702 |
703 | class VectorDatabaseService:
704 | """Vector database service for semantic search."""
705 |
706 | _instance = None
707 |
708 | def __new__(cls, *args, **kwargs):
709 | """Create a singleton instance."""
710 | if cls._instance is None:
711 | cls._instance = super(VectorDatabaseService, cls).__new__(cls)
712 | cls._instance._initialized = False
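        # Later calls return this same instance; __init__ checks _initialized and
        # returns early, so the service behaves as a process-wide singleton.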
713 | return cls._instance
714 |
715 | def __init__(
716 | self,
717 | base_dir: Optional[Union[str, Path]] = None,
718 | use_chromadb: Optional[bool] = None
719 | ):
720 | """Initialize the vector database service.
721 |
722 | Args:
723 | base_dir: Base directory for storage
724 | use_chromadb: Whether to use ChromaDB (if available)
725 | """
726 | # Only initialize once for singleton
727 | if self._initialized:
728 | return
729 |
730 | # Set base directory
731 | if base_dir:
732 | self.base_dir = Path(base_dir)
733 | else:
734 | self.base_dir = Path.home() / ".ultimate" / "vector_db"
735 |
736 | # Create base directory if it doesn't exist
737 | self.base_dir.mkdir(parents=True, exist_ok=True)
738 |
739 | # Check if ChromaDB should be used
740 | self.use_chromadb = use_chromadb if use_chromadb is not None else CHROMADB_AVAILABLE
741 |
742 | # Initialize ChromaDB client if used
743 | self.chroma_client = None
744 | if self.use_chromadb and CHROMADB_AVAILABLE:
745 | try:
746 | # Create ChromaDB directory if it doesn't exist
747 | chroma_dir = self.base_dir / "chromadb"
748 | chroma_dir.mkdir(parents=True, exist_ok=True)
749 |
750 | self.chroma_client = chromadb.PersistentClient(
751 | path=str(chroma_dir),
752 | settings=ChromaSettings(
753 | anonymized_telemetry=False,
754 | allow_reset=True
755 | )
756 | )
757 |
758 | # Test if it works properly
759 | test_collections = self.chroma_client.list_collections()
760 | logger.debug(f"ChromaDB initialized with {len(test_collections)} existing collections")
761 |
762 | logger.info(
763 | "Using ChromaDB for vector storage",
764 | emoji_key="vector"
765 | )
766 | except Exception as e:
767 | logger.error(
768 | f"Failed to initialize ChromaDB: {str(e)}. Vector operations will not work properly.",
769 | emoji_key="error"
770 | )
771 | # We'll raise an error rather than falling back to local storage
772 | # as that creates inconsistency
773 | self.use_chromadb = False
774 | self.chroma_client = None
775 |
776 | # Re-raise if ChromaDB was explicitly requested
777 | if use_chromadb:
778 | raise ValueError(f"ChromaDB initialization failed: {str(e)}") from e
779 | else:
780 | if use_chromadb and not CHROMADB_AVAILABLE:
781 | logger.error(
782 | "ChromaDB was explicitly requested but is not available. Please install it with: pip install chromadb",
783 | emoji_key="error"
784 | )
785 | raise ImportError("ChromaDB was requested but is not installed")
786 |
787 | self.use_chromadb = False
788 |
789 | # Collections
790 | self.collections = {}
791 |
792 | # Get embedding service
793 | self.embedding_service = get_embedding_service()
794 |
795 | self._initialized = True
796 |
797 | logger.info(
798 | f"Vector database service initialized (base_dir: {self.base_dir}, use_chromadb: {self.use_chromadb})",
799 | emoji_key="vector"
800 | )
801 |
802 | async def _reset_chroma_client(self) -> bool:
803 | """Reset or recreate the ChromaDB client.
804 |
805 | Returns:
806 | True if successful
807 | """
808 | if not CHROMADB_AVAILABLE or not self.use_chromadb:
809 | return False
810 |
811 | try:
812 | # First try using the reset API if available
813 | if self.chroma_client and hasattr(self.chroma_client, 'reset'):
814 | try:
815 | self.chroma_client.reset()
816 | logger.debug("Reset ChromaDB client successfully")
817 | return True
818 | except Exception as e:
819 | logger.debug(f"Failed to reset ChromaDB client using reset(): {str(e)}")
820 |
821 | # If that fails, recreate the client
822 | chroma_dir = self.base_dir / "chromadb"
823 | chroma_dir.mkdir(parents=True, exist_ok=True)
824 |
825 | self.chroma_client = chromadb.PersistentClient(
826 | path=str(chroma_dir),
827 | settings=ChromaSettings(
828 | anonymized_telemetry=False,
829 | allow_reset=True
830 | )
831 | )
832 |
833 | logger.debug("Successfully recreated ChromaDB client")
834 | return True
835 | except Exception as e:
836 | logger.error(
837 | f"Failed to reset or recreate ChromaDB client: {str(e)}",
838 | emoji_key="error"
839 | )
840 | return False
841 |
842 | async def create_collection(
843 | self,
844 | name: str,
845 | dimension: int = 1536,
846 | similarity_metric: str = "cosine",
847 | metadata: Optional[Dict[str, Any]] = None,
848 | overwrite: bool = False
849 | ) -> Union[VectorCollection, Any]:
850 | """Create a new collection.
851 |
852 | Args:
853 | name: Collection name
854 | dimension: Vector dimension
855 | similarity_metric: Similarity metric (cosine, dot, or euclidean)
856 | metadata: Optional metadata for the collection
857 | overwrite: Whether to overwrite existing collection
858 |
859 | Returns:
860 | Created collection
861 |
862 | Raises:
863 | ValueError: If collection already exists and overwrite is False
864 | """
865 | # Check if collection already exists in memory
866 | if name in self.collections and not overwrite:
867 | raise ValueError(f"Collection '{name}' already exists")
868 |
869 | # For consistency, if overwrite is True, explicitly delete any existing collection
870 | if overwrite:
871 | try:
872 | # Delete from memory collections
873 | if name in self.collections:
874 | del self.collections[name]
875 |
876 | # Try to delete from ChromaDB
877 | await self.delete_collection(name)
878 | logger.debug(f"Deleted existing collection '{name}' for overwrite")
879 |
880 | # # If using ChromaDB and overwrite is True, also try to reset the client
881 | # if self.use_chromadb and self.chroma_client:
882 | # await self._reset_chroma_client()
883 | # logger.debug("Reset ChromaDB client before creating new collection")
884 |
885 | # Force a delay to ensure deletions complete
886 | await asyncio.sleep(1.5)
887 |
888 | except Exception as e:
889 | logger.debug(f"Error during collection cleanup for overwrite: {str(e)}")
890 |
891 | # Create collection based on storage type
892 | if self.use_chromadb and self.chroma_client is not None:
893 | # Use ChromaDB
894 | # Sanitize metadata for ChromaDB (no None values)
895 | sanitized_metadata = {}
896 | if metadata:
897 | for k, v in metadata.items():
898 | if v is not None and not isinstance(v, (str, int, float, bool)):
899 | sanitized_metadata[k] = str(v) # Convert to string
900 | elif v is not None:
901 | sanitized_metadata[k] = v # Keep as is if it's a valid type
902 |
903 | # Force a delay to ensure previous deletions have completed
904 | await asyncio.sleep(0.1)
905 |
906 | # Create collection
907 | try:
908 | collection = self.chroma_client.create_collection(
909 | name=name,
910 | metadata=sanitized_metadata or {"description": "Vector collection"}
911 | )
912 |
913 | logger.info(
914 | f"Created ChromaDB collection '{name}'",
915 | emoji_key="vector"
916 | )
917 |
918 | self.collections[name] = collection
919 | return collection
920 | except Exception as e:
921 | # Instead of falling back to local storage, raise the error
922 | logger.error(
923 | f"Failed to create ChromaDB collection: {str(e)}",
924 | emoji_key="error"
925 | )
926 | raise ValueError(f"Failed to create ChromaDB collection: {str(e)}") from e
927 | else:
928 | # Use local storage
929 | collection = VectorCollection(
930 | name=name,
931 | dimension=dimension,
932 | similarity_metric=similarity_metric,
933 | metadata=metadata
934 | )
935 |
936 | self.collections[name] = collection
937 | return collection
938 |
939 | async def get_collection(self, name: str) -> Optional[Union[VectorCollection, Any]]:
940 | """Get a collection by name.
941 |
942 | Args:
943 | name: Collection name
944 |
945 | Returns:
946 | Collection or None if not found
947 | """
948 | # Check if collection is already loaded
949 | if name in self.collections:
950 | return self.collections[name]
951 |
952 | # Try to load from disk
953 | if self.use_chromadb and self.chroma_client is not None:
954 | # Check if ChromaDB collection exists
955 | try:
956 | # In ChromaDB v0.6.0+, list_collections() returns names not objects
957 | existing_collections = self.chroma_client.list_collections()
958 | existing_collection_names = []
959 |
960 | # Handle both chromadb v0.6.0+ and older versions
961 | if existing_collections and not isinstance(existing_collections[0], str):
962 |                     # Older versions (<0.6.0) return collection objects
963 | for collection in existing_collections:
964 | # Access name attribute or use object itself if it's a string
965 | if hasattr(collection, 'name'):
966 | existing_collection_names.append(collection.name)
967 | else:
968 | existing_collection_names.append(str(collection))
969 | else:
970 |                     # v0.6.0+ returns collection names (strings) directly
971 | existing_collection_names = existing_collections
972 |
973 | if name in existing_collection_names:
974 | collection = self.chroma_client.get_collection(name)
975 | self.collections[name] = collection
976 | return collection
977 | except Exception as e:
978 | logger.error(
979 | f"Failed to get ChromaDB collection: {str(e)}",
980 | emoji_key="error"
981 | )
982 |
983 | # Try to load local collection
984 | collection_dir = self.base_dir / "collections" / name
985 | if collection_dir.exists():
986 | try:
987 | collection = VectorCollection.load(collection_dir)
988 | self.collections[name] = collection
989 | return collection
990 | except Exception as e:
991 | logger.error(
992 | f"Failed to load collection '{name}': {str(e)}",
993 | emoji_key="error"
994 | )
995 |
996 | return None
997 |
998 | async def list_collections(self) -> List[str]:
999 | """List all collection names.
1000 |
1001 | Returns:
1002 | List of collection names
1003 | """
1004 | collection_names = set(self.collections.keys())
1005 |
1006 | # Add collections from ChromaDB
1007 | if self.use_chromadb and self.chroma_client is not None:
1008 | try:
1009 | # Handle both chromadb v0.6.0+ and older versions
1010 | chroma_collections = self.chroma_client.list_collections()
1011 |
1012 | # Check if we received a list of collection objects or just names
1013 | if chroma_collections and not isinstance(chroma_collections[0], str):
1014 |                     # Older versions (<0.6.0) return collection objects
1015 | for collection in chroma_collections:
1016 | # Access name attribute or use object itself if it's a string
1017 | if hasattr(collection, 'name'):
1018 | collection_names.add(collection.name)
1019 | else:
1020 | collection_names.add(str(collection))
1021 | else:
1022 |                     # v0.6.0+ returns collection names (strings) directly
1023 | for collection in chroma_collections:
1024 | collection_names.add(collection)
1025 | except Exception as e:
1026 | logger.error(
1027 | f"Failed to list ChromaDB collections: {str(e)}",
1028 | emoji_key="error"
1029 | )
1030 |
1031 | # Add collections from disk
1032 | collections_dir = self.base_dir / "collections"
1033 | if collections_dir.exists():
1034 | for path in collections_dir.iterdir():
1035 | if path.is_dir() and (path / "data.json").exists():
1036 | collection_names.add(path.name)
1037 |
1038 | return list(collection_names)
1039 |
1040 | async def delete_collection(self, name: str) -> bool:
1041 | """Delete a collection.
1042 |
1043 | Args:
1044 | name: Collection name
1045 |
1046 | Returns:
1047 | True if successful
1048 | """
1049 | # Remove from loaded collections
1050 | if name in self.collections:
1051 | del self.collections[name]
1052 |
1053 | success = True
1054 |
1055 | # Delete from ChromaDB
1056 | if self.use_chromadb and self.chroma_client is not None:
1057 | try:
1058 | # Check if collection exists in ChromaDB first
1059 | exists_in_chromadb = False
1060 | try:
1061 | collections = self.chroma_client.list_collections()
1062 | # Handle different versions of ChromaDB API
1063 | if collections and hasattr(collections[0], 'name'):
1064 | collection_names = [c.name for c in collections]
1065 | else:
1066 | collection_names = collections
1067 |
1068 | exists_in_chromadb = name in collection_names
1069 | except Exception as e:
1070 | logger.debug(f"Error checking ChromaDB collections: {str(e)}")
1071 |
1072 | # Only try to delete if it exists
1073 | if exists_in_chromadb:
1074 | self.chroma_client.delete_collection(name)
1075 | logger.debug(f"Deleted ChromaDB collection '{name}'")
1076 | except Exception as e:
1077 | logger.warning(
1078 | f"Failed to delete ChromaDB collection: {str(e)}",
1079 | emoji_key="warning"
1080 | )
1081 | success = False
1082 |
1083 | # Delete from disk
1084 | collection_dir = self.base_dir / "collections" / name
1085 | if collection_dir.exists():
1086 | try:
1087 | import shutil
1088 | shutil.rmtree(collection_dir)
1089 | logger.debug(f"Deleted collection directory: {collection_dir}")
1090 | except Exception as e:
1091 | logger.error(
1092 | f"Failed to delete collection directory: {str(e)}",
1093 | emoji_key="error"
1094 | )
1095 | return False
1096 |
1097 | logger.info(
1098 | f"Deleted collection '{name}'",
1099 | emoji_key="vector"
1100 | )
1101 |
1102 | return success
1103 |
1104 | async def add_texts(
1105 | self,
1106 | collection_name: str,
1107 | texts: List[str],
1108 | metadatas: Optional[List[Dict[str, Any]]] = None,
1109 | ids: Optional[List[str]] = None,
1110 | embedding_model: Optional[str] = None,
1111 | batch_size: int = 100
1112 | ) -> List[str]:
1113 | """Add texts to a collection.
1114 |
1115 | Args:
1116 | collection_name: Collection name
1117 | texts: Texts to add
1118 | metadatas: Optional metadata for each text
1119 | ids: Optional IDs for the texts
1120 | embedding_model: Embedding model name (NOTE: Model is set during EmbeddingService init)
1121 | batch_size: Maximum batch size for embedding generation
1122 |
1123 | Returns:
1124 | List of document IDs
1125 |
1126 | Raises:
1127 | ValueError: If collection not found
1128 | """
1129 | # Get or create collection
1130 | collection = await self.get_collection(collection_name)
1131 | if collection is None:
1132 | collection = await self.create_collection(collection_name)
1133 |
1134 | # Generate embeddings
1135 | logger.debug(f"Generating embeddings for {len(texts)} texts using model: {self.embedding_service.model_name}")
1136 | embeddings = []
1137 | for i in range(0, len(texts), batch_size):
1138 | batch_texts = texts[i:i + batch_size]
1139 | batch_embeddings = await self.embedding_service.create_embeddings(
1140 | texts=batch_texts,
1141 | )
1142 | embeddings.extend(batch_embeddings)
1143 | if len(texts) > batch_size: # Add delay if batching
1144 | await asyncio.sleep(0.1) # Small delay between batches
1145 |
1146 | logger.debug(f"Generated {len(embeddings)} embeddings")
1147 |
1148 | # Add to collection
1149 | if self.use_chromadb and isinstance(collection, chromadb.Collection):
1150 | # ChromaDB collection
1151 | try:
1152 | # Generate IDs if not provided
1153 | if ids is None:
1154 | ids = [str(uuid.uuid4()) for _ in range(len(texts))]
1155 |
1156 | # Ensure metadatas is provided
1157 | if metadatas is None:
1158 | metadatas = [{} for _ in range(len(texts))]
1159 |
1160 | # Add to ChromaDB collection
1161 | collection.add(
1162 | embeddings=embeddings,
1163 | documents=texts,
1164 | metadatas=metadatas,
1165 | ids=ids
1166 | )
1167 |
1168 | logger.info(
1169 | f"Added {len(texts)} documents to ChromaDB collection '{collection_name}'",
1170 | emoji_key="vector"
1171 | )
1172 |
1173 | return ids
1174 | except Exception as e:
1175 | logger.error(
1176 | f"Failed to add documents to ChromaDB collection: {str(e)}",
1177 | emoji_key="error"
1178 | )
1179 | raise
1180 | else:
1181 | # Local collection
1182 | # For local collection, store text in metadata
1183 | combined_metadata = []
1184 |         for text, meta in zip(texts, metadatas or [{} for _ in range(len(texts))], strict=False):
1185 | # Create metadata with text as main content
1186 | combined_meta = {"text": text}
1187 | # Add any other metadata
1188 | if meta:
1189 | combined_meta.update(meta)
1190 | combined_metadata.append(combined_meta)
1191 |
1192 | logger.debug(f"Adding vectors to local collection with metadata: {combined_metadata[0] if combined_metadata else None}")
1193 |
1194 | result_ids = collection.add(
1195 | vectors=embeddings,
1196 | ids=ids,
1197 | metadatas=combined_metadata
1198 | )
1199 |
1200 | logger.debug(f"Added {len(result_ids)} vectors to local collection '{collection_name}'")
1201 |
1202 | return result_ids
1203 |
1204 | async def search_by_text(
1205 | self,
1206 | collection_name: str,
1207 | query_text: str,
1208 | top_k: int = 5,
1209 | filter: Optional[Dict[str, Any]] = None,
1210 | embedding_model: Optional[str] = None,
1211 | include_vectors: bool = False,
1212 | similarity_threshold: float = 0.0
1213 | ) -> List[Dict[str, Any]]:
1214 | """Search a collection by text query.
1215 |
1216 | Args:
1217 | collection_name: Collection name
1218 | query_text: Text query
1219 | top_k: Number of results to return
1220 | filter: Optional metadata filter
1221 | embedding_model: Embedding model name
1222 | include_vectors: Whether to include vectors in results
1223 | similarity_threshold: Minimum similarity score (0.0 to 1.0)
1224 |
1225 | Returns:
1226 | List of search results
1227 |
1228 | Raises:
1229 | ValueError: If collection not found
1230 | """
1231 | # Get collection
1232 | collection = await self.get_collection(collection_name)
1233 | if collection is None:
1234 | raise ValueError(f"Collection '{collection_name}' not found")
1235 |
1236 | # Search collection
1237 | if self.use_chromadb and isinstance(collection, chromadb.Collection):
1238 | # ChromaDB collection
1239 | try:
1240 | # Convert filter to ChromaDB format if provided
1241 | chroma_filter = self._convert_to_chroma_filter(filter) if filter else None
1242 |
1243 | # Prepare include parameters for ChromaDB
1244 | include_params = ["documents", "metadatas", "distances"]
1245 | if include_vectors:
1246 | include_params.append("embeddings")
1247 |
1248 | # Get embedding directly using our service
1249 | query_embeddings = await self.embedding_service.create_embeddings(
1250 | texts=[query_text],
1251 | # model=embedding_model # Model is defined in the service instance
1252 | )
1253 | if not query_embeddings:
1254 | logger.error(f"Failed to generate embedding for query: {query_text}")
1255 | return []
1256 | query_embedding = query_embeddings[0]
1257 |
1258 | logger.debug(f"Using explicitly generated embedding with model {self.embedding_service.model_name}")
1259 |
1260 | # Search ChromaDB collection with our embedding
1261 | results = collection.query(
1262 | query_embeddings=[query_embedding], # Use our embedding directly, not ChromaDB's
1263 | n_results=top_k,
1264 | where=chroma_filter,
1265 | where_document=None,
1266 | include=include_params
1267 | )
1268 |
1269 | # Format results and apply similarity threshold
1270 | formatted_results = []
1271 | for i in range(len(results["ids"][0])):
1272 | similarity = 1.0 - float(results["distances"][0][i]) # Convert distance to similarity
1273 |
1274 | # Skip results below threshold
1275 | if similarity < similarity_threshold:
1276 | continue
1277 |
1278 | result = {
1279 | "id": results["ids"][0][i],
1280 | "text": results["documents"][0][i],
1281 | "metadata": results["metadatas"][0][i],
1282 | "similarity": similarity,
1283 | }
1284 |
1285 | if include_vectors and "embeddings" in results:
1286 | result["vector"] = results["embeddings"][0][i]
1287 |
1288 | formatted_results.append(result)
1289 |
1290 | return formatted_results
1291 | except Exception as e:
1292 | logger.error(
1293 | f"Failed to search ChromaDB collection: {str(e)}",
1294 | emoji_key="error"
1295 | )
1296 | raise
1297 | else:
1298 | # Local collection
1299 | results = await collection.search_by_text(
1300 | query_text=query_text,
1301 | top_k=top_k,
1302 | filter=filter,
1303 | # model=embedding_model, # Pass model used by the collection's service instance
1304 | similarity_threshold=similarity_threshold
1305 | )
1306 |
1307 | # Format results
1308 | formatted_results = []
1309 | for result in results:
1310 | formatted_result = {
1311 | "id": result["id"],
1312 | "text": result["metadata"].get("text", ""),
1313 | "metadata": {k: v for k, v in result["metadata"].items() if k != "text"},
1314 | "similarity": result["similarity"],
1315 | }
1316 |
1317 | if include_vectors:
1318 | formatted_result["vector"] = result["vector"]
1319 |
1320 | formatted_results.append(formatted_result)
1321 |
1322 | return formatted_results
1323 |
1324 | def _convert_to_chroma_filter(self, filter: Dict[str, Any]) -> Dict[str, Any]:
1325 | """Convert filter to ChromaDB format.
1326 |
1327 | Args:
1328 | filter: Filter dictionary
1329 |
1330 | Returns:
1331 | ChromaDB-compatible filter
1332 | """
1333 | # Simple equality filter for now
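        # (ChromaDB's `where` syntax additionally supports operator filters such
        # as {"topic": {"$eq": "nlp"}} and logical forms like {"$and": [...]};
        # a plain {"key": value} dict passes through as an equality match.)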
1334 | return filter
1335 |
1336 | def save_all_collections(self) -> int:
1337 | """Save all local collections to disk.
1338 |
1339 | Returns:
1340 | Number of collections saved
1341 | """
1342 | saved_count = 0
1343 | collections_dir = self.base_dir / "collections"
1344 | collections_dir.mkdir(parents=True, exist_ok=True)
1345 |
1346 | for name, collection in self.collections.items():
1347 | if not self.use_chromadb or not isinstance(collection, chromadb.Collection):
1348 | # Only save local collections
1349 | collection_dir = collections_dir / name
1350 | if collection.save(collection_dir):
1351 | saved_count += 1
1352 |
1353 | logger.info(
1354 | f"Saved {saved_count} collections to disk",
1355 | emoji_key="vector"
1356 | )
1357 |
1358 | return saved_count
1359 |
1360 | async def get_stats(self) -> Dict[str, Any]:
1361 | """Get statistics about collections.
1362 |
1363 | Returns:
1364 | Dictionary of statistics
1365 | """
1366 | collection_names = await self.list_collections()
1367 | collection_stats = {}
1368 |
1369 | for name in collection_names:
1370 | collection = await self.get_collection(name)
1371 | if collection:
1372 | if isinstance(collection, VectorCollection):
1373 | collection_stats[name] = collection.get_stats()
1374 | else:
1375 | # ChromaDB collection
1376 | try:
1377 | count = collection.count()
1378 | collection_stats[name] = {
1379 | "count": count,
1380 | "type": "chromadb"
1381 | }
1382 | except Exception as e:
1383 | logger.error(
1384 | f"Error getting stats for ChromaDB collection '{name}': {str(e)}",
1385 | emoji_key="error"
1386 | )
1387 | collection_stats[name] = {
1388 | "count": 0,
1389 | "type": "chromadb",
1390 | "error": str(e)
1391 | }
1392 |
1393 | stats = {
1394 | "collections": len(collection_names),
1395 | "collection_stats": collection_stats
1396 | }
1397 |
1398 | return stats
1399 |
1400 | async def get_collection_metadata(self, name: str) -> Dict[str, Any]:
1401 | """Get collection metadata.
1402 |
1403 | Args:
1404 | name: Collection name
1405 |
1406 | Returns:
1407 | Collection metadata
1408 |
1409 | Raises:
1410 | ValueError: If collection not found
1411 | """
1412 | # Get collection
1413 | collection = await self.get_collection(name)
1414 | if collection is None:
1415 | raise ValueError(f"Collection '{name}' not found")
1416 |
1417 | # Get metadata
1418 | try:
1419 | if self.use_chromadb and hasattr(collection, "get_metadata"):
1420 | # ChromaDB collection
1421 | return collection.get_metadata() or {}
1422 | elif hasattr(collection, "metadata"):
1423 | # Local collection
1424 | return collection.metadata or {}
1425 | except Exception as e:
1426 | logger.error(
1427 | f"Failed to get collection metadata: {str(e)}",
1428 | emoji_key="error"
1429 | )
1430 |
1431 | return {}
1432 |
1433 | async def update_collection_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
1434 | """Update collection metadata.
1435 |
1436 | Args:
1437 | name: Collection name
1438 | metadata: New metadata
1439 |
1440 | Returns:
1441 | True if successful
1442 |
1443 | Raises:
1444 | ValueError: If collection not found
1445 | """
1446 | # Get collection
1447 | collection = await self.get_collection(name)
1448 | if collection is None:
1449 | raise ValueError(f"Collection '{name}' not found")
1450 |
1451 | # Update metadata
1452 | try:
1453 | if self.use_chromadb and hasattr(collection, "update_metadata"):
1454 | # ChromaDB collection - needs validation
1455 | validated_metadata = {}
1456 | for k, v in metadata.items():
1457 | # ChromaDB accepts only str, int, float, bool
1458 | if isinstance(v, (str, int, float, bool)):
1459 | validated_metadata[k] = v
1460 | elif v is None:
1461 | # Skip None values
1462 | logger.debug(f"Skipping None value for metadata key '{k}'")
1463 | continue
1464 | else:
1465 | # Convert other types to string
1466 | validated_metadata[k] = str(v)
1467 |
1468 | # Debug log the validated metadata
1469 | logger.debug(f"Updating ChromaDB collection metadata with: {validated_metadata}")
1470 |
1471 | collection.update_metadata(validated_metadata)
1472 | elif hasattr(collection, "metadata"):
1473 | # Local collection
1474 | collection.metadata.update(metadata)
1475 |
1476 | logger.info(
1477 | f"Updated metadata for collection '{name}'",
1478 | emoji_key="vector"
1479 | )
1480 | return True
1481 | except Exception as e:
1482 | logger.error(
1483 | f"Failed to update collection metadata: {str(e)}",
1484 | emoji_key="error"
1485 | )
1486 | # Don't re-raise, just return false
1487 | return False
1488 |
1489 |
1490 | # Singleton instance getter
1491 | def get_vector_db_service(
1492 | base_dir: Optional[Union[str, Path]] = None,
1493 | use_chromadb: Optional[bool] = None
1494 | ) -> VectorDatabaseService:
1495 | """Get the vector database service singleton instance.
1496 |
1497 | Args:
1498 | base_dir: Base directory for storage
1499 | use_chromadb: Whether to use ChromaDB (if available)
1500 |
1501 | Returns:
1502 | VectorDatabaseService singleton instance
1503 | """
1504 | return VectorDatabaseService(base_dir, use_chromadb)
```
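Taken together, the module exposes a small async API: `get_vector_db_service()` returns the singleton, `create_collection()` and `add_texts()` populate a collection (texts are embedded in batches via the configured embedding service), and `search_by_text()` performs semantic retrieval with optional metadata filtering and a similarity threshold. A short usage sketch based on the definitions above, assuming embedding-provider credentials are configured; the collection name, texts, and printed values are illustrative:

```python
import asyncio

from ultimate_mcp_server.services.vector.vector_service import get_vector_db_service

async def demo() -> None:
    db = get_vector_db_service()  # singleton; defaults to ~/.ultimate/vector_db

    # Create (or replace) a collection; falls back to local numpy/hnswlib
    # storage when ChromaDB is not installed.
    await db.create_collection("demo_docs", dimension=1536, overwrite=True)

    # Texts are embedded and stored alongside their metadata.
    await db.add_texts(
        collection_name="demo_docs",
        texts=[
            "HNSW graphs enable fast approximate nearest-neighbor search.",
            "Transformers rely on self-attention over token embeddings.",
        ],
        metadatas=[{"topic": "search"}, {"topic": "nlp"}],
    )

    # Results are dicts with id, text, metadata, and similarity fields.
    results = await db.search_by_text(
        collection_name="demo_docs",
        query_text="approximate nearest neighbor indexes",
        top_k=2,
        similarity_threshold=0.1,
    )
    for r in results:
        print(r["id"], f"{r['similarity']:.3f}", r["text"])

if __name__ == "__main__":
    asyncio.run(demo())
```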