This is page 6 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│   ├── __init__.py
│   ├── advanced_agent_flows_using_unified_memory_system_demo.py
│   ├── advanced_extraction_demo.py
│   ├── advanced_unified_memory_system_demo.py
│   ├── advanced_vector_search_demo.py
│   ├── analytics_reporting_demo.py
│   ├── audio_transcription_demo.py
│   ├── basic_completion_demo.py
│   ├── cache_demo.py
│   ├── claude_integration_demo.py
│   ├── compare_synthesize_demo.py
│   ├── cost_optimization.py
│   ├── data
│   │   ├── sample_event.txt
│   │   ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│   │   └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│   ├── docstring_refiner_demo.py
│   ├── document_conversion_and_processing_demo.py
│   ├── entity_relation_graph_demo.py
│   ├── filesystem_operations_demo.py
│   ├── grok_integration_demo.py
│   ├── local_text_tools_demo.py
│   ├── marqo_fused_search_demo.py
│   ├── measure_model_speeds.py
│   ├── meta_api_demo.py
│   ├── multi_provider_demo.py
│   ├── ollama_integration_demo.py
│   ├── prompt_templates_demo.py
│   ├── python_sandbox_demo.py
│   ├── rag_example.py
│   ├── research_workflow_demo.py
│   ├── sample
│   │   ├── article.txt
│   │   ├── backprop_paper.pdf
│   │   ├── buffett.pdf
│   │   ├── contract_link.txt
│   │   ├── legal_contract.txt
│   │   ├── medical_case.txt
│   │   ├── northwind.db
│   │   ├── research_paper.txt
│   │   ├── sample_data.json
│   │   └── text_classification_samples
│   │       ├── email_classification.txt
│   │       ├── news_samples.txt
│   │       ├── product_reviews.txt
│   │       └── support_tickets.txt
│   ├── sample_docs
│   │   └── downloaded
│   │       └── attention_is_all_you_need.pdf
│   ├── sentiment_analysis_demo.py
│   ├── simple_completion_demo.py
│   ├── single_shot_synthesis_demo.py
│   ├── smart_browser_demo.py
│   ├── sql_database_demo.py
│   ├── sse_client_demo.py
│   ├── test_code_extraction.py
│   ├── test_content_detection.py
│   ├── test_ollama.py
│   ├── text_classification_demo.py
│   ├── text_redline_demo.py
│   ├── tool_composition_examples.py
│   ├── tournament_code_demo.py
│   ├── tournament_text_demo.py
│   ├── unified_memory_system_demo.py
│   ├── vector_search_demo.py
│   ├── web_automation_instruction_packs.py
│   └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│   └── smart_browser_internal
│       ├── locator_cache.db
│       ├── readability.js
│       └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_server.py
│   ├── manual
│   │   ├── test_extraction_advanced.py
│   │   └── test_extraction.py
│   └── unit
│       ├── __init__.py
│       ├── test_cache.py
│       ├── test_providers.py
│       └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── commands.py
│   │   ├── helpers.py
│   │   └── typer_cli.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── completion_client.py
│   │   └── rag_client.py
│   ├── config
│   │   └── examples
│   │       └── filesystem_config.yaml
│   ├── config.py
│   ├── constants.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── base.py
│   │   │   └── evaluators.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   ├── deepseek.py
│   │   │   ├── gemini.py
│   │   │   ├── grok.py
│   │   │   ├── ollama.py
│   │   │   ├── openai.py
│   │   │   └── openrouter.py
│   │   ├── server.py
│   │   ├── state_store.py
│   │   ├── tournaments
│   │   │   ├── manager.py
│   │   │   ├── tasks.py
│   │   │   └── utils.py
│   │   └── ums_api
│   │       ├── __init__.py
│   │       ├── ums_database.py
│   │       ├── ums_endpoints.py
│   │       ├── ums_models.py
│   │       └── ums_services.py
│   ├── exceptions.py
│   ├── graceful_shutdown.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── analytics
│   │   │   ├── __init__.py
│   │   │   ├── metrics.py
│   │   │   └── reporting.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── cache_service.py
│   │   │   ├── persistence.py
│   │   │   ├── strategies.py
│   │   │   └── utils.py
│   │   ├── cache.py
│   │   ├── document.py
│   │   ├── knowledge_base
│   │   │   ├── __init__.py
│   │   │   ├── feedback.py
│   │   │   ├── manager.py
│   │   │   ├── rag_engine.py
│   │   │   ├── retriever.py
│   │   │   └── utils.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── repository.py
│   │   │   └── templates.py
│   │   ├── prompts.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── embeddings.py
│   │       └── vector_service.py
│   ├── tool_token_counter.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── audio_transcription.py
│   │   ├── base.py
│   │   ├── completion.py
│   │   ├── docstring_refiner.py
│   │   ├── document_conversion_and_processing.py
│   │   ├── enhanced-ums-lookbook.html
│   │   ├── entity_relation_graph.py
│   │   ├── excel_spreadsheet_automation.py
│   │   ├── extraction.py
│   │   ├── filesystem.py
│   │   ├── html_to_markdown.py
│   │   ├── local_text_tools.py
│   │   ├── marqo_fused_search.py
│   │   ├── meta_api_tool.py
│   │   ├── ocr_tools.py
│   │   ├── optimization.py
│   │   ├── provider.py
│   │   ├── pyodide_boot_template.html
│   │   ├── python_sandbox.py
│   │   ├── rag.py
│   │   ├── redline-compiled.css
│   │   ├── sentiment_analysis.py
│   │   ├── single_shot_synthesis.py
│   │   ├── smart_browser.py
│   │   ├── sql_databases.py
│   │   ├── text_classification.py
│   │   ├── text_redline_tools.py
│   │   ├── tournament.py
│   │   ├── ums_explorer.html
│   │   └── unified_memory_system.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── display.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   ├── console.py
│   │   │   ├── emojis.py
│   │   │   ├── formatter.py
│   │   │   ├── logger.py
│   │   │   ├── panels.py
│   │   │   ├── progress.py
│   │   │   └── themes.py
│   │   ├── parse_yaml.py
│   │   ├── parsing.py
│   │   ├── security.py
│   │   └── text.py
│   └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/examples/prompt_templates_demo.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python
2 | """Prompt templates and repository demonstration for Ultimate MCP Server."""
3 | import asyncio
4 | import sys
5 | import time
6 | from pathlib import Path
7 |
8 | # Add project root to path for imports when running as script
9 | sys.path.insert(0, str(Path(__file__).parent.parent))
10 |
11 | from rich import box
12 | from rich.markup import escape
13 | from rich.panel import Panel
14 | from rich.rule import Rule
15 | from rich.syntax import Syntax
16 | from rich.table import Table
17 |
18 | from ultimate_mcp_server.constants import Provider
19 | from ultimate_mcp_server.core.server import Gateway
20 | from ultimate_mcp_server.services.prompts import PromptTemplate, get_prompt_repository
21 | from ultimate_mcp_server.utils import get_logger
22 | from ultimate_mcp_server.utils.display import CostTracker, display_text_content_result
23 |
24 | # --- Add Rich Imports ---
25 | from ultimate_mcp_server.utils.logging.console import console
26 |
27 | # ----------------------
28 |
29 | # Initialize logger
30 | logger = get_logger("example.prompt_templates")
31 |
32 |
33 | async def demonstrate_prompt_templates():
34 | """Demonstrate prompt template creation and rendering."""
35 | # Use Rich Rule for title
36 | console.print(Rule("[bold blue]Prompt Template Demonstration[/bold blue]"))
37 | logger.info("Starting prompt template demonstration", emoji_key="start")
38 |
39 | # Simple prompt template
40 | template_text = """
41 | You are an expert in {{field}}.
42 | Please explain {{concept}} in simple terms that a {{audience}} could understand.
43 | """
44 |
45 | # Create a prompt template
46 | template = PromptTemplate(
47 | template=template_text,
48 | template_id="simple_explanation",
49 | description="A template for generating simple explanations of concepts"
50 | )
51 |
52 | logger.info(
53 | f"Created prompt template: {template.template_id}",
54 | emoji_key="template"
55 | )
56 |
57 | # Render the template with variables
58 | variables = {
59 | "field": "artificial intelligence",
60 | "concept": "neural networks",
61 | "audience": "high school student"
62 | }
63 |
64 | rendered_prompt = template.render(variables)
65 |
66 | logger.info(
67 | "Template rendered successfully",
68 | emoji_key="success",
69 | variables=list(variables.keys())
70 | )
71 |
72 | # Display rendered template using Rich
73 | console.print(Rule("[cyan]Simple Template Rendering[/cyan]"))
74 | console.print(Panel(
75 | Syntax(template.template, "jinja2", theme="default", line_numbers=False),
76 | title="[bold]Template Source[/bold]",
77 | border_style="dim blue",
78 | expand=False
79 | ))
80 | vars_table = Table(title="[bold]Variables[/bold]", box=box.MINIMAL, show_header=False)
81 | vars_table.add_column("Key", style="magenta")
82 | vars_table.add_column("Value", style="white")
83 | for key, value in variables.items():
84 | vars_table.add_row(escape(key), escape(value))
85 | console.print(vars_table)
86 | console.print(Panel(
87 | escape(rendered_prompt.strip()),
88 | title="[bold green]Rendered Prompt[/bold green]",
89 | border_style="green",
90 | expand=False
91 | ))
92 | console.print()
93 |
94 |
95 | # Create a more complex template with conditional blocks
96 | complex_template_text = """
97 | {% if system_message %}
98 | {{system_message}}
99 | {% else %}
100 | You are a helpful assistant that provides accurate information.
101 | {% endif %}
102 |
103 | {% if context %}
104 | Here is some context to help you answer:
105 | {{context}}
106 | {% endif %}
107 |
108 | USER: {{query}}
109 |
110 | Please respond with:
111 | {% for item in response_items %}
112 | - {{item}}
113 | {% endfor %}
114 | """
115 |
116 | complex_template_obj = PromptTemplate(
117 | template=complex_template_text, # Use the text variable
118 | template_id="complex_assistant",
119 | description="A complex assistant template with conditionals and loops",
120 | required_vars=["system_message", "query", "response_items", "context"]
121 | )
122 |
123 | # Complex variables
124 | complex_variables = {
125 | "system_message": "You are an expert in climate science who explains concepts clearly and objectively.",
126 | "query": "What are the main causes of climate change?",
127 | "context": """
128 | Recent data shows that global temperatures have risen by about 1.1°C since pre-industrial times.
129 | The IPCC Sixth Assessment Report (2021) states that human activities are unequivocally the main driver
130 | of climate change, primarily through greenhouse gas emissions. CO2 levels have increased by 48% since
131 | the industrial revolution, reaching levels not seen in at least 800,000 years.
132 | """,
133 | "response_items": [
134 | "A summary of the main causes based on scientific consensus",
135 | "The role of greenhouse gases (CO2, methane, etc.) in climate change",
136 | "Human activities that contribute most significantly to emissions",
137 | "Natural vs anthropogenic factors and their relative impact",
138 | "Regional variations in climate change impacts"
139 | ]
140 | }
141 |
142 | complex_rendered = complex_template_obj.render(complex_variables)
143 |
144 | logger.info(
145 | "Complex template rendered successfully",
146 | emoji_key="success",
147 | template_id=complex_template_obj.template_id
148 | )
149 |
150 | # Display complex template rendering using Rich
151 | console.print(Rule("[cyan]Complex Template Rendering[/cyan]"))
152 | console.print(Panel(
153 | Syntax(complex_template_obj.template, "jinja2", theme="default", line_numbers=False),
154 | title="[bold]Template Source[/bold]",
155 | border_style="dim blue",
156 | expand=False
157 | ))
158 | complex_vars_table = Table(title="[bold]Variables[/bold]", box=box.MINIMAL, show_header=False)
159 | complex_vars_table.add_column("Key", style="magenta")
160 | complex_vars_table.add_column("Value", style="white")
161 | for key, value in complex_variables.items():
162 | # Truncate long context for display
163 | display_value = escape(str(value))
164 | if key == 'context' and len(display_value) > 150:
165 | display_value = display_value[:150] + '...'
166 | elif isinstance(value, list):
167 | display_value = escape(str(value)[:100] + '...' if len(str(value)) > 100 else str(value)) # Truncate lists too
168 | complex_vars_table.add_row(escape(key), display_value)
169 | console.print(complex_vars_table)
170 | console.print(Panel(
171 | escape(complex_rendered.strip()),
172 | title="[bold green]Rendered Prompt[/bold green]",
173 | border_style="green",
174 | expand=False
175 | ))
176 | console.print()
177 |
178 | # Demonstrate rendering with missing variables (handled by Jinja's default behavior or errors)
179 | console.print(Rule("[cyan]Template with Missing Variables[/cyan]"))
180 | missing_variables = {
181 | "query": "How can individuals reduce their carbon footprint?",
182 | "response_items": [
183 | "Daily lifestyle changes with significant impact",
184 | "Transportation choices and alternatives",
185 | "Home energy consumption reduction strategies"
186 | ]
187 | # system_message and context are intentionally missing
188 | }
189 |
190 | try:
191 | missing_rendered = complex_template_obj.render(missing_variables)
192 | logger.info(
193 | "Template rendered with missing optional variables (using defaults)",
194 | emoji_key="info",
195 | missing=["system_message", "context"]
196 | )
197 | console.print(Panel(
198 | escape(missing_rendered.strip()),
199 | title="[bold yellow]Rendered with Defaults[/bold yellow]",
200 | border_style="yellow",
201 | expand=False
202 | ))
203 | except Exception as e: # Catch Jinja exceptions or others
204 | logger.warning(f"Could not render with missing variables: {str(e)}", emoji_key="warning")
205 | console.print(Panel(
206 | f"[red]Error rendering template:[/red]\n{escape(str(e))}",
207 | title="[bold red]Rendering Error[/bold red]",
208 | border_style="red",
209 | expand=False
210 | ))
211 | console.print()
212 |
213 | return template, complex_template_obj
214 |
215 |
216 | async def demonstrate_prompt_repository():
217 | """Demonstrate saving and retrieving templates from repository."""
218 | # Use Rich Rule
219 | console.print(Rule("[bold blue]Prompt Repository Demonstration[/bold blue]"))
220 | logger.info("Starting prompt repository demonstration", emoji_key="start")
221 |
222 | # Get repository
223 | repo = get_prompt_repository()
224 |
225 | # Check repository path
226 | logger.info(f"Prompt repository path: {repo.base_dir}", emoji_key="info")
227 |
228 | # List existing prompts (if any)
229 | prompts = await repo.list_prompts()
230 | if prompts:
231 | logger.info(f"Found {len(prompts)} existing prompts: {', '.join(prompts)}", emoji_key="info")
232 | else:
233 | logger.info("No existing prompts found in repository", emoji_key="info")
234 |
235 | # Create a new prompt template for saving
236 | translation_template = """
237 | Translate the following {{source_language}} text into {{target_language}}:
238 |
239 | TEXT: {{text}}
240 |
241 | The translation should be:
242 | - Accurate and faithful to the original
243 | - Natural in the target language
244 | - Preserve the tone and style of the original
245 |
246 | TRANSLATION:
247 | """
248 |
249 | template = PromptTemplate(
250 | template=translation_template,
251 | template_id="translation_prompt",
252 | description="A template for translation tasks",
253 | metadata={
254 | "author": "Ultimate MCP Server",
255 | "version": "1.0",
256 | "supported_languages": ["English", "Spanish", "French", "German", "Japanese"]
257 | }
258 | )
259 |
260 | # Save to repository
261 | template_dict = template.to_dict()
262 |
263 | logger.info(
264 | f"Saving template '{template.template_id}' to repository",
265 | emoji_key="save",
266 | metadata=template.metadata
267 | )
268 |
269 | save_result = await repo.save_prompt(template.template_id, template_dict)
270 |
271 | if save_result:
272 | logger.success(
273 | f"Template '{template.template_id}' saved successfully",
274 | emoji_key="success"
275 | )
276 | else:
277 | logger.error(
278 | f"Failed to save template '{template.template_id}'",
279 | emoji_key="error"
280 | )
281 | return
282 |
283 | # Retrieve the saved template
284 | logger.info(f"Retrieving template '{template.template_id}' from repository", emoji_key="loading")
285 |
286 | retrieved_dict = await repo.get_prompt(template.template_id)
287 |
288 | if retrieved_dict:
289 | # Convert back to PromptTemplate object
290 | retrieved_template = PromptTemplate.from_dict(retrieved_dict)
291 |
292 | logger.success(
293 | f"Retrieved template '{retrieved_template.template_id}' successfully",
294 | emoji_key="success",
295 | metadata=retrieved_template.metadata
296 | )
297 |
298 | # Display retrieved template details using Rich
299 | retrieved_table = Table(title=f"[bold]Retrieved Template: {escape(retrieved_template.template_id)}[/bold]", box=box.ROUNDED, show_header=False)
300 | retrieved_table.add_column("Attribute", style="cyan")
301 | retrieved_table.add_column("Value", style="white")
302 | retrieved_table.add_row("Description", escape(retrieved_template.description))
303 | retrieved_table.add_row("Metadata", escape(str(retrieved_template.metadata)))
304 | console.print(retrieved_table)
305 | console.print(Panel(
306 | Syntax(retrieved_template.template, "jinja2", theme="default", line_numbers=False),
307 | title="[bold]Template Source[/bold]",
308 | border_style="dim blue",
309 | expand=False
310 | ))
311 | console.print()
312 |
313 | else:
314 | logger.error(
315 | f"Failed to retrieve template '{template.template_id}'",
316 | emoji_key="error"
317 | )
318 |
319 | # List prompts again to confirm addition
320 | updated_prompts = await repo.list_prompts()
321 | logger.info(
322 | f"Repository now contains {len(updated_prompts)} prompts: {', '.join(updated_prompts)}",
323 | emoji_key="info"
324 | )
325 |
326 |     # The deletion below is left commented out so the template remains available for the LLM demo.
327 |     # Uncommenting it would delete the template:
328 | """
329 | delete_result = await repo.delete_prompt(template.template_id)
330 | if delete_result:
331 | logger.info(
332 | f"Deleted template '{template.template_id}' from repository",
333 | emoji_key="cleaning"
334 | )
335 | """
336 |
337 | return retrieved_template
338 |
339 |
340 | async def demonstrate_llm_with_templates(tracker: CostTracker):
341 | """Demonstrate using a template from the repository with an LLM."""
342 | # Use Rich Rule
343 | console.print(Rule("[bold blue]LLM with Template Demonstration[/bold blue]"))
344 | logger.info("Starting LLM with template demonstration", emoji_key="start")
345 |
346 | # Retrieve the translation template saved earlier
347 | repo = get_prompt_repository()
348 | template_id = "translation_prompt"
349 | template_dict = await repo.get_prompt(template_id)
350 |
351 | if not template_dict:
352 | console.print(f"Prompt '{template_id}' not found")
353 | logger.error(f"Template '{template_id}' not found. Skipping LLM demo.", emoji_key="error")
354 | return
355 |
356 | template = PromptTemplate.from_dict(template_dict)
357 | logger.info(f"Retrieved template '{template_id}' for LLM use", emoji_key="template")
358 |
359 | # Variables for translation
360 | translation_vars = {
361 | "source_language": "English",
362 | "target_language": "French",
363 | "text": "The quick brown fox jumps over the lazy dog."
364 | }
365 |
366 | # Render the prompt
367 | try:
368 | rendered_prompt = template.render(translation_vars)
369 | logger.info("Translation prompt rendered", emoji_key="success")
370 |
371 | # Display the rendered prompt for clarity
372 | console.print(Panel(
373 | escape(rendered_prompt.strip()),
374 | title="[bold]Rendered Translation Prompt[/bold]",
375 | border_style="blue",
376 | expand=False
377 | ))
378 |
379 | except Exception as e:
380 | logger.error(f"Error rendering translation prompt: {str(e)}", emoji_key="error", exc_info=True)
381 | return
382 |
383 | # Initialize gateway with providers
384 | gateway = Gateway("prompt-templates-demo", register_tools=False)
385 | logger.info("Initializing providers...", emoji_key="provider")
386 | await gateway._initialize_providers()
387 |
388 | # Providers to try in order of preference
389 | providers_to_try = [
390 | Provider.OPENAI.value,
391 | Provider.ANTHROPIC.value,
392 | Provider.GEMINI.value,
393 | Provider.DEEPSEEK.value
394 | ]
395 |
396 | # Find an available provider
397 | provider = None
398 | provider_name = None
399 |
400 | for p_name in providers_to_try:
401 | if p_name in gateway.providers:
402 | provider = gateway.providers[p_name]
403 | provider_name = p_name
404 | logger.info(f"Using provider {p_name}", emoji_key="provider")
405 | break
406 |
407 | try:
408 | model = provider.get_default_model()
409 | logger.info(f"Using provider {provider_name} with model {model}", emoji_key="provider")
410 |
411 | # Generate completion using the rendered prompt
412 | logger.info("Generating translation...", emoji_key="processing")
413 | start_time = time.time()
414 | result = await provider.generate_completion(
415 | prompt=rendered_prompt,
416 | model=model,
417 | temperature=0.5,
418 | max_tokens=150
419 | )
420 | processing_time = time.time() - start_time
421 |
422 | logger.success("Translation generated successfully!", emoji_key="success")
423 |
424 | # Use display.py function for better visualization
425 | display_text_content_result(
426 | f"Translation Result ({escape(provider_name)}/{escape(model)})",
427 | result,
428 | console_instance=console
429 | )
430 |
431 | # Track cost
432 | tracker.add_call(result)
433 |
434 | # Display additional stats with standard rich components
435 | stats_table = Table(title="Translation Stats", show_header=False, box=box.ROUNDED)
436 | stats_table.add_column("Metric", style="cyan")
437 | stats_table.add_column("Value", style="white")
438 | stats_table.add_row("Provider", provider_name)
439 | stats_table.add_row("Model", model)
440 | stats_table.add_row("Input Tokens", str(result.input_tokens))
441 | stats_table.add_row("Output Tokens", str(result.output_tokens))
442 | stats_table.add_row("Cost", f"${result.cost:.6f}")
443 | stats_table.add_row("Processing Time", f"{processing_time:.3f}s")
444 | console.print(stats_table)
445 |
446 | except Exception as e:
447 | logger.error(f"Error during LLM completion: {str(e)}", emoji_key="error", exc_info=True)
448 | # Fall back to mock response
449 | console.print(Panel(
450 | "[yellow]Failed to generate real translation. Here's a mock response:[/yellow]\n" +
451 | "Le renard brun rapide saute par-dessus le chien paresseux.",
452 | title="[bold yellow]Mock Translation (After Error)[/bold yellow]",
453 | border_style="yellow"
454 | ))
455 |
456 | # Display cost summary at the end of this demo section
457 | tracker.display_summary(console)
458 |
459 |
460 | async def main():
461 | """Run all demonstrations."""
462 | try:
463 | # Demonstrate template creation and rendering
464 | template1, template2 = await demonstrate_prompt_templates()
465 | console.print() # Add space
466 |
467 | # Demonstrate repository usage
468 | retrieved_template = await demonstrate_prompt_repository() # noqa: F841
469 | console.print()
470 |
471 | # Demonstrate using a template with LLM - no longer check for retrieved_template
472 | # as it should always be available since we commented out the deletion
473 | tracker = CostTracker() # Instantiate tracker here
474 | await demonstrate_llm_with_templates(tracker)
475 |
476 | except Exception as e:
477 | logger.critical(f"Demo failed: {str(e)}", emoji_key="critical", exc_info=True)
478 | return 1
479 |
480 | # Clean up after demo is complete - optionally delete the template
481 | try:
482 | # After demo is complete, we can clean up by deleting the template
483 | repo = get_prompt_repository()
484 | await repo.delete_prompt("translation_prompt")
485 | logger.info("Deleted demonstration template", emoji_key="cleaning")
486 | except Exception as e:
487 | logger.warning(f"Cleanup error: {str(e)}", emoji_key="warning")
488 |
489 | logger.success("Prompt Template Demo Finished Successfully!", emoji_key="complete")
490 | return 0
491 |
492 |
493 | if __name__ == "__main__":
494 | exit_code = asyncio.run(main())
495 | sys.exit(exit_code)
```
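The demo above exercises `PromptTemplate.render()` as a Jinja2-style renderer ({{var}} substitution plus {% if %} and {% for %} blocks) whose defaults tolerate missing variables. The real class lives in `ultimate_mcp_server/services/prompts/templates.py`, which is not included on this page; the following is only a minimal sketch, assuming a Jinja2 backend, with illustrative names (`MinimalPromptTemplate`, `strict`) that are not part of the project.

```python
# Hypothetical sketch of a Jinja2-backed render(), analogous to what the demo exercises.
# Names and defaults are illustrative assumptions, not the project's actual implementation.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from jinja2 import Environment, StrictUndefined, Undefined


@dataclass
class MinimalPromptTemplate:
    template: str
    template_id: str
    description: str = ""
    required_vars: Optional[List[str]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def render(self, variables: Dict[str, Any], strict: bool = False) -> str:
        # StrictUndefined raises on missing variables; the default Undefined renders
        # them as empty strings, matching the "missing optional variables" path
        # demonstrated in the script above.
        env = Environment(undefined=StrictUndefined if strict else Undefined)
        return env.from_string(self.template).render(**variables)


if __name__ == "__main__":
    tmpl = MinimalPromptTemplate(
        template="You are an expert in {{field}}. Explain {{concept}} simply.",
        template_id="sketch",
    )
    print(tmpl.render({"field": "AI", "concept": "neural networks"}))
```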
--------------------------------------------------------------------------------
/tests/unit/test_providers.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for the provider implementations."""
2 | from typing import Any, Dict
3 |
4 | import pytest
5 | from pytest import MonkeyPatch
6 |
7 | from ultimate_mcp_server.constants import Provider
8 | from ultimate_mcp_server.core.providers.anthropic import AnthropicProvider
9 | from ultimate_mcp_server.core.providers.base import (
10 | BaseProvider,
11 | ModelResponse,
12 | get_provider,
13 | )
14 | from ultimate_mcp_server.core.providers.deepseek import DeepSeekProvider
15 | from ultimate_mcp_server.core.providers.gemini import GeminiProvider
16 | from ultimate_mcp_server.core.providers.openai import OpenAIProvider
17 | from ultimate_mcp_server.utils import get_logger
18 |
19 | logger = get_logger("test.providers")
20 |
21 | # Set the loop scope for all tests - function scope is recommended for isolated test execution
22 | pytestmark = pytest.mark.asyncio(loop_scope="function")
23 |
24 |
25 | class TestBaseProvider:
26 | """Tests for the base provider class."""
27 |
28 | def test_init(self):
29 | """Test provider initialization."""
30 | logger.info("Testing base provider initialization", emoji_key="test")
31 |
32 | class TestProvider(BaseProvider):
33 | provider_name = "test"
34 |
35 | async def initialize(self):
36 | return True
37 |
38 | async def generate_completion(self, prompt, **kwargs):
39 | return ModelResponse(
40 | text="Test response",
41 | model="test-model",
42 | provider=self.provider_name
43 | )
44 |
45 | async def generate_completion_stream(self, prompt, **kwargs):
46 | yield "Test", {}
47 |
48 | def get_default_model(self):
49 | return "test-model"
50 |
51 | provider = TestProvider(api_key="test-key", test_option="value")
52 |
53 | assert provider.api_key == "test-key"
54 | assert provider.options == {"test_option": "value"}
55 | assert provider.provider_name == "test"
56 |
57 | async def test_process_with_timer(self, mock_provider: BaseProvider):
58 | """Test the process_with_timer utility method."""
59 | logger.info("Testing process_with_timer", emoji_key="test")
60 |
61 | # Create a mock async function that returns a value
62 | async def mock_func(arg1, arg2=None):
63 | return {"result": arg1 + str(arg2 or "")}
64 |
65 | # Process with timer
66 | result, time_taken = await mock_provider.process_with_timer(
67 | mock_func, "test", arg2="arg"
68 | )
69 |
70 | assert result == {"result": "testarg"}
71 | assert isinstance(time_taken, float)
72 | assert time_taken >= 0 # Time should be non-negative
73 |
74 | def test_model_response(self):
75 | """Test the ModelResponse class."""
76 | logger.info("Testing ModelResponse", emoji_key="test")
77 |
78 | # Create a response with minimal info
79 | response = ModelResponse(
80 | text="Test response",
81 | model="test-model",
82 | provider="test"
83 | )
84 |
85 | assert response.text == "Test response"
86 | assert response.model == "test-model"
87 | assert response.provider == "test"
88 | assert response.input_tokens == 0
89 | assert response.output_tokens == 0
90 | assert response.total_tokens == 0
91 | assert response.cost == 0.0 # No tokens, no cost
92 |
93 | # Create a response with token usage
94 | response = ModelResponse(
95 | text="Test response with tokens",
96 | model="gpt-4o", # A model with known cost
97 | provider="openai",
98 | input_tokens=100,
99 | output_tokens=50
100 | )
101 |
102 | assert response.input_tokens == 100
103 | assert response.output_tokens == 50
104 | assert response.total_tokens == 150
105 | assert response.cost > 0.0 # Should have calculated a cost
106 |
107 | # Test dictionary conversion
108 | response_dict = response.to_dict()
109 | assert response_dict["text"] == "Test response with tokens"
110 | assert response_dict["model"] == "gpt-4o"
111 | assert response_dict["provider"] == "openai"
112 | assert response_dict["usage"]["input_tokens"] == 100
113 | assert response_dict["usage"]["output_tokens"] == 50
114 | assert response_dict["usage"]["total_tokens"] == 150
115 | assert "cost" in response_dict
116 |
117 | def test_get_provider_factory(self, mock_env_vars):
118 | """Test the get_provider factory function."""
119 | logger.info("Testing get_provider factory", emoji_key="test")
120 |
121 | # Test getting a provider by name
122 | openai_provider = get_provider(Provider.OPENAI.value)
123 | assert isinstance(openai_provider, OpenAIProvider)
124 | assert openai_provider.provider_name == Provider.OPENAI.value
125 |
126 | # Test with different provider
127 | anthropic_provider = get_provider(Provider.ANTHROPIC.value)
128 | assert isinstance(anthropic_provider, AnthropicProvider)
129 | assert anthropic_provider.provider_name == Provider.ANTHROPIC.value
130 |
131 | # Test with invalid provider
132 | with pytest.raises(ValueError):
133 | get_provider("invalid-provider")
134 |
135 | # Test with custom API key
136 | custom_key_provider = get_provider(Provider.OPENAI.value, api_key="custom-key")
137 | assert custom_key_provider.api_key == "custom-key"
138 |
139 |
140 | class TestOpenAIProvider:
141 | """Tests for the OpenAI provider."""
142 |
143 | @pytest.fixture
144 | def mock_openai_responses(self) -> Dict[str, Any]:
145 | """Mock responses for OpenAI API."""
146 | # Create proper class-based mocks with attributes instead of dictionaries
147 | class MockCompletion:
148 | def __init__(self):
149 | self.id = "mock-completion-id"
150 | self.choices = [MockChoice()]
151 | self.usage = MockUsage()
152 |
153 | class MockChoice:
154 | def __init__(self):
155 | self.message = MockMessage()
156 | self.finish_reason = "stop"
157 |
158 | class MockMessage:
159 | def __init__(self):
160 | self.content = "Mock OpenAI response"
161 |
162 | class MockUsage:
163 | def __init__(self):
164 | self.prompt_tokens = 10
165 | self.completion_tokens = 5
166 | self.total_tokens = 15
167 |
168 | class MockModelsResponse:
169 | def __init__(self):
170 | self.data = [
171 | type("MockModel", (), {"id": "gpt-4o", "owned_by": "openai"}),
172 | type("MockModel", (), {"id": "gpt-4.1-mini", "owned_by": "openai"}),
173 | type("MockModel", (), {"id": "gpt-4.1-mini", "owned_by": "openai"})
174 | ]
175 |
176 | return {
177 | "completion": MockCompletion(),
178 | "models": MockModelsResponse()
179 | }
180 |
181 | @pytest.fixture
182 | def mock_openai_provider(self, monkeypatch: MonkeyPatch, mock_openai_responses: Dict[str, Any]) -> OpenAIProvider:
183 | """Get a mock OpenAI provider with patched methods."""
184 | # Create the provider
185 | provider = OpenAIProvider(api_key="mock-openai-key")
186 |
187 | # Mock the AsyncOpenAI client methods
188 | class MockAsyncOpenAI:
189 | def __init__(self, **kwargs):
190 | self.kwargs = kwargs
191 | self.chat = MockChat()
192 | self.models = MockModels()
193 |
194 | class MockChat:
195 | def __init__(self):
196 | self.completions = MockCompletions()
197 |
198 | class MockCompletions:
199 | async def create(self, **kwargs):
200 | return mock_openai_responses["completion"]
201 |
202 | class MockModels:
203 | async def list(self):
204 | return mock_openai_responses["models"]
205 |
206 | # Patch the AsyncOpenAI client
207 | monkeypatch.setattr("openai.AsyncOpenAI", MockAsyncOpenAI)
208 |
209 | # Initialize the provider with the mock client
210 | provider.client = MockAsyncOpenAI(api_key="mock-openai-key")
211 |
212 | return provider
213 |
214 | async def test_initialization(self, mock_openai_provider: OpenAIProvider):
215 | """Test OpenAI provider initialization."""
216 | logger.info("Testing OpenAI provider initialization", emoji_key="test")
217 |
218 | # Initialize
219 | success = await mock_openai_provider.initialize()
220 | assert success
221 | assert mock_openai_provider.client is not None
222 |
223 | async def test_completion(self, mock_openai_provider: OpenAIProvider):
224 | """Test OpenAI completion generation."""
225 | logger.info("Testing OpenAI completion", emoji_key="test")
226 |
227 | # Generate completion
228 | result = await mock_openai_provider.generate_completion(
229 | prompt="Test prompt",
230 | model="gpt-4o",
231 | temperature=0.7
232 | )
233 |
234 | # Check result
235 | assert isinstance(result, ModelResponse)
236 | assert result.text == "Mock OpenAI response"
237 | assert result.model == "gpt-4o"
238 | assert result.provider == Provider.OPENAI.value
239 | assert result.input_tokens == 10
240 | assert result.output_tokens == 5
241 | assert result.total_tokens == 15
242 |
243 | async def test_list_models(self, mock_openai_provider: OpenAIProvider):
244 | """Test listing OpenAI models."""
245 | logger.info("Testing OpenAI list_models", emoji_key="test")
246 |
247 | # Initialize first
248 | await mock_openai_provider.initialize()
249 |
250 | # List models
251 | models = await mock_openai_provider.list_models()
252 |
253 | # Should return filtered list of models (chat-capable)
254 | assert isinstance(models, list)
255 | assert len(models) > 0
256 |
257 | # Check model format
258 | for model in models:
259 | assert "id" in model
260 | assert "provider" in model
261 | assert model["provider"] == Provider.OPENAI.value
262 |
263 | def test_default_model(self, mock_openai_provider: OpenAIProvider):
264 | """Test getting default model."""
265 | logger.info("Testing OpenAI default_model", emoji_key="test")
266 |
267 | # Should return a default model
268 | model = mock_openai_provider.get_default_model()
269 | assert model is not None
270 | assert isinstance(model, str)
271 |
272 |
273 | class TestAnthropicProvider:
274 | """Tests for the Anthropic provider."""
275 |
276 | @pytest.fixture
277 | def mock_anthropic_responses(self) -> Dict[str, Any]:
278 | """Mock responses for Anthropic API."""
279 | class MockMessage:
280 | def __init__(self):
281 | # Content should be an array of objects with text property
282 | self.content = [type("ContentBlock", (), {"text": "Mock Claude response"})]
283 | self.usage = type("Usage", (), {"input_tokens": 20, "output_tokens": 10})
284 |
285 | return {
286 | "message": MockMessage()
287 | }
288 |
289 | @pytest.fixture
290 | def mock_anthropic_provider(self, monkeypatch: MonkeyPatch, mock_anthropic_responses: Dict[str, Any]) -> AnthropicProvider:
291 | """Get a mock Anthropic provider with patched methods."""
292 | # Create the provider
293 | provider = AnthropicProvider(api_key="mock-anthropic-key")
294 |
295 | # Mock the AsyncAnthropic client methods
296 | class MockAsyncAnthropic:
297 | def __init__(self, **kwargs):
298 | self.kwargs = kwargs
299 | self.messages = MockMessages()
300 |
301 | class MockMessages:
302 | async def create(self, **kwargs):
303 | return mock_anthropic_responses["message"]
304 |
305 | async def stream(self, **kwargs):
306 | class MockStream:
307 | async def __aenter__(self):
308 | return self
309 |
310 | async def __aexit__(self, exc_type, exc_val, exc_tb):
311 | pass
312 |
313 | async def __aiter__(self):
314 | yield type("MockChunk", (), {
315 | "type": "content_block_delta",
316 | "delta": type("MockDelta", (), {"text": "Mock streaming content"})
317 | })
318 |
319 | async def get_final_message(self):
320 | return mock_anthropic_responses["message"]
321 |
322 | return MockStream()
323 |
324 | # Patch the AsyncAnthropic client
325 | monkeypatch.setattr("anthropic.AsyncAnthropic", MockAsyncAnthropic)
326 |
327 | # Initialize the provider with the mock client
328 | provider.client = MockAsyncAnthropic(api_key="mock-anthropic-key")
329 |
330 | return provider
331 |
332 | async def test_initialization(self, mock_anthropic_provider: AnthropicProvider):
333 | """Test Anthropic provider initialization."""
334 | logger.info("Testing Anthropic provider initialization", emoji_key="test")
335 |
336 | # Initialize
337 | success = await mock_anthropic_provider.initialize()
338 | assert success
339 | assert mock_anthropic_provider.client is not None
340 |
341 | async def test_completion(self, mock_anthropic_provider: AnthropicProvider):
342 | """Test Anthropic completion generation."""
343 | logger.info("Testing Anthropic completion", emoji_key="test")
344 |
345 | # Generate completion
346 | result = await mock_anthropic_provider.generate_completion(
347 | prompt="Test prompt",
348 | model="claude-3-sonnet-20240229",
349 | temperature=0.7
350 | )
351 |
352 | # Check result
353 | assert isinstance(result, ModelResponse)
354 | assert result.text == "Mock Claude response"
355 | assert result.model == "claude-3-sonnet-20240229"
356 | assert result.provider == Provider.ANTHROPIC.value
357 | assert result.input_tokens == 20
358 | assert result.output_tokens == 10
359 | assert result.total_tokens == 30
360 |
361 | async def test_list_models(self, mock_anthropic_provider: AnthropicProvider):
362 | """Test listing Anthropic models."""
363 | logger.info("Testing Anthropic list_models", emoji_key="test")
364 |
365 | # Initialize first
366 | await mock_anthropic_provider.initialize()
367 |
368 | # List models
369 | models = await mock_anthropic_provider.list_models()
370 |
371 | # Should return a list of models
372 | assert isinstance(models, list)
373 | assert len(models) > 0
374 |
375 | # Check model format
376 | for model in models:
377 | assert "id" in model
378 | assert "provider" in model
379 | assert model["provider"] == Provider.ANTHROPIC.value
380 |
381 | def test_default_model(self, mock_anthropic_provider: AnthropicProvider):
382 | """Test getting default model."""
383 | logger.info("Testing Anthropic default_model", emoji_key="test")
384 |
385 | # Should return a default model
386 | model = mock_anthropic_provider.get_default_model()
387 | assert model is not None
388 | assert isinstance(model, str)
389 |
390 |
391 | # Brief tests for the other providers to save space
392 | class TestOtherProviders:
393 | """Brief tests for other providers."""
394 |
395 | async def test_deepseek_provider(self, monkeypatch: MonkeyPatch):
396 | """Test DeepSeek provider."""
397 | logger.info("Testing DeepSeek provider", emoji_key="test")
398 |
399 | # Mock the API client
400 | monkeypatch.setattr("openai.AsyncOpenAI", lambda **kwargs: type("MockClient", (), {
401 | "chat": type("MockChat", (), {
402 | "completions": type("MockCompletions", (), {
403 | "create": lambda **kwargs: type("MockResponse", (), {
404 | "choices": [type("MockChoice", (), {
405 | "message": type("MockMessage", (), {"content": "Mock DeepSeek response"}),
406 | "finish_reason": "stop"
407 | })],
408 | "usage": type("MockUsage", (), {
409 | "prompt_tokens": 15,
410 | "completion_tokens": 8,
411 | "total_tokens": 23
412 | })
413 | })
414 | })
415 | })
416 | }))
417 |
418 | provider = DeepSeekProvider(api_key="mock-deepseek-key")
419 | assert provider.provider_name == Provider.DEEPSEEK.value
420 |
421 | # Should return a default model
422 | model = provider.get_default_model()
423 | assert model is not None
424 | assert isinstance(model, str)
425 |
426 | async def test_gemini_provider(self, monkeypatch: MonkeyPatch):
427 | """Test Gemini provider."""
428 | logger.info("Testing Gemini provider", emoji_key="test")
429 |
430 | # Create mock response
431 | mock_response = type("MockResponse", (), {
432 | "text": "Mock Gemini response",
433 | "candidates": [
434 | type("MockCandidate", (), {
435 | "content": {
436 | "parts": [{"text": "Mock Gemini response"}]
437 | }
438 | })
439 | ]
440 | })
441 |
442 | # Mock the Google Generative AI Client
443 | class MockClient:
444 | def __init__(self, **kwargs):
445 | self.kwargs = kwargs
446 | self.models = MockModels()
447 |
448 | class MockModels:
449 | def generate_content(self, **kwargs):
450 | return mock_response
451 |
452 | def list(self):
453 | return [
454 | {"name": "gemini-2.0-flash-lite"},
455 | {"name": "gemini-2.0-pro"}
456 | ]
457 |
458 | # Patch the genai Client
459 | monkeypatch.setattr("google.genai.Client", MockClient)
460 |
461 | # Create and test the provider
462 | provider = GeminiProvider(api_key="mock-gemini-key")
463 | assert provider.provider_name == Provider.GEMINI.value
464 |
465 | # Initialize with the mock client
466 | await provider.initialize()
467 |
468 | # Should return a default model
469 | model = provider.get_default_model()
470 | assert model is not None
471 | assert isinstance(model, str)
472 |
473 | # Test completion
474 | result = await provider.generate_completion(
475 | prompt="Test prompt",
476 | model="gemini-2.0-pro"
477 | )
478 |
479 | # Check result
480 | assert result.text is not None
481 | assert "Gemini" in result.text # Should contain "Mock Gemini response"
```
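These tests reference two fixtures, `mock_provider` and `mock_env_vars`, that come from `tests/conftest.py` rather than this file, and that conftest is not included on this page. The sketch below only illustrates the kind of fixtures the test signatures imply; it mirrors the `TestProvider` pattern defined inside `test_init`, and the environment variable names are assumptions, not the project's actual configuration.

```python
# Hypothetical sketch of the shared fixtures these tests pull from tests/conftest.py.
# Fixture names match the test signatures above; everything else is illustrative.
import pytest

from ultimate_mcp_server.core.providers.base import BaseProvider, ModelResponse


class _DummyProvider(BaseProvider):
    provider_name = "mock"

    async def initialize(self) -> bool:
        return True

    async def generate_completion(self, prompt, **kwargs) -> ModelResponse:
        return ModelResponse(text="ok", model="mock-model", provider=self.provider_name)

    async def generate_completion_stream(self, prompt, **kwargs):
        yield "ok", {}

    def get_default_model(self) -> str:
        return "mock-model"


@pytest.fixture
def mock_provider() -> BaseProvider:
    # A concrete BaseProvider, as used by TestBaseProvider.test_process_with_timer.
    return _DummyProvider(api_key="mock-key")


@pytest.fixture
def mock_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    # get_provider() needs API keys available; the variable names here are guesses.
    monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key")
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-anthropic-key")
```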
--------------------------------------------------------------------------------
/examples/rag_example.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """Example of using the RAG functionality with Ultimate MCP Server."""
3 | import asyncio
4 | import sys
5 | from pathlib import Path
6 |
7 | # Add parent directory to path to import ultimate_mcp_server
8 | sys.path.insert(0, str(Path(__file__).parent.parent))
9 |
10 | from rich.panel import Panel
11 | from rich.rule import Rule
12 | from rich.table import Table
13 |
14 | from ultimate_mcp_server.core.server import Gateway
15 | from ultimate_mcp_server.services.knowledge_base import (
16 | get_knowledge_base_manager,
17 | get_knowledge_base_retriever,
18 | )
19 | from ultimate_mcp_server.utils import get_logger
20 | from ultimate_mcp_server.utils.display import CostTracker
21 | from ultimate_mcp_server.utils.logging.console import console
22 |
23 | # Initialize logger
24 | logger = get_logger("rag_example")
25 |
26 | # Sample documents about different AI technologies
27 | AI_DOCUMENTS = [
28 | """Transformers are a type of neural network architecture introduced in the paper
29 | "Attention is All You Need" by Vaswani et al. in 2017. They use self-attention
30 | mechanisms to process sequential data, making them highly effective for natural
31 | language processing tasks. Unlike recurrent neural networks (RNNs), transformers
32 | process entire sequences in parallel, which allows for more efficient training.
33 | The original transformer architecture consists of an encoder and a decoder, each
34 | made up of multiple layers of self-attention and feed-forward neural networks.""",
35 |
36 | """Retrieval-Augmented Generation (RAG) is an AI framework that combines the
37 | strengths of retrieval-based and generation-based approaches. In RAG systems,
38 | a retrieval component first finds relevant information from a knowledge base,
39 | and then a generation component uses this information to produce more accurate,
40 | factual, and contextually relevant outputs. RAG helps to mitigate hallucination
41 | issues in large language models by grounding the generation in retrieved facts.""",
42 |
43 | """Reinforcement Learning from Human Feedback (RLHF) is a technique used to align
44 | language models with human preferences. The process typically involves three steps:
45 | First, a language model is pre-trained on a large corpus of text. Second, human
46 | evaluators rank different model outputs, creating a dataset of preferred responses.
47 | Third, this dataset is used to train a reward model, which is then used to fine-tune
48 | the language model using reinforcement learning techniques such as Proximal Policy
49 | Optimization (PPO).""",
50 |
51 | """Mixture of Experts (MoE) is an architecture where multiple specialized neural
52 | networks (experts) are trained to handle different types of inputs or tasks. A
53 | gating network determines which expert(s) should process each input. This approach
54 | allows for larger model capacity without a proportional increase in computational
55 | costs, as only a subset of the parameters is activated for any given input. MoE
56 | has been successfully applied in large language models like Google's Switch
57 | Transformer and Mistral AI's Mixtral."""
58 | ]
59 |
60 | AI_METADATAS = [
61 | {"title": "Transformers", "source": "AI Handbook", "type": "architecture"},
62 | {"title": "Retrieval-Augmented Generation", "source": "AI Handbook", "type": "technique"},
63 | {"title": "RLHF", "source": "AI Handbook", "type": "technique"},
64 | {"title": "Mixture of Experts", "source": "AI Handbook", "type": "architecture"}
65 | ]
66 |
67 | EXAMPLE_QUERIES = [
68 | "How do transformers work?",
69 | "What is retrieval-augmented generation?",
70 | "Compare RLHF and MoE approaches."
71 | ]
72 |
73 | KB_NAME = "ai_technologies"
74 |
75 | async def run_rag_demo(tracker: CostTracker):
76 | """Run the complete RAG demonstration."""
77 | console.print("[bold blue]RAG Example with Ultimate MCP Server[/bold blue]")
78 | console.print("This example demonstrates the RAG functionality using direct knowledge base services.")
79 | console.print()
80 |
81 | # Initialize Gateway for proper provider and API key management
82 | gateway = Gateway("rag-example", register_tools=False)
83 | await gateway._initialize_providers()
84 |
85 | # Get knowledge base services directly
86 | kb_manager = get_knowledge_base_manager()
87 | kb_retriever = get_knowledge_base_retriever()
88 |
89 | # Clean up any existing knowledge base with the same name before starting
90 | console.print(Rule("[bold blue]Cleaning Up Previous Runs[/bold blue]"))
91 |
92 | # Force a clean start
93 | try:
94 | # Get direct reference to the vector service
95 | from ultimate_mcp_server.services.vector import get_vector_db_service
96 | vector_service = get_vector_db_service()
97 |
98 | # Try a more aggressive approach by resetting chromadb client directly
99 | if hasattr(vector_service, 'chroma_client') and vector_service.chroma_client:
100 | try:
101 | # First try standard deletion
102 | try:
103 | vector_service.chroma_client.delete_collection(KB_NAME)
104 | logger.info("Successfully deleted ChromaDB collection using client API")
105 | except Exception as e:
106 | logger.debug(f"Standard ChromaDB deletion failed: {str(e)}")
107 |
108 | # Wait longer to ensure deletion propagates
109 | await asyncio.sleep(1.0)
110 |
111 | # Force reset the ChromaDB client when all else fails
112 | if hasattr(vector_service.chroma_client, 'reset'):
113 | try:
114 | vector_service.chroma_client.reset()
115 | logger.info("Reset ChromaDB client to ensure clean start")
116 | await asyncio.sleep(0.5)
117 | except Exception as e:
118 | logger.warning(f"Failed to reset ChromaDB client: {str(e)}")
119 | except Exception as e:
120 | logger.warning(f"Error with ChromaDB client manipulation: {str(e)}")
121 |
122 | # Try to delete at the vector database level again
123 | try:
124 | await vector_service.delete_collection(KB_NAME)
125 | logger.info(f"Directly deleted vector collection '{KB_NAME}'")
126 | await asyncio.sleep(0.5)
127 | except Exception as e:
128 | logger.warning(f"Error directly deleting vector collection: {str(e)}")
129 |
130 | # Also try to delete at the knowledge base level
131 | try:
132 | kb_info = await kb_manager.get_knowledge_base(KB_NAME)
133 | if kb_info and kb_info.get("status") != "not_found":
134 | await kb_manager.delete_knowledge_base(name=KB_NAME)
135 | logger.info(f"Deleted existing knowledge base '{KB_NAME}'")
136 | await asyncio.sleep(0.5)
137 | except Exception as e:
138 | logger.warning(f"Error deleting knowledge base: {str(e)}")
139 |
140 | logger.info("Cleanup completed, proceeding with clean start")
141 | except Exception as e:
142 | logger.warning(f"Error during initial cleanup: {str(e)}")
143 |
144 | console.print()
145 |
146 | # Step 1: Create knowledge base
147 | console.print(Rule("[bold blue]Step 1: Creating Knowledge Base[/bold blue]"))
148 | try:
149 | await kb_manager.create_knowledge_base(
150 | name=KB_NAME,
151 | description="Information about various AI technologies",
152 | embedding_model="text-embedding-3-small",
153 | overwrite=True
154 | )
155 | logger.success(f"Knowledge base created: {KB_NAME}", emoji_key="success")
156 | except Exception as e:
157 | logger.error(f"Failed to create knowledge base: {str(e)}", emoji_key="error")
158 | return 1
159 |
160 | console.print()
161 |
162 | # Step 2: Add documents
163 | console.print(Rule("[bold blue]Step 2: Adding Documents[/bold blue]"))
164 | try:
165 | result = await kb_manager.add_documents(
166 | knowledge_base_name=KB_NAME,
167 | documents=AI_DOCUMENTS,
168 | metadatas=AI_METADATAS,
169 | embedding_model="text-embedding-3-small",
170 | chunk_size=1000,
171 | chunk_method="semantic"
172 | )
173 | added_count = result.get("added_count", 0)
174 | logger.success(f"Added {added_count} documents to knowledge base", emoji_key="success")
175 | except Exception as e:
176 | logger.error(f"Failed to add documents: {str(e)}", emoji_key="error")
177 | return 1
178 |
179 | console.print()
180 |
181 | # Step 3: List knowledge bases
182 | console.print(Rule("[bold blue]Step 3: Listing Knowledge Bases[/bold blue]"))
183 | try:
184 | knowledge_bases = await kb_manager.list_knowledge_bases()
185 |
186 | # Create a Rich table for display
187 | table = Table(title="Available Knowledge Bases", box=None)
188 | table.add_column("Name", style="cyan")
189 | table.add_column("Description", style="green")
190 | table.add_column("Document Count", style="magenta")
191 |
192 | # Handle various return types
193 | try:
194 | if knowledge_bases is None:
195 | table.add_row("No knowledge bases found", "", "")
196 | elif isinstance(knowledge_bases, dict):
197 | # Handle dictionary response
198 | kb_names = knowledge_bases.get("knowledge_bases", [])
199 | if isinstance(kb_names, list):
200 | for kb_item in kb_names:
201 | if isinstance(kb_item, dict):
202 | # Extract name and metadata from dictionary
203 | name = kb_item.get("name", "Unknown")
204 | metadata = kb_item.get("metadata", {})
205 | description = metadata.get("description", "No description") if isinstance(metadata, dict) else "No description"
206 | doc_count = metadata.get("doc_count", "Unknown") if isinstance(metadata, dict) else "Unknown"
207 | table.add_row(str(name), str(description), str(doc_count))
208 | else:
209 | table.add_row(str(kb_item), "No description available", "Unknown")
210 | else:
211 | table.add_row("Error parsing response", "", "")
212 | elif isinstance(knowledge_bases, list):
213 | # Handle list response
214 | for kb in knowledge_bases:
215 | if isinstance(kb, str):
216 | table.add_row(kb, "No description", "0")
217 | elif isinstance(kb, dict):
218 | name = kb.get("name", "Unknown")
219 | metadata = kb.get("metadata", {})
220 | description = metadata.get("description", "No description") if isinstance(metadata, dict) else "No description"
221 | doc_count = metadata.get("doc_count", "Unknown") if isinstance(metadata, dict) else "Unknown"
222 | table.add_row(str(name), str(description), str(doc_count))
223 | else:
224 | kb_name = str(getattr(kb, 'name', str(kb)))
225 | table.add_row(kb_name, "No description", "0")
226 | else:
227 | # Fallback for unexpected response type
228 | table.add_row(f"Unexpected response: {type(knowledge_bases)}", "", "")
229 |
230 | console.print(table)
231 | except Exception as e:
232 | logger.error(f"Error rendering knowledge bases table: {str(e)}", emoji_key="error")
233 | # Simple fallback display
234 | console.print(f"Knowledge bases available: {knowledge_bases}")
235 | except Exception as e:
236 | logger.error(f"Failed to list knowledge bases: {str(e)}", emoji_key="error")
237 |
238 | console.print()
239 |
240 | # Step 4: Retrieve context for first query
241 | console.print(Rule("[bold blue]Step 4: Retrieving Context[/bold blue]"))
242 |
243 | query = EXAMPLE_QUERIES[0]
244 | logger.info(f"Retrieving context for query: '{query}'", emoji_key="processing")
245 |
246 | # Default fallback document if retrieval fails
247 | retrieved_results = []
248 |
249 | try:
250 | try:
251 | results = await kb_retriever.retrieve(
252 | knowledge_base_name=KB_NAME,
253 | query=query,
254 | top_k=2,
255 | min_score=0.0, # Set min_score to 0 to see all results
256 | embedding_model="text-embedding-3-small" # Use the same embedding model as when adding documents
257 | )
258 | retrieved_results = results.get('results', [])
259 |
260 | # Debug raw results
261 | logger.debug(f"Raw retrieval results: {results}")
262 | except Exception as e:
263 | logger.error(f"Error retrieving from knowledge base: {str(e)}", emoji_key="error")
264 | # Fallback to using the documents directly
265 | retrieved_results = [
266 | {
267 | "document": AI_DOCUMENTS[0],
268 | "score": 0.95,
269 | "metadata": AI_METADATAS[0]
270 | }
271 | ]
272 |
273 | console.print(f"Retrieved {len(retrieved_results)} results for query: '{query}'")
274 |
275 | # Display results in panels
276 | if retrieved_results:
277 | for i, doc in enumerate(retrieved_results):
278 | try:
279 | score = doc.get('score', 0.0)
280 | document = doc.get('document', '')
281 | metadata = doc.get('metadata', {})
282 | source = metadata.get('title', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
283 |
284 | console.print(Panel(
285 | f"[bold]Document {i+1}[/bold] (score: {score:.2f})\n" +
286 | f"[italic]{document[:150]}...[/italic]",
287 | title=f"Source: {source}",
288 | border_style="blue"
289 | ))
290 | except Exception as e:
291 | logger.error(f"Error displaying document {i}: {str(e)}", emoji_key="error")
292 | else:
293 | console.print(Panel(
294 | "[italic]No results found. Using sample document as fallback for demonstration.[/italic]",
295 | title="No Results",
296 | border_style="yellow"
297 | ))
298 | # Create a fallback document for the next step
299 | retrieved_results = [
300 | {
301 | "document": AI_DOCUMENTS[0],
302 | "score": 0.0,
303 | "metadata": AI_METADATAS[0]
304 | }
305 | ]
306 | except Exception as e:
307 | logger.error(f"Failed to process retrieval results: {str(e)}", emoji_key="error")
308 | # Ensure we have something to continue with
309 | retrieved_results = [
310 | {
311 | "document": AI_DOCUMENTS[0],
312 | "score": 0.0,
313 | "metadata": AI_METADATAS[0]
314 | }
315 | ]
316 |
317 | console.print()
318 |
319 | # Step 5: Generate completions using retrieved context for the first query
320 | console.print(Rule("[bold blue]Step 5: Generating Response with Retrieved Context[/bold blue]"))
321 | query = EXAMPLE_QUERIES[0]
322 | console.print(f"\n[bold]Query:[/bold] {query}")
323 |
324 | try:
325 | # Get the provider
326 | provider_key = "gemini"
327 | provider = gateway.providers.get(provider_key)
328 | if not provider:
329 | provider_key = "openai"
330 | provider = gateway.providers.get(provider_key) # Fallback
331 |
332 | if not provider:
333 | logger.error("No suitable provider found", emoji_key="error")
334 | return 1
335 |
336 | # Use a hardcoded model based on provider type
337 | if provider_key == "gemini":
338 | model = "gemini-2.0-flash-lite"
339 | elif provider_key == "openai":
340 | model = "gpt-4.1-mini"
341 | elif provider_key == "anthropic":
342 | model = "claude-3-haiku-latest"
343 | else:
344 | # Get first available model or fallback
345 | models = getattr(provider, 'available_models', [])
346 | model = models[0] if models else "unknown-model"
347 |
348 | # Prepare context from retrieved documents
349 | if retrieved_results:
350 | context = "\n\n".join([doc.get("document", "") for doc in retrieved_results if doc.get("document")])
351 | else:
352 | # Fallback to using the first document directly if no results
353 | context = AI_DOCUMENTS[0]
354 |
355 | # Build prompt with context
356 | prompt = f"""Answer the following question based on the provided context.
357 | If the context doesn't contain relevant information, say so.
358 |
359 | Context:
360 | {context}
361 |
362 | Question: {query}
363 |
364 | Answer:"""
365 |
366 | # Generate response
367 | response = await provider.generate_completion(
368 | prompt=prompt,
369 | model=model,
370 | temperature=0.3,
371 | max_tokens=300
372 | )
373 |
374 | # Display the answer
375 | console.print(Panel(
376 | response.text,
377 | title=f"Answer from {provider_key}/{model}",
378 | border_style="green"
379 | ))
380 |
381 | # Display usage stats
382 | metrics_table = Table(title="Performance Metrics", box=None)
383 | metrics_table.add_column("Metric", style="cyan")
384 | metrics_table.add_column("Value", style="white")
385 | metrics_table.add_row("Input Tokens", str(response.input_tokens))
386 | metrics_table.add_row("Output Tokens", str(response.output_tokens))
387 | metrics_table.add_row("Processing Time", f"{response.processing_time:.2f}s")
388 | metrics_table.add_row("Cost", f"${response.cost:.6f}")
389 |
390 | console.print(metrics_table)
391 |
392 | # Track the generation call
393 | tracker.add_call(response)
394 |
395 | except Exception as e:
396 | logger.error(f"Failed to generate response: {str(e)}", emoji_key="error")
397 |
398 | console.print()
399 |
400 | # Step 6: Clean up
401 | console.print(Rule("[bold blue]Step 6: Cleaning Up[/bold blue]"))
402 |
403 | # Display cost summary before final cleanup
404 | tracker.display_summary(console)
405 |
406 | try:
407 | await kb_manager.delete_knowledge_base(name=KB_NAME)
408 | logger.success(f"Knowledge base {KB_NAME} deleted successfully", emoji_key="success")
409 | except Exception as e:
410 | logger.error(f"Failed to delete knowledge base: {str(e)}", emoji_key="error")
411 | return 1
412 |
413 | return 0
414 |
415 | async def main():
416 | """Run the RAG example."""
417 | tracker = CostTracker() # Instantiate tracker
418 | try:
419 | await run_rag_demo(tracker) # Pass tracker
420 | except Exception as e:
421 | logger.critical(f"RAG demo failed unexpectedly: {e}", exc_info=True)
422 | return 1
423 | return 0
424 |
425 | if __name__ == "__main__":
426 | # Run the demonstration
427 | exit_code = asyncio.run(main())
428 | sys.exit(exit_code)
```
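
The prompt-assembly step above (join the retrieved documents into a context block, then wrap them in a question-answering template) can be factored into a small helper. A minimal sketch, assuming the same result shape ({"document": ..., "score": ..., "metadata": ...}) produced by the retrieval step; the helper name `build_rag_prompt` is illustrative and not part of the project API:

```python
from typing import Any, Dict, List


def build_rag_prompt(query: str, results: List[Dict[str, Any]], fallback: str = "") -> str:
    """Join retrieved documents into a context block and wrap them in a QA template."""
    docs = [r.get("document", "") for r in results if r.get("document")]
    context = "\n\n".join(docs) if docs else fallback
    return (
        "Answer the following question based on the provided context.\n"
        "If the context doesn't contain relevant information, say so.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )


# Usage inside the demo (names as defined above):
# prompt = build_rag_prompt(EXAMPLE_QUERIES[0], retrieved_results, fallback=AI_DOCUMENTS[0])
# response = await provider.generate_completion(prompt=prompt, model=model, temperature=0.3, max_tokens=300)
```
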
--------------------------------------------------------------------------------
/ultimate_mcp_server/core/tournaments/manager.py:
--------------------------------------------------------------------------------
```python
1 | # --- core/tournaments/manager.py (Updates) ---
2 | import asyncio
3 | import json
4 | import re
5 | from datetime import datetime, timezone
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional, Tuple
8 |
9 | from pydantic import ValidationError
10 |
11 | import ultimate_mcp_server.core.evaluation.evaluators # Ensures evaluators are registered # noqa: F401
12 | from ultimate_mcp_server.core.evaluation.base import EVALUATOR_REGISTRY, Evaluator
13 | from ultimate_mcp_server.core.models.tournament import (
14 | CreateTournamentInput,
15 | TournamentConfig, # ModelConfig is nested in TournamentConfig from CreateTournamentInput
16 | TournamentData,
17 | TournamentRoundResult,
18 | TournamentStatus,
19 | )
20 | from ultimate_mcp_server.core.models.tournament import (
21 | ModelConfig as CoreModelConfig, # Alias to avoid confusion
22 | )
23 | from ultimate_mcp_server.utils import get_logger
24 |
25 | logger = get_logger("ultimate_mcp_server.tournaments.manager")
26 |
27 | STORAGE_DIR = Path(__file__).resolve().parent.parent.parent.parent / "storage"
28 | TOURNAMENT_STORAGE_BASE = STORAGE_DIR / "tournaments"
29 |
30 | class TournamentManager:
31 | def __init__(self):
32 | self.tournaments: Dict[str, TournamentData] = {}
33 | # --- NEW: Store instantiated evaluators per tournament ---
34 | self.tournament_evaluators: Dict[str, List[Evaluator]] = {}
35 | TOURNAMENT_STORAGE_BASE.mkdir(parents=True, exist_ok=True)
36 | logger.info(f"Tournament storage initialized at: {TOURNAMENT_STORAGE_BASE}")
37 | self._load_all_tournaments()
38 |
39 | def _instantiate_evaluators(self, tournament_id: str, config: TournamentConfig) -> bool:
40 | """Instantiates and stores evaluators for a tournament."""
41 | self.tournament_evaluators[tournament_id] = []
42 | for eval_config in config.evaluators:
43 | evaluator_cls = EVALUATOR_REGISTRY.get(eval_config.type)
44 | if not evaluator_cls:
45 | logger.error(f"Unknown evaluator type '{eval_config.type}' for tournament {tournament_id}. Skipping.")
46 | # Optionally, fail tournament creation if a critical evaluator is missing
47 | continue
48 | try:
49 | self.tournament_evaluators[tournament_id].append(evaluator_cls(eval_config.params))
50 | logger.info(f"Instantiated evaluator '{eval_config.type}' (ID: {eval_config.evaluator_id}) for tournament {tournament_id}")
51 | except Exception as e:
52 | logger.error(f"Failed to instantiate evaluator '{eval_config.type}' (ID: {eval_config.evaluator_id}): {e}", exc_info=True)
53 | # Decide if this is a fatal error for the tournament
54 | return False # Example: Fail if any evaluator instantiation fails
55 | return True
56 |
57 | def get_evaluators_for_tournament(self, tournament_id: str) -> List[Evaluator]:
58 | """Returns the list of instantiated evaluators for a given tournament."""
59 | return self.tournament_evaluators.get(tournament_id, [])
60 |
61 | def create_tournament(self, input_data: CreateTournamentInput) -> Optional[TournamentData]:
62 | try:
63 | logger.debug(f"Creating tournament with name: {input_data.name}, {len(input_data.model_configs)} model configs")
64 |
65 | # Map input ModelConfig to core ModelConfig used in TournamentConfig
66 | core_model_configs = [
67 | CoreModelConfig(
68 | model_id=mc.model_id,
69 | diversity_count=mc.diversity_count,
70 | temperature=mc.temperature,
71 | max_tokens=mc.max_tokens,
72 | system_prompt=mc.system_prompt,
73 | seed=mc.seed
74 | ) for mc in input_data.model_configs
75 | ]
76 |
77 | tournament_cfg = TournamentConfig(
78 | name=input_data.name,
79 | prompt=input_data.prompt,
80 | models=core_model_configs, # Use the mapped core_model_configs
81 | rounds=input_data.rounds,
82 | tournament_type=input_data.tournament_type,
83 | extraction_model_id=input_data.extraction_model_id,
84 | evaluators=input_data.evaluators, # Pass evaluator configs
85 | max_retries_per_model_call=input_data.max_retries_per_model_call,
86 | retry_backoff_base_seconds=input_data.retry_backoff_base_seconds,
87 | max_concurrent_model_calls=input_data.max_concurrent_model_calls
88 | )
89 |
90 | tournament = TournamentData(
91 | name=input_data.name,
92 | config=tournament_cfg,
93 | current_round=-1, # Initialize current_round
94 | start_time=None, # Will be set when execution starts
95 | end_time=None
96 | )
97 |
98 | tournament.storage_path = str(self._get_storage_path(tournament.tournament_id, tournament.name)) # Pass name for better paths
99 |
100 | # --- NEW: Instantiate evaluators ---
101 | if not self._instantiate_evaluators(tournament.tournament_id, tournament.config):
102 | logger.error(f"Failed to instantiate one or more evaluators for tournament {tournament.name}. Creation aborted.")
103 | # Clean up if necessary, e.g., remove from self.tournament_evaluators
104 | if tournament.tournament_id in self.tournament_evaluators:
105 | del self.tournament_evaluators[tournament.tournament_id]
106 | return None # Or raise an error
107 |
108 | self.tournaments[tournament.tournament_id] = tournament
109 | self._save_tournament_state(tournament)
110 | logger.info(f"Tournament '{tournament.name}' (ID: {tournament.tournament_id}) created successfully.")
111 | return tournament
112 | except ValidationError as ve:
113 | logger.error(f"Tournament validation failed: {ve}")
114 | return None
115 | except Exception as e:
116 | logger.error(f"Unexpected error creating tournament: {e}", exc_info=True)
117 | return None
118 |
119 | def get_tournament(self, tournament_id: str, force_reload: bool = False) -> Optional[TournamentData]:
120 | logger.debug(f"Getting tournament {tournament_id} (force_reload={force_reload})")
121 | if not force_reload and tournament_id in self.tournaments:
122 | return self.tournaments[tournament_id]
123 |
124 | tournament = self._load_tournament_state(tournament_id)
125 | if tournament:
126 | # --- NEW: Ensure evaluators are loaded/re-instantiated if not present ---
127 | if tournament_id not in self.tournament_evaluators:
128 | logger.info(f"Evaluators for tournament {tournament_id} not in memory, re-instantiating from config.")
129 | if not self._instantiate_evaluators(tournament_id, tournament.config):
130 | logger.error(f"Failed to re-instantiate evaluators for loaded tournament {tournament_id}. Evaluation might fail.")
131 | self.tournaments[tournament_id] = tournament # Update cache
132 | return tournament
133 |
134 | def _save_tournament_state(self, tournament: TournamentData):
135 | if not tournament.storage_path:
136 | logger.error(f"Cannot save state for tournament {tournament.tournament_id}: storage_path not set.")
137 | return
138 |
139 | state_file = Path(tournament.storage_path) / "tournament_state.json"
140 | try:
141 | state_file.parent.mkdir(parents=True, exist_ok=True)
142 | # Pydantic's model_dump_json handles datetime to ISO string conversion
143 | json_data = tournament.model_dump_json(indent=2)
144 | with open(state_file, 'w', encoding='utf-8') as f:
145 | f.write(json_data)
146 | logger.debug(f"Saved state for tournament {tournament.tournament_id} to {state_file}")
147 | except IOError as e:
148 | logger.error(f"Failed to save state for tournament {tournament.tournament_id}: {e}")
149 | except Exception as e: # Catch other potential errors from model_dump_json
150 | logger.error(f"Error serializing tournament state for {tournament.tournament_id}: {e}", exc_info=True)
151 |
152 |
153 | def _load_tournament_state(self, tournament_id: str) -> Optional[TournamentData]:
154 | # Try finding by explicit ID first (common case for direct access)
155 | # The storage path might be complex now, so scan might be more reliable if ID is the only input
156 |
157 | # Robust scan: iterate through all subdirectories of TOURNAMENT_STORAGE_BASE
158 | if TOURNAMENT_STORAGE_BASE.exists():
159 | for potential_tournament_dir in TOURNAMENT_STORAGE_BASE.iterdir():
160 | if potential_tournament_dir.is_dir():
161 | state_file = potential_tournament_dir / "tournament_state.json"
162 | if state_file.exists():
163 | try:
164 | with open(state_file, 'r', encoding='utf-8') as f:
165 | data = json.load(f)
166 | if data.get("tournament_id") == tournament_id:
167 | # Use Pydantic for robust parsing and type conversion
168 | parsed_tournament = TournamentData.model_validate(data)
169 | logger.debug(f"Loaded state for tournament {tournament_id} from {state_file}")
170 | return parsed_tournament
171 | except (IOError, json.JSONDecodeError, ValidationError) as e:
172 | logger.warning(f"Failed to load or validate state from {state_file} for tournament ID {tournament_id}: {e}")
173 | # Don't return, continue scanning
174 | except Exception as e: # Catch any other unexpected error
175 | logger.error(f"Unexpected error loading state from {state_file}: {e}", exc_info=True)
176 |
177 | logger.debug(f"Tournament {tournament_id} not found in any storage location during scan.")
178 | return None
179 |
180 | def _load_all_tournaments(self):
181 | logger.info(f"Scanning {TOURNAMENT_STORAGE_BASE} for existing tournaments...")
182 | count = 0
183 | if not TOURNAMENT_STORAGE_BASE.exists():
184 | logger.warning("Tournament storage directory does not exist. No tournaments loaded.")
185 | return
186 |
187 | for item in TOURNAMENT_STORAGE_BASE.iterdir():
188 | if item.is_dir():
189 | # Attempt to load tournament_state.json from this directory
190 | state_file = item / "tournament_state.json"
191 | if state_file.exists():
192 | try:
193 | with open(state_file, 'r', encoding='utf-8') as f:
194 | data = json.load(f)
195 | tournament_id_from_file = data.get("tournament_id")
196 | if not tournament_id_from_file:
197 | logger.warning(f"Skipping directory {item.name}, tournament_state.json missing 'tournament_id'.")
198 | continue
199 |
200 | if tournament_id_from_file not in self.tournaments: # Avoid reloading if already cached by some other means
201 | # Use the get_tournament method which handles re-instantiating evaluators
202 | loaded_tournament = self.get_tournament(tournament_id_from_file, force_reload=True)
203 | if loaded_tournament:
204 | count += 1
205 | logger.debug(f"Loaded tournament '{loaded_tournament.name}' (ID: {loaded_tournament.tournament_id}) from {item.name}")
206 | else:
207 | logger.warning(f"Failed to fully load tournament from directory: {item.name} (ID in file: {tournament_id_from_file})")
208 | except (IOError, json.JSONDecodeError, ValidationError) as e:
209 | logger.warning(f"Error loading tournament from directory {item.name}: {e}")
210 | except Exception as e:
211 | logger.error(f"Unexpected error loading tournament from {item.name}: {e}", exc_info=True)
212 | logger.info(f"Finished scan. Loaded {count} existing tournaments into manager.")
213 |
214 | def start_tournament_execution(self, tournament_id: str) -> bool:
215 | logger.debug(f"Attempting to start tournament execution for {tournament_id}")
216 | tournament = self.get_tournament(tournament_id) # Ensures evaluators are loaded
217 | if not tournament:
218 | logger.error(f"Cannot start execution: Tournament {tournament_id} not found.")
219 | return False
220 |
221 | if tournament.status not in [TournamentStatus.PENDING, TournamentStatus.CREATED]:
222 | logger.warning(f"Tournament {tournament_id} is not in a runnable state ({tournament.status}). Cannot start.")
223 | return False
224 |
225 | tournament.status = TournamentStatus.RUNNING # Or QUEUED if worker mode is implemented
226 | tournament.start_time = datetime.now(timezone.utc)
227 | tournament.current_round = 0 # Explicitly set to 0 when starting
228 | # Ensure rounds_results is initialized if empty
229 | if not tournament.rounds_results:
230 | tournament.rounds_results = [
231 | TournamentRoundResult(round_num=i) for i in range(tournament.config.rounds)
232 | ]
233 |
234 | self._save_tournament_state(tournament)
235 | logger.info(f"Tournament {tournament_id} status set to {tournament.status}, ready for async execution.")
236 |
237 | try:
238 | from ultimate_mcp_server.core.tournaments.tasks import (
239 | run_tournament_async, # Local import
240 | )
241 | asyncio.create_task(run_tournament_async(tournament_id))
242 | logger.info(f"Asyncio task created for tournament {tournament_id}.")
243 | return True
244 | except Exception as e:
245 | logger.error(f"Error creating asyncio task for tournament {tournament_id}: {e}", exc_info=True)
246 | tournament.status = TournamentStatus.FAILED
247 | tournament.error_message = f"Failed during asyncio task creation: {str(e)}"
248 | tournament.end_time = datetime.now(timezone.utc)
249 | self._save_tournament_state(tournament)
250 | return False
251 |
252 | async def cancel_tournament(self, tournament_id: str) -> Tuple[bool, str, TournamentStatus]: # Return final status
253 | """Attempts to cancel a tournament. Returns success, message, and final status."""
254 | tournament = self.get_tournament(tournament_id, force_reload=True)
255 | if not tournament:
256 | logger.warning(f"Cannot cancel non-existent tournament {tournament_id}")
257 | # Use FAILED or a specific status for "not found" if added to enum,
258 | # or rely on the tool layer to raise 404. For manager, FAILED can represent this.
259 | return False, "Tournament not found.", TournamentStatus.FAILED
260 |
261 | current_status = tournament.status
262 | final_status = current_status # Default to current status if no change
263 | message = ""
264 |
265 | if current_status == TournamentStatus.RUNNING or current_status == TournamentStatus.QUEUED:
266 | logger.info(f"Attempting to cancel tournament {tournament_id} (status: {current_status})...")
267 | tournament.status = TournamentStatus.CANCELLED
268 | tournament.error_message = tournament.error_message or "Tournament cancelled by user request."
269 | tournament.end_time = datetime.now(timezone.utc)
270 | final_status = TournamentStatus.CANCELLED
271 | message = "Cancellation requested. Tournament status set to CANCELLED."
272 | self._save_tournament_state(tournament)
273 | logger.info(f"Tournament {tournament_id} status set to CANCELLED.")
274 | # The background task needs to observe this status.
275 | return True, message, final_status
276 | elif current_status in [TournamentStatus.COMPLETED, TournamentStatus.FAILED, TournamentStatus.CANCELLED]:
277 | message = f"Tournament {tournament_id} is already finished or cancelled (Status: {current_status})."
278 | logger.warning(message)
279 | return False, message, final_status
280 | elif current_status == TournamentStatus.PENDING or current_status == TournamentStatus.CREATED:
281 | tournament.status = TournamentStatus.CANCELLED
282 | tournament.error_message = "Tournament cancelled before starting."
283 | tournament.end_time = datetime.now(timezone.utc)
284 | final_status = TournamentStatus.CANCELLED
285 | message = "Pending/Created tournament cancelled successfully."
286 | self._save_tournament_state(tournament)
287 | logger.info(f"Pending/Created tournament {tournament_id} cancelled.")
288 | return True, message, final_status
289 | else:
290 | # Should not happen, but handle unknown state
291 | message = f"Tournament {tournament_id} is in an unexpected state ({current_status}). Cannot determine cancellation action."
292 | logger.error(message)
293 | return False, message, current_status
294 |
295 |
296 | def list_tournaments(self) -> List[Dict[str, Any]]:
297 | # Ensure cache is up-to-date if new tournaments might have been added externally (less likely with file storage)
298 | # self._load_all_tournaments() # Consider if this is too expensive for every list call
299 |
300 | basic_list = []
301 | for t_data in self.tournaments.values():
302 | basic_list.append({
303 | "tournament_id": t_data.tournament_id,
304 | "name": t_data.name,
305 | "tournament_type": t_data.config.tournament_type,
306 | "status": t_data.status,
307 | "current_round": t_data.current_round,
308 | "total_rounds": t_data.config.rounds,
309 | "created_at": t_data.created_at.isoformat() if t_data.created_at else None, # Ensure ISO format
310 | "updated_at": t_data.updated_at.isoformat() if t_data.updated_at else None,
311 | "start_time": t_data.start_time.isoformat() if t_data.start_time else None,
312 | "end_time": t_data.end_time.isoformat() if t_data.end_time else None,
313 | })
314 | basic_list.sort(key=lambda x: x['created_at'] or '', reverse=True) # Handle None created_at for sorting
315 | return basic_list
316 |
317 | def _get_storage_path(self, tournament_id: str, tournament_name: str) -> Path:
318 | timestamp_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
319 | # Sanitize tournament name for use in path
320 | safe_name = re.sub(r'[^\w\s-]', '', tournament_name).strip().replace(' ', '_')
321 | safe_name = re.sub(r'[-\s]+', '-', safe_name) # Replace multiple spaces/hyphens with single hyphen
322 | safe_name = safe_name[:50] # Limit length
323 |
324 | # Use first 8 chars of UUID for brevity if name is too generic or empty
325 | uuid_suffix = tournament_id.split('-')[0]
326 |
327 | folder_name = f"{timestamp_str}_{safe_name}_{uuid_suffix}" if safe_name else f"{timestamp_str}_{uuid_suffix}"
328 |
329 | path = TOURNAMENT_STORAGE_BASE / folder_name
330 | path.mkdir(parents=True, exist_ok=True) # Ensure directory is created
331 | return path
332 |
333 | tournament_manager = TournamentManager()
```
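
For reference, a minimal sketch of driving the module-level `tournament_manager` singleton defined above. The import paths and method names come directly from this file; the exact `CreateTournamentInput` fields and their validation live in the Pydantic models, so the field values below (and passing a plain dict for the nested model config) are assumptions for illustration only:

```python
from ultimate_mcp_server.core.models.tournament import CreateTournamentInput
from ultimate_mcp_server.core.tournaments.manager import tournament_manager

# Field names mirror the attributes read in create_tournament(); values are illustrative.
input_data = CreateTournamentInput(
    name="demo-tournament",
    prompt="Write a function that reverses a singly linked list.",
    model_configs=[{"model_id": "openai/gpt-4.1-mini"}],
    rounds=2,
    tournament_type="code",
)

tournament = tournament_manager.create_tournament(input_data)
if tournament is not None:
    # State is persisted under storage/tournaments/<timestamp>_<sanitized-name>_<uuid-prefix>/
    print(tournament.tournament_id, tournament.status)
    print(tournament_manager.list_tournaments())
    # start_tournament_execution() schedules run_tournament_async() via asyncio.create_task,
    # so it must be called from inside a running event loop.
```
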
--------------------------------------------------------------------------------
/examples/audio_transcription_demo.py:
--------------------------------------------------------------------------------
```python
1 | """Demonstration script for audio transcription using faster-whisper.
2 |
3 | This version uses the faster-whisper library which offers better performance than whisper.cpp.
4 | """
5 |
6 | import asyncio
7 | import os
8 | import sys
9 | import time
10 | from pathlib import Path
11 | from typing import Any, Dict, List, Optional, Tuple
12 |
13 | from rich import box
14 | from rich.console import Console
15 | from rich.live import Live
16 | from rich.markup import escape
17 | from rich.panel import Panel
18 | from rich.progress import (
19 | BarColumn,
20 | Progress,
21 | SpinnerColumn,
22 | TextColumn,
23 | TimeElapsedColumn,
24 | TimeRemainingColumn,
25 | )
26 | from rich.rule import Rule
27 | from rich.table import Table
28 |
29 | # Add the project root to the Python path
30 | # This allows finding the ultimate_mcp_server package when running the script directly
31 | project_root = Path(__file__).resolve().parents[1]
32 | sys.path.insert(0, str(project_root))
33 |
34 | EXAMPLE_DIR = Path(__file__).parent
35 | DATA_DIR = EXAMPLE_DIR / "data"
36 | SAMPLE_AUDIO_PATH = str(DATA_DIR / "Steve_Jobs_Introducing_The_iPhone_compressed.mp3")
37 |
38 | from ultimate_mcp_server.utils import get_logger # noqa: E402
39 |
40 | # --- Configuration ---
41 | logger = get_logger("audio_demo")
42 |
43 | # Get the directory of the current script
44 | SCRIPT_DIR = Path(__file__).parent.resolve()
45 | DATA_DIR = SCRIPT_DIR / "data"
46 |
47 | # Define allowed audio extensions
48 | ALLOWED_EXTENSIONS = [".mp3", ".wav", ".flac", ".ogg", ".m4a"]
49 |
50 | # --- Helper Functions ---
51 | def find_audio_files(directory: Path) -> List[Path]:
52 | """Finds audio files with allowed extensions in the given directory."""
53 | return [p for p in directory.iterdir() if p.is_file() and p.suffix.lower() in ALLOWED_EXTENSIONS]
54 |
55 | def format_timestamp(seconds: float) -> str:
56 | """Format seconds into a timestamp string."""
57 | hours = int(seconds / 3600)
58 | minutes = int((seconds % 3600) / 60)
59 | secs = seconds % 60
60 | if hours > 0:
61 | return f"{hours:02d}:{minutes:02d}:{secs:05.2f}"
62 | else:
63 | return f"{minutes:02d}:{secs:05.2f}"
64 |
65 | def detect_device() -> Tuple[str, str, Optional[str]]:
66 | """Detect if CUDA GPU is available and return appropriate device and compute_type."""
67 | try:
68 | # Import torch to check if CUDA is available
69 | import torch
70 | if torch.cuda.is_available():
71 | # Get GPU info for display
72 | gpu_name = torch.cuda.get_device_name(0)
73 | return "cuda", "float16", gpu_name
74 | else:
75 | return "cpu", "int8", None
76 | except ImportError:
77 | # If torch is not available, try to directly check for NVIDIA GPUs with ctranslate2
78 | try:
79 | import subprocess
80 | nvidia_smi_output = subprocess.check_output(["nvidia-smi", "-L"], text=True, stderr=subprocess.DEVNULL)
81 | if "GPU" in nvidia_smi_output:
82 | # Extract GPU name
83 | gpu_name = nvidia_smi_output.strip().split(':')[1].strip().split('(')[0].strip()
84 | return "cuda", "float16", gpu_name
85 | else:
86 | return "cpu", "int8", None
87 | except Exception:
88 | # If all else fails, default to CPU
89 | return "cpu", "int8", None
90 |
91 | def generate_markdown_transcript(transcript: Dict[str, Any], file_path: str) -> str:
92 | """Generate a markdown version of the transcript with metadata."""
93 | audio_filename = os.path.basename(file_path)
94 | metadata = transcript.get("metadata", {})
95 | segments = transcript.get("segments", [])
96 |
97 | markdown = [
98 | f"# Transcript: {audio_filename}",
99 | "",
100 | "## Metadata",
101 | f"- **Duration:** {format_timestamp(metadata.get('duration', 0))}",
102 | f"- **Language:** {metadata.get('language', 'unknown')} (confidence: {metadata.get('language_probability', 0):.2f})",
103 | f"- **Transcription Model:** {metadata.get('model', 'unknown')}",
104 | f"- **Device:** {metadata.get('device', 'unknown')}",
105 | f"- **Processing Time:** {transcript.get('processing_time', {}).get('total', 0):.2f} seconds",
106 | "",
107 | "## Full Transcript",
108 | "",
109 | transcript.get("enhanced_transcript", transcript.get("raw_transcript", "")),
110 | "",
111 | "## Segments",
112 | ""
113 | ]
114 |
115 | for segment in segments:
116 | start_time = format_timestamp(segment["start"])
117 | end_time = format_timestamp(segment["end"])
118 | markdown.append(f"**[{start_time} → {end_time}]** {segment['text']}")
119 | markdown.append("")
120 |
121 | return "\n".join(markdown)
122 |
123 | def save_markdown_transcript(transcript: Dict[str, Any], file_path: str) -> Tuple[str, str]:
124 | """Save the transcript as markdown and text files.
125 |
126 | Returns:
127 | Tuple containing paths to markdown and text files
128 | """
129 | audio_path = Path(file_path)
130 | markdown_path = audio_path.with_suffix(".md")
131 | txt_path = audio_path.with_suffix(".txt")
132 |
133 | # Generate and save markdown (enhanced transcript)
134 | markdown_content = generate_markdown_transcript(transcript, file_path)
135 | with open(markdown_path, "w", encoding="utf-8") as f:
136 | f.write(markdown_content)
137 |
138 | # Save raw transcript as plain text file
139 | with open(txt_path, "w", encoding="utf-8") as f:
140 | f.write(transcript.get("raw_transcript", ""))
141 |
142 | return str(markdown_path), str(txt_path)
143 |
144 | async def enhance_transcript_with_llm(raw_transcript: str, console: Console) -> str:
145 | """Enhance the transcript using an LLM to improve readability."""
146 | try:
147 | from ultimate_mcp_server.tools.completion import chat_completion
148 | except ImportError:
149 | console.print("[yellow]Ultimate MCP Server tools not available for enhancement. Using raw transcript.[/yellow]")
150 | return raw_transcript
151 |
152 | # Setup progress display
153 | with Progress(
154 | SpinnerColumn(),
155 | TextColumn("[bold green]Enhancing transcript with LLM[/bold green]"),
156 | BarColumn(),
157 | TextColumn("[cyan]{task.percentage:>3.0f}%"),
158 | TimeElapsedColumn(),
159 | console=console
160 | ) as progress:
161 | enhance_task = progress.add_task("Enhancing...", total=100)
162 |
163 | try:
164 | # Create the prompt for transcript enhancement
165 | system_prompt = """You are an expert transcription editor. Your task is to enhance the following raw transcript:
166 | 1. Fix any spelling or grammar errors
167 | 2. Add proper punctuation and capitalization
168 | 3. Format the text into logical paragraphs
169 | 4. Remove filler words and repeated phrases
170 | 5. Preserve the original meaning and all factual content
171 | 6. Format numbers, acronyms, and technical terms consistently
172 | 7. Keep the text faithful to the original but make it more readable"""
173 |
174 | user_prompt = f"Here is the raw transcript to enhance:\n\n{raw_transcript}\n\nPlease provide only the enhanced transcript without explanations."
175 |
176 | # Split the transcript into chunks if it's very long
177 | progress.update(enhance_task, completed=20)
178 |
179 | # Call the chat completion function
180 | result = await chat_completion(
181 | system_prompt=system_prompt,
182 | messages=[{"role": "user", "content": user_prompt}],
183 | model="gpt-4.1-mini",
184 | temperature=0.3,
185 | )
186 |
187 | progress.update(enhance_task, completed=90)
188 |
189 | enhanced_transcript = result.get("content", raw_transcript)
190 |
191 | progress.update(enhance_task, completed=100)
192 |
193 | return enhanced_transcript
194 |
195 | except Exception as e:
196 | console.print(f"[red]Error enhancing transcript: {e}[/red]")
197 | progress.update(enhance_task, completed=100)
198 | return raw_transcript
199 |
200 | async def transcribe_with_faster_whisper(file_path: str, console: Console) -> Dict[str, Any]:
201 | """Transcribe audio using faster-whisper library with real-time progress updates."""
202 | logger.info(f"Processing file: {file_path}")
203 |
204 | # Check if audio file exists
205 | if not os.path.exists(file_path):
206 | logger.error(f"Audio file not found at {file_path}")
207 | return {"success": False, "error": f"Audio file not found at {file_path}"}
208 |
209 | try:
210 | # Import faster-whisper - install if not present
211 | try:
212 | from faster_whisper import WhisperModel
213 | except ImportError:
214 | console.print("[yellow]faster-whisper not installed. Installing now...[/yellow]")
215 | import subprocess
216 | subprocess.check_call([sys.executable, "-m", "pip", "install", "faster-whisper"])
217 | from faster_whisper import WhisperModel
218 |
219 | # Start timing
220 | start_time = time.time()
221 |
222 | # Get audio duration for progress calculation
223 | audio_duration = 0
224 | with Progress(
225 | SpinnerColumn(),
226 | TextColumn("[bold blue]{task.description}"),
227 | console=console
228 | ) as progress:
229 | analysis_task = progress.add_task("Analyzing audio file...", total=None)
230 | try:
231 | import av
232 | with av.open(file_path) as container:
233 | # Get duration in seconds
234 | if container.duration is not None:
235 | audio_duration = container.duration / 1000000 # microseconds to seconds
236 | console.print(f"Audio duration: [cyan]{format_timestamp(audio_duration)}[/cyan] seconds")
237 | progress.update(analysis_task, completed=True)
238 | except Exception as e:
239 | console.print(f"[yellow]Could not determine audio duration: {e}[/yellow]")
240 |
241 | # Detect device (CPU or GPU)
242 | device, compute_type, gpu_name = detect_device()
243 |
244 | # Load the model with progress
245 | model_size = "large-v3"
246 | console.print(f"Loading Whisper model: [bold]{model_size}[/bold]")
247 |
248 | if device == "cuda" and gpu_name:
249 | console.print(f"Using device: [bold green]GPU ({gpu_name})[/bold green], compute_type: [bold cyan]{compute_type}[/bold cyan]")
250 | else:
251 | console.print(f"Using device: [bold yellow]CPU[/bold yellow], compute_type: [bold cyan]{compute_type}[/bold cyan]")
252 |
253 | with Progress(
254 | SpinnerColumn(),
255 | TextColumn("[bold blue]{task.description}"),
256 | BarColumn(),
257 | TextColumn("[bold cyan]{task.percentage:>3.0f}%"),
258 | console=console
259 | ) as progress:
260 | load_task = progress.add_task("Loading model...", total=100)
261 | model = WhisperModel(model_size, device=device, compute_type=compute_type, download_root="./models")
262 | progress.update(load_task, completed=100)
263 |
264 | # Setup progress display for transcription
265 | console.print("\n[bold green]Starting transcription...[/bold green]")
266 |
267 | # Create table for displaying transcribed segments in real time
268 | table = Table(title="Transcription Progress", expand=True, box=box.ROUNDED)
269 | table.add_column("Segment")
270 | table.add_column("Time", style="yellow")
271 | table.add_column("Text", style="white")
272 |
273 | # Progress bar for overall transcription
274 | progress = Progress(
275 | SpinnerColumn(),
276 | TextColumn("[bold blue]Transcribing..."),
277 | BarColumn(),
278 | TextColumn("[cyan]{task.percentage:>3.0f}%"),
279 | TimeElapsedColumn(),
280 | TimeRemainingColumn(),
281 | )
282 |
283 | # Add main progress task
284 | transcribe_task = progress.add_task("Transcription", total=100)
285 |
286 | # Combine table and progress bar
287 | transcription_display = Table.grid()
288 | transcription_display.add_row(table)
289 | transcription_display.add_row(progress)
290 |
291 | segments_list = []
292 | segment_idx = 0
293 |
294 | # Run the transcription with live updating display
295 | with Live(transcription_display, console=console, refresh_per_second=10) as live:
296 | # Run transcription
297 | segments, info = model.transcribe(
298 | file_path,
299 | beam_size=5,
300 | vad_filter=True,
301 | word_timestamps=True,
302 | language="en", # Specify language to avoid language detection phase
303 | )
304 |
305 | # Process segments as they become available
306 | for segment in segments:
307 | segments_list.append(segment)
308 |
309 | # Update progress bar based on timestamp
310 | if audio_duration > 0:
311 | current_progress = min(int((segment.end / audio_duration) * 100), 99)
312 | progress.update(transcribe_task, completed=current_progress)
313 |
314 | # Add segment to table
315 | timestamp = f"[{format_timestamp(segment.start)} → {format_timestamp(segment.end)}]"
316 | table.add_row(
317 | f"[cyan]#{segment_idx+1}[/cyan]",
318 | timestamp,
319 | segment.text
320 | )
321 |
322 | # Update the live display
323 | live.update(transcription_display)
324 | segment_idx += 1
325 |
326 | # Finish progress
327 | progress.update(transcribe_task, completed=100)
328 |
329 | # Build full transcript
330 | raw_transcript = " ".join([segment.text for segment in segments_list])
331 |
332 | # Convert segments to dictionary format
333 | segments_dict = []
334 | for segment in segments_list:
335 | segments_dict.append({
336 | "start": segment.start,
337 | "end": segment.end,
338 | "text": segment.text,
339 | "words": [{"word": word.word, "start": word.start, "end": word.end, "probability": word.probability}
340 | for word in (segment.words or [])]
341 | })
342 |
343 | # Enhance the transcript with LLM
344 | console.print("\n[bold green]Raw transcription complete. Now enhancing the transcript...[/bold green]")
345 | enhanced_transcript = await enhance_transcript_with_llm(raw_transcript, console)
346 |
347 | # Calculate processing time
348 | processing_time = time.time() - start_time
349 |
350 | # Create the result dictionary
351 | result = {
352 | "success": True,
353 | "raw_transcript": raw_transcript,
354 | "enhanced_transcript": enhanced_transcript,
355 | "segments": segments_dict,
356 | "metadata": {
357 | "language": info.language,
358 | "language_probability": info.language_probability,
359 | "model": model_size,
360 | "duration": audio_duration,
361 | "device": device
362 | },
363 | "processing_time": {
364 | "total": processing_time,
365 | "transcription": processing_time
366 | }
367 | }
368 |
369 | # Save the transcripts
370 | markdown_path, txt_path = save_markdown_transcript(result, file_path)
371 | console.print(f"\n[bold green]Saved enhanced transcript to:[/bold green] [cyan]{markdown_path}[/cyan]")
372 | console.print(f"[bold green]Saved raw transcript to:[/bold green] [cyan]{txt_path}[/cyan]")
373 |
374 | return result
375 |
376 | except Exception as e:
377 | import traceback
378 | logger.error(f"Transcription error: {e}")
379 | logger.error(traceback.format_exc())
380 | return {"success": False, "error": f"Transcription error: {e}"}
381 |
382 |
383 | async def main():
384 | """Runs the audio transcription demonstrations."""
385 |
386 | logger.info("Starting Audio Transcription Demo", emoji_key="audio")
387 |
388 | console = Console()
389 | console.print(Rule("[bold green]Audio Transcription Demo (faster-whisper)[/bold green]"))
390 |
391 | # --- Find Audio Files ---
392 | audio_files = find_audio_files(DATA_DIR)
393 | if not audio_files:
394 | console.print(f"[bold red]Error:[/bold red] No audio files found in {DATA_DIR}. Please place audio files (e.g., .mp3, .wav) there.")
395 | return
396 |
397 | console.print(f"Found {len(audio_files)} audio file(s) in {DATA_DIR}:")
398 | for f in audio_files:
399 | console.print(f"- [cyan]{f.name}[/cyan]")
400 | console.print()
401 |
402 | # --- Process Each File ---
403 | for file_path in audio_files:
404 | try:
405 | console.print(Panel(
406 | f"Processing file: [cyan]{escape(str(file_path))}[/cyan]",
407 | title="Audio Transcription",
408 | border_style="blue"
409 | ))
410 |
411 | # Call our faster-whisper transcription function
412 | result = await transcribe_with_faster_whisper(str(file_path), console)
413 |
414 | if result.get("success", False):
415 | console.print(f"[green]Transcription successful for {escape(str(file_path))}.[/green]")
416 |
417 | # Show comparison of raw vs enhanced transcript
418 | if "raw_transcript" in result and "enhanced_transcript" in result:
419 | comparison = Table(title="Transcript Comparison", expand=True, box=box.ROUNDED)
420 | comparison.add_column("Raw Transcript", style="yellow")
421 | comparison.add_column("Enhanced Transcript", style="green")
422 |
423 | # Limit to a preview of the first part
424 | raw_preview = result["raw_transcript"][:500] + ("..." if len(result["raw_transcript"]) > 500 else "")
425 | enhanced_preview = result["enhanced_transcript"][:500] + ("..." if len(result["enhanced_transcript"]) > 500 else "")
426 |
427 | comparison.add_row(raw_preview, enhanced_preview)
428 | console.print(comparison)
429 |
430 | # Display metadata if available
431 | if "metadata" in result and result["metadata"]:
432 | console.print("[bold]Metadata:[/bold]")
433 | for key, value in result["metadata"].items():
434 | console.print(f" - [cyan]{key}[/cyan]: {value}")
435 |
436 | # Display processing time
437 | if "processing_time" in result:
438 | console.print("[bold]Processing Times:[/bold]")
439 | for key, value in result["processing_time"].items():
440 | if isinstance(value, (int, float)):
441 | console.print(f" - [cyan]{key}[/cyan]: {value:.2f}s")
442 | else:
443 | console.print(f" - [cyan]{key}[/cyan]: {value}")
444 | else:
445 |                     console.print("[yellow]Warning:[/yellow] No processing time information was returned.")
446 | else:
447 | console.print(f"[bold red]Transcription failed:[/bold red] {escape(result.get('error', 'Unknown error'))}")
448 |
449 | console.print() # Add a blank line between files
450 |
451 | except Exception as outer_e:
452 | import traceback
453 | console.print(f"[bold red]Unexpected error processing file {escape(str(file_path))}:[/bold red] {escape(str(outer_e))}")
454 | console.print("[bold red]Traceback:[/bold red]")
455 | console.print(escape(traceback.format_exc()))
456 | continue # Move to the next file
457 |
458 | logger.info("Audio Transcription Demo Finished", emoji_key="audio")
459 |
460 | if __name__ == "__main__":
461 | # Basic error handling for the async execution itself
462 | try:
463 | asyncio.run(main())
464 | except Exception as e:
465 | print(f"An error occurred running the demo: {e}")
466 | import traceback
467 | traceback.print_exc()
```
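
Stripped of the Rich progress display and LLM enhancement, the core faster-whisper call used above reduces to a few lines. A minimal sketch with the same model size and transcription options as the demo (the sample MP3 path is the one shipped in examples/data):

```python
from faster_whisper import WhisperModel

# Same settings the demo falls back to on CPU; use device="cuda", compute_type="float16" on a GPU.
model = WhisperModel("large-v3", device="cpu", compute_type="int8", download_root="./models")

# transcribe() returns a lazy generator of segments plus a TranscriptionInfo object.
segments, info = model.transcribe(
    "examples/data/Steve_Jobs_Introducing_The_iPhone_compressed.mp3",
    beam_size=5,
    vad_filter=True,
    word_timestamps=True,
    language="en",
)

print(f"Detected language: {info.language} (p={info.language_probability:.2f})")
for segment in segments:
    print(f"[{segment.start:7.2f} -> {segment.end:7.2f}] {segment.text}")
```
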
--------------------------------------------------------------------------------
/ultimate_mcp_server/utils/logging/panels.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Panel definitions for Ultimate MCP Server logging system.
3 |
4 | This module provides specialized panels for different types of output like
5 | headers, results, errors, warnings, etc.
6 | """
7 |
8 | from typing import Any, Dict, List, Optional, Union
9 |
10 | from rich.box import SIMPLE
11 | from rich.columns import Columns
12 | from rich.console import ConsoleRenderable
13 | from rich.panel import Panel
14 | from rich.syntax import Syntax
15 | from rich.table import Table
16 | from rich.text import Text
17 |
18 | from ultimate_mcp_server.utils.logging.console import console
19 | from ultimate_mcp_server.utils.logging.emojis import ERROR, INFO, SUCCESS, WARNING
20 |
21 |
22 | class HeaderPanel:
23 | """Panel for section headers."""
24 |
25 | def __init__(
26 | self,
27 | title: str,
28 | subtitle: Optional[str] = None,
29 | component: Optional[str] = None,
30 | style: str = "bright_blue",
31 | ):
32 | """Initialize a header panel.
33 |
34 | Args:
35 | title: Panel title
36 | subtitle: Optional subtitle
37 | component: Optional component name
38 | style: Panel style
39 | """
40 | self.title = title
41 | self.subtitle = subtitle
42 | self.component = component
43 | self.style = style
44 |
45 | def __rich__(self) -> ConsoleRenderable:
46 | """Render the panel."""
47 | # Create the title text
48 | title_text = Text()
49 | title_text.append("- ", style="bright_black")
50 | title_text.append(self.title, style="bold")
51 | title_text.append(" -", style="bright_black")
52 |
53 | # Create the content
54 | content = Text()
55 |
56 | if self.component:
57 | content.append(f"[{self.component}] ", style="component")
58 |
59 | if self.subtitle:
60 | content.append(self.subtitle)
61 |
62 | return Panel(
63 | content,
64 | title=title_text,
65 | title_align="center",
66 | border_style=self.style,
67 | expand=True,
68 | padding=(1, 2),
69 | )
70 |
71 | class ResultPanel:
72 | """Panel for displaying operation results."""
73 |
74 | def __init__(
75 | self,
76 | title: str,
77 | results: Union[List[Dict[str, Any]], Dict[str, Any]],
78 | status: str = "success",
79 | component: Optional[str] = None,
80 | show_count: bool = True,
81 | compact: bool = False,
82 | ):
83 | """Initialize a result panel.
84 |
85 | Args:
86 | title: Panel title
87 | results: Results to display (list of dicts or single dict)
88 | status: Result status (success, warning, error)
89 | component: Optional component name
90 | show_count: Whether to show result count in title
91 | compact: Whether to use a compact display style
92 | """
93 | self.title = title
94 | self.results = results if isinstance(results, list) else [results]
95 | self.status = status.lower()
96 | self.component = component
97 | self.show_count = show_count
98 | self.compact = compact
99 |
100 | def __rich__(self) -> ConsoleRenderable:
101 | """Render the panel."""
102 | # Determine style and emoji based on status
103 | if self.status == "success":
104 | style = "result.success"
105 | emoji = SUCCESS
106 | elif self.status == "warning":
107 | style = "result.warning"
108 | emoji = WARNING
109 | elif self.status == "error":
110 | style = "result.error"
111 | emoji = ERROR
112 | else:
113 | style = "result.info"
114 | emoji = INFO
115 |
116 | # Create title
117 | title_text = Text()
118 | title_text.append(f"{emoji} ", style=style)
119 | title_text.append(self.title, style=f"bold {style}")
120 |
121 | if self.show_count and len(self.results) > 0:
122 | title_text.append(f" ({len(self.results)} items)", style="bright_black")
123 |
124 | # Create content
125 | if self.compact:
126 | # Compact mode - just show key/value list
127 | rows = []
128 | for item in self.results:
129 | for k, v in item.items():
130 | rows.append({
131 | "key": k,
132 | "value": self._format_value(v),
133 | })
134 |
135 | table = Table(box=None, expand=True, show_header=False)
136 | table.add_column("Key", style="data.key")
137 | table.add_column("Value", style="", overflow="fold")
138 |
139 | for row in rows:
140 | table.add_row(row["key"], row["value"])
141 |
142 | content = table
143 | else:
144 | # Full mode - create a table per result item
145 | tables = []
146 |
147 | for i, item in enumerate(self.results):
148 | if not item: # Skip empty items
149 | continue
150 |
151 | table = Table(
152 | box=SIMPLE,
153 | title=f"Item {i+1}" if len(self.results) > 1 else None,
154 | title_style="bright_black",
155 | expand=True,
156 | show_header=False,
157 | )
158 | table.add_column("Key", style="data.key")
159 | table.add_column("Value", style="", overflow="fold")
160 |
161 | for k, v in item.items():
162 | table.add_row(k, self._format_value(v))
163 |
164 | tables.append(table)
165 |
166 | content = Columns(tables) if len(tables) > 1 else tables[0] if tables else Text("No results")
167 |
168 | # Return the panel
169 | return Panel(
170 | content,
171 | title=title_text,
172 | border_style=style,
173 | expand=True,
174 | padding=(1, 1),
175 | )
176 |
177 | def _format_value(self, value: Any) -> str:
178 | """Format a value for display.
179 |
180 | Args:
181 | value: Value to format
182 |
183 | Returns:
184 | Formatted string
185 | """
186 | if value is None:
187 | return "[dim]None[/dim]"
188 | elif isinstance(value, bool):
189 | return str(value)
190 | elif isinstance(value, (int, float)):
191 | return str(value)
192 | elif isinstance(value, list):
193 | return ", ".join(self._format_value(v) for v in value[:5]) + \
194 | (f" ... (+{len(value) - 5} more)" if len(value) > 5 else "")
195 | elif isinstance(value, dict):
196 | if len(value) == 0:
197 | return "{}"
198 | else:
199 | return "{...}" # Just indicate there's content
200 | else:
201 | return str(value)
202 |
203 | class InfoPanel:
204 | """Panel for displaying information."""
205 |
206 | def __init__(
207 | self,
208 | title: str,
209 | content: Union[str, List[str], Dict[str, Any]],
210 | icon: Optional[str] = None,
211 | style: str = "info",
212 | ):
213 | """Initialize an information panel.
214 |
215 | Args:
216 | title: Panel title
217 | content: Content to display (string, list, or dict)
218 | icon: Emoji or icon character
219 | style: Style name to apply (from theme)
220 | """
221 | self.title = title
222 | self.content = content
223 | self.icon = icon or INFO
224 | self.style = style
225 |
226 | def __rich__(self) -> ConsoleRenderable:
227 | """Render the panel."""
228 | # Create title
229 | title_text = Text()
230 | title_text.append(f"{self.icon} ", style=self.style)
231 | title_text.append(self.title, style=f"bold {self.style}")
232 |
233 | # Format content based on type
234 | if isinstance(self.content, str):
235 | content = Text(self.content)
236 | elif isinstance(self.content, list):
237 | content = Text()
238 | for i, item in enumerate(self.content):
239 | if i > 0:
240 | content.append("\n")
241 | content.append(f"• {item}")
242 | elif isinstance(self.content, dict):
243 | # Create a table for dict content
244 | table = Table(box=None, expand=True, show_header=False)
245 | table.add_column("Key", style="data.key")
246 | table.add_column("Value", style="", overflow="fold")
247 |
248 | for k, v in self.content.items():
249 | table.add_row(k, str(v))
250 |
251 | content = table
252 | else:
253 | content = Text(str(self.content))
254 |
255 | # Return the panel
256 | return Panel(
257 | content,
258 | title=title_text,
259 | border_style=self.style,
260 | expand=True,
261 | padding=(1, 2),
262 | )
263 |
264 | class WarningPanel:
265 | """Panel for displaying warnings."""
266 |
267 | def __init__(
268 | self,
269 | title: Optional[str] = None,
270 | message: str = "",
271 | details: Optional[List[str]] = None,
272 | ):
273 | """Initialize a warning panel.
274 |
275 | Args:
276 | title: Optional panel title
277 | message: Main warning message
278 | details: Optional list of detail points
279 | """
280 | self.title = title or "Warning"
281 | self.message = message
282 | self.details = details or []
283 |
284 | def __rich__(self) -> ConsoleRenderable:
285 | """Render the panel."""
286 | # Create title
287 | title_text = Text()
288 | title_text.append(f"{WARNING} ", style="warning")
289 | title_text.append(self.title, style="bold warning")
290 |
291 | # Create content
292 | content = Text()
293 |
294 | # Add message
295 | if self.message:
296 | content.append(self.message)
297 |
298 | # Add details if any
299 | if self.details and len(self.details) > 0:
300 | if self.message:
301 | content.append("\n\n")
302 |
303 | content.append("Details:", style="bold")
304 | content.append("\n")
305 |
306 | for i, detail in enumerate(self.details):
307 | if i > 0:
308 | content.append("\n")
309 | content.append(f"• {detail}")
310 |
311 | # Return the panel
312 | return Panel(
313 | content,
314 | title=title_text,
315 | border_style="warning",
316 | expand=True,
317 | padding=(1, 2),
318 | )
319 |
320 | class ErrorPanel:
321 | """Panel for displaying errors."""
322 |
323 | def __init__(
324 | self,
325 | title: Optional[str] = None,
326 | message: str = "",
327 | details: Optional[str] = None,
328 | resolution_steps: Optional[List[str]] = None,
329 | error_code: Optional[str] = None,
330 | ):
331 | """Initialize an error panel.
332 |
333 | Args:
334 | title: Optional panel title
335 | message: Main error message
336 | details: Optional error details
337 | resolution_steps: Optional list of steps to resolve the error
338 | error_code: Optional error code for reference
339 | """
340 | self.title = title or "Error"
341 | self.message = message
342 | self.details = details
343 | self.resolution_steps = resolution_steps or []
344 | self.error_code = error_code
345 |
346 | def __rich__(self) -> ConsoleRenderable:
347 | """Render the panel."""
348 | # Create title
349 | title_text = Text()
350 | title_text.append(f"{ERROR} ", style="error")
351 | title_text.append(self.title, style="bold error")
352 |
353 | if self.error_code:
354 | title_text.append(f" [{self.error_code}]", style="bright_black")
355 |
356 | # Create content
357 | content = Text()
358 |
359 | # Add message
360 | if self.message:
361 | content.append(self.message, style="bold")
362 |
363 | # Add details if any
364 | if self.details:
365 | if self.message:
366 | content.append("\n\n")
367 |
368 | content.append(self.details)
369 |
370 | # Add resolution steps if any
371 | if self.resolution_steps and len(self.resolution_steps) > 0:
372 | if self.message or self.details:
373 | content.append("\n\n")
374 |
375 | content.append("Resolution steps:", style="bold")
376 | content.append("\n")
377 |
378 | for i, step in enumerate(self.resolution_steps):
379 | if i > 0:
380 | content.append("\n")
381 | content.append(f"{i+1}. {step}")
382 |
383 | # Return the panel
384 | return Panel(
385 | content,
386 | title=title_text,
387 | border_style="error",
388 | expand=True,
389 | padding=(1, 2),
390 | )
391 |
392 | class ToolOutputPanel:
393 | """Panel for displaying tool command output."""
394 |
395 | def __init__(
396 | self,
397 | tool: str,
398 | command: str,
399 | output: str,
400 | status: str = "success",
401 | duration: Optional[float] = None,
402 | ):
403 | """Initialize a tool output panel.
404 |
405 | Args:
406 | tool: Tool name (ripgrep, awk, jq, etc.)
407 | command: Command that was executed
408 | output: Command output text
409 | status: Execution status (success, error)
410 | duration: Optional execution duration in seconds
411 | """
412 | self.tool = tool
413 | self.command = command
414 | self.output = output
415 | self.status = status.lower()
416 | self.duration = duration
417 |
418 | def __rich__(self) -> ConsoleRenderable:
419 | """Render the panel."""
420 | # Determine style and emoji based on status
421 | if self.status == "success":
422 | style = "tool.success"
423 | emoji = SUCCESS
424 | else:
425 | style = "tool.error"
426 | emoji = ERROR
427 |
428 | # Create title
429 | title_text = Text()
430 | title_text.append(f"{emoji} ", style=style)
431 | title_text.append(f"{self.tool}", style=f"bold {style}")
432 |
433 | if self.duration is not None:
434 | title_text.append(f" ({self.duration:.2f}s)", style="tool.duration")
435 |
436 | # Create content
437 | content = Columns(
438 | [
439 | Panel(
440 | Text(self.command, style="tool.command"),
441 |                 title=Text("Command", style="bright_black"),
442 |                 # rich.panel.Panel has no title_style kwarg; the title style is set on the Text above
443 | border_style="tool.command",
444 | padding=(1, 1),
445 | ),
446 | Panel(
447 | Text(self.output, style="tool.output"),
448 |                 title=Text("Output", style="bright_black"),
449 |                 # rich.panel.Panel has no title_style kwarg; the title style is set on the Text above
450 | border_style="bright_black",
451 | padding=(1, 1),
452 | ),
453 | ],
454 | expand=True,
455 | padding=(0, 1),
456 | )
457 |
458 | # Return the panel
459 | return Panel(
460 | content,
461 | title=title_text,
462 | border_style=style,
463 | expand=True,
464 | padding=(1, 1),
465 | )
466 |
467 | class CodePanel:
468 | """Panel for displaying code with syntax highlighting."""
469 |
470 | def __init__(
471 | self,
472 | code: str,
473 | language: str = "python",
474 | title: Optional[str] = None,
475 | line_numbers: bool = True,
476 | highlight_lines: Optional[List[int]] = None,
477 | ):
478 | """Initialize a code panel.
479 |
480 | Args:
481 | code: The code to display
482 | language: Programming language for syntax highlighting
483 | title: Optional panel title
484 | line_numbers: Whether to show line numbers
485 | highlight_lines: List of line numbers to highlight
486 | """
487 | self.code = code
488 | self.language = language
489 | self.title = title
490 | self.line_numbers = line_numbers
491 | self.highlight_lines = highlight_lines
492 |
493 | def __rich__(self) -> ConsoleRenderable:
494 | """Render the panel."""
495 | # Create syntax highlighting component
496 | syntax = Syntax(
497 | self.code,
498 | self.language,
499 | theme="monokai",
500 | line_numbers=self.line_numbers,
501 | highlight_lines=self.highlight_lines,
502 | )
503 |
504 | # Create title
505 | if self.title:
506 | title_text = Text(self.title)
507 | else:
508 | title_text = Text()
509 | title_text.append(self.language.capitalize(), style="bright_blue bold")
510 | title_text.append(" Code", style="bright_black")
511 |
512 | # Return the panel
513 | return Panel(
514 | syntax,
515 | title=title_text,
516 | border_style="bright_blue",
517 | expand=True,
518 | padding=(0, 0),
519 | )
520 |
521 | # Helper functions for creating panels
522 |
523 | def display_header(
524 | title: str,
525 | subtitle: Optional[str] = None,
526 | component: Optional[str] = None,
527 | ) -> None:
528 | """Display a section header.
529 |
530 | Args:
531 | title: Section title
532 | subtitle: Optional subtitle
533 | component: Optional component name
534 | """
535 | panel = HeaderPanel(title, subtitle, component)
536 | console.print(panel)
537 |
538 | def display_results(
539 | title: str,
540 | results: Union[List[Dict[str, Any]], Dict[str, Any]],
541 | status: str = "success",
542 | component: Optional[str] = None,
543 | show_count: bool = True,
544 | compact: bool = False,
545 | ) -> None:
546 | """Display operation results.
547 |
548 | Args:
549 | title: Results title
550 | results: Results to display (list of dicts or single dict)
551 | status: Result status (success, warning, error)
552 | component: Optional component name
553 | show_count: Whether to show result count in title
554 | compact: Whether to use a compact display style
555 | """
556 | panel = ResultPanel(title, results, status, component, show_count, compact)
557 | console.print(panel)
558 |
559 | def display_info(
560 | title: str,
561 | content: Union[str, List[str], Dict[str, Any]],
562 | icon: Optional[str] = None,
563 | style: str = "info",
564 | ) -> None:
565 | """Display an information panel.
566 |
567 | Args:
568 | title: Panel title
569 | content: Content to display (string, list, or dict)
570 | icon: Emoji or icon character
571 | style: Style name to apply (from theme)
572 | """
573 | panel = InfoPanel(title, content, icon, style)
574 | console.print(panel)
575 |
576 | def display_warning(
577 | title: Optional[str] = None,
578 | message: str = "",
579 | details: Optional[List[str]] = None,
580 | ) -> None:
581 | """Display a warning panel.
582 |
583 | Args:
584 | title: Optional panel title
585 | message: Main warning message
586 | details: Optional list of detail points
587 | """
588 | panel = WarningPanel(title, message, details)
589 | console.print(panel)
590 |
591 | def display_error(
592 | title: Optional[str] = None,
593 | message: str = "",
594 | details: Optional[str] = None,
595 | resolution_steps: Optional[List[str]] = None,
596 | error_code: Optional[str] = None,
597 | ) -> None:
598 | """Display an error panel.
599 |
600 | Args:
601 | title: Optional panel title
602 | message: Main error message
603 | details: Optional error details
604 | resolution_steps: Optional list of steps to resolve the error
605 | error_code: Optional error code for reference
606 | """
607 | panel = ErrorPanel(title, message, details, resolution_steps, error_code)
608 | console.print(panel)
609 |
610 | def display_tool_output(
611 | tool: str,
612 | command: str,
613 | output: str,
614 | status: str = "success",
615 | duration: Optional[float] = None,
616 | ) -> None:
617 | """Display tool command output.
618 |
619 | Args:
620 | tool: Tool name (ripgrep, awk, jq, etc.)
621 | command: Command that was executed
622 | output: Command output text
623 | status: Execution status (success, error)
624 | duration: Optional execution duration in seconds
625 | """
626 | panel = ToolOutputPanel(tool, command, output, status, duration)
627 | console.print(panel)
628 |
629 | def display_code(
630 | code: str,
631 | language: str = "python",
632 | title: Optional[str] = None,
633 | line_numbers: bool = True,
634 | highlight_lines: Optional[List[int]] = None,
635 | ) -> None:
636 | """Display code with syntax highlighting.
637 |
638 | Args:
639 | code: The code to display
640 | language: Programming language for syntax highlighting
641 | title: Optional panel title
642 | line_numbers: Whether to show line numbers
643 | highlight_lines: List of line numbers to highlight
644 | """
645 | panel = CodePanel(code, language, title, line_numbers, highlight_lines)
646 | console.print(panel)
```
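
The module-level helpers at the bottom wrap each panel class and print straight to the shared console. A short usage sketch; the named styles (result.success, data.key, etc.) are resolved by the theme configured in ultimate_mcp_server.utils.logging.console, so rendering assumes that theme is loaded:

```python
from ultimate_mcp_server.utils.logging.panels import (
    display_code,
    display_error,
    display_header,
    display_results,
)

display_header("Provider Check", subtitle="Verifying configured API keys", component="startup")

display_results(
    "Provider Status",
    [{"provider": "openai", "models": 12}, {"provider": "anthropic", "models": 7}],
    status="success",
    compact=True,
)

display_error(
    title="Completion Failed",
    message="The provider returned HTTP 429.",
    resolution_steps=["Reduce request rate", "Check the account quota"],
    error_code="RATE_LIMIT",
)

display_code("print('hello')", language="python", title="Snippet")
```
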
--------------------------------------------------------------------------------
/ultimate_mcp_server/core/evaluation/evaluators.py:
--------------------------------------------------------------------------------
```python
1 | # --- core/evaluation/evaluators.py ---
2 | import re
3 | from pathlib import Path
4 | from typing import Any, Dict, List, Literal, Optional
5 |
6 | from ultimate_mcp_server.core.evaluation.base import (
7 | EvaluationScore,
8 | Evaluator,
9 | register_evaluator,
10 | )
11 | from ultimate_mcp_server.core.models.tournament import ModelResponseData
12 | from ultimate_mcp_server.tools.completion import generate_completion
13 |
14 | # --- Import the sandbox execution tool ---
15 | from ultimate_mcp_server.tools.python_sandbox import (
16 | ProviderError,
17 | ToolError,
18 | ToolInputError,
19 | execute_python,
20 | )
21 | from ultimate_mcp_server.utils import get_logger
22 |
23 | logger = get_logger("ultimate_mcp_server.evaluation.evaluators")
24 |
25 |
26 | @register_evaluator
27 | class LLMGraderEvaluator(Evaluator):
28 | evaluator_type = "llm_grader"
29 |
30 | def __init__(self, config: Dict[str, Any]):
31 | super().__init__(config)
32 | self.grader_model_id = config.get("model_id", "anthropic/claude-3-5-haiku-20241022")
33 | self.rubric = config.get(
34 | "rubric",
35 | "Score the response on a scale of 0-100 for quality, relevance, and clarity. Explain your reasoning.",
36 | )
37 | self.score_extraction_regex_str = config.get(
38 | "score_extraction_regex", r"Score:\s*(\d{1,3})"
39 | )
40 | try:
41 | self.score_extraction_regex = re.compile(self.score_extraction_regex_str)
42 | except re.error as e:
43 | logger.error(
44 | f"Invalid regex for score_extraction_regex in LLMGrader: {self.score_extraction_regex_str}. Error: {e}"
45 | )
46 | self.score_extraction_regex = re.compile(r"Score:\s*(\d{1,3})")
47 |
48 | async def score(
49 | self,
50 | response_data: ModelResponseData,
51 | original_prompt: str,
52 | tournament_type: Literal["code", "text"],
53 | ) -> EvaluationScore:
54 |         # Grade the extracted code for code tournaments; otherwise grade the raw response text.
55 | content_to_grade = (
56 | response_data.extracted_code
57 | if tournament_type == "code" and response_data.extracted_code
58 | else response_data.response_text
59 | )
60 |
61 | if not content_to_grade:
62 | return EvaluationScore(score=0.0, details="No content to grade.")
63 |
64 | prompt = f"""Original Prompt:
65 | {original_prompt}
66 |
67 | Model Response to Evaluate:
68 | ---
69 | {content_to_grade}
70 | ---
71 |
72 | Rubric:
73 | {self.rubric}
74 |
75 | Please provide a score (0-100) and a brief justification. Format the score clearly, e.g., "Score: 90".
76 | """
77 | try:
78 | provider = self.grader_model_id.split("/")[0] if "/" in self.grader_model_id else None
79 |
80 | grader_response_dict = await generate_completion(
81 | prompt=prompt,
82 | model=self.grader_model_id,
83 | provider=provider,
84 | max_tokens=500,
85 | temperature=0.2,
86 |             )
87 |
88 |             if not grader_response_dict.get("success"):
89 | return EvaluationScore(
90 | score=0.0, details=f"Grader LLM failed: {grader_response_dict.get('error')}"
91 | )
92 |
93 |             grader_text = grader_response_dict.get("text", "")
94 |
95 | score_match = self.score_extraction_regex.search(grader_text)
96 | numerical_score = 0.0
97 | if score_match:
98 | try:
99 | numerical_score = float(score_match.group(1))
100 | if not (0 <= numerical_score <= 100):
101 | numerical_score = max(0.0, min(100.0, numerical_score))
102 | except ValueError:
103 | logger.warning(
104 | f"LLMGrader: Could not parse score from '{score_match.group(1)}'"
105 | )
106 | except IndexError:
107 | logger.warning(
108 | f"LLMGrader: Regex '{self.score_extraction_regex_str}' matched but had no capture group 1."
109 | )
110 | else:
111 | logger.warning(
112 | f"LLMGrader: Could not find score pattern in grader response: {grader_text[:200]}"
113 | )
114 |
115 | return EvaluationScore(
116 | score=numerical_score,
117 | details=grader_text,
118 |                 metrics={"grader_cost": grader_response_dict.get("cost", 0)},
119 | )
120 |
121 | except Exception as e:
122 | logger.error(f"LLMGrader failed: {e}", exc_info=True)
123 | return EvaluationScore(score=0.0, details=f"Error during LLM grading: {str(e)}")
124 |
125 |
126 | @register_evaluator
127 | class UnitTestEvaluator(Evaluator):
128 | evaluator_type = "unit_test"
129 |
130 | def __init__(self, config: Dict[str, Any]):
131 | super().__init__(config)
132 | test_file_path_str = config.get("test_file_path")
133 | self.required_packages: List[str] = config.get("required_packages", []) # For sandbox
134 |
135 | if not test_file_path_str:
136 | logger.warning(
137 | "UnitTestEvaluator: 'test_file_path' not provided in config. This evaluator may not function."
138 | )
139 | self.test_file_path = Path()
140 | else:
141 | self.test_file_path = Path(test_file_path_str)
142 |         self.timeout_seconds = config.get("timeout_seconds", 30)  # seconds here; converted to ms for the sandbox call
143 |
144 | async def score(
145 | self,
146 | response_data: ModelResponseData,
147 | original_prompt: str, # Unused but part of interface
148 | tournament_type: Literal["code", "text"],
149 | ) -> EvaluationScore:
150 | if tournament_type != "code" or not response_data.extracted_code:
151 | return EvaluationScore(
152 | score=0.0,
153 | details="Unit test evaluator only applicable to code tournaments with extracted code.",
154 | )
155 |
156 | if (
157 | not self.test_file_path
158 | or not self.test_file_path.exists()
159 | or not self.test_file_path.is_file()
160 | ):
161 | details = f"Test file not found, not configured, or not a file: {self.test_file_path}"
162 | if not self.test_file_path.name:
163 | details = "Test file path not configured for UnitTestEvaluator."
164 | logger.warning(f"UnitTestEvaluator: {details}")
165 | return EvaluationScore(score=0.0, details=details)
166 |
167 | try:
168 | # Read the user's test code from the host filesystem
169 | user_test_code = self.test_file_path.read_text(encoding="utf-8")
170 | except Exception as e:
171 | logger.error(f"UnitTestEvaluator: Failed to read test file {self.test_file_path}: {e}")
172 | return EvaluationScore(score=0.0, details=f"Failed to read test file: {e}")
173 |
174 |         # Combine the model's generated code and the user's test code into a single
175 |         # script to be run in the Pyodide sandbox.
176 |         # The generated code is defined first, then the test code, so both end up in
177 |         # the same global scope; the tests can therefore call the generated functions
178 |         # and classes directly instead of importing them from a module.
179 |
180 |         # A more robust layout would write the generated code to solution.py and the
181 |         # tests to test_solution.py (which imports solution), but that would require a
182 |         # real filesystem inside the sandbox, so plain concatenation is used instead.
183 |
184 |         # --- Approach: inject the generated code directly, then the test code ---
185 |         # Test code should assume the generated code's functions/classes are available
186 |         # in the global scope; for Pyodide, defining them globally is the easiest path.
187 |
188 |         # The unittest_runner_script below executes the combined code: it defines the
189 |         # generated code, then the test code, then discovers and runs the tests with
190 |         # unittest, capturing the runner's output in a StringIO buffer.
191 |         # The script prints its metrics between UNIT_TEST_RESULTS_START and
192 |         # UNIT_TEST_RESULTS_END delimiters on stdout, which this evaluator parses
193 |         # below to recover tests run, failures, errors, and the pass rate.
194 |
195 |
196 | generated_code_to_run = response_data.extracted_code
197 |
198 | # This script will be executed by python_sandbox.py
199 | # It needs to define the generated functions/classes, then define and run tests.
200 | # stdout from this script will be parsed for results.
201 | unittest_runner_script = f"""
202 | # --- Generated Code from Model ---
203 | {generated_code_to_run}
204 | # --- End of Generated Code ---
205 |
206 | # --- User's Test Code ---
207 | {user_test_code}
208 | # --- End of User's Test Code ---
209 |
210 | # --- Unittest Execution ---
211 | import unittest
212 | import sys
213 | import io # To capture unittest output
214 |
215 | # Capture unittest's output to a string buffer instead of stderr
216 | # This makes parsing easier and cleaner from the sandbox output.
217 | suite = unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__])
218 | output_buffer = io.StringIO()
219 | runner = unittest.TextTestRunner(stream=output_buffer, verbosity=2)
220 | result = runner.run(suite)
221 |
222 | # Print results in a parsable format to STDOUT
223 | # The python_sandbox tool will capture this stdout.
224 | print("UNIT_TEST_RESULTS_START") # Delimiter for easier parsing
225 | print(f"TestsRun:{{result.testsRun}}")
226 | print(f"Failures:{{len(result.failures)}}")
227 | print(f"Errors:{{len(result.errors)}}")
228 | pass_rate = (result.testsRun - len(result.failures) - len(result.errors)) / result.testsRun if result.testsRun > 0 else 0.0
229 | print(f"PassRate:{{pass_rate:.4f}}")
230 | print("UNIT_TEST_RESULTS_END")
231 |
232 | # Also print the full unittest output (which was captured in output_buffer)
233 | # This can go to stdout as well, or we can separate it.
234 | print("\\n--- Unittest Full Output ---")
235 | print(output_buffer.getvalue())
236 | """
237 | details_output = "Unit test execution details via Pyodide Sandbox:\n"
238 | pass_rate = 0.0
239 | tests_run = 0
240 | failures = 0
241 | errors = 0
242 | sandbox_stdout = ""
243 | sandbox_stderr = ""
244 |
245 | try:
246 | sandbox_result = await execute_python(
247 | code=unittest_runner_script,
248 | packages=self.required_packages, # Pass packages needed by generated code or tests
249 | # wheels=... # If wheels are needed
250 | allow_network=False, # Usually False for unit tests unless they test network code
251 | allow_fs=False, # Usually False unless tests interact with mcpfs
252 | timeout_ms=self.timeout_seconds * 1000,
253 | )
254 |
255 | if sandbox_result.get("success"):
256 | sandbox_stdout = sandbox_result.get("stdout", "")
257 | sandbox_stderr = sandbox_result.get("stderr", "") # Unittest output now in stdout
258 | details_output += f"Sandbox STDOUT:\n{sandbox_stdout}\n"
259 | if sandbox_stderr: # Still log stderr if sandbox itself had issues
260 | details_output += f"Sandbox STDERR:\n{sandbox_stderr}\n"
261 |
262 | # Parse metrics from sandbox_stdout
263 | # Use re.search with MULTILINE if parsing from a larger block
264 | run_match = re.search(r"TestsRun:(\d+)", sandbox_stdout)
265 | fail_match = re.search(r"Failures:(\d+)", sandbox_stdout)
266 | err_match = re.search(r"Errors:(\d+)", sandbox_stdout)
267 | rate_match = re.search(r"PassRate:([0-9.]+)", sandbox_stdout)
268 |
269 | if run_match:
270 | tests_run = int(run_match.group(1))
271 | if fail_match:
272 | failures = int(fail_match.group(1))
273 | if err_match:
274 | errors = int(err_match.group(1))
275 | if rate_match:
276 | pass_rate = float(rate_match.group(1))
277 | else:
278 | logger.warning(
279 | f"UnitTestEvaluator: Could not parse PassRate from sandbox stdout. Output: {sandbox_stdout[:500]}"
280 | )
281 | details_output += "Warning: Could not parse PassRate from output.\n"
282 | else: # Sandbox execution itself failed
283 | error_msg = sandbox_result.get("error_message", "Sandbox execution failed")
284 | error_details = sandbox_result.get("error_details", {})
285 | details_output += (
286 | f"Sandbox Execution Failed: {error_msg}\nDetails: {error_details}\n"
287 | )
288 | logger.error(
289 | f"UnitTestEvaluator: Sandbox execution failed: {error_msg} - {error_details}"
290 | )
291 | pass_rate = 0.0
292 |
293 | except (
294 | ProviderError,
295 | ToolError,
296 | ToolInputError,
297 | ) as e: # Catch errors from execute_python tool
298 | logger.error(f"UnitTestEvaluator: Error calling python_sandbox: {e}", exc_info=True)
299 | details_output += f"Error calling python_sandbox: {str(e)}\n"
300 | pass_rate = 0.0
301 | except Exception as e: # Catch any other unexpected errors
302 | logger.error(f"UnitTestEvaluator: Unexpected error: {e}", exc_info=True)
303 | details_output += f"Unexpected error during unit test evaluation: {str(e)}\n"
304 | pass_rate = 0.0
305 |
306 | return EvaluationScore(
307 | score=pass_rate * 100, # Score 0-100
308 | details=details_output,
309 | metrics={
310 | "tests_run": tests_run,
311 | "failures": failures,
312 | "errors": errors,
313 | "pass_rate": pass_rate,
314 | "sandbox_stdout_len": len(sandbox_stdout),
315 | "sandbox_stderr_len": len(sandbox_stderr),
316 | },
317 | )
318 |
319 |
320 | @register_evaluator
321 | class RegexMatchEvaluator(Evaluator):
322 | evaluator_type = "regex_match"
323 |
324 | def __init__(self, config: Dict[str, Any]):
325 | super().__init__(config)
326 | self.patterns_str: List[str] = config.get("patterns", [])
327 | if not self.patterns_str or not isinstance(self.patterns_str, list):
328 | logger.error("RegexMatchEvaluator: 'patterns' (list of strings) is required in config.")
329 | self.patterns_str = []
330 |
331 | self.target_field: Literal["response_text", "extracted_code"] = config.get(
332 | "target_field", "response_text"
333 | )
334 | self.match_mode: Literal["all_must_match", "any_can_match", "proportion_matched"] = (
335 | config.get("match_mode", "all_must_match")
336 | )
337 |
338 | flag_options_str: Optional[List[str]] = config.get("regex_flag_options")
339 | self.regex_flags: int = 0
340 | if flag_options_str:
341 | for flag_name in flag_options_str:
342 | if hasattr(re, flag_name.upper()):
343 | self.regex_flags |= getattr(re, flag_name.upper())
344 | else:
345 | logger.warning(
346 | f"RegexMatchEvaluator: Unknown regex flag '{flag_name}' specified."
347 | )
348 |
349 | self.compiled_patterns: List[re.Pattern] = []
350 | for i, p_str in enumerate(
351 | self.patterns_str
352 | ): # Use enumerate to get index for original string
353 | try:
354 | self.compiled_patterns.append(re.compile(p_str, self.regex_flags))
355 | except re.error as e:
356 | logger.error(
357 | f"RegexMatchEvaluator: Invalid regex pattern '{p_str}' (index {i}): {e}. Skipping this pattern."
358 | )
359 | # Add a placeholder or skip to keep lengths consistent if needed,
360 | # or ensure patterns_str is filtered alongside compiled_patterns.
361 | # For simplicity now, compiled_patterns might be shorter if errors occur.
362 |
363 | async def score(
364 | self,
365 | response_data: ModelResponseData,
366 | original_prompt: str,
367 | tournament_type: Literal["code", "text"],
368 | ) -> EvaluationScore:
369 | # Iterate using original patterns_str for error reporting if compiled_patterns is shorter
370 | num_configured_patterns = len(self.patterns_str)
371 |
372 | if not self.compiled_patterns and self.patterns_str: # Some patterns were invalid
373 | return EvaluationScore(
374 | score=0.0,
375 | details="No valid regex patterns could be compiled from configuration.",
376 | metrics={
377 | "patterns_configured": num_configured_patterns,
378 | "patterns_compiled": 0,
379 | "patterns_matched": 0,
380 | },
381 | )
382 | if not self.compiled_patterns and not self.patterns_str: # No patterns provided at all
383 | return EvaluationScore(
384 | score=0.0,
385 | details="No regex patterns configured for matching.",
386 | metrics={"patterns_configured": 0, "patterns_compiled": 0, "patterns_matched": 0},
387 | )
388 |
389 | content_to_check: Optional[str] = None
390 | if self.target_field == "extracted_code":
391 | content_to_check = response_data.extracted_code
392 | elif self.target_field == "response_text":
393 | content_to_check = response_data.response_text
394 | else:
395 | return EvaluationScore(
396 | score=0.0,
397 | details=f"Invalid target_field '{self.target_field}'.",
398 | metrics={"patterns_compiled": len(self.compiled_patterns), "patterns_matched": 0},
399 | )
400 |
401 | if content_to_check is None:
402 | return EvaluationScore(
403 | score=0.0,
404 | details=f"Target content field '{self.target_field}' is empty or None.",
405 | metrics={"patterns_compiled": len(self.compiled_patterns), "patterns_matched": 0},
406 | )
407 |
408 | num_matched = 0
409 | all_patterns_details: List[str] = []
410 |
411 | # Corrected loop over successfully compiled patterns
412 | for pattern_obj in self.compiled_patterns:
413 | if pattern_obj.search(content_to_check):
414 | num_matched += 1
415 | all_patterns_details.append(f"Pattern '{pattern_obj.pattern}': MATCHED")
416 | else:
417 | all_patterns_details.append(f"Pattern '{pattern_obj.pattern}': NOT MATCHED")
418 |
419 | final_score = 0.0
420 | num_effective_patterns = len(self.compiled_patterns) # Base score on only valid patterns
421 |
422 | if num_effective_patterns == 0 and num_configured_patterns > 0: # All patterns were invalid
423 | details_str = f"Target field: '{self.target_field}'. Mode: '{self.match_mode}'.\nAll {num_configured_patterns} configured regex patterns were invalid and could not be compiled."
424 | return EvaluationScore(
425 | score=0.0,
426 | details=details_str,
427 | metrics={
428 | "patterns_configured": num_configured_patterns,
429 | "patterns_compiled": 0,
430 | "patterns_matched": 0,
431 | },
432 | )
433 | elif num_effective_patterns == 0 and num_configured_patterns == 0: # No patterns configured
434 | details_str = f"Target field: '{self.target_field}'. Mode: '{self.match_mode}'.\nNo regex patterns configured."
435 | return EvaluationScore(
436 | score=0.0,
437 | details=details_str,
438 | metrics={"patterns_configured": 0, "patterns_compiled": 0, "patterns_matched": 0},
439 | )
440 |
441 | if self.match_mode == "all_must_match":
442 | final_score = 100.0 if num_matched == num_effective_patterns else 0.0
443 | elif self.match_mode == "any_can_match":
444 | final_score = 100.0 if num_matched > 0 else 0.0
445 | elif self.match_mode == "proportion_matched":
446 | final_score = (num_matched / num_effective_patterns) * 100.0
447 |
448 | details_str = f"Target field: '{self.target_field}'. Mode: '{self.match_mode}'.\n"
449 | details_str += f"Matched {num_matched} out of {num_effective_patterns} validly compiled patterns (from {num_configured_patterns} configured).\n"
450 | details_str += "\n".join(all_patterns_details)
451 |
452 | return EvaluationScore(
453 | score=final_score,
454 | details=details_str,
455 | metrics={
456 | "patterns_configured": num_configured_patterns,
457 | "patterns_compiled": num_effective_patterns,
458 | "patterns_matched": num_matched,
459 | "match_proportion_compiled": (num_matched / num_effective_patterns)
460 | if num_effective_patterns
461 | else 0.0,
462 | },
463 | )
464 |
```
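A short configuration sketch for the evaluators above. The key names mirror the ones read in each `__init__`; the standalone instantiation (and the assumption that the base `Evaluator` accepts a bare config dict, as the `super().__init__(config)` calls suggest) is illustrative rather than how a tournament necessarily wires them in.

```python
# Sketch: instantiating the registered evaluators with config dicts.
# Key names are taken from the __init__ methods above; values are illustrative.
from ultimate_mcp_server.core.evaluation.evaluators import (
    LLMGraderEvaluator,
    RegexMatchEvaluator,
    UnitTestEvaluator,
)

grader = LLMGraderEvaluator({
    "model_id": "anthropic/claude-3-5-haiku-20241022",
    "rubric": "Score 0-100 for correctness and clarity. End with 'Score: NN'.",
    "score_extraction_regex": r"Score:\s*(\d{1,3})",
})

unit_tests = UnitTestEvaluator({
    "test_file_path": "tests/generated/test_solution.py",  # read from the host FS at score() time
    "required_packages": [],   # extra Pyodide packages forwarded to execute_python
    "timeout_seconds": 30,     # converted to milliseconds for the sandbox call
})

regex_check = RegexMatchEvaluator({
    "patterns": [r"def\s+\w+\(", r"\breturn\b"],
    "target_field": "extracted_code",      # or "response_text"
    "match_mode": "proportion_matched",    # "all_must_match" | "any_can_match"
    "regex_flag_options": ["IGNORECASE", "MULTILINE"],
})
```

Each evaluator's async `score()` then takes a `ModelResponseData`, the original prompt, and the tournament type, and returns an `EvaluationScore` with a 0-100 score, free-text details, and a metrics dict.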
--------------------------------------------------------------------------------
/ultimate_mcp_server/clients/completion_client.py:
--------------------------------------------------------------------------------
```python
1 | """High-level client for LLM completion operations."""
2 |
3 | import asyncio
4 | from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
5 |
6 | from ultimate_mcp_server.constants import Provider
7 | from ultimate_mcp_server.core.providers.base import BaseProvider, get_provider
8 | from ultimate_mcp_server.services.cache import get_cache_service
9 | from ultimate_mcp_server.utils import get_logger
10 |
11 | logger = get_logger("ultimate_mcp_server.clients.completion")
12 |
13 | class CompletionClient:
14 | """
15 | High-level client for LLM text generation operations with advanced features.
16 |
17 | The CompletionClient provides a unified interface for interacting with various LLM providers
18 | (OpenAI, Anthropic, etc.) through a simple, consistent API. It abstracts away the complexity
19 | of provider-specific implementations, offering a range of features that enhance reliability
20 | and performance.
21 |
22 | Key features:
23 | - Multi-provider support with unified interface
24 | - Automatic fallback between providers
25 | - Result caching for improved performance and reduced costs
26 | - Streaming support for real-time text generation
27 | - Provider initialization and error handling
28 | - Comprehensive error handling and logging
29 |
30 | Architecture:
31 | The client follows a layered architecture pattern:
32 | 1. High-level methods (generate_completion, generate_completion_stream) provide the main API
33 | 2. Provider abstraction layer manages provider-specific implementation details
34 | 3. Caching layer intercepts requests to reduce redundant API calls
35 | 4. Error handling layer provides graceful fallbacks and informative errors
36 |
37 | Performance Considerations:
38 | - Caching is enabled by default and can significantly reduce API costs and latency
39 | - For time-sensitive or unique responses, caching can be disabled per request
40 | - Streaming mode reduces time-to-first-token but cannot leverage caching
41 | - Provider fallback adds resilience but may increase latency if primary providers fail
42 |
43 | This client is designed for MCP tools that require text generation using LLMs,
44 | making interactions more robust by handling common issues like rate limits,
45 | timeouts, and provider-specific errors.
46 |
47 | Example:
48 | ```python
49 | # Create client with default settings
50 | client = CompletionClient()
51 |
52 | # Generate text non-streaming with specific provider and model
53 | result = await client.generate_completion(
54 | prompt="Explain quantum computing",
55 | provider="anthropic",
56 | model="claude-3-5-haiku-20241022",
57 | temperature=0.5,
58 | max_tokens=1000
59 | )
60 | print(f"Generated by {result.model} in {result.processing_time:.2f}s")
61 | print(result.text)
62 |
63 | # Generate text with streaming for real-time output
64 | async for chunk, metadata in client.generate_completion_stream(
65 | prompt="Write a short story about robots",
66 | temperature=0.8
67 | ):
68 | print(chunk, end="")
69 | if metadata.get("done", False):
70 | print("\nGeneration complete!")
71 |
72 | # Use provider fallback for high availability
73 | try:
74 | result = await client.try_providers(
75 | prompt="Summarize this article",
76 | providers=["openai", "anthropic", "gemini"],
77 | models=["gpt-4", "claude-instant-1", "gemini-pro"],
78 | temperature=0.3
79 | )
80 | except Exception as e:
81 | print(f"All providers failed: {e}")
82 | ```
83 | """
84 |
85 | def __init__(self, default_provider: str = Provider.OPENAI.value, use_cache_by_default: bool = True):
86 | """Initialize the completion client.
87 |
88 | Args:
89 | default_provider: Default provider to use for completions
90 | use_cache_by_default: Whether to use cache by default
91 | """
92 | self.default_provider = default_provider
93 | self.cache_service = get_cache_service()
94 | self.use_cache_by_default = use_cache_by_default
95 |
96 | async def initialize_provider(self, provider_name: str, api_key: Optional[str] = None) -> BaseProvider:
97 | """
98 | Initialize and return a provider instance ready for LLM interactions.
99 |
100 | This method handles the creation and initialization of a specific LLM provider,
101 | ensuring it's properly configured and ready to generate completions. It abstracts
102 | the details of provider initialization, including async initialization methods
103 | that some providers might require.
104 |
105 | The method performs several steps:
106 | 1. Retrieves the provider implementation based on the provider name
107 | 2. Applies the API key if provided (otherwise uses environment configuration)
108 | 3. Runs any provider-specific async initialization if required
109 | 4. Returns the ready-to-use provider instance
110 |
111 | Provider initialization follows these architecture principles:
112 | - Late binding: Providers are initialized on-demand, not at client creation
113 | - Dependency injection: API keys can be injected at runtime rather than relying only on environment
114 | - Fail-fast: Validation occurs during initialization rather than at generation time
115 | - Extensibility: New providers can be added without changing client code
116 |
117 | Common provider names include:
118 | - "openai": OpenAI API (GPT models)
119 | - "anthropic": Anthropic API (Claude models)
120 | - "google": Google AI/Vertex API (Gemini models)
121 | - "mistral": Mistral AI API (Mistral, Mixtral models)
122 | - "ollama": Local Ollama server for various open-source models
123 |
124 | Error handling:
125 | - Invalid provider names are caught and reported immediately
126 | - Authentication issues (e.g., invalid API keys) are detected during initialization
127 | - Provider-specific initialization failures are propagated with detailed error messages
128 |
129 | Args:
130 | provider_name: Identifier for the desired provider (e.g., "openai", "anthropic")
131 | api_key: Optional API key to use instead of environment-configured keys
132 |
133 | Returns:
134 | A fully initialized BaseProvider instance ready to generate completions
135 |
136 | Raises:
137 | ValueError: If the provider name is invalid or not supported
138 | Exception: If initialization fails (e.g., invalid API key, network issues)
139 |
140 | Note:
141 | This method is typically called internally by other client methods,
142 | but can be used directly when you need a specific provider instance
143 | for specialized operations not covered by the main client methods.
144 |
145 | Example:
146 | ```python
147 | # Get a specific provider instance for custom operations
148 | openai_provider = await client.initialize_provider("openai")
149 |
150 | # Custom operation using provider-specific features
151 | response = await openai_provider.some_specialized_method(...)
152 | ```
153 | """
154 | try:
155 | provider = await get_provider(provider_name, api_key=api_key)
156 | # Ensure the provider is initialized (some might need async init)
157 | if hasattr(provider, 'initialize') and asyncio.iscoroutinefunction(provider.initialize):
158 | await provider.initialize()
159 | return provider
160 | except Exception as e:
161 | logger.error(f"Failed to initialize provider {provider_name}: {e}", emoji_key="error")
162 | raise
163 |
164 | async def generate_completion(
165 | self,
166 | prompt: str,
167 | provider: Optional[str] = None,
168 | model: Optional[str] = None,
169 | temperature: float = 0.7,
170 | max_tokens: Optional[int] = None,
171 | use_cache: bool = True,
172 | cache_ttl: int = 3600,
173 | **kwargs
174 | ):
175 | """
176 | Generate text completion from an LLM with optional caching.
177 |
178 | This method provides a unified interface for generating text completions from
179 | any supported LLM provider. It includes intelligent caching to avoid redundant
180 | API calls for identical inputs, reducing costs and latency.
181 |
182 | The caching system:
183 | - Creates a unique key based on the prompt, provider, model, and parameters
184 | - Checks for cached results before making API calls
185 | - Stores successful responses with a configurable TTL
186 | - Can be disabled per-request with the use_cache parameter
187 |
188 | Args:
189 | prompt: The text prompt to send to the LLM
190 | provider: The LLM provider to use (e.g., "openai", "anthropic", "google")
191 | If None, uses the client's default_provider
192 | model: Specific model to use (e.g., "gpt-4", "claude-instant-1")
193 | If None, uses the provider's default model
194 | temperature: Sampling temperature for controlling randomness (0.0-1.0)
195 | Lower values are more deterministic, higher values more creative
196 | max_tokens: Maximum number of tokens to generate
197 | If None, uses provider-specific defaults
198 | use_cache: Whether to use the caching system (default: True)
199 | cache_ttl: Time-to-live for cache entries in seconds (default: 1 hour)
200 | **kwargs: Additional provider-specific parameters
201 | (e.g., top_p, frequency_penalty, presence_penalty)
202 |
203 | Returns:
204 | CompletionResult object with attributes:
205 | - text: The generated completion text
206 | - provider: The provider that generated the text
207 | - model: The model used
208 | - processing_time: Time taken to generate the completion (in seconds)
209 | - tokens: Token usage information (if available)
210 | - error: Error information (if an error occurred but was handled)
211 |
212 | Raises:
213 | ValueError: For invalid parameters
214 | Exception: For provider errors or other issues during generation
215 |
216 | Example:
217 | ```python
218 | result = await client.generate_completion(
219 | prompt="Write a poem about artificial intelligence",
220 | temperature=0.8,
221 | max_tokens=1000
222 | )
223 | print(f"Generated by {result.model} in {result.processing_time:.2f}s")
224 | print(result.text)
225 | ```
226 | """
227 | provider_name = provider or self.default_provider
228 |
229 | # Check cache if enabled
230 | if use_cache and self.cache_service.enabled:
231 | # Create a robust cache key
232 | provider_instance = await self.initialize_provider(provider_name)
233 | model_id = model or provider_instance.get_default_model()
234 | # Include relevant parameters in the cache key
235 |             params_hash = hash((prompt, temperature, max_tokens, str(kwargs)))  # NOTE: str hashes are salted per process, so keys are stable only within a single run
236 | cache_key = f"completion:{provider_name}:{model_id}:{params_hash}"
237 |
238 | cached_result = await self.cache_service.get(cache_key)
239 | if cached_result is not None:
240 | logger.success("Cache hit! Using cached result", emoji_key="cache")
241 | # Set a nominal processing time for cached results
242 | cached_result.processing_time = 0.001
243 | return cached_result
244 |
245 | # Cache miss or cache disabled
246 | if use_cache and self.cache_service.enabled:
247 | logger.info("Cache miss. Generating new completion...", emoji_key="processing")
248 | else:
249 | logger.info("Generating completion...", emoji_key="processing")
250 |
251 | # Initialize provider and generate completion
252 | try:
253 | provider_instance = await self.initialize_provider(provider_name)
254 | model_id = model or provider_instance.get_default_model()
255 |
256 | result = await provider_instance.generate_completion(
257 | prompt=prompt,
258 | model=model_id,
259 | temperature=temperature,
260 | max_tokens=max_tokens,
261 | **kwargs
262 | )
263 |
264 | # Save to cache if enabled
265 | if use_cache and self.cache_service.enabled:
266 | await self.cache_service.set(
267 | key=cache_key,
268 | value=result,
269 | ttl=cache_ttl
270 | )
271 | logger.info(f"Result saved to cache (key: ...{cache_key[-10:]})", emoji_key="cache")
272 |
273 | return result
274 |
275 | except Exception as e:
276 | logger.error(f"Error generating completion: {str(e)}", emoji_key="error")
277 | raise
278 |
279 | async def generate_completion_stream(
280 | self,
281 | prompt: str,
282 | provider: Optional[str] = None,
283 | model: Optional[str] = None,
284 | temperature: float = 0.7,
285 | max_tokens: Optional[int] = None,
286 | **kwargs
287 | ) -> AsyncGenerator[Tuple[str, Dict[str, Any]], None]:
288 | """
289 | Generate a streaming text completion with real-time chunks.
290 |
291 | This method provides a streaming interface to LLM text generation, where
292 | text is returned incrementally as it's generated, rather than waiting for
293 | the entire response. This enables real-time UI updates, faster apparent
294 | response times, and the ability to process partial responses.
295 |
296 | Unlike the non-streaming version, this method:
297 | - Does not support caching (each streaming response is unique)
298 | - Returns an async generator that yields content incrementally
299 | - Provides metadata with each chunk for tracking generation progress
300 |
301 | Args:
302 | prompt: The text prompt to send to the LLM
303 | provider: The LLM provider to use (e.g., "openai", "anthropic")
304 | If None, uses the client's default_provider
305 | model: Specific model to use (e.g., "gpt-4", "claude-instant-1")
306 | If None, uses the provider's default model
307 | temperature: Sampling temperature for controlling randomness (0.0-1.0)
308 | Lower values are more deterministic, higher values more creative
309 | max_tokens: Maximum number of tokens to generate
310 | If None, uses provider-specific defaults
311 | **kwargs: Additional provider-specific parameters
312 |
313 | Yields:
314 | Tuples of (chunk_text, metadata), where:
315 | - chunk_text: A string containing the next piece of generated text
316 | - metadata: A dictionary with information about the generation process:
317 | - done: Boolean indicating if this is the final chunk
318 | - chunk_index: Index of the current chunk (0-based)
319 | - token_count: Number of tokens in this chunk (if available)
320 | - total_tokens: Running total of tokens generated so far (if available)
321 |
322 | Raises:
323 | ValueError: For invalid parameters
324 | Exception: For provider errors or other issues during streaming
325 |
326 | Example:
327 | ```python
328 | # Display text as it's generated
329 | async for chunk, metadata in client.generate_completion_stream(
330 | prompt="Explain the theory of relativity",
331 | temperature=0.3
332 | ):
333 | print(chunk, end="")
334 | if metadata.get("done", False):
335 | print("\nGeneration complete!")
336 | ```
337 |
338 | Note:
339 | Not all providers support streaming completions. Check the provider
340 | documentation for compatibility.
341 | """
342 | provider_name = provider or self.default_provider
343 |
344 | logger.info("Generating streaming completion...", emoji_key="processing")
345 |
346 | # Initialize provider and generate streaming completion
347 | try:
348 | provider_instance = await self.initialize_provider(provider_name)
349 | model_id = model or provider_instance.get_default_model()
350 |
351 | stream = provider_instance.generate_completion_stream(
352 | prompt=prompt,
353 | model=model_id,
354 | temperature=temperature,
355 | max_tokens=max_tokens,
356 | **kwargs
357 | )
358 |
359 | async for chunk, metadata in stream:
360 | yield chunk, metadata
361 |
362 | except Exception as e:
363 | logger.error(f"Error generating streaming completion: {str(e)}", emoji_key="error")
364 | raise
365 |
366 | async def try_providers(
367 | self,
368 | prompt: str,
369 | providers: List[str],
370 | models: Optional[List[str]] = None,
371 | **kwargs
372 | ):
373 | """
374 | Try multiple providers in sequence until one succeeds.
375 |
376 | This method implements an automatic fallback mechanism that attempts to generate
377 | a completion using a list of providers in order, continuing to the next provider
378 | if the current one fails. This provides resilience against provider downtime,
379 | rate limits, or other temporary failures.
380 |
381 | The method tries each provider exactly once in the order they're specified, with
382 | an optional corresponding model for each. This is useful for scenarios where you
383 | need high availability or want to implement prioritized provider selection.
384 |
385 | Args:
386 | prompt: The text prompt to send to the LLM
387 | providers: An ordered list of provider names to try (e.g., ["openai", "anthropic", "google"])
388 | Providers are tried in the specified order until one succeeds
389 | models: Optional list of models to use with each provider
390 | If provided, must be the same length as providers
391 | If None, each provider's default model is used
392 | **kwargs: Additional parameters passed to generate_completion
393 | Applies to all provider attempts
394 |
395 | Returns:
396 | CompletionResult from the first successful provider,
397 | with the same structure as generate_completion results
398 |
399 | Raises:
400 | ValueError: If no providers are specified or if models list length doesn't match providers
401 | Exception: If all specified providers fail, with details of the last error
402 |
403 | Example:
404 | ```python
405 | # Try OpenAI first, fall back to Anthropic, then Google
406 | result = await client.try_providers(
407 | prompt="Write a sonnet about programming",
408 | providers=["openai", "anthropic", "google"],
409 | models=["gpt-4", "claude-2", "gemini-pro"],
410 | temperature=0.7,
411 | max_tokens=800
412 | )
413 | print(f"Successfully used {result.provider} with model {result.model}")
414 | print(result.text)
415 | ```
416 |
417 | Note:
418 | Each provider attempt is logged, making it easy to track which providers
419 | succeeded or failed during the fallback sequence.
420 | """
421 | if not providers:
422 | raise ValueError("No providers specified")
423 |
424 | models = models or [None] * len(providers)
425 | if len(models) != len(providers):
426 | raise ValueError("If models are specified, there must be one for each provider")
427 |
428 | last_error = None
429 |
430 | for i, provider_name in enumerate(providers):
431 | try:
432 | logger.info(f"Trying provider: {provider_name}", emoji_key="provider")
433 | result = await self.generate_completion(
434 | prompt=prompt,
435 | provider=provider_name,
436 | model=models[i],
437 | **kwargs
438 | )
439 | return result
440 | except Exception as e:
441 | logger.warning(f"Provider {provider_name} failed: {str(e)}", emoji_key="warning")
442 | last_error = e
443 |
444 | # If we get here, all providers failed
445 | raise Exception(f"All providers failed. Last error: {str(last_error)}")
```
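A compact sketch of the caching and fallback behavior described in the docstrings above. It assumes the gateway's configuration (API keys, cache service) is already in place; the provider and model names are illustrative.

```python
# Sketch: per-request cache control and ordered provider fallback.
import asyncio

from ultimate_mcp_server.clients.completion_client import CompletionClient


async def main() -> None:
    client = CompletionClient(default_provider="openai", use_cache_by_default=True)

    # Cached path: an identical prompt + parameters within the TTL returns the stored
    # result (processing_time is reset to a nominal 0.001s on a cache hit).
    cached = await client.generate_completion(
        prompt="Summarize the CAP theorem in two sentences.",
        temperature=0.3,
        cache_ttl=600,
    )

    # Unique responses: bypass the cache entirely for this one request.
    fresh = await client.generate_completion(
        prompt="Summarize the CAP theorem in two sentences.",
        use_cache=False,
    )

    # Ordered fallback: stop at the first provider that succeeds.
    result = await client.try_providers(
        prompt="Draft a one-line commit message for a typo fix.",
        providers=["openai", "anthropic"],
        models=["gpt-4o-mini", "claude-3-5-haiku-20241022"],  # one model per provider
        max_tokens=60,
    )
    print(result.provider, result.model)
    print(cached.text[:80], fresh.text[:80])


asyncio.run(main())
```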
--------------------------------------------------------------------------------
/ultimate_mcp_server/utils/logging/progress.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Progress tracking and visualization for Gateway.
3 |
4 | This module provides enhanced progress tracking capabilities with Rich,
5 | supporting nested tasks, task groups, and dynamic progress updates.
6 | """
7 | import time
8 | import uuid
9 | from contextlib import contextmanager
10 | from dataclasses import dataclass, field
11 | from typing import Any, Dict, Generator, Iterable, List, Optional, TypeVar
12 |
13 | from rich.box import ROUNDED
14 | from rich.console import Console, ConsoleRenderable, Group
15 | from rich.live import Live
16 | from rich.progress import (
17 | BarColumn,
18 | SpinnerColumn,
19 | TaskID, # Import TaskID type hint
20 | TextColumn,
21 | TimeElapsedColumn,
22 | TimeRemainingColumn,
23 | )
24 | from rich.progress import Progress as RichProgress # Renamed to avoid clash
25 | from rich.table import Table
26 |
27 | from .console import console as default_console # Use the shared console instance
28 |
29 | # Use relative imports
30 |
31 | # TypeVar for generic progress tracking over iterables
32 | T = TypeVar("T")
33 |
34 | @dataclass
35 | class TaskInfo:
36 | """Information about a single task being tracked."""
37 | description: str
38 | total: float
39 | completed: float = 0.0
40 | status: str = "running" # running, success, error, skipped
41 | start_time: float = field(default_factory=time.time)
42 | end_time: Optional[float] = None
43 | parent_id: Optional[str] = None
44 | rich_task_id: Optional[TaskID] = None # ID from Rich Progress
45 | meta: Dict[str, Any] = field(default_factory=dict)
46 |
47 | @property
48 | def elapsed(self) -> float:
49 | """Calculate elapsed time."""
50 | end = self.end_time or time.time()
51 | return end - self.start_time
52 |
53 | @property
54 | def is_complete(self) -> bool:
55 | """Check if the task is in a terminal state."""
56 | return self.status in ("success", "error", "skipped")
57 |
58 | class GatewayProgress:
59 | """Manages multiple progress tasks with Rich integration and context.
60 |
61 | Allows for nested tasks and displays an overall summary.
62 | Uses a single Rich Progress instance managed internally.
63 | """
64 |
65 | def __init__(
66 | self,
67 | console: Optional[Console] = None,
68 | transient: bool = False, # Keep visible after completion?
69 | auto_refresh: bool = True,
70 | expand: bool = True, # Expand progress bars to full width?
71 | show_summary: bool = True,
72 | summary_refresh_rate: float = 1.0 # How often to refresh summary
73 | ):
74 | """Initialize the progress manager.
75 |
76 | Args:
77 | console: Rich Console instance (defaults to shared console)
78 | transient: Hide progress bars upon completion
79 | auto_refresh: Automatically refresh the display
80 | expand: Expand bars to console width
81 | show_summary: Display the summary panel below progress bars
82 | summary_refresh_rate: Rate limit for summary updates (seconds)
83 | """
84 | self.console = console or default_console
85 | self._rich_progress = self._create_progress(transient, auto_refresh, expand)
86 | self._live: Optional[Live] = None
87 | self._tasks: Dict[str, TaskInfo] = {}
88 | self._task_stack: List[str] = [] # For context managers
89 | self.show_summary = show_summary
90 | self._summary_renderable = self._render_summary() # Initial summary
91 | self._last_summary_update = 0.0
92 | self.summary_refresh_rate = summary_refresh_rate
93 |
94 | def _create_progress(self, transient: bool, auto_refresh: bool, expand: bool) -> RichProgress:
95 | """Create the underlying Rich Progress instance."""
96 | return RichProgress(
97 | SpinnerColumn(),
98 | TextColumn("[progress.description]{task.description}"),
99 | BarColumn(bar_width=None if expand else 40),
100 | "[progress.percentage]{task.percentage:>3.1f}%",
101 | TimeElapsedColumn(),
102 | TimeRemainingColumn(),
103 | console=self.console,
104 | transient=transient,
105 | auto_refresh=auto_refresh,
106 | expand=expand,
107 | # disable=True # Useful for debugging
108 | )
109 |
110 | def _render_summary(self) -> Group:
111 | """Render the overall progress summary table."""
112 | if not self.show_summary or not self._tasks:
113 | return Group() # Empty group if no summary needed or no tasks yet
114 |
115 | completed_count = sum(1 for t in self._tasks.values() if t.is_complete)
116 | running_count = len(self._tasks) - completed_count
117 | success_count = sum(1 for t in self._tasks.values() if t.status == 'success')
118 | error_count = sum(1 for t in self._tasks.values() if t.status == 'error')
119 | skipped_count = sum(1 for t in self._tasks.values() if t.status == 'skipped')
120 |
121 | total_elapsed = time.time() - min(t.start_time for t in self._tasks.values()) if self._tasks else 0
122 |
123 | # Calculate overall percentage (weighted average might be better?)
124 | overall_total = sum(t.total for t in self._tasks.values())
125 | overall_completed = sum(t.completed for t in self._tasks.values())
126 | overall_perc = (overall_completed / overall_total * 100) if overall_total > 0 else 100.0
127 |
128 | summary_table = Table(box=ROUNDED, show_header=False, padding=(0, 1), expand=True)
129 | summary_table.add_column("Metric", style="dim", width=15)
130 | summary_table.add_column("Value", style="bold")
131 |
132 | summary_table.add_row("Overall Prog.", f"{overall_perc:.1f}%")
133 | summary_table.add_row("Total Tasks", str(len(self._tasks)))
134 | summary_table.add_row(" Running", str(running_count))
135 | summary_table.add_row(" Completed", str(completed_count))
136 | if success_count > 0:
137 | summary_table.add_row(" Success", f"[success]{success_count}[/]")
138 | if error_count > 0:
139 | summary_table.add_row(" Errors", f"[error]{error_count}[/]")
140 | if skipped_count > 0:
141 | summary_table.add_row(" Skipped", f"[warning]{skipped_count}[/]")
142 | summary_table.add_row("Elapsed Time", f"{total_elapsed:.2f}s")
143 |
144 | return Group(summary_table)
145 |
146 | def _get_renderable(self) -> ConsoleRenderable:
147 | """Get the combined renderable for the Live display."""
148 | # Throttle summary updates
149 | now = time.time()
150 | if self.show_summary and (now - self._last_summary_update > self.summary_refresh_rate):
151 | self._summary_renderable = self._render_summary()
152 | self._last_summary_update = now
153 |
154 | if self.show_summary:
155 | return Group(self._rich_progress, self._summary_renderable)
156 | else:
157 | return self._rich_progress
158 |
159 | def add_task(
160 | self,
161 | description: str,
162 | name: Optional[str] = None,
163 | total: float = 100.0,
164 | parent: Optional[str] = None, # Name of parent task
165 | visible: bool = True,
166 | start: bool = True, # Start the Rich task immediately
167 | **meta: Any # Additional metadata
168 | ) -> str:
169 | """Add a new task to track.
170 |
171 | Args:
172 | description: Text description of the task.
173 | name: Unique name/ID for this task (auto-generated if None).
174 | total: Total steps/units for completion.
175 | parent: Name of the parent task for nesting (visual indent).
176 | visible: Whether the task is initially visible.
177 | start: Start the task in the Rich progress bar immediately.
178 | **meta: Arbitrary metadata associated with the task.
179 |
180 | Returns:
181 | The unique name/ID of the added task.
182 | """
183 | if name is None:
184 | name = str(uuid.uuid4()) # Generate unique ID if not provided
185 |
186 | if name in self._tasks:
187 | raise ValueError(f"Task with name '{name}' already exists.")
188 |
189 | parent_rich_id = None
190 | if parent:
191 | if parent not in self._tasks:
192 | raise ValueError(f"Parent task '{parent}' not found.")
193 | parent_task_info = self._tasks[parent]
194 | if parent_task_info.rich_task_id is not None:
195 | parent_rich_id = parent_task_info.rich_task_id
196 | # Quick hack for indentation - needs better Rich integration? Rich doesn't directly support tree view in Progress
197 | # description = f" {description}"
198 |
199 | task_info = TaskInfo(
200 | description=description,
201 | total=total,
202 | parent_id=parent,
203 | meta=meta,
204 | )
205 |
206 | # Add to Rich Progress if active
207 | rich_task_id = None
208 | if self._live and self._rich_progress:
209 | rich_task_id = self._rich_progress.add_task(
210 | description,
211 | total=total,
212 | start=start,
213 | visible=visible,
214 | parent=parent_rich_id # Rich uses TaskID for parent
215 | )
216 | task_info.rich_task_id = rich_task_id
217 |
218 | self._tasks[name] = task_info
219 | return name
220 |
221 | def update_task(
222 | self,
223 | name: str,
224 | description: Optional[str] = None,
225 | advance: Optional[float] = None,
226 | completed: Optional[float] = None,
227 | total: Optional[float] = None,
228 | visible: Optional[bool] = None,
229 | status: Optional[str] = None, # running, success, error, skipped
230 | **meta: Any
231 | ) -> None:
232 | """Update an existing task.
233 |
234 | Args:
235 | name: The unique name/ID of the task to update.
236 | description: New description text.
237 | advance: Amount to advance the completion progress.
238 | completed: Set completion to a specific value.
239 | total: Set a new total value.
240 | visible: Change task visibility.
241 | status: Update the task status (affects summary).
242 | **meta: Update or add metadata.
243 | """
244 | if name not in self._tasks:
245 | # Optionally log a warning or error
246 | # default_console.print(f"[warning]Attempted to update non-existent task: {name}[/]")
247 | return
248 |
249 | task_info = self._tasks[name]
250 | update_kwargs = {}
251 |
252 | if description is not None:
253 | task_info.description = description
254 | update_kwargs['description'] = description
255 |
256 | if total is not None:
257 | task_info.total = float(total)
258 | update_kwargs['total'] = task_info.total
259 |
260 | # Update completed status
261 | if completed is not None:
262 | task_info.completed = max(0.0, min(float(completed), task_info.total))
263 | update_kwargs['completed'] = task_info.completed
264 | elif advance is not None:
265 | task_info.completed = max(0.0, min(task_info.completed + float(advance), task_info.total))
266 | update_kwargs['completed'] = task_info.completed
267 |
268 | if visible is not None:
269 | update_kwargs['visible'] = visible
270 |
271 | if meta:
272 | task_info.meta.update(meta)
273 |
274 | # Update status (after completion update)
275 | if status is not None:
276 | task_info.status = status
277 | if task_info.is_complete and task_info.end_time is None:
278 | task_info.end_time = time.time()
279 | # Ensure Rich task is marked as complete
280 | if 'completed' not in update_kwargs:
281 | update_kwargs['completed'] = task_info.total
282 |
283 | # Update Rich progress bar if active
284 | if task_info.rich_task_id is not None and self._live and self._rich_progress:
285 | self._rich_progress.update(task_info.rich_task_id, **update_kwargs)
286 |
287 | def complete_task(self, name: str, status: str = "success") -> None:
288 | """Mark a task as complete with a final status.
289 |
290 | Args:
291 | name: The unique name/ID of the task.
292 | status: Final status ('success', 'error', 'skipped').
293 | """
294 | if name not in self._tasks:
295 | return # Or raise error/log warning
296 |
297 | task_info = self._tasks[name]
298 | self.update_task(
299 | name,
300 | completed=task_info.total, # Ensure it reaches 100%
301 | status=status
302 | )
303 |
304 | def start(self) -> "GatewayProgress":
305 | """Start the Rich Live display."""
306 | if self._live is None:
307 | # Add any tasks that were created before start()
308 | for _name, task_info in self._tasks.items():
309 | if task_info.rich_task_id is None:
310 | parent_rich_id = None
311 | if task_info.parent_id and task_info.parent_id in self._tasks:
312 | parent_rich_id = self._tasks[task_info.parent_id].rich_task_id
313 |
314 | task_info.rich_task_id = self._rich_progress.add_task(
315 | task_info.description,
316 | total=task_info.total,
317 | completed=task_info.completed,
318 | start=True, # Assume tasks added before start should be started
319 | visible=True, # Assume visible
320 | parent=parent_rich_id
321 | )
322 |
323 | self._live = Live(self._get_renderable(), console=self.console, refresh_per_second=10, vertical_overflow="visible")
324 | self._live.start(refresh=True)
325 | return self
326 |
327 | def stop(self) -> None:
328 | """Stop the Rich Live display."""
329 | if self._live is not None:
330 | # Ensure all running tasks in Rich are marked complete before stopping Live
331 | # to avoid them getting stuck visually
332 | if self._rich_progress:
333 | for task in self._rich_progress.tasks:
334 | if not task.finished:
335 | self._rich_progress.update(task.id, completed=task.total)
336 |
337 | self._live.stop()
338 | self._live = None
339 | # Optional: Clear the Rich Progress tasks?
340 | # self._rich_progress = self._create_progress(...) # Recreate if needed
341 |
342 | def update(self) -> None:
343 | """Force a refresh of the Live display (if active)."""
344 | if self._live:
345 | self._live.update(self._get_renderable(), refresh=True)
346 |
347 | def reset(self) -> None:
348 | """Reset the progress tracker, clearing all tasks."""
349 | self.stop() # Stop live display
350 | self._tasks.clear()
351 | self._task_stack.clear()
352 | # Recreate Rich progress to clear its tasks
353 | self._rich_progress = self._create_progress(
354 | self._rich_progress.transient,
355 | self._rich_progress.auto_refresh,
356 | True # Assuming expand is derived from console width anyway
357 | )
358 | self._summary_renderable = self._render_summary()
359 | self._last_summary_update = 0.0
360 |
361 | @contextmanager
362 | def task(
363 | self,
364 | description: str,
365 | name: Optional[str] = None,
366 | total: float = 100.0,
367 | parent: Optional[str] = None,
368 | autostart: bool = True, # Start Live display if not already started?
369 | **meta: Any
370 | ) -> Generator["GatewayProgress", None, None]: # Yields self for updates
371 | """Context manager for a single task.
372 |
373 | Args:
374 | description: Description of the task.
375 | name: Optional unique name/ID (auto-generated if None).
376 | total: Total steps/units for the task.
377 | parent: Optional parent task name.
378 | autostart: Start the overall progress display if not running.
379 | **meta: Additional metadata for the task.
380 |
381 | Yields:
382 | The GatewayProgress instance itself, allowing updates via `update_task`.
383 | """
384 | if autostart and self._live is None:
385 | self.start()
386 |
387 | task_name = self.add_task(description, name, total, parent, **meta)
388 | self._task_stack.append(task_name)
389 |
390 | try:
391 | yield self # Yield self to allow calling update_task(task_name, ...)
392 | except Exception:
393 | # Mark task as errored on exception
394 | self.complete_task(task_name, status="error")
395 | raise # Re-raise the exception
396 | else:
397 | # Mark task as successful if no exception
398 | # Check if it was already completed with a different status
399 | if task_name in self._tasks and not self._tasks[task_name].is_complete:
400 | self.complete_task(task_name, status="success")
401 | finally:
402 | # Pop task from stack
403 | if self._task_stack and self._task_stack[-1] == task_name:
404 | self._task_stack.pop()
405 | # No automatic stop here - allow multiple context managers
406 | # self.stop()
407 |
408 | def track(
409 | self,
410 | iterable: Iterable[T],
411 | description: str,
412 | name: Optional[str] = None,
413 | total: Optional[float] = None,
414 | parent: Optional[str] = None,
415 | autostart: bool = True,
416 | **meta: Any
417 | ) -> Iterable[T]:
418 | """Track progress over an iterable.
419 |
420 | Args:
421 | iterable: The iterable to track progress over.
422 | description: Description of the task.
423 | name: Optional unique name/ID (auto-generated if None).
424 | total: Total number of items (estimated if None).
425 | parent: Optional parent task name.
426 | autostart: Start the overall progress display if not running.
427 | **meta: Additional metadata for the task.
428 |
429 | Returns:
430 | The iterable, yielding items while updating progress.
431 | """
432 | if total is None:
433 | try:
434 | total = float(len(iterable)) # type: ignore
435 | except (TypeError, AttributeError):
436 | total = 100.0 # Default if length cannot be determined
437 |
438 | if autostart and self._live is None:
439 | self.start()
440 |
441 | task_name = self.add_task(description, name, total, parent, **meta)
442 |
443 | try:
444 | for item in iterable:
445 | yield item
446 | self.update_task(task_name, advance=1)
447 | except Exception:
448 | self.complete_task(task_name, status="error")
449 | raise
450 | else:
451 | # Check if it was already completed with a different status
452 | if task_name in self._tasks and not self._tasks[task_name].is_complete:
453 | self.complete_task(task_name, status="success")
454 | # No automatic stop
455 | # finally:
456 | # self.stop()
457 |
458 | def __enter__(self) -> "GatewayProgress":
459 | """Enter context manager, starts the display."""
460 | return self.start()
461 |
462 | def __exit__(self, exc_type, exc_val, exc_tb) -> None:
463 | """Exit context manager, stops the display."""
464 | self.stop()
465 |
466 | # --- Global Convenience Functions (using a default progress instance) ---
467 | # Note: Managing a truly global progress instance can be tricky.
468 | # It might be better to explicitly create and manage GatewayProgress instances.
469 | _global_progress: Optional[GatewayProgress] = None
470 |
471 | def get_global_progress() -> GatewayProgress:
472 | """Get or create the default global progress manager."""
473 | global _global_progress
474 | if _global_progress is None:
475 | _global_progress = GatewayProgress()
476 | return _global_progress
477 |
478 | def track(
479 | iterable: Iterable[T],
480 | description: str,
481 | name: Optional[str] = None,
482 | total: Optional[float] = None,
483 | parent: Optional[str] = None,
484 | ) -> Iterable[T]:
485 | """Track progress over an iterable using the global progress manager."""
486 | prog = get_global_progress()
487 | # Ensure global progress is started if used this way
488 | if prog._live is None:
489 | prog.start()
490 | return prog.track(iterable, description, name, total, parent, autostart=False)
491 |
492 | @contextmanager
493 | def task(
494 | description: str,
495 | name: Optional[str] = None,
496 | total: float = 100.0,
497 | parent: Optional[str] = None,
498 | ) -> Generator["GatewayProgress", None, None]:
499 | """Context manager for a single task using the global progress manager."""
500 | prog = get_global_progress()
501 | # Ensure global progress is started if used this way
502 | if prog._live is None:
503 | prog.start()
504 | with prog.task(description, name, total, parent, autostart=False) as task_context:
505 | yield task_context # Yields the progress manager itself
```
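A minimal sketch of driving `GatewayProgress` directly, using only the class and method names defined above; the `time.sleep` calls stand in for real work.

```python
# Sketch: nested task tracking with GatewayProgress.
import time

from ultimate_mcp_server.utils.logging.progress import GatewayProgress

with GatewayProgress(show_summary=True) as progress:
    # Iterate with automatic per-item advancement (total inferred from len()).
    for _shard in progress.track(range(3), "Fetching shards", name="fetch"):
        time.sleep(0.1)

    # Manually advance a named task inside its context manager.
    with progress.task("Indexing corpus", name="index", total=100.0) as p:
        for _ in range(10):
            time.sleep(0.05)
            p.update_task("index", advance=10)
```

The module-level `track()` and `task()` helpers do the same thing against the shared global `GatewayProgress` instance.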