This is page 3 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│   ├── __init__.py
│   ├── advanced_agent_flows_using_unified_memory_system_demo.py
│   ├── advanced_extraction_demo.py
│   ├── advanced_unified_memory_system_demo.py
│   ├── advanced_vector_search_demo.py
│   ├── analytics_reporting_demo.py
│   ├── audio_transcription_demo.py
│   ├── basic_completion_demo.py
│   ├── cache_demo.py
│   ├── claude_integration_demo.py
│   ├── compare_synthesize_demo.py
│   ├── cost_optimization.py
│   ├── data
│   │   ├── sample_event.txt
│   │   ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│   │   └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│   ├── docstring_refiner_demo.py
│   ├── document_conversion_and_processing_demo.py
│   ├── entity_relation_graph_demo.py
│   ├── filesystem_operations_demo.py
│   ├── grok_integration_demo.py
│   ├── local_text_tools_demo.py
│   ├── marqo_fused_search_demo.py
│   ├── measure_model_speeds.py
│   ├── meta_api_demo.py
│   ├── multi_provider_demo.py
│   ├── ollama_integration_demo.py
│   ├── prompt_templates_demo.py
│   ├── python_sandbox_demo.py
│   ├── rag_example.py
│   ├── research_workflow_demo.py
│   ├── sample
│   │   ├── article.txt
│   │   ├── backprop_paper.pdf
│   │   ├── buffett.pdf
│   │   ├── contract_link.txt
│   │   ├── legal_contract.txt
│   │   ├── medical_case.txt
│   │   ├── northwind.db
│   │   ├── research_paper.txt
│   │   ├── sample_data.json
│   │   └── text_classification_samples
│   │       ├── email_classification.txt
│   │       ├── news_samples.txt
│   │       ├── product_reviews.txt
│   │       └── support_tickets.txt
│   ├── sample_docs
│   │   └── downloaded
│   │       └── attention_is_all_you_need.pdf
│   ├── sentiment_analysis_demo.py
│   ├── simple_completion_demo.py
│   ├── single_shot_synthesis_demo.py
│   ├── smart_browser_demo.py
│   ├── sql_database_demo.py
│   ├── sse_client_demo.py
│   ├── test_code_extraction.py
│   ├── test_content_detection.py
│   ├── test_ollama.py
│   ├── text_classification_demo.py
│   ├── text_redline_demo.py
│   ├── tool_composition_examples.py
│   ├── tournament_code_demo.py
│   ├── tournament_text_demo.py
│   ├── unified_memory_system_demo.py
│   ├── vector_search_demo.py
│   ├── web_automation_instruction_packs.py
│   └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│   └── smart_browser_internal
│       ├── locator_cache.db
│       ├── readability.js
│       └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_server.py
│   ├── manual
│   │   ├── test_extraction_advanced.py
│   │   └── test_extraction.py
│   └── unit
│       ├── __init__.py
│       ├── test_cache.py
│       ├── test_providers.py
│       └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── commands.py
│   │   ├── helpers.py
│   │   └── typer_cli.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── completion_client.py
│   │   └── rag_client.py
│   ├── config
│   │   └── examples
│   │       └── filesystem_config.yaml
│   ├── config.py
│   ├── constants.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── base.py
│   │   │   └── evaluators.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   ├── deepseek.py
│   │   │   ├── gemini.py
│   │   │   ├── grok.py
│   │   │   ├── ollama.py
│   │   │   ├── openai.py
│   │   │   └── openrouter.py
│   │   ├── server.py
│   │   ├── state_store.py
│   │   ├── tournaments
│   │   │   ├── manager.py
│   │   │   ├── tasks.py
│   │   │   └── utils.py
│   │   └── ums_api
│   │       ├── __init__.py
│   │       ├── ums_database.py
│   │       ├── ums_endpoints.py
│   │       ├── ums_models.py
│   │       └── ums_services.py
│   ├── exceptions.py
│   ├── graceful_shutdown.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── analytics
│   │   │   ├── __init__.py
│   │   │   ├── metrics.py
│   │   │   └── reporting.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── cache_service.py
│   │   │   ├── persistence.py
│   │   │   ├── strategies.py
│   │   │   └── utils.py
│   │   ├── cache.py
│   │   ├── document.py
│   │   ├── knowledge_base
│   │   │   ├── __init__.py
│   │   │   ├── feedback.py
│   │   │   ├── manager.py
│   │   │   ├── rag_engine.py
│   │   │   ├── retriever.py
│   │   │   └── utils.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── repository.py
│   │   │   └── templates.py
│   │   ├── prompts.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── embeddings.py
│   │       └── vector_service.py
│   ├── tool_token_counter.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── audio_transcription.py
│   │   ├── base.py
│   │   ├── completion.py
│   │   ├── docstring_refiner.py
│   │   ├── document_conversion_and_processing.py
│   │   ├── enhanced-ums-lookbook.html
│   │   ├── entity_relation_graph.py
│   │   ├── excel_spreadsheet_automation.py
│   │   ├── extraction.py
│   │   ├── filesystem.py
│   │   ├── html_to_markdown.py
│   │   ├── local_text_tools.py
│   │   ├── marqo_fused_search.py
│   │   ├── meta_api_tool.py
│   │   ├── ocr_tools.py
│   │   ├── optimization.py
│   │   ├── provider.py
│   │   ├── pyodide_boot_template.html
│   │   ├── python_sandbox.py
│   │   ├── rag.py
│   │   ├── redline-compiled.css
│   │   ├── sentiment_analysis.py
│   │   ├── single_shot_synthesis.py
│   │   ├── smart_browser.py
│   │   ├── sql_databases.py
│   │   ├── text_classification.py
│   │   ├── text_redline_tools.py
│   │   ├── tournament.py
│   │   ├── ums_explorer.html
│   │   └── unified_memory_system.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── display.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   ├── console.py
│   │   │   ├── emojis.py
│   │   │   ├── formatter.py
│   │   │   ├── logger.py
│   │   │   ├── panels.py
│   │   │   ├── progress.py
│   │   │   └── themes.py
│   │   ├── parse_yaml.py
│   │   ├── parsing.py
│   │   ├── security.py
│   │   └── text.py
│   └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/examples/multi_provider_demo.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python
2 | """Multi-provider completion demo using Ultimate MCP Server."""
3 | import asyncio
4 | import sys
5 | from pathlib import Path
6 |
7 | # Add project root to path for imports when running as script
8 | sys.path.insert(0, str(Path(__file__).parent.parent))
9 |
10 | # Third-party imports
11 | # These imports need to be below sys.path modification, which is why they have noqa comments
12 | from rich import box # noqa: E402
13 | from rich.markup import escape # noqa: E402
14 | from rich.panel import Panel # noqa: E402
15 | from rich.rule import Rule # noqa: E402
16 | from rich.table import Table # noqa: E402
17 |
18 | # Project imports
19 | from ultimate_mcp_server.constants import Provider # noqa: E402
20 | from ultimate_mcp_server.core.server import Gateway # noqa: E402
21 | from ultimate_mcp_server.utils import get_logger # noqa: E402
22 | from ultimate_mcp_server.utils.display import CostTracker # noqa: E402
23 | from ultimate_mcp_server.utils.logging.console import console # noqa: E402
24 |
25 | # Initialize logger (console is imported above)
26 | logger = get_logger("example.multi_provider")
27 |
28 | async def run_provider_comparison(tracker: CostTracker):
29 | """Run a comparison of completions across multiple providers using Rich."""
30 | console.print(Rule("[bold blue]Multi-Provider Completion Comparison[/bold blue]"))
31 | logger.info("Starting multi-provider comparison demo", emoji_key="start")
32 |
33 | # Create Gateway instance - this handles provider initialization
34 | gateway = Gateway("multi-provider-demo", register_tools=False)
35 |
36 | # Initialize providers
37 | logger.info("Initializing providers...", emoji_key="provider")
38 | await gateway._initialize_providers()
39 |
40 | prompt = "Explain the advantages of quantum computing in 3-4 sentences."
41 | console.print(f"[cyan]Prompt:[/cyan] {escape(prompt)}")
42 |
43 | # Use model names directly if providers are inferred or handled by get_provider
44 | configs = [
45 | {"provider": Provider.OPENAI.value, "model": "gpt-4.1-mini"},
46 | {"provider": Provider.ANTHROPIC.value, "model": "claude-3-5-haiku-20241022"},
47 | {"provider": Provider.GEMINI.value, "model": "gemini-2.0-flash-lite"},
48 | {"provider": Provider.DEEPSEEK.value, "model": "deepseek-chat"},
49 | {"provider": Provider.GROK.value, "model": "grok-3-mini-latest"},
50 | {"provider": Provider.OPENROUTER.value, "model": "mistralai/mistral-nemo"},
51 | {"provider": Provider.OLLAMA.value, "model": "llama3.2"}
52 | ]
53 |
54 | results_data = []
55 |
56 | for config in configs:
57 | provider_name = config["provider"]
58 | model_name = config["model"]
59 |
60 | provider = gateway.providers.get(provider_name)
61 | if not provider:
62 | logger.warning(f"Provider {provider_name} not available or initialized, skipping.", emoji_key="warning")
63 | continue
64 |
65 | try:
66 | logger.info(f"Generating completion with {provider_name}/{model_name}...", emoji_key="processing")
67 | result = await provider.generate_completion(
68 | prompt=prompt,
69 | model=model_name,
70 | temperature=0.7,
71 | max_tokens=150
72 | )
73 |
74 | # Track the cost
75 | tracker.add_call(result)
76 |
77 | results_data.append({
78 | "provider": provider_name,
79 | "model": model_name,
80 | "text": result.text,
81 | "input_tokens": result.input_tokens,
82 | "output_tokens": result.output_tokens,
83 | "cost": result.cost,
84 | "processing_time": result.processing_time
85 | })
86 | logger.success(f"Completion from {provider_name}/{model_name} successful.", emoji_key="success")
87 |
88 | except Exception as e:
89 | logger.error(f"Error with {provider_name}/{model_name}: {e}", emoji_key="error", exc_info=True)
90 | # Optionally store error result
91 | results_data.append({
92 | "provider": provider_name,
93 | "model": model_name,
94 | "text": f"[red]Error: {escape(str(e))}[/red]",
95 | "cost": 0.0, "processing_time": 0.0, "input_tokens": 0, "output_tokens": 0
96 | })
97 |
98 | # Print comparison results using Rich Panels
99 | console.print(Rule("[bold green]Comparison Results[/bold green]"))
100 | for result in results_data:
101 | stats_line = (
102 | f"Cost: [green]${result['cost']:.6f}[/green] | "
103 | f"Time: [yellow]{result['processing_time']:.2f}s[/yellow] | "
104 | f"Tokens: [cyan]{result['input_tokens']} in, {result['output_tokens']} out[/cyan]"
105 | )
106 | console.print(Panel(
107 | escape(result['text'].strip()),
108 | title=f"[bold magenta]{escape(result['provider'])} / {escape(result['model'])}[/bold magenta]",
109 | subtitle=stats_line,
110 | border_style="blue",
111 | expand=False
112 | ))
113 |
114 | # Filter out error results before calculating summary stats
115 | valid_results = [r for r in results_data if "Error" not in r["text"]]
116 |
117 | if valid_results:
118 | summary_table = Table(title="Comparison Summary", box=box.ROUNDED, show_header=False)
119 | summary_table.add_column("Metric", style="cyan")
120 | summary_table.add_column("Value", style="white")
121 |
122 | try:
123 | fastest = min(valid_results, key=lambda r: r['processing_time'])
124 | summary_table.add_row("⚡ Fastest", f"{escape(fastest['provider'])}/{escape(fastest['model'])} ({fastest['processing_time']:.2f}s)")
125 | except ValueError:
126 | pass # Handle empty list
127 |
128 | try:
129 | cheapest = min(valid_results, key=lambda r: r['cost'])
130 | summary_table.add_row("💰 Cheapest", f"{escape(cheapest['provider'])}/{escape(cheapest['model'])} (${cheapest['cost']:.6f})")
131 | except ValueError:
132 | pass
133 |
134 | try:
135 | most_tokens = max(valid_results, key=lambda r: r['output_tokens'])
136 | summary_table.add_row("📄 Most Tokens", f"{escape(most_tokens['provider'])}/{escape(most_tokens['model'])} ({most_tokens['output_tokens']} tokens)")
137 | except ValueError:
138 | pass
139 |
140 | if summary_table.row_count > 0:
141 | console.print(summary_table)
142 | else:
143 | console.print("[yellow]No valid results to generate summary.[/yellow]")
144 |
145 | # Display final summary
146 | tracker.display_summary(console) # Display summary at the end
147 |
148 | console.print() # Final spacing
149 | return 0
150 |
151 | async def main():
152 | """Run the demo."""
153 | tracker = CostTracker() # Instantiate tracker
154 | try:
155 | return await run_provider_comparison(tracker) # Pass tracker
156 | except Exception as e:
157 | logger.critical(f"Demo failed: {str(e)}", emoji_key="critical")
158 | return 1
159 |
160 | if __name__ == "__main__":
161 | # Run the demo
162 | exit_code = asyncio.run(main())
163 | sys.exit(exit_code)
```
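For quick reference, here is the demo's core call pattern stripped of the Rich display code. This is a minimal sketch, assuming at least one provider API key (OpenAI here) is configured in the environment; `_initialize_providers` is the same private initializer the demo itself calls.

```python
import asyncio

from ultimate_mcp_server.constants import Provider
from ultimate_mcp_server.core.server import Gateway

async def one_completion() -> str:
    # Gateway handles provider setup; tool registration isn't needed for direct calls
    gateway = Gateway("quickstart", register_tools=False)
    await gateway._initialize_providers()
    provider = gateway.providers[Provider.OPENAI.value]
    result = await provider.generate_completion(
        prompt="Say hello.",
        model="gpt-4.1-mini",  # any model the provider exposes
        max_tokens=20,
    )
    return result.text

print(asyncio.run(one_completion()))
```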
--------------------------------------------------------------------------------
/tool_annotations.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Tool annotations for MCP servers.
3 |
4 | This module provides a standardized way to annotate tools with hints that help LLMs
5 | understand when and how to use them effectively.
6 | """
7 | from typing import List, Optional
8 |
9 |
10 | class ToolAnnotations:
11 | """
12 | Tool annotations providing hints to LLMs about tool behavior and usage patterns.
13 |
14 | ToolAnnotations supply metadata that helps LLMs make informed decisions about:
15 | - WHEN to use a particular tool (appropriate contexts and priority)
16 | - HOW to use the tool correctly (through examples and behavior hints)
17 | - WHAT the potential consequences of using the tool might be (read-only vs. destructive)
18 | - WHO should use the tool (assistant, user, or both via audience hints)
19 |
20 | These annotations serve as a bridge between tool developers and LLMs, providing
21 | crucial context beyond just function signatures and descriptions. For example, the
22 | annotations can indicate that a file deletion tool is destructive and should be used
23 | with caution, or that a search tool is safe to retry multiple times.
24 |
25 | The system supports four key behavioral hints:
26 | - read_only_hint: Tool doesn't modify state (safe for exploratory use)
27 | - destructive_hint: Tool may perform irreversible changes (use with caution)
28 | - idempotent_hint: Repeated calls with same arguments produce same results
29 | - open_world_hint: Tool interacts with external systems beyond the LLM's knowledge
30 |
31 | Additional metadata includes:
32 | - audience: Who can/should use this tool
33 | - priority: How important/commonly used this tool is
34 | - title: Human-readable name for the tool
35 | - examples: Sample inputs and expected outputs
36 |
37 | Usage example:
38 | ```python
39 | # For a document deletion tool
40 | delete_doc_annotations = ToolAnnotations(
41 | read_only_hint=False, # Modifies state
42 | destructive_hint=True, # Deletion is destructive
43 | idempotent_hint=True, # Deleting twice has same effect as once
44 | open_world_hint=True, # Changes external file system
45 | audience=["assistant"], # Only assistant should use it
46 | priority=0.3, # Lower priority (use cautiously)
47 | title="Delete Document",
48 | examples=[{
49 | "input": {"document_id": "doc-123"},
50 | "output": {"success": True, "message": "Document deleted"}
51 | }]
52 | )
53 | ```
54 |
55 | Note: All hints are advisory only - they don't enforce behavior but help LLMs
56 | make better decisions about tool usage.
57 | """
58 |
59 | def __init__(
60 | self,
61 | read_only_hint: bool = False,
62 | destructive_hint: bool = True,
63 | idempotent_hint: bool = False,
64 | open_world_hint: bool = True,
65 | audience: List[str] = None,
66 | priority: float = 0.5,
67 | title: Optional[str] = None,
68 | examples: List[dict] = None,
69 | ):
70 | """
71 | Initialize tool annotations.
72 |
73 | Args:
74 | read_only_hint: If True, indicates this tool does not modify its environment.
75 | Tools with read_only_hint=True are safe to call for exploration without
76 | side effects. Examples: search tools, data retrieval, information queries.
77 | Default: False
78 |
79 | destructive_hint: If True, the tool may perform destructive updates that
80 | can't easily be reversed or undone. Only meaningful when read_only_hint
81 | is False. Examples: deletion operations, irreversible state changes, payments.
82 | Default: True
83 |
84 | idempotent_hint: If True, calling the tool repeatedly with the same arguments
85 | will have no additional effect beyond the first call. Useful for retry logic.
86 | Only meaningful when read_only_hint is False. Examples: setting a value,
87 | deleting an item (calling it twice doesn't delete it twice).
88 | Default: False
89 |
90 | open_world_hint: If True, this tool may interact with systems or information
91 | outside the LLM's knowledge context (external APIs, file systems, etc.).
92 | If False, the tool operates in a closed domain the LLM can fully model.
93 | Default: True
94 |
95 | audience: Who is the intended user of this tool, as a list of roles:
96 | - "assistant": The AI assistant can use this tool
97 | - "user": The human user can use this tool
98 | Default: ["assistant"]
99 |
100 | priority: How important this tool is, from 0.0 (lowest) to 1.0 (highest).
101 | Higher priority tools should be considered first when multiple tools
102 | might accomplish a similar task. Default: 0.5 (medium priority)
103 |
104 | title: Human-readable title for the tool. If not provided, the tool's
105 | function name is typically used instead.
106 |
107 | examples: List of usage examples, each containing 'input' and 'output' keys.
108 | These help the LLM understand expected patterns of use and responses.
109 | """
110 | self.read_only_hint = read_only_hint
111 | self.destructive_hint = destructive_hint
112 | self.idempotent_hint = idempotent_hint
113 | self.open_world_hint = open_world_hint
114 | self.audience = audience or ["assistant"]
115 | self.priority = max(0.0, min(1.0, priority)) # Clamp between 0 and 1
116 | self.title = title
117 | self.examples = examples or []
118 |
119 | def to_dict(self) -> dict:
120 | """Convert annotations to dictionary for MCP protocol."""
121 | return {
122 | "readOnlyHint": self.read_only_hint,
123 | "destructiveHint": self.destructive_hint,
124 | "idempotentHint": self.idempotent_hint,
125 | "openWorldHint": self.open_world_hint,
126 | "title": self.title,
127 | "audience": self.audience,
128 | "priority": self.priority,
129 | "examples": self.examples
130 | }
131 |
132 | # Pre-defined annotation templates for common tool types
133 |
134 | # A tool that only reads/queries data without modifying any state
135 | READONLY_TOOL = ToolAnnotations(
136 | read_only_hint=True,
137 | destructive_hint=False,
138 | idempotent_hint=True,
139 | open_world_hint=False,
140 | priority=0.8,
141 | title="Read-Only Tool"
142 | )
143 |
144 | # A tool that queries external systems or APIs for information
145 | QUERY_TOOL = ToolAnnotations(
146 | read_only_hint=True,
147 | destructive_hint=False,
148 | idempotent_hint=True,
149 | open_world_hint=True,
150 | priority=0.7,
151 | title="Query Tool"
152 | )
153 |
154 | # A tool that performs potentially irreversible changes to state
155 | # The LLM should use these with caution, especially without confirmation
156 | DESTRUCTIVE_TOOL = ToolAnnotations(
157 | read_only_hint=False,
158 | destructive_hint=True,
159 | idempotent_hint=False,
160 | open_world_hint=True,
161 | priority=0.3,
162 | title="Destructive Tool"
163 | )
164 |
165 | # A tool that modifies state but can be safely called multiple times
166 | # with the same arguments (e.g., setting a value, creating if not exists)
167 | IDEMPOTENT_UPDATE_TOOL = ToolAnnotations(
168 | read_only_hint=False,
169 | destructive_hint=False,
170 | idempotent_hint=True,
171 | open_world_hint=False,
172 | priority=0.5,
173 | title="Idempotent Update Tool"
174 | )
```
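The annotations only take effect once they are serialized into the MCP protocol via `to_dict()`. Below is a minimal sketch, assuming the repository root is on `sys.path`; how a particular server attaches the resulting dict to a tool registration varies, so no registration call is shown.

```python
from tool_annotations import QUERY_TOOL, ToolAnnotations

# Annotations for a hypothetical web-search tool: reads external state,
# never destroys anything, and is safe to retry.
search_annotations = ToolAnnotations(
    read_only_hint=True,
    destructive_hint=False,
    idempotent_hint=True,
    open_world_hint=True,  # talks to an external search API
    priority=0.7,
    title="Web Search",
    examples=[{"input": {"query": "MCP spec"}, "output": {"results": ["..."]}}],
)

print(search_annotations.to_dict()["readOnlyHint"])  # True
print(QUERY_TOOL.to_dict()["title"])                 # Query Tool
```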
--------------------------------------------------------------------------------
/examples/test_content_detection.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python
2 | """
3 | Test script to demonstrate enhanced content type detection with Magika integration
4 | in the DocumentProcessingTool.
5 | """
6 |
7 | import asyncio
8 | import sys
9 | from pathlib import Path
10 |
11 | # Add project root to path for imports when running as script
12 | _PROJECT_ROOT = Path(__file__).resolve().parent.parent
13 | if str(_PROJECT_ROOT) not in sys.path:
14 | sys.path.insert(0, str(_PROJECT_ROOT))
15 |
16 | from rich.console import Console # noqa: E402
17 | from rich.panel import Panel # noqa: E402
18 | from rich.table import Table # noqa: E402
19 |
20 | from ultimate_mcp_server.core.server import Gateway # noqa: E402
21 | from ultimate_mcp_server.tools.document_conversion_and_processing import ( # noqa: E402
22 | DocumentProcessingTool, # noqa: E402
23 | )
24 |
25 | console = Console()
26 |
27 | # Sample content for testing
28 | HTML_CONTENT = """<!DOCTYPE html>
29 | <html>
30 | <head>
31 | <title>Test HTML Document</title>
32 | <meta charset="utf-8">
33 | </head>
34 | <body>
35 | <h1>This is a test HTML document</h1>
36 | <p>This paragraph is for testing the content detection.</p>
37 | <div class="container">
38 | <ul>
39 | <li>Item 1</li>
40 | <li>Item 2</li>
41 | </ul>
42 | </div>
43 | <script>
44 | // Some JavaScript
45 | console.log("Hello world");
46 | </script>
47 | </body>
48 | </html>
49 | """
50 |
51 | MARKDOWN_CONTENT = """# Test Markdown Document
52 |
53 | This is a paragraph in markdown format.
54 |
55 | ## Section 1
56 |
57 | * Item 1
58 | * Item 2
59 |
60 | [Link to example](https://example.com)
61 |
62 | ```python
63 | def hello_world():
64 | print("Hello world")
65 | ```
66 |
67 | | Column 1 | Column 2 |
68 | |----------|----------|
69 | | Cell 1 | Cell 2 |
70 | """
71 |
72 | CODE_CONTENT = """
73 | #!/usr/bin/env python
74 | import sys
75 | from typing import List, Dict, Optional
76 |
77 | class TestClass:
78 | def __init__(self, name: str, value: int = 0):
79 | self.name = name
80 | self.value = value
81 |
82 | def process(self, data: List[Dict]) -> Optional[Dict]:
83 | result = {}
84 | for item in data:
85 | if "key" in item:
86 | result[item["key"]] = item["value"]
87 | return result if result else None
88 |
89 | def main():
90 | test = TestClass("test", 42)
91 | result = test.process([{"key": "a", "value": 1}, {"key": "b", "value": 2}])
92 | print(f"Result: {result}")
93 |
94 | if __name__ == "__main__":
95 | main()
96 | """
97 |
98 | PLAIN_TEXT_CONTENT = """
99 | This is a plain text document with no special formatting.
100 |
101 | It contains multiple paragraphs and some sentences.
102 | There are no markdown elements, HTML tags, or code structures.
103 |
104 | Just regular text that someone might write in a simple text editor.
105 | """
106 |
107 | AMBIGUOUS_CONTENT = """
108 | Here's some text with a <div> tag in it.
109 |
110 | # This looks like a heading
111 |
112 | But it also has some <span>HTML</span> elements.
113 |
114 | def is_this_code():
115 | return "maybe"
116 |
117 | Regular paragraph text continues here.
118 | """
119 |
120 | async def test_content_detection():
121 | console.print(Panel("Testing Content Type Detection with Magika Integration", style="bold green"))
122 |
123 | # Initialize the document processor
124 | gateway = Gateway("content-detection-test")
125 | # Initialize providers
126 | console.print("Initializing gateway and providers...")
127 | await gateway._initialize_providers()
128 |
129 | # Create document processing tool
130 | doc_tool = DocumentProcessingTool(gateway)
131 |
132 | # Define test cases
133 | test_cases = [
134 | ("HTML Document", HTML_CONTENT),
135 | ("Markdown Document", MARKDOWN_CONTENT),
136 | ("Code Document", CODE_CONTENT),
137 | ("Plain Text Document", PLAIN_TEXT_CONTENT),
138 | ("Ambiguous Content", AMBIGUOUS_CONTENT),
139 | ]
140 |
141 | # Create results table
142 | results_table = Table(title="Content Type Detection Results")
143 | results_table.add_column("Content Type", style="cyan")
144 | results_table.add_column("Detected Type", style="green")
145 | results_table.add_column("Confidence", style="yellow")
146 | results_table.add_column("Method", style="magenta")
147 | results_table.add_column("Detection Criteria", style="blue")
148 |
149 | # Test each case
150 | for name, content in test_cases:
151 | console.print(f"\nDetecting content type for: [bold cyan]{name}[/]")
152 |
153 | # Detect content type
154 | result = await doc_tool.detect_content_type(content)
155 |
156 | # Get detection details
157 | detected_type = result.get("content_type", "unknown")
158 | confidence = result.get("confidence", 0.0)
159 | criteria = ", ".join(result.get("detection_criteria", []))
160 |
161 | # Check if Magika was used
162 | method = "Magika" if result.get("detection_method") == "magika" else "Heuristic"
163 | if result.get("detection_method") != "magika" and result.get("magika_details"):
164 | method = "Combined (Magika + Heuristic)"
165 |
166 | # Add to results table
167 | results_table.add_row(
168 | name,
169 | detected_type,
170 | f"{confidence:.2f}",
171 | method,
172 | criteria[:100] + "..." if len(criteria) > 100 else criteria
173 | )
174 |
175 | # Show all scores
176 | scores = result.get("all_scores", {})
177 | if scores:
178 | scores_table = Table(title="Detection Scores")
179 | scores_table.add_column("Content Type", style="cyan")
180 | scores_table.add_column("Score", style="yellow")
181 |
182 | for ctype, score in scores.items():
183 | scores_table.add_row(ctype, f"{score:.3f}")
184 |
185 | console.print(scores_table)
186 |
187 | # Show Magika details if available
188 | if "magika_details" in result:
189 | magika_details = result["magika_details"]
190 | console.print(Panel(
191 | f"Magika Type: {magika_details.get('type', 'unknown')}\n"
192 | f"Magika Confidence: {magika_details.get('confidence', 0.0):.3f}\n"
193 | f"Matched Primary Type: {magika_details.get('matched_primary_type', False)}",
194 | title="Magika Details",
195 | style="blue"
196 | ))
197 |
198 | # Print final results table
199 | console.print("\n")
200 | console.print(results_table)
201 |
202 | # Now test HTML to Markdown conversion with a clearly broken HTML case
203 | console.print(Panel("Testing HTML to Markdown Conversion with Content Detection", style="bold green"))
204 |
205 | # Create a test case with problematic HTML (the one that previously failed)
206 | problematic_html = """<!DOCTYPE html>
207 | <html class="client-nojs vector-feature-language-in-header-enabled vector-feature-language-in-main-page-header-disabled">
208 | <head>
209 | <meta charset="UTF-8">
210 | <title>Transformer (deep learning architecture) - Wikipedia</title>
211 | <script>(function(){var className="client-js vector-feature-language-in-header-enabled vector-feature-language-in-main-page-header-disabled";</script>
212 | </head>
213 | <body>
214 | <h1>Transformer Model</h1>
215 | <p>The Transformer is a deep learning model introduced in the paper "Attention Is All You Need".</p>
216 | </body>
217 | </html>"""
218 |
219 | console.print("Converting problematic HTML to Markdown...")
220 | result = await doc_tool.clean_and_format_text_as_markdown(
221 | text=problematic_html,
222 | extraction_method="auto",
223 | preserve_tables=True,
224 | preserve_links=True
225 | )
226 |
227 | console.print(Panel(
228 | f"Original Type: {result.get('original_content_type', 'unknown')}\n"
229 | f"Was HTML: {result.get('was_html', False)}\n"
230 | f"Extraction Method: {result.get('extraction_method_used', 'none')}",
231 | title="Conversion Details",
232 | style="cyan"
233 | ))
234 |
235 | console.print(Panel(
236 | result.get("markdown_text", "No markdown produced"),
237 | title="Converted Markdown",
238 | style="green"
239 | ))
240 |
241 | if __name__ == "__main__":
242 | asyncio.run(test_content_detection())
```
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/provider.py:
--------------------------------------------------------------------------------
```python
1 | """Provider tools for Ultimate MCP Server."""
2 | from typing import Any, Dict, Optional
3 |
4 | # Import ToolError explicitly
5 | from ultimate_mcp_server.exceptions import ToolError
6 |
7 | # REMOVE global instance logic
8 | from ultimate_mcp_server.utils import get_logger
9 |
10 | from .base import with_error_handling, with_tool_metrics
11 |
12 | logger = get_logger("ultimate_mcp_server.tools.provider")
13 |
14 | def _get_provider_status_dict() -> Dict[str, Any]:
15 | """Reliably gets the provider_status dictionary from the gateway instance."""
16 | provider_status = {}
17 | # Import here to avoid circular dependency at module load time
18 | try:
19 | from ultimate_mcp_server.core import get_gateway_instance
20 | gateway = get_gateway_instance()
21 | if gateway and hasattr(gateway, 'provider_status'):
22 | provider_status = gateway.provider_status
23 | if provider_status:
24 | logger.debug("Retrieved provider status via global instance.")
25 | return provider_status
26 | except ImportError as e:
27 | logger.error(f"Failed to import get_gateway_instance: {e}")
28 | except Exception as e:
29 | logger.error(f"Error accessing global gateway instance: {e}")
30 |
31 | if not provider_status:
32 | logger.warning("Could not retrieve provider status from global gateway instance.")
33 |
34 | return provider_status
35 |
36 | # --- Tool Functions (Standalone, Decorated) ---
37 |
38 | @with_tool_metrics
39 | @with_error_handling
40 | async def get_provider_status() -> Dict[str, Any]:
41 | """Checks the status and availability of all configured LLM providers.
42 |
43 | Use this tool to determine which LLM providers (e.g., OpenAI, Anthropic, Gemini)
44 | are currently enabled, configured correctly (e.g., API keys), and ready to accept requests.
45 | This helps in deciding which provider to use for a task or for troubleshooting.
46 |
47 | Returns:
48 | A dictionary mapping provider names to their status details:
49 | {
50 | "providers": {
51 | "openai": { # Example for one provider
52 | "enabled": true, # Is the provider enabled in the server config?
53 | "available": true, # Is the provider initialized and ready for requests?
54 | "api_key_configured": true, # Is the necessary API key set?
55 | "error": null, # Error message if initialization failed, null otherwise.
56 | "models_count": 38 # Number of models detected for this provider.
57 | },
58 | "anthropic": { # Example for another provider
59 | "enabled": true,
60 | "available": false,
61 | "api_key_configured": true,
62 | "error": "Initialization failed: Connection timeout",
63 | "models_count": 0
64 | },
65 | ...
66 | }
67 | }
68 | Returns an empty "providers" dict and a message if status info is unavailable.
69 |
70 | Usage:
71 | - Call this tool before attempting complex tasks to ensure required providers are available.
72 | - Use the output to inform the user about available options or diagnose issues.
73 | - If a provider shows "available: false", check the "error" field for clues.
74 | """
75 | provider_status = _get_provider_status_dict()
76 |
77 | if not provider_status:
78 | # Raise ToolError if status cannot be retrieved
79 | raise ToolError(status_code=503, detail="Provider status information is currently unavailable. The server might be initializing or an internal error occurred.")
80 |
81 | return {
82 | "providers": {
83 | name: {
84 | "enabled": status.enabled,
85 | "available": status.available,
86 | "api_key_configured": status.api_key_configured,
87 | "error": status.error,
88 | "models_count": len(status.models)
89 | }
90 | for name, status in provider_status.items()
91 | }
92 | }
93 |
94 | @with_tool_metrics
95 | @with_error_handling
96 | async def list_models(
97 | provider: Optional[str] = None
98 | ) -> Dict[str, Any]:
99 | """Lists available LLM models, optionally filtered by provider.
100 |
101 | Use this tool to discover specific models offered by the configured and available LLM providers.
102 | The returned model IDs (e.g., 'openai/gpt-4.1-mini') are needed for other tools like
103 | `chat_completion`, `generate_completion`, `estimate_cost`, or `create_tournament`.
104 |
105 | Args:
106 | provider: (Optional) The specific provider name (e.g., "openai", "anthropic", "gemini")
107 | to list models for. If omitted, models from *all available* providers are listed.
108 |
109 | Returns:
110 | A dictionary mapping provider names to a list of their available models:
111 | {
112 | "models": {
113 | "openai": [ # Example for one provider
114 | {
115 | "id": "openai/gpt-4.1-mini", # Unique ID used in other tools
116 | "name": "GPT-4o Mini", # Human-friendly name
117 | "context_window": 128000,
118 | "features": ["chat", "completion", "vision"],
119 | "input_cost_pmt": 0.15, # Cost per Million Tokens (Input)
120 | "output_cost_pmt": 0.60 # Cost per Million Tokens (Output)
121 | },
122 | ...
123 | ],
124 | "gemini": [ # Example for another provider
125 | {
126 | "id": "gemini/gemini-2.5-pro-preview-03-25",
127 | "name": "Gemini 2.5 Pro Experimental",
128 | "context_window": 8192,
129 | "features": ["chat", "completion"],
130 | "input_cost_pmt": null, # Cost info might be null
131 | "output_cost_pmt": null
132 | },
133 | ...
134 | ],
135 | ...
136 | }
137 | }
138 | Returns an empty "models" dict or includes warnings/errors if providers/models are unavailable.
139 |
140 | Usage Flow:
141 | 1. (Optional) Call `get_provider_status` to see which providers are generally available.
142 | 2. Call `list_models` (optionally specifying a provider) to get usable model IDs.
143 | 3. Use a specific model ID (like "openai/gpt-4.1-mini") as the 'model' parameter in other tools.
144 |
145 | Raises:
146 | ToolError: If the specified provider name is invalid or provider status is unavailable.
147 | """
148 | provider_status = _get_provider_status_dict()
149 |
150 | if not provider_status:
151 | raise ToolError(status_code=503, detail="Provider status information is currently unavailable. Cannot list models.")
152 |
153 | models = {}
154 | if provider:
155 | if provider not in provider_status:
156 | valid_providers = list(provider_status.keys())
157 | raise ToolError(status_code=404, detail=f"Invalid provider specified: '{provider}'. Valid options: {valid_providers}")
158 |
159 | status = provider_status[provider]
160 | if not status.available:
161 | # Return empty list for the provider but include a warning message
162 | return {
163 | "models": {provider: []},
164 | "warning": f"Provider '{provider}' is configured but currently unavailable. Reason: {status.error or 'Unknown error'}"
165 | }
166 | # Use model details directly from the ProviderStatus object
167 | models[provider] = [m for m in status.models] if status.models else []
168 | else:
169 | # List models for all *available* providers
170 | any_available = False
171 | for name, status in provider_status.items():
172 | if status.available:
173 | any_available = True
174 | # Use model details directly from the ProviderStatus object
175 | models[name] = [m for m in status.models] if status.models else []
176 | # else: Provider not available, don't include it unless specifically requested
177 |
178 | if not any_available:
179 | return {
180 | "models": {},
181 | "warning": "No providers are currently available. Check provider status using get_provider_status."
182 | }
183 | elif all(len(model_list) == 0 for model_list in models.values()):
184 | return {
185 | "models": models,
186 | "warning": "No models listed for any available provider. Check provider status or configuration."
187 | }
188 |
189 | return {"models": models}
```
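A usage sketch tying the two tools together, assuming the global gateway has already been initialized so `_get_provider_status_dict()` can find it; the model-entry shape (`id`, etc.) follows the docstrings above.

```python
import asyncio

from ultimate_mcp_server.tools.provider import get_provider_status, list_models

async def first_available_model() -> str:
    status = await get_provider_status()  # raises ToolError if status is unavailable
    available = [name for name, info in status["providers"].items() if info["available"]]
    if not available:
        raise RuntimeError("No providers available")
    listing = await list_models(provider=available[0])
    # Each entry is a model-detail dict; its 'id' is what other tools accept.
    return listing["models"][available[0]][0]["id"]

print(asyncio.run(first_available_model()))
```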
--------------------------------------------------------------------------------
/ultimate_mcp_server/services/cache.py:
--------------------------------------------------------------------------------
```python
1 | """Cache service for LLM and RAG results."""
2 | import asyncio
3 | import os
4 | import pickle
5 | import time
6 | from pathlib import Path
7 | from typing import Any, Dict, Optional
8 |
9 | from ultimate_mcp_server.config import get_config
10 | from ultimate_mcp_server.utils import get_logger
11 |
12 | logger = get_logger(__name__)
13 |
14 | # Singleton instance
15 | _cache_service = None
16 |
17 |
18 | def get_cache_service():
19 | """
20 | Get or create the global singleton cache service instance.
21 |
22 | This function implements the singleton pattern for the CacheService, ensuring that only
23 | one instance is created across the entire application. On the first call, it creates a
24 | new CacheService instance and stores it in a module-level variable. Subsequent calls
25 | return the same instance.
26 |
27 | Using this function instead of directly instantiating CacheService ensures consistent
28 | caching behavior throughout the application, with a shared cache that persists across
29 | different components and requests.
30 |
31 | Returns:
32 | CacheService: The global singleton cache service instance.
33 |
34 | Example:
35 | ```python
36 | # Get the same cache service instance from anywhere in the code
37 | cache = get_cache_service()
38 | await cache.set("my_key", my_value, ttl=3600)
39 | ```
40 | """
41 | global _cache_service
42 | if _cache_service is None:
43 | _cache_service = CacheService()
44 | return _cache_service
45 |
46 |
47 | class CacheService:
48 | """Service for caching LLM and RAG results."""
49 |
50 | def __init__(self, cache_dir: Optional[str] = None):
51 | """Initialize the cache service.
52 |
53 | Args:
54 | cache_dir: Directory to store cache files
55 | """
56 | config = get_config()
57 | cache_config = config.cache
58 | self.cache_dir = cache_config.directory
59 |
60 | # Create cache directory if it doesn't exist
61 | os.makedirs(self.cache_dir, exist_ok=True)
62 |
63 | # In-memory cache
64 | self.memory_cache: Dict[str, Dict[str, Any]] = {}
65 |
66 | # Load cache from disk
67 | self._load_cache()
68 |
69 | # Schedule cache maintenance
70 | self._schedule_maintenance()
71 |
72 | logger.info(f"Cache service initialized with directory: {self.cache_dir}")
73 |
74 | def _load_cache(self) -> None:
75 | """Load cache from disk."""
76 | try:
77 | cache_file = Path(self.cache_dir) / "cache.pickle"
78 | if cache_file.exists():
79 | with open(cache_file, "rb") as f:
80 | loaded_cache = pickle.load(f)
81 |
82 | # Filter out expired items
83 | current_time = time.time()
84 | filtered_cache = {
85 | key: value for key, value in loaded_cache.items()
86 | if "expiry" not in value or value["expiry"] > current_time
87 | }
88 |
89 | self.memory_cache = filtered_cache
90 | logger.info(f"Loaded {len(self.memory_cache)} items from cache")
91 | else:
92 | logger.info("No cache file found, starting with empty cache")
93 | except Exception as e:
94 | logger.error(f"Error loading cache: {str(e)}")
95 | # Start with empty cache
96 | self.memory_cache = {}
97 |
98 | def _save_cache(self) -> None:
99 | """Save cache to disk."""
100 | try:
101 | cache_file = Path(self.cache_dir) / "cache.pickle"
102 |
103 | with open(cache_file, "wb") as f:
104 | pickle.dump(self.memory_cache, f)
105 |
106 | logger.info(f"Saved {len(self.memory_cache)} items to cache")
107 | except Exception as e:
108 | logger.error(f"Error saving cache: {str(e)}")
109 |
110 | def _schedule_maintenance(self) -> None:
111 | """Schedule periodic cache maintenance."""
112 | asyncio.create_task(self._periodic_maintenance()) # requires a running event loop; raises RuntimeError otherwise
113 |
114 | async def _periodic_maintenance(self) -> None:
115 | """Perform periodic cache maintenance."""
116 | while True:
117 | try:
118 | # Clean expired items
119 | self._clean_expired()
120 |
121 | # Save cache to disk
122 | self._save_cache()
123 |
124 | # Wait for next maintenance cycle (every hour)
125 | await asyncio.sleep(3600)
126 | except Exception as e:
127 | logger.error(f"Error in cache maintenance: {str(e)}")
128 | await asyncio.sleep(300) # Wait 5 minutes on error
129 |
130 | def _clean_expired(self) -> None:
131 | """Clean expired items from cache."""
132 | current_time = time.time()
133 | initial_count = len(self.memory_cache)
134 |
135 | self.memory_cache = {
136 | key: value for key, value in self.memory_cache.items()
137 | if "expiry" not in value or value["expiry"] > current_time
138 | }
139 |
140 | removed = initial_count - len(self.memory_cache)
141 | if removed > 0:
142 | logger.info(f"Cleaned {removed} expired items from cache")
143 |
144 | async def get(self, key: str) -> Optional[Any]:
145 | """Get an item from the cache.
146 |
147 | Args:
148 | key: Cache key
149 |
150 | Returns:
151 | Cached value or None if not found or expired
152 | """
153 | if key not in self.memory_cache:
154 | return None
155 |
156 | cache_item = self.memory_cache[key]
157 |
158 | # Check expiry
159 | if "expiry" in cache_item and cache_item["expiry"] < time.time():
160 | # Remove expired item
161 | del self.memory_cache[key]
162 | return None
163 |
164 | # Update access time
165 | cache_item["last_access"] = time.time()
166 | cache_item["access_count"] = cache_item.get("access_count", 0) + 1
167 |
168 | return cache_item["value"]
169 |
170 | async def set(
171 | self,
172 | key: str,
173 | value: Any,
174 | ttl: Optional[int] = None
175 | ) -> bool:
176 | """Set an item in the cache.
177 |
178 | Args:
179 | key: Cache key
180 | value: Value to cache
181 | ttl: Time to live in seconds (None for no expiry)
182 |
183 | Returns:
184 | True if successful
185 | """
186 | try:
187 | expiry = time.time() + ttl if ttl is not None else None
188 |
189 | self.memory_cache[key] = {
190 | "value": value,
191 | "created": time.time(),
192 | "last_access": time.time(),
193 | "access_count": 0,
194 | "expiry": expiry
195 | }
196 |
197 | # Kick off an async save whenever the cache size reaches a multiple of 10
198 | if len(self.memory_cache) % 10 == 0:
199 | asyncio.create_task(self._async_save_cache())
200 |
201 | return True
202 | except Exception as e:
203 | logger.error(f"Error setting cache item: {str(e)}")
204 | return False
205 |
206 | async def _async_save_cache(self) -> None:
207 | """Save cache asynchronously."""
208 | self._save_cache()
209 |
210 | async def delete(self, key: str) -> bool:
211 | """Delete an item from the cache.
212 |
213 | Args:
214 | key: Cache key
215 |
216 | Returns:
217 | True if item was deleted, False if not found
218 | """
219 | if key in self.memory_cache:
220 | del self.memory_cache[key]
221 | return True
222 | return False
223 |
224 | async def clear(self) -> None:
225 | """Clear all items from the cache."""
226 | self.memory_cache.clear()
227 | self._save_cache()
228 | logger.info("Cache cleared")
229 |
230 | async def get_stats(self) -> Dict[str, Any]:
231 | """Get cache statistics.
232 |
233 | Returns:
234 | Cache statistics
235 | """
236 | total_items = len(self.memory_cache)
237 |
238 | # Count expired items
239 | current_time = time.time()
240 | expired_items = sum(
241 | 1 for item in self.memory_cache.values()
242 | if "expiry" in item and item["expiry"] < current_time
243 | )
244 |
245 | # Calculate average access count
246 | access_counts = [
247 | item.get("access_count", 0)
248 | for item in self.memory_cache.values()
249 | ]
250 | avg_access = sum(access_counts) / max(1, len(access_counts))
251 |
252 | return {
253 | "total_items": total_items,
254 | "expired_items": expired_items,
255 | "active_items": total_items - expired_items,
256 | "avg_access_count": avg_access,
257 | "cache_dir": self.cache_dir
258 | }
```
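A usage sketch for the singleton, assuming `get_config()` resolves a writable cache directory. It must run inside an event loop, because the constructor schedules the maintenance task with `asyncio.create_task`.

```python
import asyncio

from ultimate_mcp_server.services.cache import get_cache_service

async def demo() -> None:
    cache = get_cache_service()  # same instance everywhere in the process
    await cache.set("greeting", {"text": "hi"}, ttl=3600)  # expires in 1 hour
    print(await cache.get("greeting"))  # {'text': 'hi'}
    print(await cache.get_stats())      # item counts, avg access, cache dir

asyncio.run(demo())
```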
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
```python
1 | """Pytest fixtures for Ultimate MCP Server tests."""
2 | import asyncio
3 | import json
4 | import os
5 | from pathlib import Path
6 | from typing import Any, Dict, Generator, List, Optional
7 |
8 | import pytest
9 | from pytest import MonkeyPatch
10 |
11 | from ultimate_mcp_server.config import Config, get_config
12 | from ultimate_mcp_server.constants import Provider
13 | from ultimate_mcp_server.core.providers.base import BaseProvider, ModelResponse
14 | from ultimate_mcp_server.core.server import Gateway
15 | from ultimate_mcp_server.utils import get_logger
16 |
17 | logger = get_logger("tests")
18 |
19 |
20 | class MockResponse:
21 | """Mock response for testing."""
22 | def __init__(self, status_code: int = 200, json_data: Optional[Dict[str, Any]] = None):
23 | self.status_code = status_code
24 | self.json_data = json_data or {}
25 |
26 | async def json(self):
27 | return self.json_data
28 |
29 | async def text(self):
30 | return json.dumps(self.json_data)
31 |
32 | async def __aenter__(self):
33 | return self
34 |
35 | async def __aexit__(self, exc_type, exc_val, exc_tb):
36 | pass
37 |
38 |
39 | class MockClient:
40 | """Mock HTTP client for testing."""
41 | def __init__(self, responses: Optional[Dict[str, Any]] = None):
42 | self.responses = responses or {}
43 | self.requests = []
44 |
45 | async def post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str, str]] = None):
46 | self.requests.append({"url": url, "json": json, "headers": headers})
47 | return MockResponse(json_data=self.responses.get(url, {"choices": [{"message": {"content": "Mock response"}}]}))
48 |
49 | async def get(self, url: str, headers: Optional[Dict[str, str]] = None):
50 | self.requests.append({"url": url, "headers": headers})
51 | return MockResponse(json_data=self.responses.get(url, {"data": [{"id": "mock-model"}]}))
52 |
53 |
54 | class MockProvider(BaseProvider):
55 | """Mock provider for testing."""
56 |
57 | provider_name = "mock"
58 |
59 | def __init__(self, api_key: Optional[str] = None, **kwargs):
60 | super().__init__(api_key=api_key, **kwargs)
61 | self.responses = kwargs.pop("responses", {})
62 | self.initialized = False
63 | self.calls = []
64 |
65 | async def initialize(self) -> bool:
66 | self.initialized = True
67 | self.logger.success("Mock provider initialized successfully", emoji_key="provider")
68 | return True
69 |
70 | async def generate_completion(
71 | self,
72 | prompt: str,
73 | model: Optional[str] = None,
74 | max_tokens: Optional[int] = None,
75 | temperature: float = 0.7,
76 | **kwargs
77 | ) -> ModelResponse:
78 | self.calls.append({
79 | "type": "completion",
80 | "prompt": prompt,
81 | "model": model,
82 | "max_tokens": max_tokens,
83 | "temperature": temperature,
84 | "kwargs": kwargs
85 | })
86 |
87 | model = model or self.get_default_model()
88 |
89 | return ModelResponse(
90 | text=self.responses.get("text", "Mock completion response"),
91 | model=model,
92 | provider=self.provider_name,
93 | input_tokens=100,
94 | output_tokens=50,
95 | processing_time=0.1,
96 | raw_response={"id": "mock-response-id"}
97 | )
98 |
99 | async def generate_completion_stream(
100 | self,
101 | prompt: str,
102 | model: Optional[str] = None,
103 | max_tokens: Optional[int] = None,
104 | temperature: float = 0.7,
105 | **kwargs
106 | ):
107 | self.calls.append({
108 | "type": "stream",
109 | "prompt": prompt,
110 | "model": model,
111 | "max_tokens": max_tokens,
112 | "temperature": temperature,
113 | "kwargs": kwargs
114 | })
115 |
116 | model = model or self.get_default_model()
117 |
118 | chunks = self.responses.get("chunks", ["Mock ", "streaming ", "response"])
119 |
120 | for i, chunk in enumerate(chunks):
121 | yield chunk, {
122 | "model": model,
123 | "provider": self.provider_name,
124 | "chunk_index": i + 1,
125 | "finish_reason": "stop" if i == len(chunks) - 1 else None
126 | }
127 |
128 | async def list_models(self) -> List[Dict[str, Any]]:
129 | return self.responses.get("models", [
130 | {
131 | "id": "mock-model-1",
132 | "provider": self.provider_name,
133 | "description": "Mock model 1"
134 | },
135 | {
136 | "id": "mock-model-2",
137 | "provider": self.provider_name,
138 | "description": "Mock model 2"
139 | }
140 | ])
141 |
142 | def get_default_model(self) -> str:
143 | return "mock-model-1"
144 |
145 | async def check_api_key(self) -> bool:
146 | return True
147 |
148 |
149 | @pytest.fixture
150 | def test_dir() -> Path:
151 | """Get the tests directory path."""
152 | return Path(__file__).parent
153 |
154 |
155 | @pytest.fixture
156 | def sample_data_dir(test_dir: Path) -> Path:
157 | """Get the sample data directory path."""
158 | data_dir = test_dir / "data"
159 | data_dir.mkdir(exist_ok=True)
160 | return data_dir
161 |
162 |
163 | @pytest.fixture
164 | def mock_env_vars(monkeypatch: MonkeyPatch) -> None:
165 | """Set mock environment variables."""
166 | monkeypatch.setenv("OPENAI_API_KEY", "mock-openai-key")
167 | monkeypatch.setenv("ANTHROPIC_API_KEY", "mock-anthropic-key")
168 | monkeypatch.setenv("GEMINI_API_KEY", "mock-gemini-key")
169 | monkeypatch.setenv("DEEPSEEK_API_KEY", "mock-deepseek-key")
170 | monkeypatch.setenv("CACHE_ENABLED", "true")
171 | monkeypatch.setenv("LOG_LEVEL", "DEBUG")
172 |
173 |
174 | @pytest.fixture
175 | def test_config() -> Config:
176 | """Get a test configuration."""
177 | # Create a test configuration
178 | test_config = Config()
179 |
180 | # Override settings for testing
181 | test_config.cache.enabled = True
182 | test_config.cache.ttl = 60 # Short TTL for testing
183 | test_config.cache.max_entries = 100
184 | test_config.server.port = 8888 # Different port for testing
185 |
186 | # Set test API keys
187 | test_config.providers.openai.api_key = "test-openai-key"
188 | test_config.providers.anthropic.api_key = "test-anthropic-key"
189 | test_config.providers.gemini.api_key = "test-gemini-key"
190 | test_config.providers.deepseek.api_key = "test-deepseek-key"
191 |
192 | return test_config
193 |
194 |
195 | @pytest.fixture
196 | def mock_provider() -> MockProvider:
197 | """Get a mock provider."""
198 | return MockProvider(api_key="mock-api-key")
199 |
200 |
201 | @pytest.fixture
202 | def mock_gateway(mock_provider: MockProvider) -> Gateway:
203 | """Get a mock gateway with the mock provider."""
204 | gateway = Gateway(name="test-gateway")
205 |
206 | # Add mock provider
207 | gateway.providers["mock"] = mock_provider
208 | gateway.provider_status["mock"] = {
209 | "enabled": True,
210 | "available": True,
211 | "api_key_configured": True,
212 | "models": [
213 | {
214 | "id": "mock-model-1",
215 | "provider": "mock",
216 | "description": "Mock model 1"
217 | },
218 | {
219 | "id": "mock-model-2",
220 | "provider": "mock",
221 | "description": "Mock model 2"
222 | }
223 | ]
224 | }
225 |
226 | return gateway
227 |
228 |
229 | @pytest.fixture
230 | def mock_http_client(monkeypatch: MonkeyPatch) -> MockClient:
231 | """Mock HTTP client to avoid real API calls."""
232 | mock_client = MockClient()
233 |
234 | # We'll need to patch any HTTP clients used by the providers
235 | # This will be implemented as needed in specific tests
236 |
237 | return mock_client
238 |
239 |
240 | @pytest.fixture
241 | def sample_document() -> str:
242 | """Get a sample document for testing."""
243 | return """
244 | # Sample Document
245 |
246 | This is a sample document for testing purposes.
247 |
248 | ## Section 1
249 |
250 | Lorem ipsum dolor sit amet, consectetur adipiscing elit.
251 | Nullam auctor, nisl eget ultricies aliquam, est libero tincidunt nisi,
252 | eu aliquet nunc nisl eu nisl.
253 |
254 | ## Section 2
255 |
256 | Praesent euismod, nisl eget ultricies aliquam, est libero tincidunt nisi,
257 | eu aliquet nunc nisl eu nisl.
258 |
259 | ### Subsection 2.1
260 |
261 | - Item 1
262 | - Item 2
263 | - Item 3
264 |
265 | ### Subsection 2.2
266 |
267 | 1. First item
268 | 2. Second item
269 | 3. Third item
270 | """
271 |
272 |
273 | @pytest.fixture
274 | def sample_json_data() -> Dict[str, Any]:
275 | """Get sample JSON data for testing."""
276 | return {
277 | "name": "Test User",
278 | "age": 30,
279 | "email": "[email protected]",
280 | "address": {
281 | "street": "123 Test St",
282 | "city": "Test City",
283 | "state": "TS",
284 | "zip": "12345"
285 | },
286 | "tags": ["test", "sample", "json"]
287 | }
288 |
289 |
290 | @pytest.fixture(scope="session")
291 | def event_loop_policy():
292 | """Return an event loop policy for the test session."""
293 | return asyncio.DefaultEventLoopPolicy()
```
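An illustrative unit test built on these fixtures (a hypothetical tests/unit/test_mock_provider.py; assumes pytest is configured to run async tests, e.g. pytest-asyncio in auto mode, as the unmarked async tests under tests/integration suggest).

```python
async def test_mock_completion(mock_provider):
    # MockProvider records every call and returns canned defaults.
    assert await mock_provider.initialize()
    response = await mock_provider.generate_completion(prompt="Hello")
    assert response.text == "Mock completion response"
    assert response.input_tokens == 100 and response.output_tokens == 50
    assert mock_provider.calls[0]["type"] == "completion"
```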
--------------------------------------------------------------------------------
/tests/integration/test_server.py:
--------------------------------------------------------------------------------
```python
1 | """Integration tests for the Ultimate MCP Server server."""
2 | from contextlib import asynccontextmanager
3 | from typing import Any, Dict, Optional
4 |
5 | import pytest
6 | from pytest import MonkeyPatch
7 |
8 | from ultimate_mcp_server.core.server import Gateway
9 | from ultimate_mcp_server.utils import get_logger
10 |
11 | logger = get_logger("test.integration.server")
12 |
13 |
14 | @pytest.fixture
15 | async def test_gateway() -> Gateway:
16 | """Create a test gateway instance."""
17 | gateway = Gateway(name="test-gateway")
18 | await gateway._initialize_providers()
19 | return gateway
20 |
21 |
22 | class TestGatewayServer:
23 | """Tests for the Gateway server."""
24 |
25 | async def test_initialization(self, test_gateway: Gateway):
26 | """Test gateway initialization."""
27 | logger.info("Testing gateway initialization", emoji_key="test")
28 |
29 | assert test_gateway.name == "test-gateway"
30 | assert test_gateway.mcp is not None
31 | assert hasattr(test_gateway, "providers")
32 | assert hasattr(test_gateway, "provider_status")
33 |
34 | async def test_provider_status(self, test_gateway: Gateway):
35 | """Test provider status information."""
36 | logger.info("Testing provider status", emoji_key="test")
37 |
38 | # Should have provider status information
39 | assert test_gateway.provider_status is not None
40 |
41 | # Get info - we need to use the resource accessor instead of get_resource
42 | @test_gateway.mcp.resource("info://server")
43 | def server_info() -> Dict[str, Any]:
44 | return {
45 | "name": test_gateway.name,
46 | "version": "0.1.0",
47 | "providers": list(test_gateway.provider_status.keys())
48 | }
49 |
50 | # Access the server info
51 | server_info_data = server_info()
52 | assert server_info_data is not None
53 | assert "name" in server_info_data
54 | assert "version" in server_info_data
55 | assert "providers" in server_info_data
56 |
57 | async def test_tool_registration(self, test_gateway: Gateway):
58 | """Test tool registration."""
59 | logger.info("Testing tool registration", emoji_key="test")
60 |
61 | # Define a test tool
62 | @test_gateway.mcp.tool()
63 | async def test_tool(arg1: str, arg2: Optional[str] = None) -> Dict[str, Any]:
64 | """Test tool for testing."""
65 | return {"result": f"{arg1}-{arg2 or 'default'}", "success": True}
66 |
67 | # Execute the tool - result appears to be a list not a dict
68 | result = await test_gateway.mcp.call_tool("test_tool", {"arg1": "test", "arg2": "value"})
69 |
70 | # Verify test passed by checking we get a valid response (without assuming exact structure)
71 | assert result is not None
72 |
73 | # Execute with default
74 | result2 = await test_gateway.mcp.call_tool("test_tool", {"arg1": "test"})
75 | assert result2 is not None
76 |
77 | async def test_tool_error_handling(self, test_gateway: Gateway):
78 | """Test error handling in tools."""
79 | logger.info("Testing tool error handling", emoji_key="test")
80 |
81 | # Define a tool that raises an error
82 | @test_gateway.mcp.tool()
83 | async def error_tool(should_fail: bool = True) -> Dict[str, Any]:
84 | """Tool that fails on demand."""
85 | if should_fail:
86 | raise ValueError("Test error")
87 | return {"success": True}
88 |
89 | # Execute and catch the error
90 | with pytest.raises(Exception): # MCP might wrap the error # noqa: B017
91 | await test_gateway.mcp.call_tool("error_tool", {"should_fail": True})
92 |
93 | # Execute successful case
94 | result = await test_gateway.mcp.call_tool("error_tool", {"should_fail": False})
95 | # Just check a result is returned, not its specific structure
96 | assert result is not None
97 |
98 |
99 | class TestServerLifecycle:
100 | """Tests for server lifecycle."""
101 |
102 | async def test_server_lifespan(self, monkeypatch: MonkeyPatch):
103 | """Test server lifespan context manager."""
104 | logger.info("Testing server lifespan", emoji_key="test")
105 |
106 | # Track lifecycle events
107 | events = []
108 |
109 | # Mock Gateway.run method to avoid asyncio conflicts
110 | def mock_gateway_run(self):
111 | events.append("run")
112 |
113 | monkeypatch.setattr(Gateway, "run", mock_gateway_run)
114 |
115 | # Create a fully mocked lifespan context manager
116 | @asynccontextmanager
117 | async def mock_lifespan(server):
118 | """Mock lifespan context manager that directly adds events"""
119 | events.append("enter")
120 | try:
121 | yield {"mocked": "context"}
122 | finally:
123 | events.append("exit")
124 |
125 | # Create a gateway and replace its _server_lifespan with our mock
126 | gateway = Gateway(name="test-lifecycle")
127 | monkeypatch.setattr(gateway, "_server_lifespan", mock_lifespan)
128 |
129 | # Test run method (now mocked)
130 | gateway.run()
131 | assert "run" in events
132 |
133 | # Test the mocked lifespan context manager
134 | async with gateway._server_lifespan(None) as context:
135 | events.append("in_context")
136 | assert context is not None
137 |
138 | # Check all expected events were recorded
139 | assert "enter" in events, f"Events: {events}"
140 | assert "in_context" in events, f"Events: {events}"
141 | assert "exit" in events, f"Events: {events}"
142 |
143 |
144 | class TestServerIntegration:
145 | """Integration tests for server with tools."""
146 |
147 | async def test_provider_tools(self, test_gateway: Gateway, monkeypatch: MonkeyPatch):
148 | """Test provider-related tools."""
149 | logger.info("Testing provider tools", emoji_key="test")
150 |
151 | # Mock tool execution
152 | async def mock_call_tool(tool_name, params):
153 | if tool_name == "get_provider_status":
154 | return {
155 | "providers": {
156 | "openai": {
157 | "enabled": True,
158 | "available": True,
159 | "api_key_configured": True,
160 | "error": None,
161 | "models_count": 3
162 | },
163 | "anthropic": {
164 | "enabled": True,
165 | "available": True,
166 | "api_key_configured": True,
167 | "error": None,
168 | "models_count": 5
169 | }
170 | }
171 | }
172 | elif tool_name == "list_models":
173 | provider = params.get("provider")
174 | if provider == "openai":
175 | return {
176 | "models": {
177 | "openai": [
178 | {"id": "gpt-4o", "provider": "openai"},
179 | {"id": "gpt-4.1-mini", "provider": "openai"},
180 | {"id": "gpt-4.1-mini", "provider": "openai"}
181 | ]
182 | }
183 | }
184 | else:
185 | return {
186 | "models": {
187 | "openai": [
188 | {"id": "gpt-4o", "provider": "openai"},
189 | {"id": "gpt-4.1-mini", "provider": "openai"}
190 | ],
191 | "anthropic": [
192 | {"id": "claude-3-opus-20240229", "provider": "anthropic"},
193 | {"id": "claude-3-5-haiku-20241022", "provider": "anthropic"}
194 | ]
195 | }
196 | }
197 | else:
198 | return {"error": f"Unknown tool: {tool_name}"}
199 |
200 | monkeypatch.setattr(test_gateway.mcp, "call_tool", mock_call_tool)
201 |
202 | # Test get_provider_status
203 | status = await test_gateway.mcp.call_tool("get_provider_status", {})
204 | assert "providers" in status
205 | assert "openai" in status["providers"]
206 | assert "anthropic" in status["providers"]
207 |
208 | # Test list_models with provider
209 | models = await test_gateway.mcp.call_tool("list_models", {"provider": "openai"})
210 | assert "models" in models
211 | assert "openai" in models["models"]
212 | assert len(models["models"]["openai"]) == 3
213 |
214 | # Test list_models without provider
215 | all_models = await test_gateway.mcp.call_tool("list_models", {})
216 | assert "models" in all_models
217 | assert "openai" in all_models["models"]
218 | assert "anthropic" in all_models["models"]
```
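A hypothetical companion test, sketched on the assumption that the `test_gateway` fixture and the file's `Dict`/`Any` imports are in scope; it reuses the same resource-accessor pattern shown above rather than any separate lookup API.

```python
# Hypothetical companion test (not in the original file), using the same
# fixture and decorators as the tests above.
async def test_resource_registration(test_gateway: Gateway):
    """Register a resource and read it back via the decorated function."""

    @test_gateway.mcp.resource("config://defaults")
    def default_config() -> Dict[str, Any]:
        # Static payload; a real test might derive this from gateway config
        return {"timeout": 30, "retries": 3}

    assert default_config() == {"timeout": 30, "retries": 3}
```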
--------------------------------------------------------------------------------
/ultimate_mcp_server/utils/logging/console.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Rich console configuration for Gateway logging system.
3 |
4 | This module provides a configured Rich console instance for beautiful terminal output,
5 | along with utility functions for common console operations.
6 | """
7 | import sys
8 | from contextlib import contextmanager
9 | from typing import Any, Dict, List, Optional, Tuple, Union
10 |
11 | from rich.box import ROUNDED, Box
12 | from rich.console import Console, ConsoleRenderable
13 | from rich.live import Live
14 | from rich.panel import Panel
15 | from rich.progress import (
16 | BarColumn,
17 | Progress,
18 | SpinnerColumn,
19 | TextColumn,
20 | TimeElapsedColumn,
21 | TimeRemainingColumn,
22 | )
23 | from rich.status import Status
24 | from rich.syntax import Syntax
25 | from rich.table import Table
26 | from rich.text import Text
27 | from rich.traceback import install as install_rich_traceback
28 | from rich.tree import Tree
29 |
30 | # Use relative import for theme
31 | from .themes import RICH_THEME
32 |
33 | # Configure global console with our theme
34 | # Note: Recording might be useful for testing or specific scenarios
35 | console = Console(
36 | theme=RICH_THEME,
37 | highlight=True,
38 | markup=True,
39 | emoji=True,
40 | record=False, # Set to True to capture output for testing
41 | width=None, # Auto-width, or set a fixed width if desired
42 | color_system="auto", # "auto", "standard", "256", "truecolor"
43 | file=sys.stderr, # Always use stderr to avoid interfering with JSON-RPC messages on stdout
44 | )
45 |
46 | # Install rich traceback handler for beautiful error tracebacks
47 | # show_locals=True can be verbose, consider False for production
48 | install_rich_traceback(console=console, show_locals=False)
49 |
50 | # Custom progress bar setup
51 | def create_progress(
52 | transient: bool = True,
53 | auto_refresh: bool = True,
54 | disable: bool = False,
55 | **kwargs
56 | ) -> Progress:
57 | """Create a customized Rich Progress instance.
58 |
59 | Args:
60 | transient: Whether to remove the progress bar after completion
61 | auto_refresh: Whether to auto-refresh the progress bar
62 | disable: Whether to disable the progress bar
63 | **kwargs: Additional arguments passed to Progress constructor
64 |
65 | Returns:
66 | Configured Progress instance
67 | """
68 | return Progress(
69 | SpinnerColumn(),
70 | TextColumn("[progress.description]{task.description}"), # Use theme style
71 | BarColumn(bar_width=None),
72 | "[progress.percentage]{task.percentage:>3.0f}%", # Use theme style
73 | TimeElapsedColumn(),
74 | TimeRemainingColumn(),
75 | console=console,
76 | transient=transient,
77 | auto_refresh=auto_refresh,
78 | disable=disable,
79 | **kwargs
80 | )
81 |
82 | @contextmanager
83 | def status(message: str, spinner: str = "dots", **kwargs):
84 | """Context manager for displaying a status message during an operation.
85 |
86 | Args:
87 | message: The status message to display
88 | spinner: The spinner animation to use
89 | **kwargs: Additional arguments to pass to Status constructor
90 |
91 | Yields:
92 | Rich Status object that can be updated
93 | """
94 | with Status(message, console=console, spinner=spinner, **kwargs) as status_obj:
95 | yield status_obj
96 |
97 | def print_panel(
98 | content: Union[str, Text, ConsoleRenderable],
99 | title: Optional[str] = None,
100 | style: Optional[str] = "info", # Use theme styles by default
101 | box: Optional[Box] = ROUNDED,
102 | expand: bool = True,
103 | padding: Tuple[int, int] = (1, 2),
104 | **kwargs
105 | ) -> None:
106 | """Print content in a styled panel.
107 |
108 | Args:
109 | content: The content to display in the panel
110 | title: Optional panel title
111 | style: Style name to apply (from theme)
112 | box: Box style to use
113 | expand: Whether the panel should expand to fill width
114 | padding: Panel padding (vertical, horizontal)
115 | **kwargs: Additional arguments to pass to Panel constructor
116 | """
117 | if isinstance(content, str):
118 | content = Text.from_markup(content) # Allow markup in string content
119 |
120 | panel = Panel(
121 | content,
122 | title=title,
123 | style=style if style else "none", # Pass style name directly
124 | border_style=style, # Use same style for border unless overridden
125 | box=box,
126 | expand=expand,
127 | padding=padding,
128 | **kwargs
129 | )
130 | console.print(panel)
131 |
132 | def print_syntax(
133 | code: str,
134 | language: str = "python",
135 | line_numbers: bool = True,
136 | theme: str = "monokai", # Standard Rich theme
137 | title: Optional[str] = None,
138 | background_color: Optional[str] = None,
139 | **kwargs
140 | ) -> None:
141 | """Print syntax-highlighted code.
142 |
143 | Args:
144 | code: The code to highlight
145 | language: The programming language
146 | line_numbers: Whether to show line numbers
147 | theme: Syntax highlighting theme (e.g., 'monokai', 'native')
148 | title: Optional title for the code block (creates a panel)
149 | background_color: Optional background color
150 | **kwargs: Additional arguments to pass to Syntax constructor
151 | """
152 | syntax = Syntax(
153 | code,
154 | language,
155 | theme=theme,
156 | line_numbers=line_numbers,
157 | background_color=background_color,
158 | **kwargs
159 | )
160 |
161 | if title:
162 | # Use a neutral panel style for code
163 | print_panel(syntax, title=title, style="none", padding=(0,1))
164 | else:
165 | console.print(syntax)
166 |
167 | def print_table(
168 | title: Optional[str] = None,
169 | columns: Optional[List[Union[str, Dict[str, Any]]]] = None,
170 | rows: Optional[List[List[Any]]] = None,
171 | box: Box = ROUNDED,
172 | show_header: bool = True,
173 | **kwargs
174 | ) -> Table:
175 | """Create and print a Rich table.
176 |
177 | Args:
178 | title: Optional table title
179 | columns: List of column names or dicts for more control (e.g., {"header": "Name", "style": "bold"})
180 | rows: List of rows, each a list of values (will be converted to str)
181 | box: Box style to use
182 | show_header: Whether to show the table header
183 | **kwargs: Additional arguments to pass to Table constructor
184 |
185 | Returns:
186 | The created Table instance (in case further modification is needed)
187 | """
188 | table = Table(title=title, box=box, show_header=show_header, **kwargs)
189 |
190 | if columns:
191 | for column in columns:
192 | if isinstance(column, dict):
193 | table.add_column(**column)
194 | else:
195 | table.add_column(str(column))
196 |
197 | if rows:
198 | for row in rows:
199 | # Ensure all items are renderable (convert simple types to str)
200 | renderable_row = [
201 | item if isinstance(item, ConsoleRenderable) else str(item)
202 | for item in row
203 | ]
204 | table.add_row(*renderable_row)
205 |
206 | console.print(table)
207 | return table
208 |
209 | def print_tree(
210 | name: str,
211 | data: Union[Dict[str, Any], List[Any]],
212 | guide_style: str = "bright_black",
213 | highlight: bool = True,
214 | **kwargs
215 | ) -> None:
216 | """Print a hierarchical tree structure from nested data.
217 |
218 | Args:
219 | name: The root label of the tree
220 | data: Nested dictionary or list to render as a tree
221 | guide_style: Style for the tree guides
222 | highlight: Apply highlighting to the tree
223 | **kwargs: Additional arguments to pass to Tree constructor
224 | """
225 | tree = Tree(name, guide_style=guide_style, highlight=highlight, **kwargs)
226 |
227 | def build_tree(branch, node_data):
228 | """Recursively build the tree from nested data."""
229 | if isinstance(node_data, dict):
230 | for key, value in node_data.items():
231 | sub_branch = branch.add(str(key))
232 | build_tree(sub_branch, value)
233 | elif isinstance(node_data, list):
234 | for index, item in enumerate(node_data):
235 | # Use index as label or try to represent item briefly
236 | label = f"[{index}]"
237 | sub_branch = branch.add(label)
238 | build_tree(sub_branch, item)
239 | else:
240 | # Leaf node
241 | branch.add(Text(str(node_data)))
242 |
243 | build_tree(tree, data)
244 | console.print(tree)
245 |
246 | def print_json(data: Any, title: Optional[str] = None, indent: int = 2, highlight: bool = True) -> None:
247 | """Print data formatted as JSON with syntax highlighting.
248 |
249 | Args:
250 | data: The data to format as JSON.
251 | title: Optional title (creates a panel).
252 | indent: JSON indentation level.
253 | highlight: Apply syntax highlighting.
254 | """
255 | import json
256 | try:
257 | json_str = json.dumps(data, indent=indent, ensure_ascii=False)
258 | if highlight:
259 | syntax = Syntax(json_str, "json", theme="native", word_wrap=True)
260 | if title:
261 | print_panel(syntax, title=title, style="none", padding=(0, 1))
262 | else:
263 | console.print(syntax)
264 | else:
265 | if title:
266 | print_panel(json_str, title=title, style="none", padding=(0, 1))
267 | else:
268 | console.print(json_str)
269 | except Exception as e:
270 | console.print(f"[error]Could not format data as JSON: {e}[/error]")
271 |
272 | @contextmanager
273 | def live_display(renderable: ConsoleRenderable, **kwargs):
274 | """Context manager for displaying a live-updating renderable.
275 |
276 | Args:
277 | renderable: The Rich renderable to display live.
278 | **kwargs: Additional arguments for the Live instance.
279 |
280 | Yields:
281 | The Live instance.
282 | """
283 | with Live(renderable, console=console, **kwargs) as live:
284 | yield live
285 |
286 | def get_rich_console() -> Console:
287 | """Returns the shared Rich Console instance."""
288 | return console
```
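A minimal usage sketch of the console helpers defined above; the argument values are illustrative, but the signatures match the functions in this module.

```python
# Usage sketch for the helpers above (illustrative, not part of the module).
from ultimate_mcp_server.utils.logging.console import (
    print_json,
    print_panel,
    print_table,
    status,
)

# Styled panel using a theme style name
print_panel("Gateway started on port 8013", title="Status", style="info")

# Table with a plain column and a dict column spec (forwarded to add_column)
print_table(
    title="Providers",
    columns=["Name", {"header": "Models", "justify": "right"}],
    rows=[["openai", 3], ["anthropic", 5]],
)

# Syntax-highlighted JSON inside a panel
print_json({"enabled": True, "ttl": 60}, title="Cache config")

# Spinner that can be updated while work runs
with status("Loading models...") as s:
    s.update("Still loading...")
```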
--------------------------------------------------------------------------------
/ultimate_mcp_server/utils/security.py:
--------------------------------------------------------------------------------
```python
1 | """Security utilities for Ultimate MCP Server."""
2 | import base64
3 | import hashlib
4 | import hmac
5 | import re
6 | import secrets
7 | import time
8 | from typing import Any, Dict, List, Optional, Tuple
9 |
10 | from ultimate_mcp_server.config import get_env
11 | from ultimate_mcp_server.utils import get_logger
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | def mask_api_key(api_key: str) -> str:
17 | """Mask API key for safe logging.
18 |
19 | Args:
20 | api_key: API key to mask
21 |
22 | Returns:
23 | Masked API key
24 | """
25 | if not api_key:
26 | return ""
27 |
28 | # Keep first 4 and last 4 characters, mask the rest
29 | if len(api_key) <= 8:
30 | return "*" * len(api_key)
31 |
32 | return api_key[:4] + "*" * (len(api_key) - 8) + api_key[-4:]
33 |
34 |
35 | def validate_api_key(api_key: str, provider: str) -> bool:
36 | """Validate API key format for a provider.
37 |
38 | Args:
39 | api_key: API key to validate
40 | provider: Provider name
41 |
42 | Returns:
43 | True if API key format is valid
44 | """
45 | if not api_key:
46 | return False
47 |
48 | # Provider-specific validation patterns
49 | patterns = {
50 | "openai": r'^sk-[a-zA-Z0-9]{48}$',
51 | "anthropic": r'^sk-ant-[a-zA-Z0-9]{48}$',
52 | "deepseek": r'^sk-[a-zA-Z0-9]{32,64}$',
53 | "gemini": r'^[a-zA-Z0-9_-]{39}$',
54 | # Add more providers as needed
55 | }
56 |
57 | # Get pattern for provider
58 | pattern = patterns.get(provider.lower())
59 | if not pattern:
60 | # For unknown providers, check minimum length
61 | return len(api_key) >= 16
62 |
63 | # Check if API key matches the pattern
64 | return bool(re.match(pattern, api_key))
65 |
66 |
67 | def generate_random_string(length: int = 32) -> str:
68 | """Generate a cryptographically secure random string.
69 |
70 | Args:
71 | length: Length of the string
72 |
73 | Returns:
74 | Random string
75 | """
76 | # Generate random bytes
77 | random_bytes = secrets.token_bytes(length)
78 |
79 | # Convert to URL-safe base64
80 | random_string = base64.urlsafe_b64encode(random_bytes).decode('utf-8')
81 |
82 | # Truncate to desired length
83 | return random_string[:length]
84 |
85 |
86 | def generate_api_key(prefix: str = 'lgw') -> str:
87 | """Generate an API key for the gateway.
88 |
89 | Args:
90 | prefix: Key prefix
91 |
92 | Returns:
93 | Generated API key
94 | """
95 | # Generate timestamp
96 | timestamp = int(time.time())
97 |
98 | # Generate random bytes
99 | random_bytes = secrets.token_bytes(24)
100 |
101 | # Combine and encode
102 | timestamp_bytes = timestamp.to_bytes(4, byteorder='big')
103 | combined = timestamp_bytes + random_bytes
104 | encoded = base64.urlsafe_b64encode(combined).decode('utf-8').rstrip('=')
105 |
106 | # Add prefix
107 | return f"{prefix}-{encoded}"
108 |
109 |
110 | def create_hmac_signature(
111 | key: str,
112 | message: str,
113 | algorithm: str = 'sha256'
114 | ) -> str:
115 | """Create an HMAC signature.
116 |
117 | Args:
118 | key: Secret key
119 | message: Message to sign
120 | algorithm: Hash algorithm to use
121 |
122 | Returns:
123 | HMAC signature as hexadecimal string
124 | """
125 | # Convert inputs to bytes
126 | key_bytes = key.encode('utf-8')
127 | message_bytes = message.encode('utf-8')
128 |
129 | # Create HMAC
130 | if algorithm == 'sha256':
131 | h = hmac.new(key_bytes, message_bytes, hashlib.sha256)
132 | elif algorithm == 'sha512':
133 | h = hmac.new(key_bytes, message_bytes, hashlib.sha512)
134 | else:
135 | raise ValueError(f"Unsupported algorithm: {algorithm}")
136 |
137 | # Return hexadecimal digest
138 | return h.hexdigest()
139 |
140 |
141 | def verify_hmac_signature(
142 | key: str,
143 | message: str,
144 | signature: str,
145 | algorithm: str = 'sha256'
146 | ) -> bool:
147 | """Verify an HMAC signature.
148 |
149 | Args:
150 | key: Secret key
151 | message: Original message
152 | signature: HMAC signature to verify
153 | algorithm: Hash algorithm used
154 |
155 | Returns:
156 | True if signature is valid
157 | """
158 | # Calculate expected signature
159 | expected = create_hmac_signature(key, message, algorithm)
160 |
161 | # Compare signatures (constant-time comparison)
162 | return hmac.compare_digest(signature, expected)
163 |
164 |
165 | def sanitize_input(text: str, allowed_patterns: Optional[List[str]] = None) -> str:
166 | """Sanitize user input to prevent injection attacks.
167 |
168 | Args:
169 | text: Input text to sanitize
170 | allowed_patterns: List of regex patterns for allowed content
171 |
172 | Returns:
173 | Sanitized input
174 | """
175 | if not text:
176 | return ""
177 |
178 | # Remove control characters
179 | text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
180 |
181 | # Apply allowed patterns if specified
182 | if allowed_patterns:
183 | # Filter out anything not matching allowed patterns
184 | filtered = ""
185 | for pattern in allowed_patterns:
186 | matches = re.finditer(pattern, text)
187 | for match in matches:
188 | filtered += match.group(0)
189 | return filtered
190 |
191 | # Default sanitization (alphanumeric, spaces, and common punctuation)
192 | return re.sub(r'[^\w\s.,;:!?"\'-]', '', text)
193 |
194 |
195 | def sanitize_path(path: str) -> str:
196 | """Sanitize file path to prevent path traversal attacks.
197 |
198 | Args:
199 | path: File path to sanitize
200 |
201 | Returns:
202 | Sanitized path
203 | """
204 | if not path:
205 | return ""
206 |
207 | # Normalize path separators
208 | path = path.replace('\\', '/')
209 |
210 | # Remove path traversal sequences
211 | path = re.sub(r'\.\.[/\\]', '', path)
212 | path = re.sub(r'[/\\]\.\.[/\\]', '/', path)
213 |
214 | # Remove multiple consecutive slashes
215 | path = re.sub(r'[/\\]{2,}', '/', path)
216 |
217 | # Remove leading slash
218 | path = re.sub(r'^[/\\]', '', path)
219 |
220 | # Remove dangerous characters
221 | path = re.sub(r'[<>:"|?*]', '', path)
222 |
223 | return path
224 |
225 |
226 | def create_session_token(user_id: str, expires_in: int = 86400) -> Dict[str, Any]:
227 | """Create a session token for a user.
228 |
229 | Args:
230 | user_id: User identifier
231 | expires_in: Token expiration time in seconds
232 |
233 | Returns:
234 | Dictionary with token and expiration
235 | """
236 | # Generate expiration timestamp
237 | expiration = int(time.time()) + expires_in
238 |
239 | # Generate random token
240 | token = generate_random_string(48)
241 |
242 | # Compute signature
243 | # In a real implementation, use a secure key from config
244 | secret_key = get_env('SESSION_SECRET_KEY', 'default_session_key')
245 | signature_msg = f"{user_id}:{token}:{expiration}"
246 | signature = create_hmac_signature(secret_key, signature_msg)
247 |
248 | return {
249 | 'token': token,
250 | 'signature': signature,
251 | 'user_id': user_id,
252 | 'expiration': expiration,
253 | }
254 |
255 |
256 | def verify_session_token(token_data: Dict[str, Any]) -> bool:
257 | """Verify a session token.
258 |
259 | Args:
260 | token_data: Token data dictionary
261 |
262 | Returns:
263 | True if token is valid
264 | """
265 | # Check required fields
266 | required_fields = ['token', 'signature', 'user_id', 'expiration']
267 | if not all(field in token_data for field in required_fields):
268 | return False
269 |
270 | # Check expiration
271 | if int(time.time()) > token_data['expiration']:
272 | return False
273 |
274 | # Verify signature
275 | secret_key = get_env('SESSION_SECRET_KEY', 'default_session_key')
276 | signature_msg = f"{token_data['user_id']}:{token_data['token']}:{token_data['expiration']}"
277 |
278 | return verify_hmac_signature(secret_key, signature_msg, token_data['signature'])
279 |
280 |
281 | def hash_password(password: str, salt: Optional[str] = None) -> Tuple[str, str]:
282 | """Hash a password securely.
283 |
284 | Args:
285 | password: Password to hash
286 | salt: Optional salt (generated if not provided)
287 |
288 | Returns:
289 | Tuple of (hash, salt)
290 | """
291 | # Generate salt if not provided
292 | if not salt:
293 | salt = secrets.token_hex(16)
294 |
295 | # Create key derivation
296 | key = hashlib.pbkdf2_hmac(
297 | 'sha256',
298 | password.encode('utf-8'),
299 | salt.encode('utf-8'),
300 | 100000, # 100,000 iterations
301 | dklen=32
302 | )
303 |
304 | # Convert to hexadecimal
305 | password_hash = key.hex()
306 |
307 | return password_hash, salt
308 |
309 |
310 | def verify_password(password: str, stored_hash: str, salt: str) -> bool:
311 | """Verify a password against a stored hash.
312 |
313 | Args:
314 | password: Password to verify
315 | stored_hash: Stored password hash
316 | salt: Salt used for hashing
317 |
318 | Returns:
319 | True if password is correct
320 | """
321 | # Hash the provided password with the same salt
322 | password_hash, _ = hash_password(password, salt)
323 |
324 | # Compare hashes (constant-time comparison)
325 | return hmac.compare_digest(password_hash, stored_hash)
326 |
327 |
328 | def is_safe_url(url: str, allowed_hosts: Optional[List[str]] = None) -> bool:
329 | """Check if a URL is safe to redirect to.
330 |
331 | Args:
332 | url: URL to check
333 | allowed_hosts: List of allowed hosts
334 |
335 | Returns:
336 | True if URL is safe
337 | """
338 | if not url:
339 | return False
340 |
341 | # Treat scheme-relative URLs ('//host') as absolute so they are validated below
342 | if not url.startswith(('http://', 'https://', '//')):
343 | # Purely relative URLs are considered safe
344 | return True
345 |
346 | # Parse URL
347 | try:
348 | from urllib.parse import urlparse
349 | parsed_url = urlparse(url)
350 |
351 | # Check network location
352 | if not parsed_url.netloc:
353 | return False
354 |
355 | # Check against allowed hosts
356 | if allowed_hosts:
357 | return parsed_url.netloc in allowed_hosts
358 |
359 | # Default: only allow relative URLs
360 | return False
361 | except Exception:
362 | return False
```
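A quick, illustrative round-trip through the masking, HMAC, and password helpers above; the key and password strings are placeholders.

```python
# Usage sketch for the utilities above (illustrative only).
from ultimate_mcp_server.utils.security import (
    create_hmac_signature,
    hash_password,
    mask_api_key,
    verify_hmac_signature,
    verify_password,
)

# Safe logging of secrets: first and last 4 characters are kept
print(mask_api_key("sk-abcdefghijklmnop"))  # -> sk-a***********mnop

# Sign and verify a message with a shared secret
sig = create_hmac_signature("shared-secret", "payload")
assert verify_hmac_signature("shared-secret", "payload", sig)

# PBKDF2 password hashing round-trip
pw_hash, salt = hash_password("correct horse battery staple")
assert verify_password("correct horse battery staple", pw_hash, salt)
assert not verify_password("wrong password", pw_hash, salt)
```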
--------------------------------------------------------------------------------
/tests/unit/test_cache.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for the cache service."""
2 | import asyncio
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from ultimate_mcp_server.services.cache import (
8 | CacheService,
9 | with_cache,
10 | )
11 | from ultimate_mcp_server.services.cache.strategies import (
12 | ExactMatchStrategy,
13 | SemanticMatchStrategy,
14 | TaskBasedStrategy,
15 | )
16 | from ultimate_mcp_server.utils import get_logger
17 |
18 | logger = get_logger("test.cache")
19 |
20 |
21 | @pytest.fixture
22 | def temp_cache_dir(tmp_path: Path) -> Path:
23 | """Create a temporary cache directory."""
24 | cache_dir = tmp_path / "cache"
25 | cache_dir.mkdir(exist_ok=True)
26 | return cache_dir
27 |
28 |
29 | @pytest.fixture
30 | def cache_service(temp_cache_dir: Path) -> CacheService:
31 | """Get a cache service instance with a temporary directory."""
32 | return CacheService(
33 | enabled=True,
34 | ttl=60, # Short TTL for testing
35 | max_entries=10,
36 | enable_persistence=True,
37 | cache_dir=str(temp_cache_dir),
38 | enable_fuzzy_matching=True
39 | )
40 |
41 |
42 | class TestCacheService:
43 | """Tests for the cache service."""
44 |
45 | async def test_init(self, cache_service: CacheService):
46 | """Test cache service initialization."""
47 | logger.info("Testing cache service initialization", emoji_key="test")
48 |
49 | assert cache_service.enabled
50 | assert cache_service.ttl == 60
51 | assert cache_service.max_entries == 10
52 | assert cache_service.enable_persistence
53 | assert cache_service.enable_fuzzy_matching
54 |
55 | async def test_get_set(self, cache_service: CacheService):
56 | """Test basic get and set operations."""
57 | logger.info("Testing cache get/set operations", emoji_key="test")
58 |
59 | # Set a value
60 | key = "test-key"
61 | value = {"text": "Test value", "metadata": {"test": True}}
62 | await cache_service.set(key, value)
63 |
64 | # Get the value back
65 | result = await cache_service.get(key)
66 | assert result == value
67 |
68 | # Check cache stats
69 | assert cache_service.metrics.hits == 1
70 | assert cache_service.metrics.misses == 0
71 | assert cache_service.metrics.stores == 1
72 |
73 | async def test_cache_miss(self, cache_service: CacheService):
74 | """Test cache miss."""
75 | logger.info("Testing cache miss", emoji_key="test")
76 |
77 | # Get a non-existent key
78 | result = await cache_service.get("non-existent-key")
79 | assert result is None
80 |
81 | # Check cache stats
82 | assert cache_service.metrics.hits == 0
83 | assert cache_service.metrics.misses == 1
84 |
85 | async def test_cache_expiry(self, cache_service: CacheService):
86 | """Test cache entry expiry."""
87 | logger.info("Testing cache expiry", emoji_key="test")
88 |
89 | # Set a value with short TTL
90 | key = "expiring-key"
91 | value = {"text": "Expiring value"}
92 | await cache_service.set(key, value, ttl=1) # 1 second TTL
93 |
94 | # Get immediately (should hit)
95 | result = await cache_service.get(key)
96 | assert result == value
97 |
98 | # Wait for expiry
99 | await asyncio.sleep(1.5)
100 |
101 | # Get again (should miss)
102 | result = await cache_service.get(key)
103 | assert result is None
104 |
105 | # Check stats
106 | assert cache_service.metrics.hits == 1
107 | assert cache_service.metrics.misses == 1
108 |
109 | async def test_cache_eviction(self, cache_service: CacheService):
110 | """Test cache eviction when max size is reached."""
111 | logger.info("Testing cache eviction", emoji_key="test")
112 |
113 | # Set more values than max_entries to force eviction
114 | for i in range(cache_service.max_entries + 5):
115 | key = f"key-{i}"
116 | value = {"text": f"Value {i}"}
117 | await cache_service.set(key, value)
118 |
119 | # Check size - should be at most max_entries
120 | assert len(cache_service.cache) <= cache_service.max_entries
121 |
122 | # Check stats
123 | assert cache_service.metrics.evictions > 0
124 |
125 | async def test_fuzzy_matching(self, cache_service: CacheService):
126 | """Test fuzzy matching of cache keys."""
127 | logger.info("Testing fuzzy matching", emoji_key="test")
128 |
129 | # Set a value with a prompt that would generate a fuzzy key
130 | request_params = {
131 | "prompt": "What is the capital of France?",
132 | "model": "test-model",
133 | "temperature": 0.7
134 | }
135 |
136 | key = cache_service.generate_cache_key(request_params)
137 | fuzzy_key = cache_service.generate_fuzzy_key(request_params)
138 |
139 | value = {"text": "The capital of France is Paris."}
140 | await cache_service.set(key, value, fuzzy_key=fuzzy_key, request_params=request_params)
141 |
142 | # Create a similar request that should match via fuzzy lookup
143 | similar_request = {
144 | "prompt": "What is the capital of France? Tell me about it.",
145 | "model": "different-model",
146 | "temperature": 0.5
147 | }
148 |
149 | similar_key = cache_service.generate_cache_key(similar_request)
150 | similar_fuzzy = cache_service.generate_fuzzy_key(similar_request) # noqa: F841
151 |
152 | # Should still find the original value
153 | result = await cache_service.get(similar_key, fuzzy=True)
154 | assert result == value
155 |
156 | async def test_cache_decorator(self):
157 | """Test the cache decorator."""
158 | logger.info("Testing cache decorator", emoji_key="test")
159 |
160 | call_count = 0
161 |
162 | @with_cache(ttl=60)
163 | async def test_function(arg1, arg2=None):
164 | nonlocal call_count
165 | call_count += 1
166 | return {"result": arg1 + str(arg2)}
167 |
168 | # First call should execute the function
169 | result1 = await test_function("test", arg2="123")
170 | assert result1 == {"result": "test123"}
171 | assert call_count == 1
172 |
173 | # Second call with same args should use cache
174 | result2 = await test_function("test", arg2="123")
175 | assert result2 == {"result": "test123"}
176 | assert call_count == 1 # Still 1
177 |
178 | # Call with different args should execute function again
179 | result3 = await test_function("test", arg2="456")
180 | assert result3 == {"result": "test456"}
181 | assert call_count == 2
182 |
183 |
184 | class TestCacheStrategies:
185 | """Tests for cache strategies."""
186 |
187 | def test_exact_match_strategy(self):
188 | """Test exact match strategy."""
189 | logger.info("Testing exact match strategy", emoji_key="test")
190 |
191 | strategy = ExactMatchStrategy()
192 |
193 | # Generate key for a request
194 | request = {
195 | "prompt": "Test prompt",
196 | "model": "test-model",
197 | "temperature": 0.7
198 | }
199 |
200 | key = strategy.generate_key(request)
201 | assert key.startswith("exact:")
202 |
203 | # Should cache most requests
204 | assert strategy.should_cache(request, {"text": "Test response"})
205 |
206 | # Shouldn't cache streaming requests
207 | streaming_request = request.copy()
208 | streaming_request["stream"] = True
209 | assert not strategy.should_cache(streaming_request, {"text": "Test response"})
210 |
211 | def test_semantic_match_strategy(self):
212 | """Test semantic match strategy."""
213 | logger.info("Testing semantic match strategy", emoji_key="test")
214 |
215 | strategy = SemanticMatchStrategy()
216 |
217 | # Generate key for a request
218 | request = {
219 | "prompt": "What is the capital of France?",
220 | "model": "test-model",
221 | "temperature": 0.7
222 | }
223 |
224 | key = strategy.generate_key(request)
225 | assert key.startswith("exact:") # Primary key is still exact
226 |
227 | semantic_key = strategy.generate_semantic_key(request)
228 | assert semantic_key.startswith("semantic:")
229 |
230 | # Should generate similar semantic keys for similar prompts
231 | similar_request = {
232 | "prompt": "Tell me the capital city of France?",
233 | "model": "test-model",
234 | "temperature": 0.7
235 | }
236 |
237 | similar_semantic_key = strategy.generate_semantic_key(similar_request)
238 | assert similar_semantic_key.startswith("semantic:")
239 |
240 | # The two semantic keys should share many common words
241 | # This is a bit harder to test deterministically, so we'll skip detailed assertions
242 |
243 | def test_task_based_strategy(self):
244 | """Test task-based strategy."""
245 | logger.info("Testing task-based strategy", emoji_key="test")
246 |
247 | strategy = TaskBasedStrategy()
248 |
249 | # Test different task types
250 | summarization_request = {
251 | "prompt": "Summarize this document: Lorem ipsum...",
252 | "model": "test-model",
253 | "task_type": "summarization"
254 | }
255 |
256 | extraction_request = {
257 | "prompt": "Extract entities from this text: John Smith...",
258 | "model": "test-model",
259 | "task_type": "extraction"
260 | }
261 |
262 | # Generate keys
263 | summary_key = strategy.generate_key(summarization_request)
264 | extraction_key = strategy.generate_key(extraction_request)
265 |
266 | # Keys should include task type
267 | assert "summarization" in summary_key
268 | assert "extraction" in extraction_key
269 |
270 | # Task-specific TTL
271 | summary_ttl = strategy.get_ttl(summarization_request, None)
272 | extraction_ttl = strategy.get_ttl(extraction_request, None)
273 |
274 | # Summarization should have longer TTL than extraction (typically)
275 | assert summary_ttl is not None
276 | assert extraction_ttl is not None
277 | assert summary_ttl > extraction_ttl
```
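A standalone sketch of the cache service exercised by these tests, assuming the `CacheService` constructor accepts the keyword arguments shown in the fixture; persistence is disabled so the cache stays in memory.

```python
# Standalone sketch (assumes CacheService defaults for omitted arguments).
import asyncio

from ultimate_mcp_server.services.cache import CacheService


async def demo():
    cache = CacheService(enabled=True, ttl=60, max_entries=100, enable_persistence=False)
    await cache.set("greeting", {"text": "hello"})
    print(await cache.get("greeting"))   # -> {'text': 'hello'} (cache hit)
    print(await cache.get("missing"))    # -> None (cache miss)
    print(cache.metrics.hits, cache.metrics.misses)  # -> 1 1


asyncio.run(demo())
```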
--------------------------------------------------------------------------------
/ultimate_mcp_server/tool_token_counter.py:
--------------------------------------------------------------------------------
```python
1 | import inspect
2 | import json
3 | from typing import Any, Callable, Dict, List, Optional
4 |
5 | from rich.console import Console
6 | from rich.table import Table
7 |
8 | from ultimate_mcp_server.constants import COST_PER_MILLION_TOKENS
9 | from ultimate_mcp_server.tools.base import _get_json_schema_type
10 | from ultimate_mcp_server.utils.text import count_tokens
11 |
12 |
13 | def extract_tool_info(func: Callable, tool_name: Optional[str] = None) -> Dict[str, Any]:
14 | """
15 | Extract tool information from a function, similar to how MCP does it.
16 |
17 | Args:
18 | func: The function to extract information from
19 | tool_name: Optional custom name for the tool (defaults to function name)
20 |
21 | Returns:
22 | Dictionary containing the tool information
23 | """
24 | # Get function name and docstring
25 | name = tool_name or func.__name__
26 | description = func.__doc__ or f"Tool: {name}"
27 |
28 | # Get function parameters
29 | sig = inspect.signature(func)
30 | params = {}
31 |
32 | for param_name, param in sig.parameters.items():
33 | # Skip 'self' parameter for class methods
34 | if param_name == 'self':
35 | continue
36 |
37 | # Skip context parameter which is usually added by decorators
38 | if param_name == 'ctx':
39 | continue
40 |
41 | # Also skip state management parameters
42 | if param_name in ['get_state', 'set_state', 'delete_state']:
43 | continue
44 |
45 | # Get parameter type annotation and default value
46 | param_type = param.annotation
47 | param_default = param.default if param.default is not inspect.Parameter.empty else None
48 |
49 | # Convert Python type to JSON Schema
50 | if param_type is not inspect.Parameter.empty:
51 | param_schema = _get_json_schema_type(param_type)
52 | else:
53 | param_schema = {"type": "object"} # Default to object for unknown types
54 |
55 | # Add default value if available
56 | if param_default is not None:
57 | param_schema["default"] = param_default
58 |
59 | # Add to parameters
60 | params[param_name] = param_schema
61 |
62 | # Construct input schema
63 | input_schema = {
64 | "type": "object",
65 | "properties": params,
66 | "required": [param_name for param_name, param in sig.parameters.items()
67 | if param.default is inspect.Parameter.empty
68 | and param_name not in ['self', 'ctx', 'get_state', 'set_state', 'delete_state']]
69 | }
70 |
71 | # Construct final tool info
72 | tool_info = {
73 | "name": name,
74 | "description": description,
75 | "inputSchema": input_schema
76 | }
77 |
78 | return tool_info
79 |
80 |
81 | def count_tool_registration_tokens(tools: List[Callable], model: str = "gpt-4o") -> int:
82 | """
83 | Count the tokens that would be used to register the given tools with an LLM.
84 |
85 | Args:
86 | tools: List of tool functions
87 | model: The model to use for token counting (default: gpt-4o)
88 |
89 | Returns:
90 | Total number of tokens
91 | """
92 | # Extract tool info for each tool
93 | tool_infos = [extract_tool_info(tool) for tool in tools]
94 |
95 | # Convert to JSON string (similar to what MCP does when sending to LLM)
96 | tools_json = json.dumps({"tools": tool_infos}, ensure_ascii=False)
97 |
98 | # Count tokens
99 | token_count = count_tokens(tools_json, model)
100 |
101 | return token_count
102 |
103 |
104 | def calculate_cost_per_provider(token_count: int) -> Dict[str, float]:
105 | """
106 | Calculate the cost of including the tokens as input for various API providers.
107 |
108 | Args:
109 | token_count: Number of tokens
110 |
111 | Returns:
112 | Dictionary mapping provider names to costs in USD
113 | """
114 | costs = {}
115 |
116 | try:
117 | # Make sure we can access the cost data structure
118 | if not isinstance(COST_PER_MILLION_TOKENS, dict):
119 | console = Console()
120 | console.print("[yellow]Warning: COST_PER_MILLION_TOKENS is not a dictionary[/yellow]")
121 | return costs
122 |
123 | for provider_name, provider_info in COST_PER_MILLION_TOKENS.items():
124 | # Skip if provider_info is not a dictionary
125 | if not isinstance(provider_info, dict):
126 | continue
127 |
128 | # Choose a reasonable default input cost if we can't determine from models
129 | default_input_cost = 0.01 # $0.01 per million tokens as a safe default
130 | input_cost_per_million = default_input_cost
131 |
132 | try:
133 | # Try to get cost from provider models if available
134 | if provider_info and len(provider_info) > 0:
135 | # Try to find the most expensive model
136 | max_cost = 0
137 | for _model_name, model_costs in provider_info.items():
138 | if isinstance(model_costs, dict) and 'input' in model_costs:
139 | cost = model_costs['input']
140 | if cost > max_cost:
141 | max_cost = cost
142 |
143 | if max_cost > 0:
144 | input_cost_per_million = max_cost
145 | except Exception as e:
146 | # If any error occurs, use the default cost
147 | console = Console()
148 | console.print(f"[yellow]Warning getting costs for {provider_name}: {str(e)}[/yellow]")
149 |
150 | # Calculate cost for this token count
151 | cost = (token_count / 1_000_000) * input_cost_per_million
152 |
153 | # Store in results
154 | costs[provider_name] = cost
155 | except Exception as e:
156 | console = Console()
157 | console.print(f"[red]Error calculating costs: {str(e)}[/red]")
158 |
159 | return costs
160 |
161 |
162 | def display_tool_token_usage(current_tools_info: List[Dict[str, Any]], all_tools_info: List[Dict[str, Any]]):
163 | """
164 | Display token usage information for tools in a Rich table.
165 |
166 | Args:
167 | current_tools_info: List of tool info dictionaries for currently registered tools
168 | all_tools_info: List of tool info dictionaries for all available tools
169 | """
170 | # Convert to JSON and count tokens
171 | current_json = json.dumps({"tools": current_tools_info}, ensure_ascii=False)
172 | all_json = json.dumps({"tools": all_tools_info}, ensure_ascii=False)
173 |
174 | current_token_count = count_tokens(current_json)
175 | all_token_count = count_tokens(all_json)
176 |
177 | # Calculate size in KB
178 | current_kb = len(current_json) / 1024
179 | all_kb = len(all_json) / 1024
180 |
181 | # Calculate costs for each provider
182 | current_costs = calculate_cost_per_provider(current_token_count)
183 | all_costs = calculate_cost_per_provider(all_token_count)
184 |
185 | # Create Rich table
186 | console = Console()
187 | table = Table(title="Tool Registration Token Usage")
188 |
189 | # Add columns
190 | table.add_column("Metric", style="cyan")
191 | table.add_column("Current Tools", style="green")
192 | table.add_column("All Tools", style="yellow")
193 | table.add_column("Difference", style="magenta")
194 |
195 | # Add rows
196 | table.add_row(
197 | "Number of Tools",
198 | str(len(current_tools_info)),
199 | str(len(all_tools_info)),
200 | str(len(all_tools_info) - len(current_tools_info))
201 | )
202 |
203 | table.add_row(
204 | "Size (KB)",
205 | f"{current_kb:.2f}",
206 | f"{all_kb:.2f}",
207 | f"{all_kb - current_kb:.2f}"
208 | )
209 |
210 | table.add_row(
211 | "Token Count",
212 | f"{current_token_count:,}",
213 | f"{all_token_count:,}",
214 | f"{all_token_count - current_token_count:,}"
215 | )
216 |
217 | # Add cost rows for each provider
218 | for provider_name in sorted(current_costs.keys()):
219 | current_cost = current_costs.get(provider_name, 0)
220 | all_cost = all_costs.get(provider_name, 0)
221 |
222 | table.add_row(
223 | f"Cost ({provider_name})",
224 | f"${current_cost:.4f}",
225 | f"${all_cost:.4f}",
226 | f"${all_cost - current_cost:.4f}"
227 | )
228 |
229 | # Print table
230 | console.print(table)
231 |
232 | return {
233 | "current_tools": {
234 | "count": len(current_tools_info),
235 | "size_kb": current_kb,
236 | "tokens": current_token_count,
237 | "costs": current_costs
238 | },
239 | "all_tools": {
240 | "count": len(all_tools_info),
241 | "size_kb": all_kb,
242 | "tokens": all_token_count,
243 | "costs": all_costs
244 | }
245 | }
246 |
247 |
248 | async def count_registered_tools_tokens(mcp_server):
249 | """
250 | Count tokens for tools that are currently registered with the MCP server.
251 |
252 | Args:
253 | mcp_server: The MCP server instance
254 |
255 | Returns:
256 | Dictionary with token counts and costs
257 | """
258 | # Get registered tools info from the server
259 | # Since we might not have direct access to the function objects, extract tool info from the MCP API
260 | if hasattr(mcp_server, 'tools') and hasattr(mcp_server.tools, 'list'):
261 | # Try to get tool definitions directly
262 | current_tools_info = await mcp_server.tools.list()
263 | else:
264 | # Fallback if we can't access the tools directly
265 | current_tools_info = []
266 | console = Console()
267 | console.print("[yellow]Warning: Could not directly access registered tools from MCP server[/yellow]")
268 |
269 | try:
270 | # Import all available tools
271 | from ultimate_mcp_server.tools import STANDALONE_TOOL_FUNCTIONS
272 |
273 | # Extract full tool info for all available tools
274 | all_tools_info = [extract_tool_info(func) for func in STANDALONE_TOOL_FUNCTIONS]
275 | except ImportError:
276 | console = Console()
277 | console.print("[yellow]Warning: Could not import STANDALONE_TOOL_FUNCTIONS[/yellow]")
278 | all_tools_info = []
279 |
280 | # Display token usage
281 | result = display_tool_token_usage(current_tools_info, all_tools_info)
282 |
283 | return result
```
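An illustrative use of the counters above; `search_docs` is a hypothetical tool function, and the printed schema fields follow from how `extract_tool_info` builds the input schema.

```python
# Illustrative usage; `search_docs` is a hypothetical tool, not a real one.
from ultimate_mcp_server.tool_token_counter import (
    count_tool_registration_tokens,
    extract_tool_info,
)


async def search_docs(query: str, limit: int = 10) -> dict:
    """Search indexed documents for a query string."""
    return {"query": query, "limit": limit, "results": []}


info = extract_tool_info(search_docs)
print(info["name"])                               # search_docs
print(sorted(info["inputSchema"]["properties"]))  # ['limit', 'query']
print(info["inputSchema"]["required"])            # ['query'] (limit has a default)

tokens = count_tool_registration_tokens([search_docs])
print(f"~{tokens} input tokens to register this tool")
```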
--------------------------------------------------------------------------------
/test_sse_client.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | SSE Test Client for Ultimate MCP Server
4 | Tests server functionality over SSE (Server-Sent Events) transport
5 | """
6 |
7 | import asyncio
8 | import json
9 |
10 | from fastmcp import Client
11 |
12 |
13 | async def test_sse_server():
14 | """Test Ultimate MCP Server over SSE transport."""
15 | # SSE endpoint - note the /sse path for SSE transport
16 | server_url = "http://127.0.0.1:8013/sse"
17 |
18 | print("🔥 Ultimate MCP Server SSE Test Client")
19 | print("=" * 50)
20 | print(f"🔗 Connecting to Ultimate MCP Server SSE endpoint at {server_url}")
21 |
22 | try:
23 | async with Client(server_url) as client:
24 | print("✅ Successfully connected to SSE server")
25 |
26 | # Test 1: List available tools
27 | print("\n📋 Testing tool discovery...")
28 | tools = await client.list_tools()
29 | print(f"Found {len(tools)} tools via SSE transport:")
30 | for i, tool in enumerate(tools[:10]): # Show first 10
31 | print(f" {i+1:2d}. {tool.name}")
32 | if len(tools) > 10:
33 | print(f" ... and {len(tools) - 10} more tools")
34 |
35 | # Test 2: List available resources
36 | print("\n📚 Testing resource discovery...")
37 | resources = await client.list_resources()
38 | print(f"Found {len(resources)} resources:")
39 | for resource in resources:
40 | print(f" - {resource.uri}")
41 |
42 | # Test 3: Echo tool test
43 | print("\n🔊 Testing echo tool over SSE...")
44 | echo_result = await client.call_tool("echo", {"message": "Hello from SSE client!"})
45 | if echo_result:
46 | echo_data = json.loads(echo_result[0].text)
47 | print(f"✅ Echo response: {json.dumps(echo_data, indent=2)}")
48 |
49 | # Test 4: Provider status test
50 | print("\n🔌 Testing provider status over SSE...")
51 | try:
52 | provider_result = await client.call_tool("get_provider_status", {})
53 | if provider_result:
54 | provider_data = json.loads(provider_result[0].text)
55 | providers = provider_data.get("providers", {})
56 | print(f"✅ Found {len(providers)} providers via SSE:")
57 | for name, status in providers.items():
58 | available = "✅" if status.get("available") else "❌"
59 | model_count = len(status.get("models", []))
60 | print(f" {available} {name}: {model_count} models")
61 | except Exception as e:
62 | print(f"❌ Provider status failed: {e}")
63 |
64 | # Test 5: Resource reading test
65 | print("\n📖 Testing resource reading over SSE...")
66 | if resources:
67 | try:
68 | resource_uri = resources[0].uri
69 | resource_content = await client.read_resource(resource_uri)
70 | if resource_content:
71 | content = resource_content[0].text
72 | preview = content[:200] + "..." if len(content) > 200 else content
73 | print(f"✅ Resource {resource_uri} content preview:")
74 | print(f" {preview}")
75 | except Exception as e:
76 | print(f"❌ Resource reading failed: {e}")
77 |
78 | # Test 6: Simple completion test (if providers available)
79 | print("\n🤖 Testing completion over SSE...")
80 | try:
81 | completion_result = await client.call_tool(
82 | "generate_completion",
83 | {
84 | "prompt": "Say hello in exactly 3 words",
85 | "provider": "ollama",
86 | "model": "mix_77/gemma3-qat-tools:27b",
87 | "max_tokens": 10,
88 | },
89 | )
90 | if completion_result:
91 | result_data = json.loads(completion_result[0].text)
92 | print("✅ Completion via SSE:")
93 | print(f" Text: '{result_data.get('text', 'No text')}'")
94 | print(f" Model: {result_data.get('model', 'Unknown')}")
95 | print(f" Success: {result_data.get('success', False)}")
96 | print(f" Processing time: {result_data.get('processing_time', 0):.2f}s")
97 | except Exception as e:
98 | print(f"⚠️ Completion test failed (expected if no providers): {e}")
99 |
100 | # Test 7: Filesystem tool test
101 | print("\n📁 Testing filesystem tools over SSE...")
102 | try:
103 | dirs_result = await client.call_tool("list_allowed_directories", {})
104 | if dirs_result:
105 | dirs_data = json.loads(dirs_result[0].text)
106 | print(f"✅ Allowed directories via SSE: {dirs_data.get('count', 0)} directories")
107 | except Exception as e:
108 | print(f"❌ Filesystem test failed: {e}")
109 |
110 | # Test 8: Text processing tool test
111 | print("\n📝 Testing text processing over SSE...")
112 | try:
113 | ripgrep_result = await client.call_tool(
114 | "run_ripgrep",
115 | {
116 | "args_str": "'async' . -t py --max-count 5",
117 | "input_dir": "."
118 | }
119 | )
120 | if ripgrep_result:
121 | ripgrep_data = json.loads(ripgrep_result[0].text)
122 | if ripgrep_data.get("success"):
123 | lines = ripgrep_data.get("output", "").split('\n')
124 | line_count = len([line for line in lines if line.strip()])
125 | print(f"✅ Ripgrep via SSE found {line_count} matching lines")
126 | else:
127 | print("⚠️ Ripgrep completed but found no matches")
128 | except Exception as e:
129 | print(f"❌ Text processing test failed: {e}")
130 |
131 | print("\n🎉 SSE transport functionality test completed!")
132 | return True
133 |
134 | except Exception as e:
135 | print(f"❌ SSE connection failed: {e}")
136 | print("\nTroubleshooting:")
137 | print("1. Make sure the server is running in SSE mode:")
138 | print(" umcp run -t sse")
139 | print("2. Verify the server is accessible at http://127.0.0.1:8013")
140 | print("3. Check that the SSE endpoint is available at /sse")
141 | return False
142 |
143 |
144 | async def test_sse_interactive():
145 | """Interactive SSE testing mode."""
146 | server_url = "http://127.0.0.1:8013/sse"
147 |
148 | print("\n🎮 Entering SSE interactive mode...")
149 | print("Type 'list' to see available tools, 'quit' to exit")
150 |
151 | try:
152 | async with Client(server_url) as client:
153 | tools = await client.list_tools()
154 | resources = await client.list_resources()
155 |
156 | while True:
157 | try:
158 | command = input("\nSSE> ").strip()
159 |
160 | if command.lower() in ['quit', 'exit', 'q']:
161 | print("👋 Goodbye!")
162 | break
163 | elif command.lower() == 'list':
164 | print("Available tools:")
165 | for i, tool in enumerate(tools[:20]):
166 | print(f" {i+1:2d}. {tool.name}")
167 | if len(tools) > 20:
168 | print(f" ... and {len(tools) - 20} more")
169 | elif command.lower() == 'resources':
170 | print("Available resources:")
171 | for resource in resources:
172 | print(f" - {resource.uri}")
173 | elif command.startswith("tool "):
174 | # Call tool: tool <tool_name> <json_params>
175 | parts = command[5:].split(' ', 1)
176 | tool_name = parts[0]
177 | params = json.loads(parts[1]) if len(parts) > 1 else {}
178 |
179 | try:
180 | result = await client.call_tool(tool_name, params)
181 | if result:
182 | print(f"✅ Tool result: {result[0].text}")
183 | else:
184 | print("❌ No result returned")
185 | except Exception as e:
186 | print(f"❌ Tool call failed: {e}")
187 | elif command.startswith("read "):
188 | # Read resource: read <resource_uri>
189 | resource_uri = command[5:].strip()
190 | try:
191 | result = await client.read_resource(resource_uri)
192 | if result:
193 | content = result[0].text
194 | preview = content[:500] + "..." if len(content) > 500 else content
195 | print(f"✅ Resource content: {preview}")
196 | else:
197 | print("❌ No content returned")
198 | except Exception as e:
199 | print(f"❌ Resource read failed: {e}")
200 | else:
201 | print("Commands:")
202 | print(" list - List available tools")
203 | print(" resources - List available resources")
204 | print(" tool <name> <params> - Call a tool with JSON params")
205 | print(" read <uri> - Read a resource")
206 | print(" quit - Exit interactive mode")
207 |
208 | except KeyboardInterrupt:
209 | print("\n👋 Goodbye!")
210 | break
211 | except Exception as e:
212 | print(f"❌ Command error: {e}")
213 |
214 | except Exception as e:
215 | print(f"❌ SSE interactive mode failed: {e}")
216 |
217 |
218 | async def main():
219 | """Main test function."""
220 | # Run basic functionality test
221 | success = await test_sse_server()
222 |
223 | if success:
224 | # Ask if user wants interactive mode
225 | try:
226 | response = input("\nWould you like to enter SSE interactive mode? (y/n): ").strip().lower()
227 | if response in ['y', 'yes']:
228 | await test_sse_interactive()
229 | except KeyboardInterrupt:
230 | print("\n👋 Goodbye!")
231 |
232 |
233 | if __name__ == "__main__":
234 | asyncio.run(main())
```
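A minimal standalone variant of the client pattern above: one SSE connection, one tool call, assuming the server is already running at the same endpoint.

```python
# Minimal variant of the test client above: connect once, call one tool.
import asyncio
import json

from fastmcp import Client


async def quick_echo() -> None:
    async with Client("http://127.0.0.1:8013/sse") as client:
        result = await client.call_tool("echo", {"message": "ping"})
        if result:
            print(json.loads(result[0].text))


asyncio.run(quick_echo())
```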
--------------------------------------------------------------------------------
/ultimate_mcp_server/utils/logging/__init__.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Gateway Logging Package.
3 |
4 | This package provides enhanced logging capabilities with rich formatting,
5 | progress tracking, and console output for the Gateway system.
6 | """
7 |
8 | import logging
9 | import logging.handlers
10 | from typing import Any, Dict, List, Optional
11 |
12 | # Import Rich-based console
13 | # Adjusted imports to be relative within the new structure
14 | from .console import (
15 | console,
16 | create_progress,
17 | live_display,
18 | print_json,
19 | print_panel,
20 | print_syntax,
21 | print_table,
22 | print_tree,
23 | status,
24 | )
25 |
26 | # Import emojis
27 | from .emojis import (
28 | COMPLETED,
29 | CRITICAL,
30 | DEBUG,
31 | ERROR,
32 | FAILED,
33 | INFO,
34 | RUNNING,
35 | SUCCESS,
36 | WARNING,
37 | get_emoji,
38 | )
39 |
40 | # Import formatters and handlers
41 | from .formatter import (
42 | DetailedLogFormatter,
43 | GatewayLogRecord,
44 | RichLoggingHandler,
45 | SimpleLogFormatter,
46 | create_rich_console_handler, # Added missing import used in server.py LOGGING_CONFIG
47 | )
48 |
49 | # Import logger and related utilities
50 | from .logger import (
51 | Logger,
52 | critical,
53 | debug,
54 | error,
55 | info,
56 | section,
57 | success,
58 | warning,
59 | )
60 |
61 | # Import panels
62 | from .panels import (
63 | CodePanel,
64 | ErrorPanel,
65 | HeaderPanel,
66 | InfoPanel,
67 | ResultPanel,
68 | ToolOutputPanel,
69 | WarningPanel,
70 | display_code,
71 | display_error,
72 | display_header,
73 | display_info,
74 | display_results,
75 | display_tool_output,
76 | display_warning,
77 | )
78 |
79 | # Import progress tracking
80 | from .progress import (
81 | GatewayProgress,
82 | track,
83 | )
84 |
85 | # Create a global logger instance for importing
86 | logger = Logger("ultimate")
87 |
88 | # Removed configure_root_logger, initialize_logging, set_log_level functions
89 | # Logging is now configured via dictConfig in main.py (or server.py equivalent)
90 |
91 | def get_logger(name: str) -> Logger:
92 | """
93 | Get or create a specialized Logger instance for a specific component.
94 |
95 | This function provides access to the enhanced logging system of the Ultimate MCP Server,
96 | returning a Logger instance that includes rich formatting, emoji support, and other
97 | advanced features beyond Python's standard logging.
98 |
99 | The returned Logger is configured with the project's logging settings and integrates
100 | with the rich console output system. It provides methods like success() and section()
101 | in addition to standard logging methods.
102 |
103 | Args:
104 | name: The logger name, typically the module or component name.
105 | Can use dot notation for hierarchy (e.g., "module.submodule").
106 |
107 | Returns:
108 | An enhanced Logger instance with rich formatting and emoji support
109 |
110 | Example:
111 | ```python
112 | # In a module file
113 | from ultimate_mcp_server.utils.logging import get_logger
114 |
115 | # Create logger with the module name
116 | logger = get_logger(__name__)
117 |
118 | # Use the enhanced logging methods
119 | logger.info("Server starting") # Basic info log
120 | logger.success("Operation completed") # Success log (not in std logging)
121 | logger.warning("Resource low", resource="RAM") # With additional context
122 | logger.error("Failed to connect", emoji_key="network") # With custom emoji
123 | ```
124 | """
125 | # Use the new base name for sub-loggers if needed, or keep original logic
126 | # return Logger(f"ultimate_mcp_server.{name}") # Option 1: Prefix with base name
127 | return Logger(name) # Option 2: Keep original name logic
128 |
129 | def capture_logs(level: Optional[str] = None) -> "LogCapture":
130 | """
131 | Create a context manager to capture logs for testing or debugging.
132 |
133 | This function is a convenience wrapper around the LogCapture class, creating
134 | and returning a context manager that will capture logs at or above the specified
135 | level during its active scope.
136 |
137 | Use this function when you need to verify that certain log messages are emitted
138 | during tests, or when you want to collect logs for analysis without modifying
139 | the application's logging configuration.
140 |
141 | Args:
142 | level: Minimum log level to capture (e.g., "INFO", "WARNING", "ERROR").
143 | If None, all log levels are captured. Default: None
144 |
145 | Returns:
146 | A LogCapture context manager that will collect logs when active
147 |
148 | Example:
149 | ```python
150 | # Test that a function produces expected log messages
151 | def test_login_function():
152 | with capture_logs("WARNING") as logs:
153 | # Call function that should produce a warning log for invalid login
154 | result = login("invalid_user", "wrong_password")
155 |
156 | # Assert that the expected warning was logged
157 | assert logs.contains("Invalid login attempt")
158 | assert len(logs.get_logs()) == 1
159 | ```
160 | """
161 | return LogCapture(level)
162 |
163 | # Log capturing for testing
164 | class LogCapture:
165 | """
166 | Context manager for capturing and analyzing logs during execution.
167 |
168 | This class provides a way to intercept, store, and analyze logs emitted during
169 | a specific block of code execution. It's primarily useful for:
170 |
171 | - Testing: Verify that specific log messages were emitted during tests
172 | - Debugging: Collect logs for examination without changing logging configuration
173 | - Analysis: Gather statistics about logging patterns
174 |
175 | The LogCapture acts as a context manager, capturing logs only within its scope
176 | and providing methods to retrieve and analyze the captured logs after execution.
177 |
178 | Each captured log entry is stored as a dictionary with details including the
179 | message, level, timestamp, and source file/line information.
180 |
181 | Example usage:
182 | ```python
183 | # Capture all logs
184 | with LogCapture() as capture:
185 | # Code that generates logs
186 | perform_operation()
187 |
188 | # Check for specific log messages
189 | assert capture.contains("Database connected")
190 | assert not capture.contains("Error")
191 |
192 | # Get all captured logs
193 | all_logs = capture.get_logs()
194 |
195 | # Get only warning and error messages
196 | warnings = capture.get_logs(level="WARNING")
197 | ```
198 | """
199 |
200 | def __init__(self, level: Optional[str] = None):
201 | """Initialize the log capture.
202 |
203 | Args:
204 | level: Minimum log level to capture
205 | """
206 | self.level = level
207 | self.level_num = getattr(logging, self.level.upper(), 0) if self.level else 0
208 | self.logs: List[Dict[str, Any]] = []
209 | self.handler = self._create_handler()
210 |
211 | def _create_handler(self) -> logging.Handler:
212 | """Create a handler to capture logs.
213 |
214 | Returns:
215 | Log handler
216 | """
217 | class CaptureHandler(logging.Handler):
218 | def __init__(self, capture):
219 | super().__init__()
220 | self.capture = capture
221 |
222 | def emit(self, record):
223 | # Skip if record level is lower than minimum
224 | if record.levelno < self.capture.level_num:
225 | return
226 |
227 | # Add log record to captured logs
228 | self.capture.logs.append({
229 | "level": record.levelname,
230 | "message": record.getMessage(),
231 | "name": record.name,
232 | "time": record.created,
233 | "file": record.pathname,
234 | "line": record.lineno,
235 | })
236 |
237 | return CaptureHandler(self)
238 |
239 | def __enter__(self) -> "LogCapture":
240 | """Enter the context manager.
241 |
242 | Returns:
243 | Self
244 | """
245 |         # Attach the handler to the project's top-level logger ("ultimate"),
246 |         # not the absolute root logger, so only project-hierarchy logs are captured
247 |         logging.getLogger("ultimate").addHandler(self.handler)
248 | # Consider adding to the absolute root logger as well if needed
249 | # logging.getLogger().addHandler(self.handler)
250 | return self
251 |
252 | def __exit__(self, exc_type, exc_val, exc_tb) -> None:
253 | """Exit the context manager.
254 |
255 | Args:
256 | exc_type: Exception type
257 | exc_val: Exception value
258 | exc_tb: Exception traceback
259 | """
260 |         # Detach the handler from the project's top-level logger
261 | logging.getLogger("ultimate").removeHandler(self.handler)
262 | # logging.getLogger().removeHandler(self.handler)
263 |
264 | def get_logs(self, level: Optional[str] = None) -> List[Dict[str, Any]]:
265 | """Get captured logs, optionally filtered by level.
266 |
267 | Args:
268 | level: Filter logs by level
269 |
270 | Returns:
271 | List of log records
272 | """
273 | if not level:
274 | return self.logs
275 |
276 | level_num = getattr(logging, level.upper(), 0)
277 | return [log for log in self.logs if getattr(logging, log["level"], 0) >= level_num]
278 |
279 | def get_messages(self, level: Optional[str] = None) -> List[str]:
280 | """Get captured log messages, optionally filtered by level.
281 |
282 | Args:
283 | level: Filter logs by level
284 |
285 | Returns:
286 | List of log messages
287 | """
288 | return [log["message"] for log in self.get_logs(level)]
289 |
290 | def contains(self, text: str, level: Optional[str] = None) -> bool:
291 | """Check if any log message contains the given text.
292 |
293 | Args:
294 | text: Text to search for
295 | level: Optional level filter
296 |
297 | Returns:
298 | True if text is found in any message
299 | """
300 | return any(text in msg for msg in self.get_messages(level))
301 |
302 | __all__ = [
303 | # Console
304 | "console",
305 | "create_progress",
306 | "status",
307 | "print_panel",
308 | "print_syntax",
309 | "print_table",
310 | "print_tree",
311 | "print_json",
312 | "live_display",
313 |
314 | # Logger and utilities
315 | "logger",
316 | "Logger",
317 | "debug",
318 | "info",
319 | "success",
320 | "warning",
321 | "error",
322 | "critical",
323 | "section",
324 | "get_logger",
325 | "capture_logs",
326 | "LogCapture",
327 |
328 | # Emojis
329 | "get_emoji",
330 | "INFO",
331 | "DEBUG",
332 | "WARNING",
333 | "ERROR",
334 | "CRITICAL",
335 | "SUCCESS",
336 | "RUNNING",
337 | "COMPLETED",
338 | "FAILED",
339 |
340 | # Panels
341 | "HeaderPanel",
342 | "ResultPanel",
343 | "InfoPanel",
344 | "WarningPanel",
345 | "ErrorPanel",
346 | "ToolOutputPanel",
347 | "CodePanel",
348 | "display_header",
349 | "display_results",
350 | "display_info",
351 | "display_warning",
352 | "display_error",
353 | "display_tool_output",
354 | "display_code",
355 |
356 | # Progress tracking
357 | "GatewayProgress",
358 | "track",
359 |
360 | # Formatters and handlers
361 | "GatewayLogRecord",
362 | "SimpleLogFormatter",
363 | "DetailedLogFormatter",
364 | "RichLoggingHandler",
365 | "create_rich_console_handler",
366 | ]
```
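Below is a minimal pytest-style sketch of exercising `LogCapture` directly, assuming only the exports shown above. One caveat worth noting: `LogCapture` attaches its handler to the `"ultimate"` logger, so only records emitted within that logger hierarchy are captured; the child logger name and level setup here are illustrative.
```python
# A pytest-style check, assuming the exports above. LogCapture hooks the
# "ultimate" logger, so logs must flow through that hierarchy to be seen.
import logging

from ultimate_mcp_server.utils.logging import LogCapture


def test_warning_is_captured():
    log = logging.getLogger("ultimate.demo")  # hypothetical child logger
    log.setLevel(logging.INFO)  # make sure INFO records are not filtered out

    with LogCapture(level="INFO") as capture:
        log.info("starting demo")
        log.warning("disk space low")

    # get_logs(level=...) keeps records at or above the given level
    assert capture.contains("disk space low", level="WARNING")
    assert len(capture.get_logs(level="WARNING")) == 1
    assert "starting demo" in capture.get_messages()
```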
--------------------------------------------------------------------------------
/ultimate_mcp_server/services/vector/embeddings.py:
--------------------------------------------------------------------------------
```python
1 | """Embedding generation service for vector operations."""
2 | import asyncio
3 | import hashlib
4 | import os
5 | from pathlib import Path
6 | from typing import List, Optional
7 |
8 | import numpy as np
9 | from openai import AsyncOpenAI
10 |
11 | from ultimate_mcp_server.config import get_config
12 | from ultimate_mcp_server.utils import get_logger
13 |
14 | logger = get_logger(__name__)
15 |
16 | # Global dictionary to store embedding instances (optional)
17 | embedding_instances = {}
18 |
19 |
20 | class EmbeddingCache:
21 | """Cache for embeddings to avoid repeated API calls."""
22 |
23 | def __init__(self, cache_dir: Optional[str] = None):
24 | """Initialize the embedding cache.
25 |
26 | Args:
27 | cache_dir: Directory to store cache files
28 | """
29 | if cache_dir:
30 | self.cache_dir = Path(cache_dir)
31 | else:
32 | self.cache_dir = Path.home() / ".ultimate" / "embeddings"
33 |
34 | # Create cache directory if it doesn't exist
35 | self.cache_dir.mkdir(parents=True, exist_ok=True)
36 |
37 | # In-memory cache
38 | self.cache = {}
39 |
40 | logger.info(
41 | f"Embeddings cache initialized (directory: {self.cache_dir})",
42 | emoji_key="cache"
43 | )
44 |
45 | def _get_cache_key(self, text: str, model: str) -> str:
46 | """Generate a cache key for text and model.
47 |
48 | Args:
49 | text: Text to embed
50 | model: Embedding model name
51 |
52 | Returns:
53 | Cache key
54 | """
55 | # Create a hash based on text and model
56 | text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()
57 | return f"{model}_{text_hash}"
58 |
59 | def _get_cache_file_path(self, key: str) -> Path:
60 | """Get cache file path for a key.
61 |
62 | Args:
63 | key: Cache key
64 |
65 | Returns:
66 | Cache file path
67 | """
68 | return self.cache_dir / f"{key}.npy"
69 |
70 | def get(self, text: str, model: str) -> Optional[np.ndarray]:
71 | """Get embedding from cache.
72 |
73 | Args:
74 | text: Text to embed
75 | model: Embedding model name
76 |
77 | Returns:
78 | Cached embedding or None if not found
79 | """
80 | key = self._get_cache_key(text, model)
81 |
82 | # Check in-memory cache first
83 | if key in self.cache:
84 | return self.cache[key]
85 |
86 | # Check disk cache
87 | cache_file = self._get_cache_file_path(key)
88 | if cache_file.exists():
89 | try:
90 | embedding = np.load(str(cache_file))
91 | # Add to in-memory cache
92 | self.cache[key] = embedding
93 | return embedding
94 | except Exception as e:
95 | logger.error(
96 | f"Failed to load embedding from cache: {str(e)}",
97 | emoji_key="error"
98 | )
99 |
100 | return None
101 |
102 | def set(self, text: str, model: str, embedding: np.ndarray) -> None:
103 | """Set embedding in cache.
104 |
105 | Args:
106 | text: Text to embed
107 | model: Embedding model name
108 | embedding: Embedding vector
109 | """
110 | key = self._get_cache_key(text, model)
111 |
112 | # Add to in-memory cache
113 | self.cache[key] = embedding
114 |
115 | # Save to disk
116 | cache_file = self._get_cache_file_path(key)
117 | try:
118 | np.save(str(cache_file), embedding)
119 | except Exception as e:
120 | logger.error(
121 | f"Failed to save embedding to cache: {str(e)}",
122 | emoji_key="error"
123 | )
124 |
125 | def clear(self) -> None:
126 | """Clear the embedding cache."""
127 | # Clear in-memory cache
128 | self.cache.clear()
129 |
130 | # Clear disk cache
131 | for cache_file in self.cache_dir.glob("*.npy"):
132 | try:
133 | cache_file.unlink()
134 | except Exception as e:
135 | logger.error(
136 | f"Failed to delete cache file {cache_file}: {str(e)}",
137 | emoji_key="error"
138 | )
139 |
140 | logger.info(
141 | "Embeddings cache cleared",
142 | emoji_key="cache"
143 | )
144 |
145 |
146 | class EmbeddingService:
147 | """Generic service to create embeddings using different providers."""
148 | def __init__(self, provider_type: str = 'openai', model_name: str = 'text-embedding-3-small', api_key: Optional[str] = None, **kwargs):
149 | """Initialize the embedding service.
150 |
151 | Args:
152 | provider_type: The type of embedding provider (e.g., 'openai').
153 | model_name: The specific embedding model to use.
154 | api_key: Optional API key. If not provided, attempts to load from config.
155 | **kwargs: Additional provider-specific arguments.
156 | """
157 | self.provider_type = provider_type.lower()
158 | self.model_name = model_name
159 | self.client = None
160 | self.api_key = api_key
161 | self.kwargs = kwargs
162 |
163 | try:
164 | config = get_config()
165 | if self.provider_type == 'openai':
166 | provider_config = config.providers.openai
167 | # Use provided key first, then config key
168 | self.api_key = self.api_key or provider_config.api_key
169 | if not self.api_key:
170 | raise ValueError("OpenAI API key not provided or found in configuration.")
171 | # Pass base_url and organization from config if available
172 | openai_kwargs = {
173 | 'api_key': self.api_key,
174 | 'base_url': provider_config.base_url or self.kwargs.get('base_url'),
175 | 'organization': provider_config.organization or self.kwargs.get('organization'),
176 | 'timeout': provider_config.timeout or self.kwargs.get('timeout'),
177 | }
178 | # Filter out None values before passing to OpenAI client
179 | openai_kwargs = {k: v for k, v in openai_kwargs.items() if v is not None}
180 |
181 | # Always use AsyncOpenAI
182 | self.client = AsyncOpenAI(**openai_kwargs)
183 | logger.info(f"Initialized AsyncOpenAI embedding client for model: {self.model_name}")
184 |
185 | else:
186 | raise ValueError(f"Unsupported embedding provider type: {self.provider_type}")
187 |
188 | except Exception as e:
189 | logger.error(f"Failed to initialize embedding service for provider {self.provider_type}: {e}", exc_info=True)
190 | raise RuntimeError(f"Embedding service initialization failed: {e}") from e
191 |
192 |
193 | async def create_embeddings(self, texts: List[str]) -> List[List[float]]:
194 | """Create embeddings for a list of texts.
195 |
196 | Args:
197 | texts: A list of strings to embed.
198 |
199 | Returns:
200 | A list of embedding vectors (each a list of floats).
201 |
202 | Raises:
203 | ValueError: If the provider type is unsupported or embedding fails.
204 | RuntimeError: If the client is not initialized.
205 | """
206 | if self.client is None:
207 | raise RuntimeError("Embedding client is not initialized.")
208 |
209 | try:
210 | if self.provider_type == 'openai':
211 | response = await self.client.embeddings.create(
212 | input=texts,
213 | model=self.model_name
214 | )
215 | # Extract the embedding data
216 | embeddings = [item.embedding for item in response.data]
217 | logger.debug(f"Successfully created {len(embeddings)} embeddings using {self.model_name}.")
218 | return embeddings
219 |
220 | else:
221 | raise ValueError(f"Unsupported provider type: {self.provider_type}")
222 |
223 | except Exception as e:
224 | logger.error(f"Failed to create embeddings using {self.provider_type} model {self.model_name}: {e}", exc_info=True)
225 |             # Wrap and re-raise so callers see a consistent exception type
226 | raise ValueError(f"Embedding creation failed: {e}") from e
227 |
228 |
229 | def get_embedding_service(provider_type: str = 'openai', model_name: str = 'text-embedding-3-small', **kwargs) -> EmbeddingService:
230 | """Factory function to get or create an EmbeddingService instance.
231 |
232 | Args:
233 | provider_type: The type of embedding provider.
234 | model_name: The specific embedding model.
235 | **kwargs: Additional arguments passed to the EmbeddingService constructor.
236 |
237 | Returns:
238 | An initialized EmbeddingService instance.
239 | """
240 | # Optional: Implement caching/singleton pattern for instances if desired
241 | instance_key = (provider_type, model_name)
242 | if instance_key in embedding_instances:
243 | # TODO: Check if kwargs match cached instance? For now, assume they do.
244 | logger.debug(f"Returning cached embedding service instance for {provider_type}/{model_name}")
245 | return embedding_instances[instance_key]
246 | else:
247 | logger.debug(f"Creating new embedding service instance for {provider_type}/{model_name}")
248 | instance = EmbeddingService(provider_type=provider_type, model_name=model_name, **kwargs)
249 | embedding_instances[instance_key] = instance
250 | return instance
251 |
252 |
253 | # Example usage (for testing)
254 | async def main():
255 | # setup_logging(log_level="DEBUG") # Removed as logging is configured centrally
256 | # Make sure OPENAI_API_KEY is set in your .env file or environment
257 | os.environ['GATEWAY_FORCE_CONFIG_RELOAD'] = 'true' # Ensure latest config
258 |
259 | try:
260 | # Get the default OpenAI service
261 | openai_service = get_embedding_service()
262 |
263 | texts_to_embed = [
264 | "The quick brown fox jumps over the lazy dog.",
265 | "Quantum computing leverages quantum mechanics.",
266 | "Paris is the capital of France."
267 | ]
268 |
269 | embeddings = await openai_service.create_embeddings(texts_to_embed)
270 | print(f"Generated {len(embeddings)} embeddings.")
271 | print(f"Dimension of first embedding: {len(embeddings[0])}")
272 | # print(f"First embedding (preview): {embeddings[0][:10]}...")
273 |
274 | # Example of specifying a different model (if available and configured)
275 | # try:
276 | # ada_service = get_embedding_service(model_name='text-embedding-ada-002')
277 | # ada_embeddings = await ada_service.create_embeddings(["Test with Ada model"])
278 |         #     print("\nSuccessfully used Ada model.")
279 |         # except Exception as e:
280 |         #     print(f"\nCould not use Ada model (may need different API key/config): {e}")
281 |
282 | except Exception as e:
283 | print(f"An error occurred during the example: {e}")
284 | finally:
285 | if 'GATEWAY_FORCE_CONFIG_RELOAD' in os.environ:
286 | del os.environ['GATEWAY_FORCE_CONFIG_RELOAD']
287 |
288 | if __name__ == "__main__":
289 |     # asyncio is already imported at the top of the module
290 |     asyncio.run(main())
```
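The file above defines `EmbeddingCache` but never wires it into `EmbeddingService`. A minimal sketch of one way to combine the two, assuming only the classes and factory shown above (the wrapper function itself is hypothetical):
```python
# Hypothetical glue code: consult EmbeddingCache before calling the service,
# and make a single batched API call for whatever is missing from the cache.
import asyncio
from typing import Dict, List

import numpy as np

from ultimate_mcp_server.services.vector.embeddings import (
    EmbeddingCache,
    get_embedding_service,
)


async def cached_embeddings(
    texts: List[str], model: str = "text-embedding-3-small"
) -> List[np.ndarray]:
    cache = EmbeddingCache()
    service = get_embedding_service(model_name=model)

    results: Dict[str, np.ndarray] = {}
    misses: List[str] = []
    for text in texts:
        hit = cache.get(text, model)
        if hit is not None:
            results[text] = hit
        else:
            misses.append(text)

    if misses:
        fresh = await service.create_embeddings(misses)  # one batched call
        for text, vector in zip(misses, fresh):
            arr = np.asarray(vector, dtype=np.float32)
            cache.set(text, model, arr)
            results[text] = arr

    return [results[text] for text in texts]


if __name__ == "__main__":
    vecs = asyncio.run(cached_embeddings(["The quick brown fox."]))
    print(len(vecs), vecs[0].shape)
```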
--------------------------------------------------------------------------------
/examples/marqo_fused_search_demo.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """Demo script showcasing the marqo_fused_search tool."""
3 |
4 | import asyncio
5 | import json
6 | import os
7 | import sys
8 | import time
9 | from datetime import datetime, timedelta
10 | from typing import Any, Dict, Optional
11 |
12 | # Add Rich imports
13 | from rich.console import Console
14 | from rich.markup import escape
15 | from rich.panel import Panel
16 | from rich.rule import Rule
17 | from rich.syntax import Syntax
18 |
19 | # Add the project root to the Python path
20 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
21 | sys.path.insert(0, project_root)
22 |
23 | from ultimate_mcp_server.tools.marqo_fused_search import DateRange, marqo_fused_search # noqa: E402
24 | from ultimate_mcp_server.utils.logging import logger # noqa: E402
25 |
26 | # Initialize Rich Console
27 | console = Console()
28 |
29 | # --- Configuration ---
30 | CONFIG_FILE_PATH = os.path.join(project_root, "marqo_index_config.json")
31 |
32 | def load_marqo_config() -> Dict[str, Any]:
33 | """Loads Marqo configuration from the JSON file."""
34 | try:
35 | with open(CONFIG_FILE_PATH, 'r') as f:
36 | config = json.load(f)
37 | logger.info(f"Loaded Marqo config from {CONFIG_FILE_PATH}")
38 | return config
39 | except FileNotFoundError:
40 | logger.error(f"Marqo config file not found at {CONFIG_FILE_PATH}. Cannot run dynamic examples.")
41 | return {}
42 | except json.JSONDecodeError as e:
43 | logger.error(f"Error decoding Marqo config file {CONFIG_FILE_PATH}: {e}")
44 | return {}
45 |
46 | def find_schema_field(schema: Dict[str, Any], required_properties: Dict[str, Any]) -> Optional[str]:
47 | """
48 | Finds the first field name in the schema that matches all required properties.
49 |     Handles nested properties like 'type'; the pseudo-property 'field_name_not' excludes a field by name.
50 | """
51 | if not schema or "fields" not in schema:
52 | return None
53 |
54 | for field_name, properties in schema["fields"].items():
55 | match = True
56 |         for req_prop, req_value in required_properties.items():
57 |             # 'field_name_not' is a pseudo-property that excludes a field by name; other keys must match the schema
58 |             if (field_name == req_value) if req_prop == "field_name_not" else (properties.get(req_prop) != req_value):
59 |                 match = False
60 |                 break
61 | if match:
62 | # Avoid returning internal fields like _id unless specifically requested
63 | if field_name == "_id" and required_properties.get("role") != "internal":
64 | continue
65 | return field_name
66 | return None
67 |
68 | # --- Helper Function ---
69 | async def run_search_example(example_name: str, **kwargs):
70 | """Runs a single search example and prints the results using Rich."""
71 | console.print(Rule(f"[bold cyan]{example_name}[/bold cyan]"))
72 |
73 | # Display parameters using a panel
74 | param_str_parts = []
75 | for key, value in kwargs.items():
76 | # Format DateRange nicely
77 | if isinstance(value, DateRange):
78 | start_str = value.start_date.strftime("%Y-%m-%d") if value.start_date else "N/A"
79 | end_str = value.end_date.strftime("%Y-%m-%d") if value.end_date else "N/A"
80 | param_str_parts.append(f" [green]{key}[/green]: Start=[yellow]{start_str}[/yellow], End=[yellow]{end_str}[/yellow]")
81 | else:
82 | param_str_parts.append(f" [green]{key}[/green]: [yellow]{escape(str(value))}[/yellow]")
83 | param_str = "\n".join(param_str_parts)
84 | console.print(Panel(param_str, title="Search Parameters", border_style="blue", expand=False))
85 |
86 | try:
87 | start_time = time.time() # Use time for accurate timing
88 | results = await marqo_fused_search(**kwargs)
89 | processing_time = time.time() - start_time
90 |
91 | logger.debug(f"Raw results for '{example_name}': {results}") # Keep debug log
92 |
93 | if results.get("success"):
94 | logger.success(f"Search successful for '{example_name}'! ({processing_time:.3f}s)", emoji_key="success")
95 |
96 | # Display results using Rich Syntax for JSON
97 | results_json = json.dumps(results, indent=2, default=str)
98 | syntax = Syntax(results_json, "json", theme="default", line_numbers=True)
99 | console.print(Panel(syntax, title="Marqo Search Results", border_style="green"))
100 |
101 | else:
102 | # Display error nicely if success is False but no exception was raised
103 | error_msg = results.get("error", "Unknown error")
104 | error_code = results.get("error_code", "UNKNOWN_CODE")
105 | logger.error(f"Search failed for '{example_name}': {error_code} - {error_msg}", emoji_key="error")
106 | console.print(Panel(f"[bold red]Error ({error_code}):[/bold red]\n{escape(error_msg)}", title="Search Failed", border_style="red"))
107 |
108 | except Exception as e:
109 | processing_time = time.time() - start_time
110 | logger.error(f"An exception occurred during '{example_name}' ({processing_time:.3f}s): {e}", emoji_key="critical", exc_info=True)
111 | # Display exception using Rich traceback
112 | console.print_exception(show_locals=False)
113 | console.print(Panel(f"[bold red]Exception:[/bold red]\n{escape(str(e))}", title="Execution Error", border_style="red"))
114 |
115 | console.print() # Add space after each example
116 |
117 |
118 | # --- Main Demo Function ---
119 | async def main():
120 | """Runs various demonstrations of the marqo_fused_search tool."""
121 |
122 | # Load Marqo configuration and schema
123 | marqo_config = load_marqo_config()
124 | if not marqo_config:
125 | logger.error("Exiting demo as Marqo config could not be loaded.")
126 | return
127 |
128 | schema = marqo_config.get("default_schema", {})
129 | tensor_field = schema.get("tensor_field")
130 | # content_field = schema.get("default_content_field", "content") # Not directly used in examples
131 | date_field = schema.get("default_date_field") # Used for date range
132 |
133 | # --- Find suitable fields dynamically ---
134 | # For filter examples (keyword preferred)
135 | filter_field = find_schema_field(schema, {"filterable": True, "type": "keyword"}) or \
136 | find_schema_field(schema, {"filterable": True}) # Fallback to any filterable
137 |
138 | # For lexical search (requires searchable='lexical')
139 | lexical_field_1 = find_schema_field(schema, {"searchable": "lexical"})
140 | lexical_field_2 = find_schema_field(schema, {"searchable": "lexical", "field_name_not": lexical_field_1}) or lexical_field_1 # Find a second one if possible
141 |
142 | # For hybrid search (need tensor + lexical)
143 | hybrid_tensor_field = tensor_field # Use the main tensor field
144 | hybrid_lexical_field_1 = lexical_field_1
145 | hybrid_lexical_field_2 = lexical_field_2
146 |
147 | # For explicit tensor search (need tensor field)
148 | explicit_tensor_field = tensor_field
149 |
150 | logger.info("Dynamically determined fields for examples:")
151 | logger.info(f" Filter Field: '{filter_field}'")
152 | logger.info(f" Lexical Fields: '{lexical_field_1}', '{lexical_field_2}'")
153 | logger.info(f" Tensor Field (for hybrid/explicit): '{hybrid_tensor_field}'")
154 | logger.info(f" Date Field (for range): '{date_field}'")
155 |
156 | # --- Run Examples ---
157 |
158 | # --- Example 1: Basic Semantic Search --- (No specific fields needed)
159 | await run_search_example(
160 | "Basic Semantic Search",
161 | query="impact of AI on software development"
162 | )
163 |
164 | # --- Example 2: Search with Metadata Filter ---
165 | if filter_field:
166 | # Use a plausible value; specific value might not exist in data
167 | example_filter_value = "10-K" if filter_field == "form_type" else "example_value"
168 | await run_search_example(
169 | "Search with Metadata Filter",
170 | query="latest advancements in renewable energy",
171 | filters={filter_field: example_filter_value}
172 | )
173 | else:
174 | logger.warning("Skipping Example 2: No suitable filterable field found in schema.")
175 |
176 | # --- Example 3: Search with Multiple Filter Values (OR condition) ---
177 | if filter_field:
178 | # Use plausible values
179 | example_filter_values = ["10-K", "10-Q"] if filter_field == "form_type" else ["value1", "value2"]
180 | await run_search_example(
181 | "Search with Multiple Filter Values (OR)",
182 | query="financial report analysis",
183 | filters={filter_field: example_filter_values}
184 | )
185 | else:
186 | logger.warning("Skipping Example 3: No suitable filterable field found in schema.")
187 |
188 | # --- Example 4: Search with Date Range ---
189 |     if date_field and schema.get("fields", {}).get(date_field, {}).get("type") == "timestamp":
190 | start_date = datetime.now() - timedelta(days=900)
191 | end_date = datetime.now() - timedelta(days=30)
192 | await run_search_example(
193 | "Search with Date Range",
194 | query="market trends",
195 | date_range=DateRange(start_date=start_date, end_date=end_date)
196 | )
197 | else:
198 |         logger.warning(f"Skipping Example 4: No timestamp field named '{date_field}' (default_date_field) found in schema.")
199 |
200 | # --- Example 5: Pure Lexical Search --- (Relies on schema having lexical fields)
201 | # The tool will auto-detect lexical fields if not specified, but this tests the weight
202 | await run_search_example(
203 | "Pure Lexical Search",
204 | query="exact sciences", # Query likely to hit company name etc.
205 | semantic_weight=0.0
206 | )
207 |
208 | # --- Example 6: Hybrid Search with Custom Weight --- (Relies on schema having both)
209 | await run_search_example(
210 | "Hybrid Search with Custom Weight",
211 | query="balancing innovation and regulation",
212 | semantic_weight=0.5 # Equal weight
213 | )
214 |
215 | # --- Example 7: Pagination (Limit and Offset) --- (No specific fields needed)
216 | await run_search_example(
217 | "Pagination (Limit and Offset)",
218 | query="common programming paradigms",
219 | limit=10,
220 | offset=10
221 | )
222 |
223 | # --- Example 8: Explicit Searchable Attributes (Tensor Search) ---
224 | if explicit_tensor_field:
225 | await run_search_example(
226 | "Explicit Tensor Searchable Attributes",
227 | query="neural network architectures",
228 | searchable_attributes=[explicit_tensor_field],
229 | semantic_weight=1.0 # Ensure tensor search is used
230 | )
231 | else:
232 | logger.warning("Skipping Example 8: No tensor field found in schema.")
233 |
234 | # --- Example 9: Explicit Hybrid Search Attributes ---
235 | if hybrid_tensor_field and hybrid_lexical_field_1:
236 | lexical_fields = [hybrid_lexical_field_1]
237 | if hybrid_lexical_field_2 and hybrid_lexical_field_1 != hybrid_lexical_field_2:
238 | lexical_fields.append(hybrid_lexical_field_2)
239 | await run_search_example(
240 | "Explicit Hybrid Search Attributes",
241 | query="machine learning applications in healthcare",
242 | hybrid_search_attributes={
243 | "tensor": [hybrid_tensor_field],
244 | "lexical": lexical_fields
245 | },
246 | semantic_weight=0.6 # Specify hybrid search balance
247 | )
248 | else:
249 | logger.warning("Skipping Example 9: Need both tensor and lexical fields defined in schema.")
250 |
251 | # --- Example 12: Overriding Marqo URL and Index Name --- (Keep commented out)
252 | # ... rest of the code ...
253 | console.print(Rule("[bold magenta]Marqo Fused Search Demo Complete[/bold magenta]"))
254 |
255 |
256 | if __name__ == "__main__":
257 | console.print(Rule("[bold magenta]Starting Marqo Fused Search Demo[/bold magenta]"))
258 | # logger.info("Starting Marqo Fused Search Demo...") # Replaced by Rich rule
259 | asyncio.run(main())
260 | # logger.info("Marqo Fused Search Demo finished.") # Replaced by Rich rule
```
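For reference, a stripped-down direct call to the tool outside the demo harness. Parameter names mirror those used above; the `"results"` key below is an assumption, since the demo only pretty-prints the whole response dict:
```python
# Minimal direct use of marqo_fused_search, based on the demo's call pattern.
import asyncio

from ultimate_mcp_server.tools.marqo_fused_search import marqo_fused_search


async def quick_search():
    results = await marqo_fused_search(
        query="impact of AI on software development",
        semantic_weight=0.7,  # blend of tensor vs. lexical scoring
        limit=5,
    )
    if results.get("success"):
        for hit in results.get("results", []):  # "results" key is an assumption
            print(hit)
    else:
        print(results.get("error_code"), results.get("error"))


asyncio.run(quick_search())
```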
--------------------------------------------------------------------------------
/examples/claude_integration_demo.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python
2 | """Claude integration demonstration using Ultimate MCP Server."""
3 | import asyncio
4 | import sys
5 | import time
6 | from pathlib import Path
7 |
8 | # Add project root to path for imports when running as script
9 | sys.path.insert(0, str(Path(__file__).parent.parent))
10 |
11 | # Third-party imports
12 | # These imports need to be below sys.path modification, which is why they have noqa comments
13 | from rich import box # noqa: E402
14 | from rich.markup import escape # noqa: E402
15 | from rich.panel import Panel # noqa: E402
16 | from rich.rule import Rule # noqa: E402
17 | from rich.table import Table # noqa: E402
18 |
19 | # Project imports
20 | from ultimate_mcp_server.constants import Provider # noqa: E402
21 | from ultimate_mcp_server.core.server import Gateway # noqa: E402
22 | from ultimate_mcp_server.utils import get_logger # noqa: E402
23 | from ultimate_mcp_server.utils.display import CostTracker  # noqa: E402
24 | from ultimate_mcp_server.utils.logging.console import console # noqa: E402
25 |
26 | # Initialize logger
27 | logger = get_logger("example.claude_integration_demo")
28 |
29 |
30 | async def compare_claude_models(tracker: CostTracker):
31 | """Compare different Claude models."""
32 | console.print(Rule("[bold blue]Claude Model Comparison[/bold blue]"))
33 | logger.info("Starting Claude models comparison", emoji_key="start")
34 |
35 | # Create Gateway instance - this handles provider initialization
36 | gateway = Gateway("claude-demo", register_tools=False)
37 |
38 | # Initialize providers
39 | logger.info("Initializing providers...", emoji_key="provider")
40 | await gateway._initialize_providers()
41 |
42 | provider_name = Provider.ANTHROPIC.value
43 | try:
44 | # Get the provider from the gateway
45 | provider = gateway.providers.get(provider_name)
46 | if not provider:
47 | logger.error(f"Provider {provider_name} not available or initialized", emoji_key="error")
48 | return
49 |
50 | logger.info(f"Using provider: {provider_name}", emoji_key="provider")
51 |
52 | models = await provider.list_models()
53 | model_names = [m["id"] for m in models] # Extract names from model dictionaries
54 | console.print(f"Found {len(model_names)} Claude models: [cyan]{escape(str(model_names))}[/cyan]")
55 |
56 | # Select specific models to compare (Ensure these are valid and available)
57 | claude_models = [
58 | "anthropic/claude-3-7-sonnet-20250219",
59 | "anthropic/claude-3-5-haiku-20241022"
60 | ]
61 | # Filter based on available models
62 | models_to_compare = [m for m in claude_models if m in model_names]
63 | if not models_to_compare:
64 | logger.error("None of the selected models for comparison are available. Exiting comparison.", emoji_key="error")
65 | console.print("[red]Selected models not found in available list.[/red]")
66 | return
67 | console.print(f"Comparing models: [yellow]{escape(str(models_to_compare))}[/yellow]")
68 |
69 | prompt = """
70 | Explain the concept of quantum entanglement in a way that a high school student would understand.
71 | Keep your response brief and accessible.
72 | """
73 | console.print(f"[cyan]Using Prompt:[/cyan] {escape(prompt.strip())[:100]}...")
74 |
75 | results_data = []
76 |
77 | for model_name in models_to_compare:
78 | try:
79 | logger.info(f"Testing model: {model_name}", emoji_key="model")
80 | start_time = time.time()
81 | result = await provider.generate_completion(
82 | prompt=prompt,
83 | model=model_name,
84 | temperature=0.3,
85 | max_tokens=300
86 | )
87 | processing_time = time.time() - start_time
88 |
89 | # Track the cost
90 | tracker.add_call(result)
91 |
92 | results_data.append({
93 | "model": model_name,
94 | "text": result.text,
95 | "tokens": {
96 | "input": result.input_tokens,
97 | "output": result.output_tokens,
98 | "total": result.total_tokens
99 | },
100 | "cost": result.cost,
101 | "time": processing_time
102 | })
103 |
104 | logger.success(
105 | f"Completion for {model_name} successful",
106 | emoji_key="success",
107 | # Tokens/cost/time logged implicitly by storing in results_data
108 | )
109 |
110 | except Exception as e:
111 | logger.error(f"Error testing model {model_name}: {str(e)}", emoji_key="error", exc_info=True)
112 | # Optionally add an error entry to results_data if needed
113 |
114 | # Display comparison results using Rich
115 | if results_data:
116 | console.print(Rule("[bold green]Comparison Results[/bold green]"))
117 |
118 | for result_item in results_data:
119 | model = result_item["model"]
120 | time_s = result_item["time"]
121 | tokens = result_item.get("tokens", {}).get("total", 0)
122 | tokens_per_second = tokens / time_s if time_s > 0 else 0
123 | cost = result_item.get("cost", 0.0)
124 | text = result_item.get("text", "[red]Error generating response[/red]").strip()
125 |
126 | stats_line = (
127 | f"Time: [yellow]{time_s:.2f}s[/yellow] | "
128 | f"Tokens: [cyan]{tokens}[/cyan] | "
129 | f"Speed: [blue]{tokens_per_second:.1f} tok/s[/blue] | "
130 | f"Cost: [green]${cost:.6f}[/green]"
131 | )
132 |
133 | console.print(Panel(
134 | escape(text),
135 | title=f"[bold magenta]{escape(model)}[/bold magenta]",
136 | subtitle=stats_line,
137 | border_style="blue",
138 | expand=False
139 | ))
140 | console.print()
141 |
142 | except Exception as e:
143 | logger.error(f"Error in model comparison: {str(e)}", emoji_key="error", exc_info=True)
144 | # Optionally re-raise or handle differently
145 |
146 |
147 | async def demonstrate_system_prompt(tracker: CostTracker):
148 | """Demonstrate Claude with system prompts."""
149 | console.print(Rule("[bold blue]Claude System Prompt Demonstration[/bold blue]"))
150 | logger.info("Demonstrating Claude with system prompts", emoji_key="start")
151 |
152 | # Create Gateway instance - this handles provider initialization
153 | gateway = Gateway("claude-demo", register_tools=False)
154 |
155 | # Initialize providers
156 | logger.info("Initializing providers...", emoji_key="provider")
157 | await gateway._initialize_providers()
158 |
159 | provider_name = Provider.ANTHROPIC.value
160 | try:
161 | # Get the provider from the gateway
162 | provider = gateway.providers.get(provider_name)
163 | if not provider:
164 | logger.error(f"Provider {provider_name} not available or initialized", emoji_key="error")
165 | return
166 |
167 | # Use a fast Claude model (ensure it's available)
168 | model = "anthropic/claude-3-5-haiku-20241022"
169 | available_models = await provider.list_models()
170 | if model not in [m["id"] for m in available_models]:
171 | logger.warning(f"Model {model} not available, falling back to default.", emoji_key="warning")
172 | model = provider.get_default_model()
173 | if not model:
174 | logger.error("No suitable Claude model found for system prompt demo.", emoji_key="error")
175 | return
176 | logger.info(f"Using model: {model}", emoji_key="model")
177 |
178 | system_prompt = """
179 | You are a helpful assistant with expertise in physics.
180 | Keep all explanations accurate but very concise.
181 | Always provide real-world examples to illustrate concepts.
182 | """
183 | user_prompt = "Explain the concept of gravity."
184 |
185 | logger.info("Generating completion with system prompt", emoji_key="processing")
186 |
187 | result = await provider.generate_completion(
188 | prompt=user_prompt,
189 | model=model,
190 | temperature=0.7,
191 | system=system_prompt,
192 | max_tokens=1000 # Increased max_tokens
193 | )
194 |
195 | # Track the cost
196 | tracker.add_call(result)
197 |
198 | logger.success("Completion with system prompt successful", emoji_key="success")
199 |
200 | # Display result using Rich Panels
201 | console.print(Panel(
202 | escape(system_prompt.strip()),
203 | title="[bold cyan]System Prompt[/bold cyan]",
204 | border_style="dim cyan",
205 | expand=False
206 | ))
207 | console.print(Panel(
208 | escape(user_prompt.strip()),
209 | title="[bold yellow]User Prompt[/bold yellow]",
210 | border_style="dim yellow",
211 | expand=False
212 | ))
213 | console.print(Panel(
214 | escape(result.text.strip()),
215 | title="[bold green]Claude Response[/bold green]",
216 | border_style="green",
217 | expand=False
218 | ))
219 |
220 | # Display stats in a small table
221 | stats_table = Table(title="Execution Stats", show_header=False, box=box.MINIMAL, expand=False)
222 | stats_table.add_column("Metric", style="cyan")
223 | stats_table.add_column("Value", style="white")
224 | stats_table.add_row("Input Tokens", str(result.input_tokens))
225 | stats_table.add_row("Output Tokens", str(result.output_tokens))
226 | stats_table.add_row("Cost", f"${result.cost:.6f}")
227 | stats_table.add_row("Processing Time", f"{result.processing_time:.3f}s")
228 | console.print(stats_table)
229 | console.print()
230 |
231 | except Exception as e:
232 | logger.error(f"Error in system prompt demonstration: {str(e)}", emoji_key="error", exc_info=True)
233 | # Optionally re-raise or handle differently
234 |
235 |
236 | async def explore_claude_models():
237 | """Display available Claude models."""
238 | console.print(Rule("[bold cyan]Available Claude Models[/bold cyan]"))
239 |
240 | # Create Gateway instance - this handles provider initialization
241 | gateway = Gateway("claude-demo", register_tools=False)
242 |
243 | # Initialize providers
244 | logger.info("Initializing providers...", emoji_key="provider")
245 | await gateway._initialize_providers()
246 |
247 | # Get provider from the gateway
248 | provider = gateway.providers.get(Provider.ANTHROPIC.value)
249 | if not provider:
250 | logger.error(f"Provider {Provider.ANTHROPIC.value} not available or initialized", emoji_key="error")
251 | return
252 |
253 | # Get list of available models
254 | models = await provider.list_models()
255 | model_names = [m["id"] for m in models] # Extract names from model dictionaries
256 | console.print(f"Found {len(model_names)} Claude models: [cyan]{escape(str(model_names))}[/cyan]")
257 |
258 |
259 | async def main():
260 | """Run Claude integration examples."""
261 | tracker = CostTracker() # Instantiate tracker here
262 | try:
263 | # Run model comparison
264 | await compare_claude_models(tracker) # Pass tracker
265 |
266 | console.print() # Add space between sections
267 |
268 | # Run system prompt demonstration
269 | await demonstrate_system_prompt(tracker) # Pass tracker
270 |
271 | # Run explore Claude models
272 | await explore_claude_models()
273 |
274 | # Display final summary
275 | tracker.display_summary(console) # Display summary at the end
276 |
277 | except Exception as e:
278 | logger.critical(f"Example failed: {str(e)}", emoji_key="critical", exc_info=True)
279 | return 1
280 |
281 | logger.success("Claude Integration Demo Finished Successfully!", emoji_key="complete")
282 | return 0
283 |
284 |
285 | if __name__ == "__main__":
286 | exit_code = asyncio.run(main())
287 | sys.exit(exit_code)
```
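Distilled from the demo above, the minimal Gateway-to-provider-to-completion flow looks roughly like the sketch below. The model ID and its availability are assumptions; pick any ID returned by `provider.list_models()`:
```python
# A minimal one-shot completion, following the same (private) initialization
# path the demo uses. Model ID is assumed to be available in your config.
import asyncio

from ultimate_mcp_server.constants import Provider
from ultimate_mcp_server.core.server import Gateway


async def one_shot():
    gateway = Gateway("claude-demo", register_tools=False)
    await gateway._initialize_providers()  # private API, as used in the demo

    provider = gateway.providers.get(Provider.ANTHROPIC.value)
    if provider is None:
        raise RuntimeError("Anthropic provider not configured")

    result = await provider.generate_completion(
        prompt="In one sentence, what is quantum entanglement?",
        model="anthropic/claude-3-5-haiku-20241022",  # assumed available
        temperature=0.3,
        max_tokens=100,
    )
    print(result.text, f"(cost=${result.cost:.6f})")


asyncio.run(one_shot())
```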
--------------------------------------------------------------------------------
/comprehensive_test.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Comprehensive test script for Ultimate MCP Server
4 | Tests specific tools and REST API endpoints
5 | """
6 |
7 | import asyncio
8 | import json
9 |
10 | import aiohttp
11 | from fastmcp import Client
12 |
13 |
14 | async def test_mcp_interface():
15 | """Test the MCP interface functionality."""
16 | server_url = "http://127.0.0.1:8013/mcp"
17 |
18 | print("🔧 Testing MCP Interface")
19 | print("=" * 40)
20 |
21 | try:
22 | async with Client(server_url) as client:
23 | print("✅ MCP client connected")
24 |
25 | # Test core tools
26 | tools_to_test = [
27 | ("echo", {"message": "Hello MCP!"}),
28 | ("get_provider_status", {}),
29 | ("list_models", {}),
30 | ]
31 |
32 | for tool_name, params in tools_to_test:
33 | try:
34 | result = await client.call_tool(tool_name, params)
35 | if result:
36 | print(f"✅ {tool_name}: OK")
37 | # Show sample of result for key tools
38 | if tool_name == "get_provider_status":
39 | data = json.loads(result[0].text)
40 | provider_count = len(data.get('providers', {}))
41 | print(f" → {provider_count} providers configured")
42 | elif tool_name == "list_models":
43 | data = json.loads(result[0].text)
44 | total_models = sum(len(models) for models in data.get('models', {}).values())
45 | print(f" → {total_models} total models available")
46 | else:
47 | print(f"❌ {tool_name}: No response")
48 | except Exception as e:
49 | print(f"❌ {tool_name}: {e}")
50 |
51 | # Test filesystem tools
52 | print("\n📁 Testing filesystem access...")
53 | try:
54 | dirs_result = await client.call_tool("list_allowed_directories", {})
55 | if dirs_result:
56 | print("✅ Filesystem access configured")
57 | except Exception as e:
58 | print(f"❌ Filesystem access: {e}")
59 |
60 | # Test Python execution
61 | print("\n🐍 Testing Python sandbox...")
62 | try:
63 | python_result = await client.call_tool("execute_python", {
64 | "code": "import sys; print(f'Python {sys.version_info.major}.{sys.version_info.minor}')"
65 | })
66 | if python_result:
67 | result_data = json.loads(python_result[0].text)
68 | if result_data.get('success'):
69 | print("✅ Python sandbox working")
70 | print(f" → {result_data.get('output', '').strip()}")
71 | else:
72 | print("❌ Python sandbox failed")
73 | except Exception as e:
74 | print(f"❌ Python sandbox: {e}")
75 |
76 | except Exception as e:
77 | print(f"❌ MCP interface failed: {e}")
78 |
79 |
80 | async def test_rest_api():
81 | """Test the REST API endpoints."""
82 | base_url = "http://127.0.0.1:8013"
83 |
84 | print("\n🌐 Testing REST API Endpoints")
85 | print("=" * 40)
86 |
87 | async with aiohttp.ClientSession() as session:
88 | # Test discovery endpoint
89 | try:
90 | async with session.get(f"{base_url}/") as response:
91 | if response.status == 200:
92 | data = await response.json()
93 | print(f"✅ Discovery endpoint: {data.get('type')}")
94 | print(f" → Transport: {data.get('transport')}")
95 | print(f" → Endpoint: {data.get('endpoint')}")
96 | else:
97 | print(f"❌ Discovery endpoint: HTTP {response.status}")
98 | except Exception as e:
99 | print(f"❌ Discovery endpoint: {e}")
100 |
101 | # Test health endpoint
102 | try:
103 | async with session.get(f"{base_url}/api/health") as response:
104 | if response.status == 200:
105 | data = await response.json()
106 | print(f"✅ Health endpoint: {data.get('status')}")
107 | else:
108 | print(f"❌ Health endpoint: HTTP {response.status}")
109 | except Exception as e:
110 | print(f"❌ Health endpoint: {e}")
111 |
112 | # Test OpenAPI docs
113 | try:
114 | async with session.get(f"{base_url}/api/docs") as response:
115 | if response.status == 200:
116 | print("✅ Swagger UI accessible")
117 | else:
118 | print(f"❌ Swagger UI: HTTP {response.status}")
119 | except Exception as e:
120 | print(f"❌ Swagger UI: {e}")
121 |
122 | # Test cognitive states endpoint
123 | try:
124 | async with session.get(f"{base_url}/api/cognitive-states") as response:
125 | if response.status == 200:
126 | data = await response.json()
127 | print(f"✅ Cognitive states: {data.get('total', 0)} states")
128 | else:
129 | print(f"❌ Cognitive states: HTTP {response.status}")
130 | except Exception as e:
131 | print(f"❌ Cognitive states: {e}")
132 |
133 | # Test performance overview
134 | try:
135 | async with session.get(f"{base_url}/api/performance/overview") as response:
136 | if response.status == 200:
137 | data = await response.json()
138 | overview = data.get('overview', {})
139 | print(f"✅ Performance overview: {overview.get('total_actions', 0)} actions")
140 | else:
141 | print(f"❌ Performance overview: HTTP {response.status}")
142 | except Exception as e:
143 | print(f"❌ Performance overview: {e}")
144 |
145 | # Test artifacts endpoint
146 | try:
147 | async with session.get(f"{base_url}/api/artifacts") as response:
148 | if response.status == 200:
149 | data = await response.json()
150 | print(f"✅ Artifacts: {data.get('total', 0)} artifacts")
151 | else:
152 | print(f"❌ Artifacts: HTTP {response.status}")
153 | except Exception as e:
154 | print(f"❌ Artifacts: {e}")
155 |
156 |
157 | async def test_tool_completions():
158 | """Test actual completions with available providers."""
159 | server_url = "http://127.0.0.1:8013/mcp"
160 |
161 | print("\n🤖 Testing LLM Completions")
162 | print("=" * 40)
163 |
164 | try:
165 | async with Client(server_url) as client:
166 | # Get available providers first
167 | provider_result = await client.call_tool("get_provider_status", {})
168 | provider_data = json.loads(provider_result[0].text)
169 |
170 | available_providers = []
171 | for name, status in provider_data.get('providers', {}).items():
172 | if status.get('available') and status.get('models'):
173 | available_providers.append((name, status['models'][0]))
174 |
175 | if not available_providers:
176 | print("❌ No providers available for testing")
177 | return
178 |
179 | # Test with first available provider
180 | provider_name, model_info = available_providers[0]
181 | model_id = model_info.get('id')
182 |
183 | print(f"🧪 Testing with {provider_name} / {model_id}")
184 |
185 | try:
186 | result = await client.call_tool("generate_completion", {
187 | "prompt": "Count from 1 to 5",
188 | "provider": provider_name,
189 | "model": model_id,
190 | "max_tokens": 50
191 | })
192 |
193 | if result:
194 | response_data = json.loads(result[0].text)
195 | if response_data.get('success', True):
196 | print("✅ Completion successful")
197 | print(f" → Response: {response_data.get('text', '')[:100]}...")
198 | if 'usage' in response_data:
199 | usage = response_data['usage']
200 | print(f" → Tokens: {usage.get('total_tokens', 'N/A')}")
201 | else:
202 | print(f"❌ Completion failed: {response_data.get('error')}")
203 | else:
204 | print("❌ No completion response")
205 |
206 | except Exception as e:
207 | print(f"❌ Completion error: {e}")
208 |
209 | except Exception as e:
210 | print(f"❌ Completion test failed: {e}")
211 |
212 |
213 | async def test_memory_system():
214 | """Test the memory and cognitive state system."""
215 | server_url = "http://127.0.0.1:8013/mcp"
216 |
217 | print("\n🧠 Testing Memory System")
218 | print("=" * 40)
219 |
220 | try:
221 | async with Client(server_url) as client:
222 | # Test memory storage
223 | try:
224 | memory_result = await client.call_tool("store_memory", {
225 | "memory_type": "test",
226 | "content": "This is a test memory for the test client",
227 | "importance": 7.5,
228 | "tags": ["test", "client"]
229 | })
230 |
231 | if memory_result:
232 | memory_data = json.loads(memory_result[0].text)
233 | if memory_data.get('success'):
234 | memory_id = memory_data.get('memory_id')
235 | print(f"✅ Memory stored: {memory_id}")
236 |
237 | # Test memory retrieval
238 | try:
239 | get_result = await client.call_tool("get_memory_by_id", {
240 | "memory_id": memory_id
241 | })
242 |
243 | if get_result:
244 | print("✅ Memory retrieved successfully")
245 | except Exception as e:
246 | print(f"❌ Memory retrieval: {e}")
247 |
248 | else:
249 | print(f"❌ Memory storage failed: {memory_data.get('error')}")
250 |
251 | except Exception as e:
252 | print(f"❌ Memory system: {e}")
253 |
254 | # Test cognitive state
255 | try:
256 | state_result = await client.call_tool("save_cognitive_state", {
257 | "state_type": "test_state",
258 | "description": "Test cognitive state from client",
259 | "data": {"test": True, "client": "test_client"}
260 | })
261 |
262 | if state_result:
263 | state_data = json.loads(state_result[0].text)
264 | if state_data.get('success'):
265 | print("✅ Cognitive state saved")
266 | else:
267 | print(f"❌ Cognitive state failed: {state_data.get('error')}")
268 |
269 | except Exception as e:
270 | print(f"❌ Cognitive state: {e}")
271 |
272 | except Exception as e:
273 | print(f"❌ Memory system test failed: {e}")
274 |
275 |
276 | async def main():
277 | """Run all comprehensive tests."""
278 | print("🚀 Ultimate MCP Server Comprehensive Test Suite")
279 | print("=" * 60)
280 |
281 | # Test MCP interface
282 | await test_mcp_interface()
283 |
284 | # Test REST API
285 | await test_rest_api()
286 |
287 | # Test completions
288 | await test_tool_completions()
289 |
290 | # Test memory system
291 | await test_memory_system()
292 |
293 | print("\n🎯 Comprehensive testing completed!")
294 | print("\nIf you see mostly ✅ symbols, your server is working correctly!")
295 | print("Any ❌ symbols indicate areas that may need attention.")
296 |
297 |
298 | if __name__ == "__main__":
299 | asyncio.run(main())
```
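When the full suite is more than you need, a tiny standalone connectivity check using the same fastmcp `Client` pattern can be enough; the server URL below matches the one hard-coded in the tests above:
```python
# Quick ping of the MCP endpoint via the echo tool, mirroring the suite above.
import asyncio

from fastmcp import Client


async def ping():
    async with Client("http://127.0.0.1:8013/mcp") as client:
        result = await client.call_tool("echo", {"message": "ping"})
        print(result[0].text if result else "no response")


asyncio.run(ping())
```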