This is page 23 of 45. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│   ├── __init__.py
│   ├── advanced_agent_flows_using_unified_memory_system_demo.py
│   ├── advanced_extraction_demo.py
│   ├── advanced_unified_memory_system_demo.py
│   ├── advanced_vector_search_demo.py
│   ├── analytics_reporting_demo.py
│   ├── audio_transcription_demo.py
│   ├── basic_completion_demo.py
│   ├── cache_demo.py
│   ├── claude_integration_demo.py
│   ├── compare_synthesize_demo.py
│   ├── cost_optimization.py
│   ├── data
│   │   ├── sample_event.txt
│   │   ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│   │   └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│   ├── docstring_refiner_demo.py
│   ├── document_conversion_and_processing_demo.py
│   ├── entity_relation_graph_demo.py
│   ├── filesystem_operations_demo.py
│   ├── grok_integration_demo.py
│   ├── local_text_tools_demo.py
│   ├── marqo_fused_search_demo.py
│   ├── measure_model_speeds.py
│   ├── meta_api_demo.py
│   ├── multi_provider_demo.py
│   ├── ollama_integration_demo.py
│   ├── prompt_templates_demo.py
│   ├── python_sandbox_demo.py
│   ├── rag_example.py
│   ├── research_workflow_demo.py
│   ├── sample
│   │   ├── article.txt
│   │   ├── backprop_paper.pdf
│   │   ├── buffett.pdf
│   │   ├── contract_link.txt
│   │   ├── legal_contract.txt
│   │   ├── medical_case.txt
│   │   ├── northwind.db
│   │   ├── research_paper.txt
│   │   ├── sample_data.json
│   │   └── text_classification_samples
│   │       ├── email_classification.txt
│   │       ├── news_samples.txt
│   │       ├── product_reviews.txt
│   │       └── support_tickets.txt
│   ├── sample_docs
│   │   └── downloaded
│   │       └── attention_is_all_you_need.pdf
│   ├── sentiment_analysis_demo.py
│   ├── simple_completion_demo.py
│   ├── single_shot_synthesis_demo.py
│   ├── smart_browser_demo.py
│   ├── sql_database_demo.py
│   ├── sse_client_demo.py
│   ├── test_code_extraction.py
│   ├── test_content_detection.py
│   ├── test_ollama.py
│   ├── text_classification_demo.py
│   ├── text_redline_demo.py
│   ├── tool_composition_examples.py
│   ├── tournament_code_demo.py
│   ├── tournament_text_demo.py
│   ├── unified_memory_system_demo.py
│   ├── vector_search_demo.py
│   ├── web_automation_instruction_packs.py
│   └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│   └── smart_browser_internal
│       ├── locator_cache.db
│       ├── readability.js
│       └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_server.py
│   ├── manual
│   │   ├── test_extraction_advanced.py
│   │   └── test_extraction.py
│   └── unit
│       ├── __init__.py
│       ├── test_cache.py
│       ├── test_providers.py
│       └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│   ├── __init__.py
│   ├── __main__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── commands.py
│   │   ├── helpers.py
│   │   └── typer_cli.py
│   ├── clients
│   │   ├── __init__.py
│   │   ├── completion_client.py
│   │   └── rag_client.py
│   ├── config
│   │   └── examples
│   │       └── filesystem_config.yaml
│   ├── config.py
│   ├── constants.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── base.py
│   │   │   └── evaluators.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   ├── deepseek.py
│   │   │   ├── gemini.py
│   │   │   ├── grok.py
│   │   │   ├── ollama.py
│   │   │   ├── openai.py
│   │   │   └── openrouter.py
│   │   ├── server.py
│   │   ├── state_store.py
│   │   ├── tournaments
│   │   │   ├── manager.py
│   │   │   ├── tasks.py
│   │   │   └── utils.py
│   │   └── ums_api
│   │       ├── __init__.py
│   │       ├── ums_database.py
│   │       ├── ums_endpoints.py
│   │       ├── ums_models.py
│   │       └── ums_services.py
│   ├── exceptions.py
│   ├── graceful_shutdown.py
│   ├── services
│   │   ├── __init__.py
│   │   ├── analytics
│   │   │   ├── __init__.py
│   │   │   ├── metrics.py
│   │   │   └── reporting.py
│   │   ├── cache
│   │   │   ├── __init__.py
│   │   │   ├── cache_service.py
│   │   │   ├── persistence.py
│   │   │   ├── strategies.py
│   │   │   └── utils.py
│   │   ├── cache.py
│   │   ├── document.py
│   │   ├── knowledge_base
│   │   │   ├── __init__.py
│   │   │   ├── feedback.py
│   │   │   ├── manager.py
│   │   │   ├── rag_engine.py
│   │   │   ├── retriever.py
│   │   │   └── utils.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── repository.py
│   │   │   └── templates.py
│   │   ├── prompts.py
│   │   └── vector
│   │       ├── __init__.py
│   │       ├── embeddings.py
│   │       └── vector_service.py
│   ├── tool_token_counter.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── audio_transcription.py
│   │   ├── base.py
│   │   ├── completion.py
│   │   ├── docstring_refiner.py
│   │   ├── document_conversion_and_processing.py
│   │   ├── enhanced-ums-lookbook.html
│   │   ├── entity_relation_graph.py
│   │   ├── excel_spreadsheet_automation.py
│   │   ├── extraction.py
│   │   ├── filesystem.py
│   │   ├── html_to_markdown.py
│   │   ├── local_text_tools.py
│   │   ├── marqo_fused_search.py
│   │   ├── meta_api_tool.py
│   │   ├── ocr_tools.py
│   │   ├── optimization.py
│   │   ├── provider.py
│   │   ├── pyodide_boot_template.html
│   │   ├── python_sandbox.py
│   │   ├── rag.py
│   │   ├── redline-compiled.css
│   │   ├── sentiment_analysis.py
│   │   ├── single_shot_synthesis.py
│   │   ├── smart_browser.py
│   │   ├── sql_databases.py
│   │   ├── text_classification.py
│   │   ├── text_redline_tools.py
│   │   ├── tournament.py
│   │   ├── ums_explorer.html
│   │   └── unified_memory_system.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── async_utils.py
│   │   ├── display.py
│   │   ├── logging
│   │   │   ├── __init__.py
│   │   │   ├── console.py
│   │   │   ├── emojis.py
│   │   │   ├── formatter.py
│   │   │   ├── logger.py
│   │   │   ├── panels.py
│   │   │   ├── progress.py
│   │   │   └── themes.py
│   │   ├── parse_yaml.py
│   │   ├── parsing.py
│   │   ├── security.py
│   │   └── text.py
│   └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/examples/local_text_tools_demo.py:
--------------------------------------------------------------------------------
```python
1 | # local_text_tools_demo.py
2 | """
3 | Comprehensive demonstration script for the local_text_tools functions in Ultimate MCP Server.
4 |
5 | This script showcases the usage of local command-line text processing utilities
6 | (ripgrep, awk, sed, jq) through the secure, standalone functions provided by
7 | ultimate_mcp_server.tools.local_text_tools.
8 | It includes basic examples, advanced command-line techniques, security failure demos,
9 | streaming examples, and interactive workflows demonstrating LLM-driven tool usage
10 | on sample documents.
11 |
12 | It uses sample files from the 'sample/' directory relative to this script.
13 | NOTE: The LLM interactive demos require a configured LLM provider (e.g., OpenAI API key).
14 |
15 | -------------------------------------------------------------------------------------
16 | IMPORTANT: ABOUT ERROR INDICATORS AND "FAILURES" IN THIS DEMO
17 | -------------------------------------------------------------------------------------
18 |
19 | Many demonstrations in this script INTENTIONALLY trigger security features and error
20 | handling. These appear as red ❌ boxes but are actually showing CORRECT BEHAVIOR.
21 |
22 | Examples of intentional security demonstrations include:
23 | - Invalid regex patterns (to show proper error reporting)
24 | - AWK/SED script syntax errors (to show validation)
25 | - Path traversal attempts (to demonstrate workspace confinement)
26 | - Usage of forbidden flags like 'sed -i' (showing security limits)
27 | - Redirection attempts (demonstrating shell character blocking)
28 | - Command substitution (showing protection against command injection)
29 |
30 | When you see "SECURITY CHECK PASSED" or "INTENTIONAL DEMONSTRATION" in the description,
31 | this indicates a feature working correctly, not a bug in the tools.
32 | -------------------------------------------------------------------------------------
33 | """
34 |
35 | # --- Standard Library Imports ---
36 | import asyncio
37 | import inspect # top-level import is fine too
38 | import json
39 | import os
40 | import re
41 | import shlex
42 | import shutil
43 | import sys
44 | import time
45 | from enum import Enum # Import Enum from the enum module, not typing
46 | from pathlib import Path
47 | from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional
48 |
49 | # --- Configuration & Path Setup ---
50 | # Add project root to path for imports when running as script
51 | try:
52 | SCRIPT_DIR = Path(__file__).resolve().parent
53 | PROJECT_ROOT = SCRIPT_DIR
54 | # Try to find project root marker ('ultimate_mcp_server' dir or 'pyproject.toml')
55 | while (
56 | not (
57 | (PROJECT_ROOT / "ultimate_mcp_server").is_dir()
58 | or (PROJECT_ROOT / "pyproject.toml").is_file()
59 | )
60 | and PROJECT_ROOT.parent != PROJECT_ROOT
61 | ):
62 | PROJECT_ROOT = PROJECT_ROOT.parent
63 |
64 | # If marker found and path not added, add it
65 | if (PROJECT_ROOT / "ultimate_mcp_server").is_dir() or (
66 | PROJECT_ROOT / "pyproject.toml"
67 | ).is_file():
68 | if str(PROJECT_ROOT) not in sys.path:
69 | sys.path.insert(0, str(PROJECT_ROOT))
70 | # Fallback if no clear marker found upwards
71 | elif SCRIPT_DIR.parent != PROJECT_ROOT and (SCRIPT_DIR.parent / "ultimate_mcp_server").is_dir():
72 | PROJECT_ROOT = SCRIPT_DIR.parent
73 | print(f"Warning: Assuming project root is {PROJECT_ROOT}", file=sys.stderr)
74 | if str(PROJECT_ROOT) not in sys.path:
75 | sys.path.insert(0, str(PROJECT_ROOT))
76 | # Final fallback: add script dir itself
77 | elif str(SCRIPT_DIR) not in sys.path:
78 | sys.path.insert(0, str(SCRIPT_DIR))
79 | print(
80 | f"Warning: Could not reliably determine project root. Added script directory {SCRIPT_DIR} to path as fallback.",
81 | file=sys.stderr,
82 | )
83 | else:
84 | # If already in path, assume it's okay
85 | pass
86 |
87 | # Set MCP_TEXT_WORKSPACE environment variable to PROJECT_ROOT before importing local_text_tools
88 | os.environ["MCP_TEXT_WORKSPACE"] = str(PROJECT_ROOT)
89 | print(f"Set MCP_TEXT_WORKSPACE to: {os.environ['MCP_TEXT_WORKSPACE']}", file=sys.stderr)
90 |
91 | except Exception as e:
92 | print(f"Error setting up sys.path: {e}", file=sys.stderr)
93 | sys.exit(1)
94 |
95 |
96 | # --- Third-Party Imports ---
97 | try:
98 | from rich.console import Console
99 | from rich.markup import escape
100 | from rich.panel import Panel
101 | from rich.pretty import pretty_repr
102 | from rich.rule import Rule
103 | from rich.syntax import Syntax
104 | from rich.traceback import install as install_rich_traceback
105 | except ImportError:
106 | print("Error: 'rich' library not found. Please install it: pip install rich", file=sys.stderr)
107 | sys.exit(1)
108 |
109 |
110 | # --- Project-Specific Imports ---
111 | # Import necessary tools and components
112 | try:
113 | # Import specific functions and types
114 | from ultimate_mcp_server.config import get_config # To check LLM provider config
115 | from ultimate_mcp_server.constants import Provider # For LLM demo
116 |
117 | # Import specific exceptions
118 | from ultimate_mcp_server.exceptions import ToolExecutionError, ToolInputError
119 | from ultimate_mcp_server.tools.completion import chat_completion
120 | from ultimate_mcp_server.tools.local_text_tools import (
121 | ToolErrorCode,
122 | ToolResult,
123 | get_workspace_dir, # Function to get configured workspace
124 | run_awk,
125 | run_awk_stream,
126 | run_jq,
127 | run_jq_stream,
128 | run_ripgrep,
129 | run_ripgrep_stream,
130 | run_sed,
131 | run_sed_stream,
132 | )
133 | from ultimate_mcp_server.utils import get_logger
134 | except ImportError as import_err:
135 | print(f"Error: Failed to import necessary MCP Server components: {import_err}", file=sys.stderr)
136 | print(
137 | "Please ensure the script is run from within the correct environment, the package is installed (`pip install -e .`), and project structure is correct.",
138 | file=sys.stderr,
139 | )
140 | sys.exit(1)
141 |
142 | # --- Initialization ---
143 | console = Console()
144 | logger = get_logger("demo.local_text_tools")
145 | install_rich_traceback(show_locals=False, width=console.width)
146 |
147 | # Define path to sample files relative to this script's location
148 | SAMPLE_DIR = SCRIPT_DIR / "sample"
149 | if not SAMPLE_DIR.is_dir():
150 | print(
151 | f"Error: Sample directory not found at expected location: {SCRIPT_DIR}/sample",
152 | file=sys.stderr,
153 | )
154 | # Try locating it relative to Project Root as fallback
155 | ALT_SAMPLE_DIR = PROJECT_ROOT / "examples" / "local_text_tools_demo" / "sample"
156 | if ALT_SAMPLE_DIR.is_dir():
157 | print(f"Found sample directory at alternate location: {ALT_SAMPLE_DIR}", file=sys.stderr)
158 | SAMPLE_DIR = ALT_SAMPLE_DIR
159 | else:
160 | print(
161 | f"Please ensure the 'sample' directory exists within {SCRIPT_DIR} or {ALT_SAMPLE_DIR}.",
162 | file=sys.stderr,
163 | )
164 | sys.exit(1)
165 |
166 | # Store both absolute and relative paths for the samples
167 | SAMPLE_DIR_ABS = SAMPLE_DIR
168 | CLASSIFICATION_SAMPLES_DIR_ABS = SAMPLE_DIR / "text_classification_samples"
169 |
170 | # Create relative paths for use with the tools - relative to PROJECT_ROOT
171 | SAMPLE_DIR_REL = SAMPLE_DIR.relative_to(PROJECT_ROOT)
172 | CLASSIFICATION_SAMPLES_DIR_REL = CLASSIFICATION_SAMPLES_DIR_ABS.relative_to(PROJECT_ROOT)
173 |
174 | # Use relative paths for the tools
175 | CONTRACT_FILE_PATH = str(SAMPLE_DIR_REL / "legal_contract.txt") # Relative path
176 | ARTICLE_FILE_PATH = str(SAMPLE_DIR_REL / "article.txt")
177 | EMAIL_FILE_PATH = str(CLASSIFICATION_SAMPLES_DIR_REL / "email_classification.txt")
178 | SCHEDULE_FILE_PATH = str(SAMPLE_DIR_REL / "SCHEDULE_1.2") # Added for awk demo
179 | JSON_SAMPLE_PATH = str(SAMPLE_DIR_REL / "sample_data.json") # Added for jq file demo
180 |
181 | # But for file operations (checking existence, etc.), use absolute paths
182 | CONTRACT_FILE_PATH_ABS = str(SAMPLE_DIR_ABS / "legal_contract.txt")
183 | ARTICLE_FILE_PATH_ABS = str(SAMPLE_DIR_ABS / "article.txt")
184 | EMAIL_FILE_PATH_ABS = str(CLASSIFICATION_SAMPLES_DIR_ABS / "email_classification.txt")
185 | SCHEDULE_FILE_PATH_ABS = str(SAMPLE_DIR_ABS / "SCHEDULE_1.2")
186 | JSON_SAMPLE_PATH_ABS = str(SAMPLE_DIR_ABS / "sample_data.json")
187 |
188 | # Create sample JSON file if it doesn't exist
189 | if not Path(JSON_SAMPLE_PATH_ABS).exists():
190 | sample_json_content = """
191 | [
192 | {"user": "Alice", "dept": "Sales", "region": "North", "value": 100, "tags": ["active", "pipeline"]},
193 | {"user": "Bob", "dept": "IT", "region": "South", "value": 150, "tags": ["active", "support"]},
194 | {"user": "Charlie", "dept": "Sales", "region": "North", "value": 120, "tags": ["inactive", "pipeline"]},
195 | {"user": "David", "dept": "IT", "region": "West", "value": 200, "tags": ["active", "admin"]}
196 | ]
197 | """
198 | try:
199 | # Make sure the directory exists
200 | Path(JSON_SAMPLE_PATH_ABS).parent.mkdir(parents=True, exist_ok=True)
201 | with open(JSON_SAMPLE_PATH_ABS, "w") as f:
202 | f.write(sample_json_content)
203 | logger.info(f"Created sample JSON file: {JSON_SAMPLE_PATH_ABS}")
204 | except OSError as e:
205 | logger.error(f"Failed to create sample JSON file {JSON_SAMPLE_PATH_ABS}: {e}")
206 | # Continue without it, jq file demos will fail gracefully
207 |
208 | MAX_LLM_ITERATIONS = 5 # Limit for the interactive demo
209 |
210 | # --- Helper Functions ---
211 |
212 | ToolFunction = Callable[..., Coroutine[Any, Any, ToolResult]]
213 | StreamFunction = Callable[..., Coroutine[Any, Any, AsyncIterator[str]]]
214 |
215 |
216 | async def safe_tool_call(
217 | tool_func: ToolFunction,
218 | args: Dict[str, Any],
219 | description: str,
220 | display_input: bool = True,
221 | display_output: bool = True,
222 | ) -> ToolResult:
223 | """Helper to call a tool function, catch errors, and display results."""
224 | tool_func_name = getattr(tool_func, "__name__", "unknown_tool")
225 |
226 | if display_output:
227 | console.print(Rule(f"[bold blue]{escape(description)}[/bold blue]", style="blue"))
228 |
229 | if not callable(tool_func):
230 | console.print(
231 | f"[bold red]Error:[/bold red] Tool function '{tool_func_name}' is not callable."
232 | )
233 | return ToolResult(success=False, error=f"Function '{tool_func_name}' not callable.")
234 |
235 | if display_input and display_output:
236 | console.print(f"[dim]Calling [bold cyan]{tool_func_name}[/] with args:[/]")
237 | try:
238 | args_to_print = args.copy()
239 | # Truncate long input_data for display
240 | if "input_data" in args_to_print and isinstance(args_to_print["input_data"], str):
241 | if len(args_to_print["input_data"]) > 200:
242 | args_to_print["input_data"] = args_to_print["input_data"][:200] + "[...]"
243 | args_repr = pretty_repr(args_to_print, max_length=120, max_string=200)
244 | console.print(args_repr)
245 | except Exception:
246 | console.print("(Could not represent args)")
247 |
248 | start_time = time.monotonic()
249 | result: ToolResult = ToolResult(
250 | success=False, error="Execution did not complete.", exit_code=None
251 | ) # Default error
252 |
253 | try:
254 | result = await tool_func(**args) # Direct function call
255 | processing_time = time.monotonic() - start_time
256 | logger.debug(f"Tool '{tool_func_name}' execution time: {processing_time:.4f}s")
257 |
258 | if display_output:
259 | success = result.get("success", False)
260 | is_dry_run = result.get("dry_run_cmdline") is not None
261 |
262 | panel_title = f"[bold {'green' if success else 'red'}]Result: {tool_func_name} {'✅' if success else '❌'}{' (Dry Run)' if is_dry_run else ''}[/]"
263 | panel_border = "green" if success else "red"
264 |
265 | # Format output for display
266 | output_display = ""
267 | exit_code = result.get("exit_code", "N/A")
268 | output_display += f"[bold]Exit Code:[/bold] {exit_code}\n"
269 | duration = result.get("duration", 0.0)
270 | output_display += f"[bold]Duration:[/bold] {duration:.3f}s\n"
271 | cached = result.get("cached_result", False)
272 | output_display += f"[bold]Cached:[/bold] {'Yes' if cached else 'No'}\n"
273 |
274 | if is_dry_run:
275 | cmdline = result.get("dry_run_cmdline", [])
276 | output_display += f"\n[bold yellow]Dry Run Command:[/]\n{shlex.join(cmdline)}\n"
277 | elif success:
278 | stdout_str = result.get("stdout", "")
279 | stderr_str = result.get("stderr", "")
280 | stdout_trunc = result.get("stdout_truncated", False)
281 | stderr_trunc = result.get("stderr_truncated", False)
282 |
283 | if stdout_str:
284 | output_display += f"\n[bold green]STDOUT ({len(stdout_str)} chars{', TRUNCATED' if stdout_trunc else ''}):[/]\n"
 285 |                     # stdout is rendered as plain text inside the panel, so no
 286 |                     # JSON syntax highlighting is attempted here; a rich Syntax
 287 |                     # object cannot be embedded in this plain string, and the
 288 |                     # output is simply truncated for display just below.
289 | # Limit length for display
290 | display_stdout = stdout_str[:3000] + ("..." if len(stdout_str) > 3000 else "")
291 | # Just add the plain output text instead of the Syntax object
292 | output_display += display_stdout
293 | else:
294 | output_display += "[dim]STDOUT: (empty)[/]"
295 |
296 | if stderr_str:
297 | header = f"[bold yellow]STDERR ({len(stderr_str)} chars{', TRUNCATED' if stderr_trunc else ''}):[/]"
298 | output_display += f"\n\n{header}"
299 | # Apply syntax highlighting for stderr too if it looks structured
 300 |                     is_stderr_json_like = stderr_str.strip().startswith(
 301 |                         ("{", "[")
 302 |                     ) and stderr_str.strip().endswith(("}", "]"))
 303 |                     if is_stderr_json_like:
 304 |                         stderr_display = stderr_str[:1000] + ("..." if len(stderr_str) > 1000 else "")
 305 |                         # The panel body is assembled as a plain string, so a rich
 306 |                         # Syntax object cannot be embedded in it here. Append the
 307 |                         # truncated JSON-like stderr as escaped text instead, so
 308 |                         # that structured stderr still appears in the result panel
 309 |                         # alongside the exit code and duration fields. (Proper
 310 |                         # highlighting would require printing a separate Syntax
 311 |                         # panel after this one.)
 312 |                         output_display += "\n" + escape(stderr_display)
 313 |                     else:
 314 |                         output_display += "\n" + escape(
 315 |                             stderr_str[:1000] + ("..." if len(stderr_str) > 1000 else "")
 316 |                         )
317 | else:
318 | output_display += "\n\n[dim]STDERR: (empty)[/]"
319 |
320 | # Create panel with the text content
321 | console.print(
322 | Panel(output_display, title=panel_title, border_style=panel_border, expand=False)
323 | )
324 |
325 | except (ToolInputError, ToolExecutionError) as e: # Catch specific tool errors
326 | processing_time = time.monotonic() - start_time
327 | logger.error(f"Tool '{tool_func_name}' failed: {e}", exc_info=False)
328 | if display_output:
329 | error_title = f"[bold red]Error: {tool_func_name} Failed ❌[/]"
330 | error_code_val = getattr(e, "error_code", None)
331 | # Handle both enum and string error codes
332 | error_code_str = ""
333 | if error_code_val:
334 | if hasattr(error_code_val, "value"):
335 | error_code_str = f" ({error_code_val.value})"
336 | else:
337 | error_code_str = f" ({error_code_val})"
338 | error_content = f"[bold red]{type(e).__name__}{error_code_str}:[/] {escape(str(e))}"
339 | if hasattr(e, "details") and e.details:
340 | try:
341 | details_repr = pretty_repr(e.details)
342 | except Exception:
343 | details_repr = str(e.details)
344 | error_content += f"\n\n[yellow]Details:[/]\n{escape(details_repr)}"
345 | console.print(Panel(error_content, title=error_title, border_style="red", expand=False))
346 | # Ensure result dict structure on error
347 | result = ToolResult(
348 | success=False,
349 | error=str(e),
350 | error_code=getattr(e, "error_code", ToolErrorCode.UNEXPECTED_FAILURE),
351 | details=getattr(e, "details", {}),
352 | stdout=None,
353 | stderr=None,
354 | exit_code=None,
355 | duration=processing_time,
356 | )
357 | except Exception as e:
358 | processing_time = time.monotonic() - start_time
359 | logger.critical(f"Unexpected error calling '{tool_func_name}': {e}", exc_info=True)
360 | if display_output:
361 | console.print(f"\n[bold red]CRITICAL UNEXPECTED ERROR in {tool_func_name}:[/bold red]")
362 | console.print_exception(show_locals=False)
363 | result = ToolResult(
364 | success=False,
365 | error=f"Unexpected: {str(e)}",
366 | error_code=ToolErrorCode.UNEXPECTED_FAILURE,
367 | stdout=None,
368 | stderr=None,
369 | exit_code=None,
370 | duration=processing_time,
371 | )
372 | finally:
373 | if display_output:
374 | console.print() # Add spacing
375 |
376 | # Ensure result is always a ToolResult-like dictionary before returning
377 | if not isinstance(result, dict):
378 | logger.error(
379 | f"Tool {tool_func_name} returned non-dict type {type(result)}. Returning error dict."
380 | )
381 | result = ToolResult(
382 | success=False,
383 | error=f"Tool returned unexpected type: {type(result).__name__}",
384 | error_code=ToolErrorCode.UNEXPECTED_FAILURE,
385 | )
386 |
387 | # Ensure basic keys exist even if tool failed unexpectedly before returning dict
388 | result.setdefault("success", False)
389 | result.setdefault("cached_result", False)
390 |
391 | return result
392 |
393 |
394 | async def safe_tool_stream_call(
395 | stream_func: StreamFunction,
396 | args: Dict[str, Any],
397 | description: str,
398 | ) -> bool:
399 | """
400 | Call a run_*_stream wrapper, printing the stream as it arrives.
401 | Works whether the wrapper returns the iterator directly or returns it
402 | inside a coroutine (the current behaviour when decorators are applied).
403 | """
404 | tool_name = getattr(stream_func, "__name__", "unknown_stream_tool")
405 | console.print(
406 | Rule(f"[bold magenta]Streaming Demo: {escape(description)}[/bold magenta]",
407 | style="magenta")
408 | )
409 | console.print(f"[dim]Calling [bold cyan]{tool_name}[/] with args:[/]")
410 | console.print(pretty_repr(args, max_length=120, max_string=200))
411 |
412 | # ─── call the wrapper ────────────────────────────────────────────────────────
413 | stream_obj = stream_func(**args) # do *not* await yet
414 | if inspect.iscoroutine(stream_obj): # decorator returned coroutine
415 | stream_obj = await stream_obj # now we have AsyncIterator
416 |
417 | if not hasattr(stream_obj, "__aiter__"):
418 | console.print(
419 | Panel(f"[red]Fatal: {tool_name} did not return an async iterator.[/red]",
420 | border_style="red")
421 | )
422 | return False
423 |
424 | # ─── consume the stream ─────────────────────────────────────────────────────
425 | start = time.monotonic()
426 | line_count, buffered = 0, ""
427 | console.print("[yellow]--- Streaming Output Start ---[/]")
428 |
429 | try:
430 | async for line in stream_obj: # type: ignore[arg-type]
431 | line_count += 1
432 | buffered += line
433 | if len(buffered) > 2000 or "\n" in buffered:
434 | console.out(escape(buffered), end="")
435 | buffered = ""
436 | if buffered:
437 | console.out(escape(buffered), end="")
438 |
439 | status = "[green]Complete"
440 | ok = True
441 | except Exception:
442 | console.print_exception()
443 | status = "[red]Failed"
444 | ok = False
445 |
446 | console.print(
447 | f"\n[yellow]--- Streaming {status} ({line_count} lines in "
448 | f"{time.monotonic() - start:.3f}s) ---[/]\n"
449 | )
450 | return ok
451 |
452 |
453 | # --- Demo Functions ---
454 |
455 |
456 | async def demonstrate_ripgrep_basic():
457 | """Demonstrate basic usage of the run_ripgrep tool."""
458 | console.print(Rule("[bold green]1. Ripgrep (rg) Basic Examples[/bold green]", style="green"))
459 |
460 | classification_samples_str = str(CLASSIFICATION_SAMPLES_DIR_REL)
461 | article_file_quoted = shlex.quote(ARTICLE_FILE_PATH)
462 | class_dir_quoted = shlex.quote(classification_samples_str)
463 |
464 | # 1a: Basic search in a file
465 | await safe_tool_call(
466 | run_ripgrep,
467 | {
468 | "args_str": f"--threads=4 'Microsoft' {article_file_quoted}",
469 | "input_file": True, # Indicate args_str contains the file target
470 | },
471 | "Search for 'Microsoft' in article.txt (with thread limiting)",
472 | )
473 |
474 | # 1b: Case-insensitive search with context
475 | await safe_tool_call(
476 | run_ripgrep,
477 | {
478 | "args_str": f"--threads=4 -i --context 2 'anthropic' {article_file_quoted}",
479 | "input_file": True,
480 | },
481 | "Case-insensitive search for 'anthropic' with context (-i -C 2, limited threads)",
482 | )
483 |
484 | # 1c: Search for lines NOT containing a pattern
485 | await safe_tool_call(
486 | run_ripgrep,
487 | {
488 | "args_str": f"--threads=4 --invert-match 'AI' {article_file_quoted}",
489 | "input_file": True,
490 | },
491 | "Find lines NOT containing 'AI' in article.txt (-v, limited threads)",
492 | )
493 |
494 | # 1d: Count matches per file in a directory
495 | await safe_tool_call(
496 | run_ripgrep,
497 | {
498 | "args_str": f"--threads=4 --count-matches 'Subject:' {class_dir_quoted}",
499 | "input_dir": True, # Indicate args_str contains the dir target
500 | },
501 | "Count lines with 'Subject:' in classification samples dir (-c, limited threads)",
502 | )
503 |
504 | # 1e: Search within input_data
505 | sample_data = "Line one\nLine two with pattern\nLine three\nAnother pattern line"
506 | await safe_tool_call(
507 | run_ripgrep,
508 | {"args_str": "--threads=4 'pattern'", "input_data": sample_data},
509 | "Search for 'pattern' within input_data string (limited threads)",
510 | )
511 |
512 | # 1f: JSON output
513 | await safe_tool_call(
514 | run_ripgrep,
515 | {
516 | "args_str": f"--threads=4 --json 'acquisition' {article_file_quoted}",
517 | "input_file": True,
518 | },
519 | "Search for 'acquisition' with JSON output (--json, limited threads)",
520 | )
521 |
522 | # 1g: Error case - Invalid Regex Pattern (example)
523 | await safe_tool_call(
524 | run_ripgrep,
525 | {"args_str": f"--threads=4 '[' {article_file_quoted}", "input_file": True},
526 | "Search with potentially invalid regex pattern '[' (INTENTIONAL DEMONSTRATION: regex validation)",
527 | )
528 |
529 |
530 | async def demonstrate_ripgrep_advanced():
531 | """Demonstrate advanced usage of the run_ripgrep tool."""
532 | console.print(
533 | Rule("[bold green]1b. Ripgrep (rg) Advanced Examples[/bold green]", style="green")
534 | )
535 |
536 | contract_file_quoted = shlex.quote(CONTRACT_FILE_PATH)
537 | class_dir_quoted = shlex.quote(str(CLASSIFICATION_SAMPLES_DIR_REL))
538 |
539 | # Adv 1a: Multiline search (simple example)
540 | await safe_tool_call(
541 | run_ripgrep,
 542 |         # Search for "ARTICLE I" followed by "Consideration", allowing the match to span lines (multiline + dotall), case sensitive
543 | {
544 | "args_str": f"--threads=4 --multiline --multiline-dotall --context 1 'ARTICLE I.*?Consideration' {contract_file_quoted}",
545 | "input_file": True,
546 | },
547 | "Multiline search for 'ARTICLE I' then 'Consideration' within context (-U -C 1, limited threads)",
548 | )
549 |
550 | # Adv 1b: Search specific file types and replace output
551 | await safe_tool_call(
552 | run_ripgrep,
553 | # Search for 'Agreement' in .txt files, replace matching text with '***CONTRACT***'
554 | {
555 | "args_str": f"--threads=4 --replace '***CONTRACT***' 'Agreement' {contract_file_quoted}",
556 | "input_file": True,
557 | },
558 | "Search for 'Agreement' in contract file and replace in output (--replace, limited threads)",
559 | )
560 |
561 | # Adv 1c: Using Globs to include/exclude
562 | # Search for 'email' in classification samples, but exclude the news samples file
563 | exclude_pattern = shlex.quote(os.path.basename(CLASSIFICATION_SAMPLES_DIR_REL / "news_samples.txt"))
564 | await safe_tool_call(
565 | run_ripgrep,
566 | {
567 | "args_str": f"--threads=4 -i 'email' -g '!{exclude_pattern}' {class_dir_quoted}",
568 | "input_dir": True,
569 | },
570 | f"Search for 'email' in classification dir, excluding '{exclude_pattern}' (-g, limited threads)",
571 | )
572 |
573 | # Adv 1d: Print only matching part with line numbers and context
574 | await safe_tool_call(
575 | run_ripgrep,
576 | # Extract dates like YYYY-MM-DD
577 | {
578 | "args_str": f"--threads=4 --only-matching --line-number --context 1 '[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}' {contract_file_quoted}",
579 | "input_file": True,
580 | },
581 | "Extract date patterns (YYYY-MM-DD) with line numbers and context (-o -n -C 1, limited threads)",
582 | )
583 |
584 | # Adv 1e: Follow symlinks (if applicable and symlinks were created in setup)
585 | # This depends on your setup having symlinks pointing into allowed directories
586 | # Example assumes a symlink named 'contract_link.txt' points to legal_contract.txt
587 | link_path = SAMPLE_DIR_ABS / "contract_link.txt" # Absolute path for creation
588 | target_path = SAMPLE_DIR_ABS / "legal_contract.txt" # Absolute path for file operations
589 | # Create link for demo if target exists
590 | if target_path.exists() and not link_path.exists():
591 | try:
592 | os.symlink(target_path.name, link_path) # Relative link
593 | logger.info("Created symlink 'contract_link.txt' for demo.")
594 | except OSError as e:
595 | logger.warning(f"Could not create symlink for demo: {e}")
596 |
597 | # Use relative path for the tool
598 | link_path_rel = link_path.relative_to(PROJECT_ROOT) if link_path.exists() else "nonexistent_link.txt"
599 | link_path_quoted = shlex.quote(str(link_path_rel))
600 | await safe_tool_call(
601 | run_ripgrep,
602 | {"args_str": f"--threads=4 --follow 'Acquirer' {link_path_quoted}", "input_file": True},
603 | "Search for 'Acquirer' following symlinks (--follow, limited threads) (requires symlink setup)",
604 | )
605 |
606 |
607 | async def demonstrate_awk_basic():
608 | """Demonstrate basic usage of the run_awk tool."""
609 | console.print(Rule("[bold green]2. AWK Basic Examples[/bold green]", style="green"))
610 |
611 | email_file_quoted = shlex.quote(EMAIL_FILE_PATH)
612 |
613 | # 2a: Print specific fields (e.g., Subject lines)
614 | await safe_tool_call(
615 | run_awk,
616 | # FS = ':' is the field separator, print second field ($2) if first field is 'Subject'
617 | {
618 | "args_str": f"-F ':' '/^Subject:/ {{ print $2 }}' {email_file_quoted}",
619 | "input_file": True,
620 | },
621 | "Extract Subject lines from email sample using AWK (-F ':')",
622 | )
623 |
624 | # 2b: Count lines containing a specific word using AWK logic
625 | await safe_tool_call(
626 | run_awk,
627 | # Increment count if line contains 'account', print total at the end
628 | {
629 | "args_str": f"'/account/ {{ count++ }} END {{ print \"Lines containing account:\", count }}' {email_file_quoted}",
630 | "input_file": True,
631 | },
632 | "Count lines containing 'account' in email sample using AWK",
633 | )
634 |
635 | # 2c: Process input_data - print first word of each line
636 | awk_input_data = "Apple Banana Cherry\nDog Elephant Fox\nOne Two Three"
637 | await safe_tool_call(
638 | run_awk,
639 | {"args_str": "'{ print $1 }'", "input_data": awk_input_data},
640 | "Print first word of each line from input_data using AWK",
641 | )
642 |
643 | # 2d: Error case - Syntax error in AWK script
644 | await safe_tool_call(
645 | run_awk,
646 | {"args_str": "'{ print $1 '", "input_data": awk_input_data}, # Missing closing brace
647 | "Run AWK with a syntax error in the script (INTENTIONAL DEMONSTRATION: script validation)",
648 | )
649 |
650 |
651 | async def demonstrate_awk_advanced():
652 | """Demonstrate advanced usage of the run_awk tool."""
653 | console.print(Rule("[bold green]2b. AWK Advanced Examples[/bold green]", style="green"))
654 |
655 | contract_file_quoted = shlex.quote(CONTRACT_FILE_PATH)
656 | schedule_file_quoted = shlex.quote(SCHEDULE_FILE_PATH)
657 |
658 | # Adv 2a: Calculate sum based on a field (extracting amounts from contract)
659 | await safe_tool_call(
660 | run_awk,
661 | # Find lines with '$', extract the number after '$', sum them
662 | {
663 | "args_str": f"'/[$]/ {{ gsub(/[,USD$]/, \"\"); for(i=1;i<=NF;i++) if ($i ~ /^[0-9.]+$/) sum+=$i }} END {{ printf \"Total Value Mentioned: $%.2f\\n\", sum }}' {contract_file_quoted}",
664 | "input_file": True,
665 | },
666 | "Sum numeric values following '$' in contract using AWK"
667 | )
668 |
669 | # Adv 2b: Using BEGIN block and variables to extract definitions
670 | await safe_tool_call(
671 | run_awk,
672 | # Find lines defining terms like ("Acquirer"), print term and line number
673 | {
674 | "args_str": f"'/^\\s*[A-Z][[:alpha:] ]+\\s+\\(.*\"[[:alpha:]].*\"\\)/ {{ if(match($0, /\\(\"([^\"]+)\"\\)/, arr)) {{ term=arr[1]; print \"Term Defined: \", term, \"(Line: \" NR \")\" }} }}' {contract_file_quoted}",
675 | "input_file": True,
676 | },
677 | 'Extract defined terms (e.g., ("Acquirer")) using AWK and NR',
678 | )
679 |
680 | # Adv 2c: Change output field separator and process specific sections
681 | await safe_tool_call(
682 | run_awk,
683 | # In ARTICLE I, print section number and title, comma separated
684 | {
685 | "args_str": f"'BEGIN {{ OFS=\",\"; print \"Section,Title\" }} /^## ARTICLE I/,/^## ARTICLE II/ {{ if (/^[0-9]\\.[0-9]+\\s/) {{ title=$0; sub(/^[0-9.]+s*/, \"\", title); print $1, title }} }}' {contract_file_quoted}",
686 | "input_file": True,
687 | },
688 | "Extract section titles from ARTICLE I, CSV formatted (OFS)",
689 | )
690 |
691 | # Adv 2d: Associative arrays to count stockholder types from SCHEDULE_1.2 file
692 | if Path(SCHEDULE_FILE_PATH_ABS).exists():
693 | await safe_tool_call(
694 | run_awk,
695 | # Count occurrences based on text before '(' or '%'
696 | {
697 | "args_str": f"-F'|' '/^\\| / && NF>2 {{ gsub(/^ +| +$/, \"\", $2); types[$2]++ }} END {{ print \"Stockholder Counts:\"; for (t in types) print t \":\", types[t] }}' {schedule_file_quoted}",
698 | "input_file": True,
699 | },
700 | "Use associative array in AWK to count stockholder types in Schedule 1.2",
701 | )
702 | else:
703 | logger.warning(f"Skipping AWK advanced demo 2d, file not found: {SCHEDULE_FILE_PATH_ABS}")
704 |
705 |
706 | async def demonstrate_sed_basic():
707 | """Demonstrate basic usage of the run_sed tool."""
708 | console.print(Rule("[bold green]3. SED Basic Examples[/bold green]", style="green"))
709 |
710 | article_file_quoted = shlex.quote(ARTICLE_FILE_PATH)
711 |
712 | # 3a: Simple substitution
713 | await safe_tool_call(
714 | run_sed,
715 | {
716 | "args_str": f"'s/Microsoft/MegaCorp/g' {article_file_quoted}",
717 | "input_file": True,
718 | },
719 | "Replace 'Microsoft' with 'MegaCorp' in article.txt (global)",
720 | )
721 |
722 | # 3b: Delete lines containing a pattern
723 | await safe_tool_call(
724 | run_sed,
725 | {
726 | "args_str": f"'/Anthropic/d' {article_file_quoted}",
727 | "input_file": True,
728 | },
729 | "Delete lines containing 'Anthropic' from article.txt",
730 | )
731 |
732 | # 3c: Print only lines containing a specific pattern (-n + p)
733 | await safe_tool_call(
734 | run_sed,
735 | {
736 | "args_str": f"-n '/acquisition/p' {article_file_quoted}",
737 | "input_file": True,
738 | },
739 | "Print only lines containing 'acquisition' from article.txt",
740 | )
741 |
742 | # 3d: Process input_data - change 'line' to 'row'
743 | sed_input_data = "This is line one.\nThis is line two.\nAnother line."
744 | await safe_tool_call(
745 | run_sed,
746 | {"args_str": "'s/line/row/g'", "input_data": sed_input_data},
747 | "Replace 'line' with 'row' in input_data string",
748 | )
749 |
750 | # 3e: Demonstrate blocked in-place edit attempt (security feature)
751 | await safe_tool_call(
752 | run_sed,
753 | {
754 | "args_str": f"-i 's/AI/ArtificialIntelligence/g' {article_file_quoted}",
755 | "input_file": True,
756 | },
757 | "Attempt in-place edit with sed -i (SECURITY CHECK PASSED: forbidden flag blocked)",
758 | )
759 |
760 | # 3f: Error case - Unterminated substitute command
761 | await safe_tool_call(
762 | run_sed,
763 | {
764 | "args_str": "'s/AI/ArtificialIntelligence",
765 | "input_data": sed_input_data,
766 | }, # Missing closing quote and delimiter
767 | "Run SED with an unterminated 's' command (INTENTIONAL DEMONSTRATION: script validation)",
768 | )
769 |
770 |
771 | async def demonstrate_sed_advanced():
772 | """Demonstrate advanced usage of the run_sed tool."""
773 | console.print(Rule("[bold green]3b. SED Advanced Examples[/bold green]", style="green"))
774 |
775 | contract_file_quoted = shlex.quote(CONTRACT_FILE_PATH)
776 |
777 | # Adv 3a: Multiple commands with -e
778 | await safe_tool_call(
779 | run_sed,
780 | # Command 1: Change 'Agreement' to 'CONTRACT'. Command 2: Delete lines with 'Exhibit'.
781 | {
782 | "args_str": f"-e 's/Agreement/CONTRACT/g' -e '/Exhibit/d' {contract_file_quoted}",
783 | "input_file": True,
784 | },
785 | "Use multiple SED commands (-e) for substitution and deletion",
786 | )
787 |
788 | # Adv 3b: Using address ranges (print ARTICLE III content)
789 | await safe_tool_call(
790 | run_sed,
791 | {
792 | "args_str": f"-n '/^## ARTICLE III/,/^## ARTICLE IV/p' {contract_file_quoted}",
793 | "input_file": True,
794 | },
795 | "Print content between '## ARTICLE III' and '## ARTICLE IV' using SED addresses",
796 | )
797 |
798 | # Adv 3c: Substitute only the first occurrence on a line
799 | await safe_tool_call(
800 | run_sed,
801 | # Change only the first 'Company' to 'Firm' on each line
802 | {
803 | "args_str": f"'s/Company/Firm/' {contract_file_quoted}",
804 | "input_file": True,
805 | },
806 | "Substitute only the first occurrence of 'Company' per line",
807 | )
808 |
809 | # Adv 3d: Using capture groups to reformat dates (MM/DD/YYYY -> YYYY-MM-DD)
810 | # Note: This regex is basic, might not handle all date formats in the text perfectly
811 | await safe_tool_call(
812 | run_sed,
813 | # Capture month, day, year and rearrange
814 | {
815 | "args_str": rf"-E 's|([0-9]{{1,2}})/([0-9]{{1,2}})/([0-9]{{4}})|\3-\1-\2|g' {contract_file_quoted}",
816 | "input_file": True,
817 | },
818 | "Rearrange date format (MM/DD/YYYY -> YYYY-MM-DD) using SED capture groups",
819 | )
820 |
821 | # Adv 3e: Insert text before lines matching a pattern
822 | await safe_tool_call(
823 | run_sed,
824 | # Insert 'IMPORTANT: ' before lines starting with '## ARTICLE'
825 | {
826 | "args_str": f"'/^## ARTICLE/i IMPORTANT: ' {contract_file_quoted}",
827 | "input_file": True,
828 | },
829 | "Insert text before lines matching a pattern using SED 'i' command",
830 | )
831 |
832 |
833 | async def demonstrate_jq_basic():
834 | """Demonstrate basic usage of the run_jq tool."""
835 | console.print(Rule("[bold green]4. JQ Basic Examples[/bold green]", style="green"))
836 |
837 | # Using input_data for most basic examples
838 | jq_input_data = """
839 | {
840 | "id": "wf-123",
841 | "title": "Data Processing",
842 | "steps": [
843 | {"name": "load", "status": "completed", "duration": 5.2},
844 | {"name": "transform", "status": "running", "duration": null, "details": {"type": "pivot"}},
845 | {"name": "analyze", "status": "pending", "duration": null}
846 | ],
847 | "metadata": {
848 | "user": "admin",
849 | "priority": "high"
850 | }
851 | }
852 | """
853 |
854 | # 4a: Select a top-level field
855 | await safe_tool_call(
856 | run_jq,
857 | {"args_str": "'.title'", "input_data": jq_input_data},
858 | "Select the '.title' field using JQ",
859 | )
860 |
861 | # 4b: Select a nested field
862 | await safe_tool_call(
863 | run_jq,
864 | {"args_str": "'.metadata.priority'", "input_data": jq_input_data},
865 | "Select the nested '.metadata.priority' field using JQ",
866 | )
867 |
868 | # 4c: Select names from the steps array
869 | await safe_tool_call(
870 | run_jq,
871 | {"args_str": "'.steps[].name'", "input_data": jq_input_data},
872 | "Select all step names from the '.steps' array using JQ",
873 | )
874 |
875 | # 4d: Filter steps by status
876 | await safe_tool_call(
877 | run_jq,
878 | {"args_str": "'.steps[] | select(.status == \"completed\")'", "input_data": jq_input_data},
879 | "Filter steps where status is 'completed' using JQ",
880 | )
881 |
882 | # 4e: Create a new object structure
883 | await safe_tool_call(
884 | run_jq,
885 | # Create a new object with workflow id and number of steps
886 | {
887 | "args_str": "'{ workflow: .id, step_count: (.steps | length) }'",
888 | "input_data": jq_input_data,
889 | },
890 | "Create a new object structure using JQ '{ workflow: .id, step_count: .steps | length }'",
891 | )
892 |
893 | # 4f: Error case - Invalid JQ filter syntax
894 | await safe_tool_call(
895 | run_jq,
896 | {
897 | "args_str": "'.steps[] | select(.status =)'",
898 | "input_data": jq_input_data,
899 | }, # Incomplete select
900 | "Run JQ with invalid filter syntax (INTENTIONAL DEMONSTRATION: script validation)",
901 | )
902 |
903 | # 4g: Error case - Process non-JSON input (Input Validation)
904 | await safe_tool_call(
905 | run_jq,
906 | {"args_str": "'.'", "input_data": "This is not JSON."},
907 | "Run JQ on non-JSON input data (INTENTIONAL DEMONSTRATION: input validation)",
908 | )
909 |
910 | # 4h: Using a JSON file as input
911 | if Path(JSON_SAMPLE_PATH_ABS).exists():
912 | json_file_quoted = shlex.quote(JSON_SAMPLE_PATH)
913 | await safe_tool_call(
914 | run_jq,
915 | {
916 | "args_str": f"'.[] | select(.dept == \"IT\").user' {json_file_quoted}",
917 | "input_file": True,
918 | },
919 | "Select 'user' from IT department in sample_data.json",
920 | )
921 | else:
922 | logger.warning(f"Skipping JQ basic demo 4h, file not found: {JSON_SAMPLE_PATH_ABS}")
923 |
924 |
925 | async def demonstrate_jq_advanced():
926 | """Demonstrate advanced usage of the run_jq tool."""
927 | console.print(Rule("[bold green]4b. JQ Advanced Examples[/bold green]", style="green"))
928 |
929 | # Using file input for advanced examples
930 | if not Path(JSON_SAMPLE_PATH_ABS).exists():
931 | logger.warning(f"Skipping JQ advanced demos, file not found: {JSON_SAMPLE_PATH_ABS}")
932 | return
933 |
934 | json_file_quoted = shlex.quote(JSON_SAMPLE_PATH)
935 |
936 | # Adv 4a: Map and filter combined (select users with 'active' tag)
937 | await safe_tool_call(
938 | run_jq,
939 | {
940 | "args_str": f"'.[] | select(.tags | contains([\"active\"])) | .user' {json_file_quoted}",
941 | "input_file": True,
942 | },
943 | "JQ: Select users with the 'active' tag using 'contains' from file",
944 | )
945 |
946 | # Adv 4b: Group by department and calculate average value
947 | # Note: jq 'group_by' produces nested arrays, requires map to process
948 | await safe_tool_call(
949 | run_jq,
950 | {
951 | "args_str": f"'group_by(.dept) | map({{department: .[0].dept, avg_value: (map(.value) | add / length)}})' {json_file_quoted}",
952 | "input_file": True,
953 | },
954 | "JQ: Group by 'dept' and calculate average 'value' from file",
955 | )
956 |
957 | # Adv 4c: Using variables and checking multiple conditions
958 | await safe_tool_call(
959 | run_jq,
960 | # Find IT users from South or West with value > 120
961 | {
962 | "args_str": f'\'map(select(.dept == "IT" and (.region == "South" or .region == "West") and .value > 120))\' {json_file_quoted}',
963 | "input_file": True,
964 | },
965 | "JQ: Complex select with multiple AND/OR conditions from file",
966 | )
967 |
968 | # Adv 4d: Raw output (-r) to get just text values
969 | await safe_tool_call(
970 | run_jq,
971 | # Output user names directly without JSON quotes
972 | {"args_str": f"-r '.[] | .user' {json_file_quoted}", "input_file": True},
973 | "JQ: Get raw string output using -r flag from file",
974 | )
975 |
976 |
977 | async def demonstrate_security_features():
978 | """Demonstrate argument validation and security features."""
979 | console.print(Rule("[bold red]5. Security Feature Demonstrations[/bold red]", style="red"))
980 |
981 | target_file_quoted = shlex.quote(ARTICLE_FILE_PATH)
982 | workspace = get_workspace_dir() # Get the actual workspace for context # noqa: F841
983 |
984 | # Sec 1: Forbidden flag (-i for sed) - Already in sed_basic, ensure it's shown clearly
985 | console.print("[dim]--- Test: Forbidden Flag ---[/]")
986 | await safe_tool_call(
987 | run_sed,
988 | {
989 | "args_str": f"-i 's/AI/ArtificialIntelligence/g' {target_file_quoted}",
990 | "input_file": True,
991 | },
992 | "Attempt in-place edit with sed -i (SECURITY CHECK PASSED: forbidden flag blocked)",
993 | )
994 |
995 | # Sec 2: Forbidden characters (e.g., > for redirection)
996 | console.print("[dim]--- Test: Forbidden Characters ---[/]")
997 | await safe_tool_call(
998 | run_awk,
999 | {"args_str": "'{ print $1 > \"output.txt\" }'", "input_data": "hello world"},
1000 | "Attempt redirection with awk '>' (SECURITY CHECK PASSED: forbidden operation blocked)",
1001 | )
1002 |
1003 | # Sec 3: Command substitution attempt
1004 | console.print("[dim]--- Test: Command Substitution ---[/]")
1005 | await safe_tool_call(
1006 | run_ripgrep,
1007 | {
1008 | "args_str": f"--threads=4 'pattern' `echo {target_file_quoted}`",
1009 | "input_file": True,
1010 | "input_dir": False,
1011 | }, # Input from args only
1012 | "Attempt command substitution with backticks `` (SECURITY CHECK PASSED: command injection blocked)",
1013 | )
1014 | await safe_tool_call(
1015 | run_ripgrep,
1016 | {
1017 | "args_str": f"--threads=4 'pattern' $(basename {target_file_quoted})",
1018 | "input_file": True,
1019 | "input_dir": False,
1020 | },
1021 | "Attempt command substitution with $() (SECURITY CHECK PASSED: command injection blocked)",
1022 | )
1023 |
1024 | # Sec 4: Path Traversal
1025 | console.print("[dim]--- Test: Path Traversal ---[/]")
1026 | # Choose a target likely outside the workspace
1027 | traversal_path = (
1028 | "../../etc/passwd"
1029 | if sys.platform != "win32"
1030 | else "..\\..\\Windows\\System32\\drivers\\etc\\hosts"
1031 | )
1032 | traversal_path_quoted = shlex.quote(traversal_path)
1033 | await safe_tool_call(
1034 | run_ripgrep,
1035 | {"args_str": f"--threads=4 'root' {traversal_path_quoted}", "input_file": True},
1036 | f"Attempt path traversal '{traversal_path}' (SECURITY CHECK PASSED: path traversal blocked)",
1037 | )
1038 |
1039 | # Sec 5: Absolute Path
1040 | console.print("[dim]--- Test: Absolute Path ---[/]")
1041 | # Use a known absolute path
1042 | abs_path = str(
1043 | Path(target_file_quoted).resolve()
1044 | ) # Should be inside workspace IF demo runs from there, but treat as example
1045 | abs_path_quoted = shlex.quote(abs_path) # noqa: F841
1046 | # Let's try a known outside-workspace path if possible
1047 | abs_outside_path = "/tmp/testfile" if sys.platform != "win32" else "C:\\Windows\\notepad.exe"
1048 | abs_outside_path_quoted = shlex.quote(abs_outside_path)
1049 |
1050 | await safe_tool_call(
1051 | run_ripgrep,
1052 | {"args_str": f"--threads=4 'test' {abs_outside_path_quoted}", "input_file": True},
1053 | f"Attempt absolute path '{abs_outside_path}' (SECURITY CHECK PASSED: absolute path blocked)",
1054 | )
1055 |
1056 | # Sec 6: Dry Run
1057 | console.print("[dim]--- Test: Dry Run ---[/]")
1058 | await safe_tool_call(
1059 | run_ripgrep,
1060 | {
1061 | "args_str": f"--json -i 'pattern' {target_file_quoted}",
1062 | "input_file": True,
1063 | "dry_run": True,
1064 | },
1065 | "Demonstrate dry run (--json -i 'pattern' <file>)",
1066 | )
1067 |
1068 |
1069 | async def demonstrate_streaming():
1070 | """Demonstrate the streaming capabilities."""
1071 | console.print(Rule("[bold magenta]6. Streaming Examples[/bold magenta]", style="magenta"))
1072 |
1073 | # Use a file likely to produce multiple lines of output
1074 | target_file_quoted = shlex.quote(CONTRACT_FILE_PATH)
1075 |
1076 | # Stream 1: Ripgrep stream for a common word
1077 | await safe_tool_stream_call(
1078 | run_ripgrep_stream,
1079 | {"args_str": f"--threads=4 -i 'Agreement' {target_file_quoted}", "input_file": True},
1080 | "Stream search results for 'Agreement' in contract (with thread limiting)",
1081 | )
1082 |
1083 | # Stream 2: Sed stream to replace and print
1084 | await safe_tool_stream_call(
1085 | run_sed_stream,
1086 | {"args_str": f"'s/Section/Clause/g' {target_file_quoted}", "input_file": True},
1087 | "Stream sed output replacing 'Section' with 'Clause'",
1088 | )
1089 |
1090 | # Stream 3: Awk stream to print fields
1091 | await safe_tool_stream_call(
1092 | run_awk_stream,
1093 | {
1094 | "args_str": f"'/^##/ {{print \"Found Section: \", $0}}' {target_file_quoted}",
1095 | "input_file": True,
1096 | },
1097 | "Stream awk output printing lines starting with '##'",
1098 | )
1099 |
1100 | # Stream 4: JQ stream on input data
1101 | jq_stream_input = """
1102 | {"id": 1, "value": "alpha"}
1103 | {"id": 2, "value": "beta"}
1104 | {"id": 3, "value": "gamma"}
1105 | {"id": 4, "value": "delta"}
1106 | """
1107 | await safe_tool_stream_call(
1108 | run_jq_stream,
1109 | {"args_str": "'.value'", "input_data": jq_stream_input},
1110 | "Stream jq extracting '.value' from multiple JSON objects",
1111 | )
1112 |
1113 |
1114 | # --- LLM Interactive Workflow Section ---
1115 | # NOTE: run_llm_interactive_workflow helper remains largely the same,
1116 | # but system prompts are updated below.
1117 |
1118 |
1119 | async def run_llm_interactive_workflow(
1120 | goal: str,
1121 | system_prompt: str,
1122 | target_file: Optional[str] = None,
1123 | initial_input_data: Optional[str] = None,
1124 | ):
1125 | """Runs an interactive workflow driven by an LLM using the text tool functions."""
1126 | # --- LLM Config Check ---
1127 | llm_provider_name = None
1128 | llm_model_name = None
1129 | try:
1130 | config = get_config()
1131 | # Use configured default provider or fallback
1132 | llm_provider_name = config.default_provider or Provider.OPENAI.value
1133 | provider_config = getattr(config.providers, llm_provider_name, None)
1134 | if not provider_config or not provider_config.api_key:
1135 | console.print(
1136 | f"[bold yellow]Warning:[/bold yellow] LLM provider '{llm_provider_name}' API key not configured."
1137 | )
1138 | console.print("Skipping this LLM interactive workflow demo.")
1139 | return False # Indicate skip
1140 | llm_model_name = provider_config.default_model # Use provider's default (can be None)
1141 | if not llm_model_name:
1142 | # Try a known default if provider default is missing
1143 | if llm_provider_name == Provider.OPENAI.value:
1144 | llm_model_name = "gpt-3.5-turbo"
1145 | elif llm_provider_name == Provider.ANTHROPIC.value:
1146 | llm_model_name = "claude-3-5-haiku-20241022" # Use a valid model without comments
1147 | # Add other provider fallbacks if needed
1148 | else:
1149 | llm_model_name = "default" # Placeholder if truly unknown
1150 |
1151 | if llm_model_name != "default":
1152 | logger.info(
1153 | f"No default model for provider '{llm_provider_name}', using fallback: {llm_model_name}"
1154 | )
1155 | else:
1156 | console.print(
1157 | f"[bold yellow]Warning:[/bold yellow] Could not determine default model for provider '{llm_provider_name}'. LLM calls might fail."
1158 | )
1159 |
1160 | except Exception as e:
1161 | console.print(f"[bold red]Error checking LLM configuration:[/bold red] {e}")
1162 | console.print("Skipping this LLM interactive workflow demo.")
1163 | return False # Indicate skip
1164 |
1165 | # --- Workflow Setup ---
1166 | console.print(
1167 | Panel(f"[bold]Goal:[/bold]\n{escape(goal)}", title="LLM Task", border_style="blue")
1168 | )
1169 | messages = [{"role": "system", "content": system_prompt}]
1170 | # Add initial content if provided
1171 | if target_file:
1172 | messages.append(
1173 | {"role": "user", "content": f"The primary target file for operations is: {target_file}"}
1174 | )
1175 | elif initial_input_data:
1176 | messages.append(
1177 | {
1178 | "role": "user",
1179 | "content": f"The input data to process is:\n```\n{initial_input_data[:1000]}\n```",
1180 | }
1181 | )
1182 |
1183 | # --- Helper to call LLM ---
1184 | async def run_llm_step(history: List[Dict]) -> Optional[Dict]:
1185 | # (This helper remains largely the same as before, relying on imported chat_completion)
1186 | try:
1187 | llm_response = await chat_completion(
1188 | provider=llm_provider_name, # type: ignore
1189 | model=llm_model_name,
1190 | messages=history,
1191 | temperature=0.1,
1192 | max_tokens=600, # Increased slightly for potentially complex plans
1193 | additional_params={"json_mode": True} # Pass json_mode through additional_params instead
1194 | )
1195 | if not llm_response.get("success"):
1196 | error_detail = llm_response.get("error", "Unknown error")
1197 | console.print(f"[bold red]LLM call failed:[/bold red] {error_detail}")
1198 | # Provide feedback to LLM about the failure
1199 | history.append(
1200 | {
1201 | "role": "assistant",
1202 | "content": json.dumps(
1203 | {
1204 | "tool": "error",
1205 | "args": {"reason": f"LLM API call failed: {error_detail}"},
1206 | }
1207 | ),
1208 | }
1209 | )
1210 | history.append(
1211 | {
1212 | "role": "user",
1213 | "content": "Your previous response resulted in an API error. Please check your request and try again, ensuring valid JSON output.",
1214 | }
1215 | )
1216 | # Try one more time after feedback
1217 | llm_response = await chat_completion(
1218 | provider=llm_provider_name, # type: ignore
1219 | model=llm_model_name,
1220 | messages=history,
1221 | temperature=0.15, # Slightly higher temp for retry
1222 | max_tokens=600,
1223 | additional_params={"json_mode": True} # Pass json_mode through additional_params here too
1224 | )
1225 | if not llm_response.get("success"):
1226 | console.print(
1227 | f"[bold red]LLM call failed on retry:[/bold red] {llm_response.get('error')}"
1228 | )
1229 | return None # Give up after retry
1230 |
1231 | llm_content = llm_response.get("message", {}).get("content", "").strip()
1232 |
1233 | # Attempt to parse the JSON directly
1234 | try:
1235 | # Handle potential ```json blocks if provider doesn't strip them in JSON mode
1236 | if llm_content.startswith("```json"):
1237 | llm_content = re.sub(r"^```json\s*|\s*```$", "", llm_content, flags=re.DOTALL)
1238 |
1239 | parsed_action = json.loads(llm_content)
1240 | if (
1241 | isinstance(parsed_action, dict)
1242 | and "tool" in parsed_action
1243 | and "args" in parsed_action
1244 | ):
1245 | # Basic validation of args structure
1246 | if not isinstance(parsed_action["args"], dict):
1247 | raise ValueError("LLM 'args' field is not a dictionary.")
1248 | return parsed_action
1249 | else:
1250 | console.print(
1251 | "[bold yellow]Warning:[/bold yellow] LLM response is valid JSON but lacks 'tool' or 'args'. Raw:\n",
1252 | llm_content,
1253 | )
1254 | return {
1255 | "tool": "error",
1256 | "args": {
1257 | "reason": "LLM response structure invalid (expected top-level 'tool' and 'args' keys in JSON)."
1258 | },
1259 | }
1260 | except (json.JSONDecodeError, ValueError) as json_err:
1261 | console.print(
1262 | f"[bold red]Error:[/bold red] LLM response was not valid JSON ({json_err}). Raw response:\n",
1263 | llm_content,
1264 | )
1265 | # Try to find tool name even in broken JSON for feedback
1266 | tool_match = re.search(r'"tool":\s*"(\w+)"', llm_content)
1267 | reason = f"LLM response was not valid JSON ({json_err})."
1268 | if tool_match:
1269 | reason += f" It mentioned tool '{tool_match.group(1)}'."
1270 | return {"tool": "error", "args": {"reason": reason}}
1271 | except Exception as e:
1272 | console.print(f"[bold red]Error during LLM interaction:[/bold red] {e}")
1273 | logger.error("LLM interaction error", exc_info=True)
1274 | return None
1275 |
1276 | # Map tool names from LLM response to actual functions
1277 | TOOL_FUNCTIONS = {
1278 | "run_ripgrep": run_ripgrep,
1279 | "run_awk": run_awk,
1280 | "run_sed": run_sed,
1281 | "run_jq": run_jq,
1282 | # Add streaming if needed, but LLM needs careful prompting for stream handling
1283 | # "run_ripgrep_stream": run_ripgrep_stream,
1284 | }
1285 |
1286 | # --- Iteration Loop ---
1287 | for i in range(MAX_LLM_ITERATIONS):
1288 | console.print(Rule(f"[bold]LLM Iteration {i + 1}/{MAX_LLM_ITERATIONS}[/bold]"))
1289 |
1290 | llm_action = await run_llm_step(messages)
1291 | if not llm_action:
1292 | console.print("[bold red]Failed to get valid action from LLM. Stopping.[/bold red]")
1293 | break
1294 |
1295 | # Append LLM's raw action choice to history BEFORE execution
1296 | messages.append({"role": "assistant", "content": json.dumps(llm_action)})
1297 |
1298 | tool_name = llm_action.get("tool")
1299 | tool_args = llm_action.get("args", {}) # Should be a dict if validation passed
1300 |
1301 | console.print(f"[magenta]LLM Planned Action:[/magenta] Tool = {tool_name}")
1302 | console.print(f"[magenta]LLM Args:[/magenta] {pretty_repr(tool_args)}")
1303 |
1304 | if tool_name == "finish":
1305 | console.print(Rule("[bold green]LLM Finished[/bold green]", style="green"))
1306 | console.print("[bold green]Final Answer:[/bold green]")
1307 | final_answer = tool_args.get("final_answer", "No final answer provided.")
1308 | # Display potential JSON nicely
1309 | try:
1310 | # Attempt to parse if it looks like JSON, otherwise print escaped string
1311 | if isinstance(final_answer, str) and final_answer.strip().startswith(("{", "[")):
1312 | parsed_answer = json.loads(final_answer)
1313 | console.print(
1314 | Syntax(json.dumps(parsed_answer, indent=2), "json", theme="monokai")
1315 | )
1316 | else:
1317 | console.print(escape(str(final_answer))) # Ensure it's a string
1318 | except json.JSONDecodeError:
1319 | console.print(escape(str(final_answer))) # Print escaped string on parse fail
1320 | break
1321 | if tool_name == "error":
1322 | console.print(Rule("[bold red]LLM Reported Error[/bold red]", style="red"))
1323 | console.print(
1324 | f"[bold red]Reason:[/bold red] {escape(tool_args.get('reason', 'No reason provided.'))}"
1325 | )
1326 | # Don't break immediately, let LLM try again based on this error feedback
1327 | messages.append(
1328 | {
1329 | "role": "user",
1330 | "content": f"Your previous step resulted in an error state: {tool_args.get('reason')}. Please analyze the issue and plan the next step or finish.",
1331 | }
1332 | )
1333 | continue # Allow LLM to react to its own error report
1334 |
1335 | tool_func_to_call = TOOL_FUNCTIONS.get(tool_name)
1336 |
1337 | if not tool_func_to_call:
1338 | error_msg = f"LLM requested invalid or unsupported tool: '{tool_name}'. Allowed: {list(TOOL_FUNCTIONS.keys())}"
1339 | console.print(f"[bold red]Error:[/bold red] {error_msg}")
1340 | messages.append(
1341 | {
1342 | "role": "user",
1343 | "content": f"Execution Error: {error_msg}. Please choose a valid tool from the allowed list.",
1344 | }
1345 | )
1346 | continue
1347 |
1348 | # Basic validation of common args
1349 | if "args_str" not in tool_args or not isinstance(tool_args["args_str"], str):
1350 | error_msg = f"LLM tool call for '{tool_name}' is missing 'args_str' string argument."
1351 | console.print(f"[bold red]Error:[/bold red] {error_msg}")
1352 | messages.append({"role": "user", "content": f"Input Error: {error_msg}"})
1353 | continue
1354 |
1355 | # Inject target file/data if not explicitly set by LLM but context suggests it
1356 | # Less critical now LLM is prompted to include path in args_str and set flags
1357 | if (
1358 | "input_file" not in tool_args
1359 | and "input_dir" not in tool_args
1360 | and "input_data" not in tool_args
1361 | ):
1362 | # Simple heuristic: if target_file seems to be in args_str, set input_file=True
1363 | if target_file and shlex.quote(target_file) in tool_args.get("args_str", ""):
1364 | tool_args["input_file"] = True
1365 | logger.debug(f"Injecting input_file=True based on args_str content: {target_file}")
1366 | # Maybe inject input_data if available and no file/dir flags? Risky.
1367 | # Let's rely on the LLM providing the flags or safe_tool_call catching errors.
1368 |
1369 | # Execute tool using the safe helper
1370 | execution_result = await safe_tool_call(
1371 | tool_func_to_call,
1372 | tool_args, # Pass the dict received from LLM
1373 | f"Executing LLM Request: {tool_name}",
1374 | display_input=False, # Already printed LLM args
1375 | display_output=False, # Summarize below for LLM context
1376 | )
1377 |
1378 | # Prepare result summary for LLM (Truncate long outputs)
1379 | result_summary_for_llm = ""
1380 | if isinstance(execution_result, dict):
1381 | success = execution_result.get("success", False)
1382 | stdout_preview = (execution_result.get("stdout", "") or "")[:1500] # Limit length
1383 | stderr_preview = (execution_result.get("stderr", "") or "")[:500]
1384 | stdout_trunc = execution_result.get("stdout_truncated", False)
1385 | stderr_trunc = execution_result.get("stderr_truncated", False)
1386 | exit_code = execution_result.get("exit_code")
1387 | error_msg = execution_result.get("error")
1388 | error_code = execution_result.get("error_code")
1389 |
1390 | result_summary_for_llm = f"Tool Execution Result ({tool_name}):\n"
1391 | result_summary_for_llm += f"Success: {success}\n"
1392 | result_summary_for_llm += f"Exit Code: {exit_code}\n"
1393 | if error_msg:
1394 | result_summary_for_llm += f"Error: {error_msg}\n"
1395 | if error_code:
1396 | if isinstance(error_code, Enum):
1397 | error_code_repr = error_code.value
1398 | else:
1399 | error_code_repr = str(error_code)
1400 | result_summary_for_llm += f"Error Code: {error_code_repr}\n"
1401 |
1402 | stdout_info = f"STDOUT ({len(stdout_preview)} chars preview{' - TRUNCATED' if stdout_trunc else ''}):"
1403 | result_summary_for_llm += f"{stdout_info}\n```\n{stdout_preview}\n```\n"
1404 |
1405 | if stderr_preview:
1406 | stderr_info = f"STDERR ({len(stderr_preview)} chars preview{' - TRUNCATED' if stderr_trunc else ''}):"
1407 | result_summary_for_llm += f"{stderr_info}\n```\n{stderr_preview}\n```\n"
1408 | else:
1409 | result_summary_for_llm += "STDERR: (empty)\n"
1410 | else: # Should not happen if safe_tool_call works
1411 | result_summary_for_llm = (
1412 | f"Tool Execution Error: Unexpected result format: {type(execution_result)}"
1413 | )
1414 |
1415 | console.print(
1416 | "[cyan]Execution Result Summary (for LLM):[/]", escape(result_summary_for_llm)
1417 | )
1418 | # Append the outcome back to the message history for the LLM's next turn
1419 | messages.append({"role": "user", "content": result_summary_for_llm})
1420 |
1421 | if i == MAX_LLM_ITERATIONS - 1:
1422 | console.print(Rule("[bold yellow]Max Iterations Reached[/bold yellow]", style="yellow"))
1423 | console.print("Stopping LLM workflow.")
1424 | break
1425 |
1426 | return True # Indicate demo ran (or attempted to run)
1427 |
1428 |
1429 | async def demonstrate_llm_workflow_extract_contacts():
1430 | """LLM Workflow: Extract email addresses and phone numbers from legal_contract.txt."""
1431 | console.print(
1432 | Rule("[bold cyan]7. LLM Workflow: Extract Contacts from Contract[/bold cyan]", style="cyan")
1433 | )
1434 | goal = f"Extract all unique email addresses and phone numbers (in standard format like XXX-XXX-XXXX or (XXX) XXX-XXXX) from the file: {CONTRACT_FILE_PATH}. Present the results clearly as two distinct lists (emails, phone numbers) in your final answer JSON."
1435 | # Updated system prompt for standalone functions
1436 | system_prompt = rf"""
1437 | You are an expert AI assistant tasked with extracting information from text using command-line tools accessed via functions.
1438 | Your goal is: {goal}
1439 | The primary target file is: {CONTRACT_FILE_PATH}
1440 |
1441 | You have access to the following functions:
1442 | - `run_ripgrep(args_str: str, input_file: bool = False, input_data: Optional[str] = None, ...)`: For regex searching.
1443 | - `run_awk(args_str: str, input_file: bool = False, input_data: Optional[str] = None, ...)`: For text processing.
1444 | - `run_sed(args_str: str, input_file: bool = False, input_data: Optional[str] = None, ...)`: For text transformation.
1445 |
1446 | To operate on the target file, you MUST:
1447 | 1. Include the correctly quoted file path in the `args_str`. Use '{shlex.quote(CONTRACT_FILE_PATH)}'.
1448 | 2. Set `input_file=True` in the arguments dictionary.
1449 |
1450 | Example `run_ripgrep` call structure for a file:
1451 | {{
1452 | "tool": "run_ripgrep",
1453 | "args": {{
1454 | "args_str": "-oN 'pattern' {shlex.quote(CONTRACT_FILE_PATH)}",
1455 | "input_file": true
1456 | }}
1457 | }}
1458 |
1459 | Example `run_awk` call structure for stdin:
1460 | {{
1461 | "tool": "run_awk",
1462 | "args": {{
1463 | "args_str": "'{{print $1}}'",
1464 | "input_data": "some input data here"
1465 | }}
1466 | }}
1467 |
1468 | Plan your steps carefully:
1469 | 1. Use `run_ripgrep` with appropriate regex patterns to find emails and phone numbers. Use flags like `-o` (only matching), `-N` (no line numbers), `--no-filename`.
1470 | 2. You might need separate `run_ripgrep` calls for emails and phone numbers.
1471 | 3. Consider using `run_awk` or `run_sed` on the output of `run_ripgrep` (passed via `input_data`) to normalize or unique sort the results, OR present the unique lists in your final answer. A simple approach is often best.
1472 | 4. When finished, respond with `tool: "finish"` and provide the final answer in the specified format within `args: {{"final_answer": ...}}`.
1473 |
1474 | Respond ONLY with a valid JSON object representing the next single action (tool and args) or the final answer. Do not add explanations outside the JSON.
1475 | """
1476 | await run_llm_interactive_workflow(goal, system_prompt, target_file=CONTRACT_FILE_PATH)
1477 |
1478 |
1479 | async def demonstrate_llm_workflow_financial_terms():
1480 | """LLM Workflow: Extract key financial figures from legal_contract.txt."""
1481 | console.print(
1482 | Rule(
1483 | "[bold cyan]8. LLM Workflow: Extract Financial Terms from Contract[/bold cyan]",
1484 | style="cyan",
1485 | )
1486 | )
1487 | goal = f"Extract the exact 'Transaction Value', 'Cash Consideration', and 'Stock Consideration' figures (including USD amounts) mentioned in ARTICLE I of the file: {CONTRACT_FILE_PATH}. Also find the 'Escrow Amount' percentage and the Escrow Agent's name. Structure the final answer as a JSON object."
1488 | # Updated system prompt
1489 | system_prompt = rf"""
1490 | You are an AI assistant specialized in analyzing legal documents using command-line tools accessed via functions.
1491 | Your goal is: {goal}
1492 | The target file is: {CONTRACT_FILE_PATH}
1493 |
1494 | Available functions: `run_ripgrep`, `run_awk`, `run_sed`.
1495 | Remember to include the quoted file path '{shlex.quote(CONTRACT_FILE_PATH)}' in `args_str` and set `input_file=True` when operating on the file.
1496 |
1497 | Plan your steps:
1498 | 1. Use `run_ripgrep` to find relevant lines in ARTICLE I (e.g., search for 'Consideration', '$', 'USD', 'Escrow'). Use context flags like `-A`, `-C` to get surrounding lines if needed.
1499 | 2. Use `run_ripgrep` again or `run_sed`/`run_awk` on the previous output (passed via `input_data`) or the original file to isolate the exact monetary figures (e.g., '$XXX,XXX,XXX USD') and the Escrow Agent name. Regex like `\$\d{{1,3}}(,\d{{3}})*(\.\d+)?\s*USD` might be useful. Be specific with your patterns.
1500 | 3. Combine the extracted information into a JSON object for the `final_answer`.
1501 |
1502 | Respond ONLY with a valid JSON object for the next action or the final answer (`tool: "finish"`).
1503 | """
1504 | await run_llm_interactive_workflow(goal, system_prompt, target_file=CONTRACT_FILE_PATH)
1505 |
1506 |
1507 | async def demonstrate_llm_workflow_defined_terms():
1508 | """LLM Workflow: Extract defined terms like ("Acquirer") from legal_contract.txt."""
1509 | console.print(
1510 | Rule(
1511 | "[bold cyan]9. LLM Workflow: Extract Defined Terms from Contract[/bold cyan]",
1512 | style="cyan",
1513 | )
1514 | )
1515 | goal = f'Find all defined terms enclosed in parentheses and quotes, like ("Acquirer"), in the file: {CONTRACT_FILE_PATH}. List the unique terms found in the final answer.'
1516 | # Updated system prompt
1517 | system_prompt = rf"""
1518 | You are an AI assistant skilled at extracting specific patterns from text using command-line tools accessed via functions.
1519 | Your goal is: {goal}
1520 | The target file is: {CONTRACT_FILE_PATH}
1521 |
1522 | Available functions: `run_ripgrep`, `run_awk`, `run_sed`.
1523 | Remember to include the quoted file path '{shlex.quote(CONTRACT_FILE_PATH)}' in `args_str` and set `input_file=True` when operating on the file.
1524 |
1525 | Plan your steps:
1526 | 1. Use `run_ripgrep` with a regular expression to capture text inside `("...")`. The pattern should capture the content within the quotes. Use the `-o` flag for only matching parts, `-N` for no line numbers, `--no-filename`. Example regex: `\(\"([A-Za-z ]+)\"\)` (you might need to adjust escaping for rg's syntax within `args_str`).
1527 | 2. Process the output to get unique terms. You could pipe the output of ripgrep into awk/sed using `input_data`, e.g., `run_awk` with `'!seen[$0]++'` to get unique lines, or just list unique terms in the final answer.
1528 | 3. Respond ONLY with the JSON for the next action or the final answer (`tool: "finish"`).
1529 | """
1530 | await run_llm_interactive_workflow(goal, system_prompt, target_file=CONTRACT_FILE_PATH)
1531 |
1532 |
1533 | # --- Main Execution ---
1534 |
1535 |
1536 | async def main():
1537 | """Run all LocalTextTools demonstrations."""
1538 | console.print(
1539 | Rule(
1540 | "[bold magenta]Local Text Tools Demo (Standalone Functions)[/bold magenta]",
1541 | style="white",
1542 | )
1543 | )
1544 |
1545 | # Check command availability (uses the new _COMMAND_METADATA if accessible, otherwise shutil.which)
1546 | console.print("Checking availability of required command-line tools...")
1547 | available_tools: Dict[str, bool] = {}
1548 | missing_tools: List[str] = []
1549 | commands_to_check = ["rg", "awk", "sed", "jq"] # Commands used in demo
1550 | try:
1551 | # Try accessing the (internal) metadata if possible for accurate check
1552 | from ultimate_mcp_server.tools.local_text_tools import _COMMAND_METADATA
1553 |
1554 | for cmd, meta in _COMMAND_METADATA.items():
1555 | if cmd in commands_to_check:
1556 | if meta.path and meta.path.exists():
1557 | available_tools[cmd] = True
1558 | console.print(f"[green]✓ {cmd} configured at: {meta.path}[/green]")
1559 | else:
1560 | available_tools[cmd] = False
1561 | missing_tools.append(cmd)
1562 | status = "Not Found" if not meta.path else "Path Not Found"
1563 | console.print(f"[bold red]✗ {cmd} {status}[/bold red]")
1564 | # Check any commands not in metadata via simple which
1565 | for cmd in commands_to_check:
1566 | if cmd not in available_tools:
1567 | if shutil.which(cmd):
1568 | available_tools[cmd] = True
1569 | console.print(f"[green]✓ {cmd} found via shutil.which[/green]")
1570 | else:
1571 | available_tools[cmd] = False
1572 | missing_tools.append(cmd)
1573 | console.print(f"[bold red]✗ {cmd} NOT FOUND[/bold red]")
1574 |
1575 | except ImportError:
1576 | # Fallback to simple check if internal metadata not accessible
1577 | logger.warning("Could not access internal _COMMAND_METADATA, using shutil.which fallback.")
1578 | for cmd in commands_to_check:
1579 | if shutil.which(cmd):
1580 | available_tools[cmd] = True
1581 | console.print(f"[green]✓ {cmd} found via shutil.which[/green]")
1582 | else:
1583 | available_tools[cmd] = False
1584 | missing_tools.append(cmd)
1585 | console.print(f"[bold red]✗ {cmd} NOT FOUND[/bold red]")
1586 |
1587 | if missing_tools:
1588 | console.print(
1589 | f"\n[bold yellow]Warning:[/bold yellow] The following tools seem missing or not configured: {', '.join(missing_tools)}"
1590 | )
1591 | console.print("Demonstrations requiring these tools will likely fail.")
1592 | console.print("Please install them and ensure they are in your system's PATH.")
1593 | console.print("-" * 30)
1594 |
1595 | # No instantiation needed for standalone functions
1596 |
1597 | # --- Basic Demos ---
1598 | if available_tools.get("rg"):
1599 | await demonstrate_ripgrep_basic()
1600 | if available_tools.get("awk"):
1601 | await demonstrate_awk_basic()
1602 | if available_tools.get("sed"):
1603 | await demonstrate_sed_basic()
1604 | if available_tools.get("jq"):
1605 | await demonstrate_jq_basic()
1606 |
1607 | # --- Advanced Demos ---
1608 | if available_tools.get("rg"):
1609 | await demonstrate_ripgrep_advanced()
1610 | if available_tools.get("awk"):
1611 | await demonstrate_awk_advanced()
1612 | if available_tools.get("sed"):
1613 | await demonstrate_sed_advanced()
1614 | if available_tools.get("jq"):
1615 | await demonstrate_jq_advanced()
1616 |
1617 | # --- Security Demos ---
1618 | # These demos don't strictly require the tool to *succeed*, just to be called
1619 | # Run them even if some tools might be missing, to show validation layer
1620 | await demonstrate_security_features()
1621 |
1622 | # --- Streaming Demos ---
1623 | if all(available_tools.get(cmd) for cmd in ["rg", "awk", "sed", "jq"]):
1624 | await demonstrate_streaming()
1625 | else:
1626 | console.print(
1627 | Rule(
1628 | "[yellow]Skipping Streaming Demos (One or more tools missing)[/yellow]",
1629 | style="yellow",
1630 | )
1631 | )
1632 |
1633 | # --- LLM Workflow Demos ---
1634 | llm_available = False
1635 | try:
1636 | config = get_config()
1637 | provider_key = config.default_provider or Provider.OPENAI.value # Check default or fallback
1638 | if (
1639 | config.providers
1640 | and getattr(config.providers, provider_key, None)
1641 | and getattr(config.providers, provider_key).api_key
1642 | ):
1643 | llm_available = True
1644 | else:
1645 | logger.warning(f"LLM provider '{provider_key}' API key not configured.")
1646 | except Exception as e:
1647 | logger.warning(f"Could not verify LLM provider configuration: {e}")
1648 |
1649 | if llm_available and all(
1650 | available_tools.get(cmd) for cmd in ["rg", "awk", "sed"]
1651 | ): # Check tools needed by LLM demos
1652 | llm_demo_ran = await demonstrate_llm_workflow_extract_contacts()
1653 | if llm_demo_ran:
1654 | await demonstrate_llm_workflow_financial_terms()
1655 | if llm_demo_ran:
1656 | await demonstrate_llm_workflow_defined_terms()
1657 | else:
1658 | reason = (
1659 | "LLM Provider Not Configured/Available"
1660 | if not llm_available
1661 | else "One or more required tools (rg, awk, sed) missing"
1662 | )
1663 | console.print(
1664 | Rule(f"[yellow]Skipping LLM Workflow Demos ({reason})[/yellow]", style="yellow")
1665 | )
1666 |
1667 | console.print(Rule("[bold green]Local Text Tools Demo Complete[/bold green]", style="green"))
1668 | return 0
1669 |
1670 |
1671 | if __name__ == "__main__":
1672 | # Run the demo
1673 | try:
1674 | exit_code = asyncio.run(main())
1675 | sys.exit(exit_code)
1676 | except KeyboardInterrupt:
1677 | console.print("\n[bold yellow]Demo interrupted by user.[/bold yellow]")
1678 | sys.exit(1)
1679 |
```
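The LLM workflow above drives the standalone text tools through a simple JSON action protocol: each model turn is a single object with a `tool` name and an `args` dict, with `finish` and `error` as control actions. The following is a minimal, self-contained sketch of that contract; the tool names mirror the demo's `TOOL_FUNCTIONS` mapping, but `parse_llm_action` and the sample payload are illustrative assumptions, not code from the repository.

```python
import json
from typing import Any, Dict

# Tools the demo exposes to the LLM, plus the two control actions.
ALLOWED_TOOLS = {"run_ripgrep", "run_awk", "run_sed", "run_jq", "finish", "error"}


def parse_llm_action(raw: str) -> Dict[str, Any]:
    """Parse and minimally validate one LLM action message (hypothetical helper)."""
    action = json.loads(raw)  # raises json.JSONDecodeError on malformed JSON
    if not isinstance(action, dict):
        raise ValueError("action must be a JSON object")
    if "tool" not in action or "args" not in action:
        raise ValueError("action must contain 'tool' and 'args' keys")
    if action["tool"] not in ALLOWED_TOOLS:
        raise ValueError(f"unknown tool: {action['tool']!r}")
    if not isinstance(action["args"], dict):
        raise ValueError("'args' must be a JSON object")
    return action


# A well-formed action of the shape the demo's system prompts request:
example = parse_llm_action(
    '{"tool": "run_ripgrep",'
    ' "args": {"args_str": "-oN --no-filename \'[0-9]{3}-[0-9]{3}-[0-9]{4}\' contract.txt",'
    ' "input_file": true}}'
)
print(example["tool"], example["args"])
```

In the demo itself, `run_llm_step` performs the equivalent checks and converts structural problems into an `error` action, so the model can correct its output on the next iteration.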
--------------------------------------------------------------------------------
/ultimate_mcp_server/core/providers/ollama.py:
--------------------------------------------------------------------------------
```python
1 | """Ollama provider implementation for the Ultimate MCP Server.
2 |
3 | This module implements the Ollama provider, enabling interaction with locally running
4 | Ollama models through a standard interface. Ollama is an open-source framework for
5 | running LLMs locally with minimal setup.
6 |
7 | The implementation supports:
8 | - Text completion (generate) and chat completions
9 | - Streaming responses
10 | - Model listing and information retrieval
11 | - Embeddings generation
12 | - Cost tracking (estimated since Ollama is free to use locally)
13 |
14 | Ollama must be installed and running locally (by default on localhost:11434)
15 | for this provider to work properly.
16 | """
17 |
18 | import asyncio
19 | import json
20 | import re
21 | import time
22 | from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
23 |
24 | import aiohttp
25 | import httpx
26 | from pydantic import BaseModel
27 |
28 | from ultimate_mcp_server.config import get_config
29 | from ultimate_mcp_server.constants import COST_PER_MILLION_TOKENS, Provider
30 | from ultimate_mcp_server.core.providers.base import (
31 | BaseProvider,
32 | ModelResponse,
33 | )
34 | from ultimate_mcp_server.exceptions import ProviderError
35 | from ultimate_mcp_server.utils import get_logger
36 |
37 | logger = get_logger("ultimate_mcp_server.providers.ollama")
38 |
39 |
40 | # Define the Model class locally since it's not available in base.py
41 | class Model(dict):
42 | """Model information returned by providers."""
43 |
44 | def __init__(self, id: str, name: str, description: str, provider: str, **kwargs):
45 | """Initialize a model info dictionary.
46 |
47 | Args:
48 | id: Model identifier (e.g., "llama3.2")
49 | name: Human-readable model name
50 | description: Longer description of the model
51 | provider: Provider name
52 | **kwargs: Additional model metadata
53 | """
54 | super().__init__(id=id, name=name, description=description, provider=provider, **kwargs)
55 |
56 |
57 | # Define ProviderFeatures locally since it's not available in base.py
58 | class ProviderFeatures:
59 | """Features supported by a provider."""
60 |
61 | def __init__(
62 | self,
63 | supports_chat_completions: bool = False,
64 | supports_streaming: bool = False,
65 | supports_function_calling: bool = False,
66 | supports_multiple_functions: bool = False,
67 | supports_embeddings: bool = False,
68 | supports_json_mode: bool = False,
69 | max_retries: int = 3,
70 | ):
71 | """Initialize provider features.
72 |
73 | Args:
74 | supports_chat_completions: Whether the provider supports chat completions
75 | supports_streaming: Whether the provider supports streaming responses
76 | supports_function_calling: Whether the provider supports function calling
77 | supports_multiple_functions: Whether the provider supports multiple functions
78 | supports_embeddings: Whether the provider supports embeddings
79 | supports_json_mode: Whether the provider supports JSON mode
80 | max_retries: Maximum number of retries for failed requests
81 | """
82 | self.supports_chat_completions = supports_chat_completions
83 | self.supports_streaming = supports_streaming
84 | self.supports_function_calling = supports_function_calling
85 | self.supports_multiple_functions = supports_multiple_functions
86 | self.supports_embeddings = supports_embeddings
87 | self.supports_json_mode = supports_json_mode
88 | self.max_retries = max_retries
89 |
90 |
91 | # Define ProviderStatus locally since it's not available in base.py
92 | class ProviderStatus:
93 | """Status information for a provider."""
94 |
95 | def __init__(
96 | self,
97 | name: str,
98 | enabled: bool = False,
99 | available: bool = False,
100 | api_key_configured: bool = False,
101 | features: Optional[ProviderFeatures] = None,
102 | default_model: Optional[str] = None,
103 | ):
104 | """Initialize provider status.
105 |
106 | Args:
107 | name: Provider name
108 | enabled: Whether the provider is enabled
109 | available: Whether the provider is available
110 | api_key_configured: Whether an API key is configured
111 | features: Provider features
112 | default_model: Default model for the provider
113 | """
114 | self.name = name
115 | self.enabled = enabled
116 | self.available = available
117 | self.api_key_configured = api_key_configured
118 | self.features = features
119 | self.default_model = default_model
120 |
121 |
122 | class OllamaConfig(BaseModel):
123 | """Configuration for the Ollama provider."""
124 |
125 | # API endpoint (default is localhost:11434)
126 | api_url: str = "http://127.0.0.1:11434"
127 |
128 | # Default model to use if none specified
129 | default_model: str = "llama3.2"
130 |
131 | # Timeout settings
132 | request_timeout: int = 300
133 |
134 | # Whether this provider is enabled
135 | enabled: bool = True
136 |
137 |
138 | class OllamaProvider(BaseProvider):
139 | """
140 | Provider implementation for Ollama.
141 |
142 | Ollama allows running open-source language models locally with minimal setup.
143 | This provider implementation connects to a locally running Ollama instance and
144 | provides a standard interface for generating completions and embeddings.
145 |
146 | Unlike cloud providers, Ollama runs models locally, so:
147 | - No API key is required
148 | - Costs are estimated (since running locally is free)
149 | - Model availability depends on what models have been downloaded locally
150 |
151 | The Ollama provider supports both chat completions and text completions,
152 | as well as streaming responses. It requires that the Ollama service is
153 | running and accessible at the configured endpoint.
154 | """
155 |
156 | provider_name = Provider.OLLAMA
157 |
158 | def __init__(self, api_key: Optional[str] = None, **kwargs):
159 | """Initialize the Ollama provider.
160 |
161 | Args:
162 | api_key: Not used by Ollama, included for API compatibility with other providers
163 | **kwargs: Additional provider-specific options
164 | """
165 | # Skip API key, it's not used by Ollama but we accept it for compatibility
166 | super().__init__()
167 | self.logger = get_logger(f"provider.{Provider.OLLAMA}")
168 | self.logger.info("Initializing Ollama provider...")
169 | self.config = self._load_config()
170 | self.logger.info(
171 | f"Loaded config: API URL={self.config.api_url}, default_model={self.config.default_model}, enabled={self.config.enabled}"
172 | )
173 |
174 | # Initialize session to None, we'll create it when needed
175 | self._session = None
176 |
177 | self.client_session_params = {
178 | "timeout": aiohttp.ClientTimeout(total=self.config.request_timeout)
179 | }
180 |
181 | # Unlike other providers, Ollama doesn't require an API key
182 | # But we'll still set this flag to True for consistency
183 | self._api_key_configured = True
184 | self._initialized = False
185 |
186 | # Set feature flags
187 | self.features = ProviderFeatures(
188 | supports_chat_completions=True,
189 | supports_streaming=True,
190 | supports_function_calling=False, # Ollama doesn't support function calling natively
191 | supports_multiple_functions=False,
192 | supports_embeddings=True,
193 | supports_json_mode=True, # Now supported via prompt engineering and format parameter
194 | max_retries=3,
195 | )
196 |
197 | # Set default costs for Ollama models (very low estimated costs)
198 | # Since Ollama runs locally, the actual cost is hardware usage/electricity
199 | # We'll use very low values for tracking purposes
200 | self._default_token_cost = {
201 | "input": 0.0001, # $0.0001 per 1M tokens (effectively free)
202 | "output": 0.0001, # $0.0001 per 1M tokens (effectively free)
203 | }
204 | self.logger.info("Ollama provider initialization completed")
205 |
206 | @property
207 | async def session(self) -> aiohttp.ClientSession:
208 | """Get the current session or create a new one if needed."""
209 | if self._session is None or self._session.closed:
210 | self._session = aiohttp.ClientSession(**self.client_session_params)
211 | return self._session
212 |
213 | async def __aenter__(self):
214 | """Enter async context, initializing the provider."""
215 | await self.initialize()
216 | return self
217 |
218 | async def __aexit__(self, exc_type, exc_val, exc_tb):
219 | """Exit async context, ensuring proper shutdown."""
220 | await self.shutdown()
221 |
222 | async def initialize(self) -> bool:
223 | """Initialize the provider, creating a new HTTP session.
224 |
225 | This method handles the initialization of the connection to Ollama.
226 | If Ollama isn't available (not installed or not running),
227 | it will gracefully report the issue without spamming errors.
228 |
229 | Returns:
230 | bool: True if initialization was successful, False otherwise
231 | """
232 | try:
233 | # Create a temporary session with a short timeout for the initial check
234 | async with aiohttp.ClientSession(
235 | timeout=aiohttp.ClientTimeout(total=5.0)
236 | ) as check_session:
237 | # Try to connect to Ollama and check if it's running
238 | self.logger.info(
239 | f"Attempting to connect to Ollama at {self.config.api_url}/api/tags",
240 | emoji_key="provider",
241 | )
242 |
243 | # First try the configured URL
244 | try:
245 | async with check_session.get(
246 | f"{self.config.api_url}/api/tags", timeout=5.0
247 | ) as response:
248 | if response.status == 200:
249 | # Ollama is running, we'll create the main session when needed later
250 | self.logger.info(
251 | "Ollama service is available and running", emoji_key="provider"
252 | )
253 | self._initialized = True
254 | return True
255 | else:
256 | self.logger.warning(
257 | f"Ollama service responded with status {response.status}. "
258 | "The service might be misconfigured.",
259 | emoji_key="warning",
260 | )
261 | except aiohttp.ClientConnectionError:
262 | # Try alternate localhost format (127.0.0.1 instead of localhost or vice versa)
263 | alternate_url = (
264 | self.config.api_url.replace("localhost", "127.0.0.1")
265 | if "localhost" in self.config.api_url
266 | else self.config.api_url.replace("127.0.0.1", "localhost")
267 | )
268 | self.logger.info(
269 | f"Connection failed, trying alternate URL: {alternate_url}",
270 | emoji_key="provider",
271 | )
272 |
273 | try:
274 | async with check_session.get(
275 | f"{alternate_url}/api/tags", timeout=5.0
276 | ) as response:
277 | if response.status == 200:
278 | # Update the config to use the working URL
279 | self.logger.info(
280 | f"Connected successfully using alternate URL: {alternate_url}",
281 | emoji_key="provider",
282 | )
283 | self.config.api_url = alternate_url
284 | self._initialized = True
285 | return True
286 | else:
287 | self.logger.warning(
288 | f"Ollama service at alternate URL responded with status {response.status}. "
289 | "The service might be misconfigured.",
290 | emoji_key="warning",
291 | )
292 | except (aiohttp.ClientError, asyncio.TimeoutError) as e:
293 | self.logger.warning(
294 | f"Could not connect to alternate URL: {str(e)}. "
295 | "Make sure Ollama is installed and running: https://ollama.com/download",
296 | emoji_key="warning",
297 | )
298 | except aiohttp.ClientError as e:
299 | # Other client errors
300 | self.logger.warning(
301 | f"Could not connect to Ollama service: {str(e)}. "
302 | "Make sure Ollama is installed and running: https://ollama.com/download",
303 | emoji_key="warning",
304 | )
305 | except asyncio.TimeoutError:
306 | # Timeout indicates Ollama is likely not responding
307 | self.logger.warning(
308 | "Connection to Ollama service timed out. "
309 | "Make sure Ollama is installed and running: https://ollama.com/download",
310 | emoji_key="warning",
311 | )
312 |
313 | # If we got here, Ollama is not available
314 | self._initialized = False
315 | return False
316 |
317 | except Exception as e:
318 | # Catch any other exceptions to avoid spamming errors
319 | self.logger.error(
320 | f"Unexpected error initializing Ollama provider: {str(e)}", emoji_key="error"
321 | )
322 | self._initialized = False
323 | return False
324 |
325 | async def shutdown(self) -> None:
326 | """Shutdown the provider, closing the HTTP session."""
327 | try:
328 | if self._session and not self._session.closed:
329 | await self._session.close()
330 | self._session = None
331 | except Exception as e:
332 | self.logger.warning(
333 | f"Error closing Ollama session during shutdown: {str(e)}", emoji_key="warning"
334 | )
335 | finally:
336 | self._initialized = False
337 |
338 | def _load_config(self) -> OllamaConfig:
339 | """Load Ollama configuration from app configuration."""
340 | try:
341 | self.logger.info("Loading Ollama config from app configuration")
342 | config = get_config()
343 | # Print entire config for debugging
344 | self.logger.debug(f"Full config: {config}")
345 |
346 | if not hasattr(config, "providers"):
347 | self.logger.warning("Config doesn't have 'providers' attribute")
348 | return OllamaConfig()
349 |
350 | if not hasattr(config.providers, Provider.OLLAMA):
351 | self.logger.warning(f"Config doesn't have '{Provider.OLLAMA}' provider configured")
352 | return OllamaConfig()
353 |
354 | provider_config = getattr(config.providers, Provider.OLLAMA, {})
355 | self.logger.info(f"Found provider config: {provider_config}")
356 |
357 | if hasattr(provider_config, "dict"):
358 | self.logger.info("Provider config has 'dict' method, using it")
359 | return OllamaConfig(**provider_config.dict())
360 | else:
361 | self.logger.warning(
362 | "Provider config doesn't have 'dict' method, attempting direct conversion"
363 | )
364 | # Try to convert to dict directly
365 | config_dict = {}
366 |
367 | # Define mapping from ProviderConfig field names to OllamaConfig field names
368 | field_mapping = {
369 | "base_url": "api_url", # ProviderConfig -> OllamaConfig
370 | "default_model": "default_model",
371 | "timeout": "request_timeout",
372 | "enabled": "enabled",
373 | }
374 |
375 | # Map fields from provider_config to OllamaConfig's expected field names
376 | for provider_key, ollama_key in field_mapping.items():
377 | if hasattr(provider_config, provider_key):
378 | config_dict[ollama_key] = getattr(provider_config, provider_key)
379 | self.logger.info(
380 | f"Mapped {provider_key} to {ollama_key}: {getattr(provider_config, provider_key)}"
381 | )
382 |
383 | self.logger.info(f"Created config dict: {config_dict}")
384 | return OllamaConfig(**config_dict)
385 | except Exception as e:
386 | self.logger.error(f"Error loading Ollama config: {e}", exc_info=True)
387 | return OllamaConfig()
388 |
389 | def get_default_model(self) -> str:
390 | """Get the default model for this provider."""
391 | return self.config.default_model
392 |
393 | def get_status(self) -> ProviderStatus:
394 | """Get the current status of this provider."""
395 | return ProviderStatus(
396 | name=self.provider_name,
397 | enabled=self.config.enabled,
398 | available=self._initialized,
399 | api_key_configured=self._api_key_configured,
400 | features=self.features,
401 | default_model=self.get_default_model(),
402 | )
403 |
404 | async def check_api_key(self) -> bool:
405 | """
406 | Check if the Ollama service is accessible.
407 |
408 | Since Ollama doesn't use API keys, this just checks if the service is running.
409 | This check is designed to fail gracefully if Ollama is not installed or running,
410 | without causing cascading errors in the system.
411 |
412 | Returns:
413 | bool: True if Ollama service is running and accessible, False otherwise
414 | """
415 | if not self._initialized:
416 | try:
417 | # Attempt to initialize with a short timeout
418 | return await self.initialize()
419 | except Exception as e:
420 | self.logger.warning(
421 | f"Failed to initialize Ollama during service check: {str(e)}",
422 | emoji_key="warning",
423 | )
424 | return False
425 |
426 | try:
427 | # Use a dedicated session with short timeout for health check
428 | async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3.0)) as session:
429 | try:
430 | async with session.get(f"{self.config.api_url}/api/tags") as response:
431 | return response.status == 200
432 | except (aiohttp.ClientConnectionError, asyncio.TimeoutError, Exception) as e:
433 | self.logger.warning(
434 | f"Ollama service check failed: {str(e)}", emoji_key="warning"
435 | )
436 | return False
437 | except Exception as e:
438 | self.logger.warning(
439 | f"Failed to create session for Ollama check: {str(e)}", emoji_key="warning"
440 | )
441 | return False
442 |
443 | def _build_api_url(self, endpoint: str) -> str:
444 | """Build the full API URL for a given endpoint."""
445 | return f"{self.config.api_url}/api/{endpoint}"
446 |
447 | def _estimate_token_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
448 | """
449 | Estimate the cost of a completion based on token counts.
450 |
451 | Since Ollama runs locally, the costs are just estimates and very low.
452 | """
453 | # Try to get model-specific costs if available
454 | model_costs = COST_PER_MILLION_TOKENS.get(model, self._default_token_cost)
455 |
456 | # Calculate costs
457 | input_cost = (input_tokens / 1_000_000) * model_costs.get(
458 | "input", self._default_token_cost["input"]
459 | )
460 | output_cost = (output_tokens / 1_000_000) * model_costs.get(
461 | "output", self._default_token_cost["output"]
462 | )
463 |
464 | return input_cost + output_cost
465 |
466 | async def list_models(self) -> List[Model]:
467 | """
468 | List all available models from Ollama.
469 |
470 | This method attempts to list all locally available Ollama models.
471 | If Ollama is not available or cannot be reached, it will return
472 | an empty list instead of raising an exception.
473 |
474 | Returns:
475 | List of available Ollama models, or empty list if Ollama is not available
476 | """
477 | if not self._initialized:
478 | try:
479 | initialized = await self.initialize()
480 | if not initialized:
481 | self.logger.warning(
482 | "Cannot list Ollama models because the service is not available",
483 | emoji_key="warning",
484 | )
485 | return []
486 | except Exception:
487 | self.logger.warning(
488 | "Cannot list Ollama models because initialization failed", emoji_key="warning"
489 | )
490 | return []
491 |
492 | try:
493 | # Create a dedicated session for this operation to avoid shared session issues
494 | async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10.0)) as session:
495 | return await self._fetch_models(session)
496 | except Exception as e:
497 | self.logger.warning(
498 | f"Error listing Ollama models: {str(e)}. The service may not be running.",
499 | emoji_key="warning",
500 | )
501 | return []
502 |
503 | async def _fetch_models(self, session: aiohttp.ClientSession) -> List[Model]:
504 | """Fetch models using the provided session."""
505 | try:
506 | async with session.get(self._build_api_url("tags")) as response:
507 | if response.status != 200:
508 | self.logger.warning(f"Failed to list Ollama models: {response.status}")
509 | return []
510 |
511 | data = await response.json()
512 | models = []
513 |
514 | # Process the response
515 | for model_info in data.get("models", []):
516 | model_id = model_info.get("name", "")
517 |
518 | # Extract additional info if available
519 | description = f"Ollama model: {model_id}"
520 | model_size = model_info.get("size", 0)
521 | size_gb = None
522 |
523 | if model_size:
524 | # Convert to GB for readability if size is provided in bytes
525 | size_gb = model_size / (1024 * 1024 * 1024)
526 | description += f" ({size_gb:.2f} GB)"
527 |
528 | models.append(
529 | Model(
530 | id=model_id,
531 | name=model_id,
532 | description=description,
533 | provider=self.provider_name,
534 | size=f"{size_gb:.2f} GB" if size_gb else "Unknown",
535 | )
536 | )
537 |
538 | return models
539 | except aiohttp.ClientConnectionError:
540 | self.logger.warning(
541 | "Connection refused while listing Ollama models", emoji_key="warning"
542 | )
543 | return []
544 | except asyncio.TimeoutError:
545 | self.logger.warning("Timeout while listing Ollama models", emoji_key="warning")
546 | return []
547 | except Exception as e:
548 | self.logger.warning(f"Error fetching Ollama models: {str(e)}", emoji_key="warning")
549 | return []
550 |
551 | async def generate_completion(
552 | self,
553 | prompt: Optional[str] = None,
554 | messages: Optional[List[Dict[str, Any]]] = None,
555 | model: Optional[str] = None,
556 | max_tokens: Optional[int] = None,
557 | temperature: float = 0.7,
558 | stop: Optional[List[str]] = None,
559 | top_p: Optional[float] = None,
560 | top_k: Optional[int] = None,
561 | frequency_penalty: Optional[float] = None,
562 | presence_penalty: Optional[float] = None,
563 | mirostat: Optional[int] = None,
564 | mirostat_tau: Optional[float] = None,
565 | mirostat_eta: Optional[float] = None,
566 | json_mode: bool = False,
567 | **kwargs
568 | ) -> ModelResponse:
569 | """Generate a completion from Ollama.
570 |
571 | Args:
572 | prompt: Text prompt to send to Ollama (optional if messages provided)
573 | messages: List of message dictionaries (optional if prompt provided)
574 | model: Ollama model name (e.g., "llama2:13b")
575 | max_tokens: Maximum tokens to generate
576 | temperature: Controls randomness (0.0-1.0)
577 | stop: List of strings that stop generation when encountered
578 | top_p: Nucleus sampling parameter
579 | top_k: Top-k sampling parameter
580 | frequency_penalty: Frequency penalty parameter
581 | presence_penalty: Presence penalty parameter
582 | mirostat: Mirostat sampling algorithm (0, 1, or 2)
583 | mirostat_tau: Target entropy for mirostat
584 | mirostat_eta: Learning rate for mirostat
585 | json_mode: Request JSON-formatted response
586 | **kwargs: Additional parameters
587 |
588 | Returns:
589 | ModelResponse object with completion result
590 | """
591 | if not self.config.api_url:
592 | raise ValueError("Ollama API URL not configured")
593 |
594 | # Verify we have either prompt or messages
595 | if prompt is None and not messages:
596 | raise ValueError("Either prompt or messages must be provided to generate a completion")
597 |
598 | # If model is None, use configured default
599 | model = model or self.get_default_model()
600 |
601 | # Only strip provider prefix if it's our provider name, keep organization prefixes
602 | if "/" in model and model.startswith(f"{self.provider_name}/"):
603 | model = model.split("/", 1)[1]
604 |
605 | # If JSON mode is enabled, use the streaming implementation internally
606 | # since Ollama's non-streaming JSON mode is inconsistent
607 | if json_mode:
608 | self.logger.debug("JSON mode requested, using streaming implementation internally for reliability")
609 | return await self._generate_completion_via_streaming(
610 | prompt=prompt,
611 | messages=messages,
612 | model=model,
613 | max_tokens=max_tokens,
614 | temperature=temperature,
615 | stop=stop,
616 | top_p=top_p,
617 | top_k=top_k,
618 | frequency_penalty=frequency_penalty,
619 | presence_penalty=presence_penalty,
620 | mirostat=mirostat,
621 | mirostat_tau=mirostat_tau,
622 | mirostat_eta=mirostat_eta,
623 | json_mode=True, # Ensure json_mode is passed through
624 | **kwargs
625 | )
626 |
627 | # Log request start
628 | self.logger.info(
629 | f"Generating Ollama completion (generate) with model {model}",
630 | emoji_key=self.provider_name
631 | )
632 |
633 | # Convert messages to prompt if messages provided
634 | using_messages = False
635 | if messages and not prompt:
636 | using_messages = True
637 | # Convert messages to Ollama's chat format
638 | chat_params = {"messages": []}
639 |
640 | # Process messages into Ollama format
641 | for msg in messages:
642 | role = msg.get("role", "").lower()
643 | content = msg.get("content", "")
644 |
645 | # Map roles to Ollama's expected format
646 | if role == "system":
647 | ollama_role = "system"
648 | elif role == "user":
649 | ollama_role = "user"
650 | elif role == "assistant":
651 | ollama_role = "assistant"
652 | else:
653 | # Default unknown roles to user
654 | self.logger.warning(f"Unknown message role '{role}', treating as 'user'")
655 | ollama_role = "user"
656 |
657 | chat_params["messages"].append({
658 | "role": ollama_role,
659 | "content": content
660 | })
661 |
662 | # Add model and parameters to chat_params
663 | chat_params["model"] = model
664 |
665 | # Add optional parameters if provided
666 | if temperature is not None and temperature != 0.7:
667 | chat_params["options"] = chat_params.get("options", {})
668 | chat_params["options"]["temperature"] = temperature
669 |
670 | if max_tokens is not None:
671 | chat_params["options"] = chat_params.get("options", {})
672 | chat_params["options"]["num_predict"] = max_tokens
673 |
674 | if stop:
675 | chat_params["options"] = chat_params.get("options", {})
676 | chat_params["options"]["stop"] = stop
677 |
678 | # Add other parameters if provided
679 | for param_name, param_value in [
680 | ("top_p", top_p),
681 | ("top_k", top_k),
682 | ("frequency_penalty", frequency_penalty),
683 | ("presence_penalty", presence_penalty),
684 | ("mirostat", mirostat),
685 | ("mirostat_tau", mirostat_tau),
686 | ("mirostat_eta", mirostat_eta)
687 | ]:
688 | if param_value is not None:
689 | chat_params["options"] = chat_params.get("options", {})
690 | chat_params["options"][param_name] = param_value
691 |
692 | # Add json_mode if requested (as format option)
693 | if json_mode:
694 | chat_params["options"] = chat_params.get("options", {})
695 | chat_params["options"]["format"] = "json"
696 |
697 | # For Ollama non-streaming completions, we need to force the system message
698 | # because the format param alone isn't reliable
699 | kwargs["add_json_instructions"] = True
700 |
701 | # Only add system message instruction as a fallback if explicitly requested
702 | add_json_instructions = kwargs.pop("add_json_instructions", False)
703 |
704 | # Add system message for json_mode only if requested
705 | if add_json_instructions:
706 | has_system = any(msg.get("role", "").lower() == "system" for msg in messages)
707 | if not has_system:
708 | # Add JSON instruction as a system message
709 | chat_params["messages"].insert(0, {
710 | "role": "system",
711 | "content": "You must respond with valid JSON. Format your entire response as a JSON object with properly quoted keys and values."
712 | })
713 | self.logger.debug("Added JSON system instructions for chat_params")
714 |
715 | # Add any additional kwargs as options
716 | if kwargs:
717 | chat_params["options"] = chat_params.get("options", {})
718 | chat_params["options"].update(kwargs)
719 |
720 | # Use chat endpoint
721 | api_endpoint = self._build_api_url("chat")
722 | response_type = "chat"
723 | else:
724 | # Using generate endpoint with prompt
725 | # Prepare generate parameters
726 | generate_params = {
727 | "model": model,
728 | "prompt": prompt
729 | }
730 |
731 | # Add optional parameters if provided
732 | if temperature is not None and temperature != 0.7:
733 | generate_params["options"] = generate_params.get("options", {})
734 | generate_params["options"]["temperature"] = temperature
735 |
736 | if max_tokens is not None:
737 | generate_params["options"] = generate_params.get("options", {})
738 | generate_params["options"]["num_predict"] = max_tokens
739 |
740 | if stop:
741 | generate_params["options"] = generate_params.get("options", {})
742 | generate_params["options"]["stop"] = stop
743 |
744 | # Add other parameters if provided
745 | for param_name, param_value in [
746 | ("top_p", top_p),
747 | ("top_k", top_k),
748 | ("frequency_penalty", frequency_penalty),
749 | ("presence_penalty", presence_penalty),
750 | ("mirostat", mirostat),
751 | ("mirostat_tau", mirostat_tau),
752 | ("mirostat_eta", mirostat_eta)
753 | ]:
754 | if param_value is not None:
755 | generate_params["options"] = generate_params.get("options", {})
756 | generate_params["options"][param_name] = param_value
757 |
758 | # Add json_mode if requested (as format option)
759 | if json_mode:
760 | generate_params["options"] = generate_params.get("options", {})
761 | generate_params["options"]["format"] = "json"
762 |
763 | # For Ollama non-streaming completions, we need to force the JSON instructions
764 | # because the format param alone isn't reliable
765 | kwargs["add_json_instructions"] = True
766 |
767 | # Only enhance prompt with JSON instructions if explicitly requested
768 | add_json_instructions = kwargs.pop("add_json_instructions", False)
769 | if add_json_instructions:
770 | # Enhance prompt with JSON instructions for better compliance
771 | generate_params["prompt"] = f"Please respond with valid JSON only. {prompt}\nEnsure your entire response is a valid, parseable JSON object with properly quoted keys and values."
772 | self.logger.debug("Enhanced prompt with JSON instructions for generate_params")
773 |
774 | # Add any additional kwargs as options
775 | if kwargs:
776 | generate_params["options"] = generate_params.get("options", {})
777 | generate_params["options"].update(kwargs)
778 |
779 | # Use generate endpoint
780 | api_endpoint = self._build_api_url("generate")
781 | response_type = "generate" # noqa: F841
782 |
783 | # Start timer for tracking
784 | start_time = time.time()
785 |
786 | try:
787 | # Make HTTP request to Ollama
788 | async with httpx.AsyncClient(timeout=self.config.request_timeout) as client:
789 | if using_messages:
790 | # Using chat endpoint
791 | response = await client.post(api_endpoint, json=chat_params)
792 | else:
793 | # Using generate endpoint
794 | response = await client.post(api_endpoint, json=generate_params)
795 |
796 | # Check for HTTP errors
797 | response.raise_for_status()
798 |
799 | # Parse response - handle multi-line JSON data which can happen with json_mode
800 | try:
801 | # First try regular JSON parsing
802 | result = response.json()
803 | except json.JSONDecodeError as e:
804 | # If that fails, try parsing line by line and concatenate responses
805 | self.logger.debug("Response contains multiple JSON objects, parsing line by line")
806 | content = response.text
807 | lines = content.strip().split('\n')
808 |
809 | # If we have multiple JSON objects
810 | if len(lines) > 1:
811 | # For multiple objects, take the last one which should have the final response
812 | # This happens in some Ollama versions when using format=json
813 | try:
814 | result = json.loads(lines[-1]) # Use the last line, which typically has the complete response
815 |
816 | # Verify result has response/message field, if not try the first line
817 | if using_messages and "message" not in result:
818 | result = json.loads(lines[0])
819 | elif not using_messages and "response" not in result:
820 | result = json.loads(lines[0])
821 |
822 | except json.JSONDecodeError as e:
823 | raise RuntimeError(f"Failed to parse Ollama JSON response: {str(e)}. Response: {content[:200]}...") from e
824 | else:
825 | # If we only have one line but still got a JSON error
826 | raise RuntimeError(f"Invalid JSON in Ollama response: {content[:200]}...") from e
827 |
828 | # Calculate processing time
829 | processing_time = time.time() - start_time
830 |
831 | # Extract response text based on endpoint
832 | if using_messages:
833 | # Extract from chat endpoint
834 | completion_text = result.get("message", {}).get("content", "")
835 | else:
836 | # Extract from generate endpoint
837 | completion_text = result.get("response", "")
838 |
839 | # Log the raw response for debugging
840 | self.logger.debug(f"Raw Ollama response: {result}")
841 | self.logger.debug(f"Extracted completion text: {completion_text[:500]}...")
842 |
843 | # For JSON mode, ensure the completion text is properly formatted JSON
844 | if json_mode and completion_text:
845 | # Always use add_json_instructions for this model since it seems to need it
846 | if "gemma" in model.lower():
847 | # Force adding instructions for gemma models specifically
848 | kwargs["add_json_instructions"] = True
849 |
850 | try:
851 | # First try to extract JSON using our comprehensive method
852 | extracted_json = self._extract_json_from_text(completion_text)
853 | self.logger.debug(f"Extracted JSON: {extracted_json[:500]}...")
854 |
855 | # If we found valid JSON, parse and format it
856 | json_data = json.loads(extracted_json)
857 |
858 | # If successful, format it nicely with indentation
859 | if isinstance(json_data, (dict, list)):
860 | completion_text = json.dumps(json_data, indent=2)
861 | self.logger.debug("Successfully parsed and formatted JSON response")
862 | else:
863 | self.logger.warning(f"JSON response is not a dict or list: {type(json_data)}")
864 | except (json.JSONDecodeError, TypeError) as e:
865 | self.logger.warning(f"Failed to extract valid JSON from response: {str(e)[:100]}...")
866 |
867 | # Calculate token usage
868 | prompt_tokens = result.get("prompt_eval_count", 0)
869 | completion_tokens = result.get("eval_count", 0)
870 |
871 | # Format the standardized response
872 | model_response = ModelResponse(
873 | text=completion_text,
874 | model=f"{self.provider_name}/{model}",
875 | provider=self.provider_name,
876 | input_tokens=prompt_tokens,
877 | output_tokens=completion_tokens,
878 | processing_time=processing_time,
879 | raw_response=result
880 | )
881 |
882 | # Add message field for chat_completion compatibility
883 | model_response.message = {"role": "assistant", "content": completion_text}
884 |
885 | # Ensure there's always a value returned for JSON mode to prevent empty displays
886 | if json_mode and (not completion_text or not completion_text.strip()):
887 | # If we got an empty response, create a default one
888 | default_json = {
889 | "response": "No content was returned by the model",
890 | "error": "Empty response with json_mode enabled"
891 | }
892 | completion_text = json.dumps(default_json, indent=2)
893 | model_response.text = completion_text
894 | model_response.message["content"] = completion_text
895 | self.logger.warning("Empty response with JSON mode, returning default JSON structure")
896 |
897 | # Log success
898 | self.logger.success(
899 | f"Ollama completion successful with model {model}",
900 | emoji_key="completion_success",
901 | tokens={"input": prompt_tokens, "output": completion_tokens},
902 | time=processing_time,
903 | model=model
904 | )
905 |
906 | return model_response
907 |
908 | except httpx.HTTPStatusError as http_err:
909 | # Handle HTTP errors
910 | processing_time = time.time() - start_time
911 | try:
912 | error_json = http_err.response.json()
913 | error_msg = error_json.get("error", str(http_err))
914 | except (json.JSONDecodeError, KeyError):
915 | error_msg = f"HTTP error: {http_err.response.status_code} - {http_err.response.text}"
916 |
917 | self.logger.error(
918 | f"Ollama API error: {error_msg}",
919 | emoji_key="error",
920 | status_code=http_err.response.status_code,
921 | model=model
922 | )
923 |
924 | raise ConnectionError(f"Ollama API error: {error_msg}") from http_err
925 |
926 | except httpx.RequestError as req_err:
927 | # Handle request errors (e.g., connection issues)
928 | processing_time = time.time() - start_time
929 | error_msg = f"Request error: {str(req_err)}"
930 |
931 | self.logger.error(
932 | f"Ollama request error: {error_msg}",
933 | emoji_key="error",
934 | model=model
935 | )
936 |
937 | raise ConnectionError(f"Ollama request error: {error_msg}") from req_err
938 |
939 | except Exception as e:
940 | # Handle other unexpected errors
941 | processing_time = time.time() - start_time
942 |
943 | self.logger.error(
944 | f"Unexpected error calling Ollama: {str(e)}",
945 | emoji_key="error",
946 | model=model,
947 | exc_info=True
948 | )
949 |
950 | raise RuntimeError(f"Unexpected error calling Ollama: {str(e)}") from e
951 |
952 | async def generate_completion_stream(
953 | self,
954 | prompt: Optional[str] = None,
955 | messages: Optional[List[Dict[str, Any]]] = None,
956 | model: Optional[str] = None,
957 | temperature: float = 0.7,
958 | max_tokens: Optional[int] = None,
959 | stop: Optional[List[str]] = None,
960 | system: Optional[str] = None,
961 | **kwargs: Any,
962 | ) -> AsyncGenerator[Tuple[str, Dict[str, Any]], None]:
963 | # This is the main try block for the whole function - needs exception handling
964 | try:
965 | # Verify we have either prompt or messages
966 | if prompt is None and not messages:
967 | raise ValueError("Either prompt or messages must be provided to generate a streaming completion")
968 |
969 | # Check if provider is initialized before attempting to generate
970 | if not self._initialized:
971 | try:
972 | initialized = await self.initialize()
973 | if not initialized:
974 | # Yield an error message and immediately terminate
975 | error_metadata = {
976 | "model": f"{self.provider_name}/{model or self.get_default_model()}",
977 | "provider": self.provider_name,
978 | "error": "Ollama service is not available. Make sure Ollama is installed and running: https://ollama.com/download",
979 | "finish_reason": "error",
980 | "input_tokens": 0,
981 | "output_tokens": 0,
982 | "total_tokens": 0,
983 | "processing_time": 0.0,
984 | }
985 | yield "", error_metadata
986 | return
987 | except Exception as e:
988 | # Yield an error message and immediately terminate
989 | error_metadata = {
990 | "model": f"{self.provider_name}/{model or self.get_default_model()}",
991 | "provider": self.provider_name,
992 | "error": f"Failed to initialize Ollama provider: {str(e)}. Make sure Ollama is installed and running: https://ollama.com/download",
993 | "finish_reason": "error",
994 | "input_tokens": 0,
995 | "output_tokens": 0,
996 | "total_tokens": 0,
997 | "processing_time": 0.0,
998 | }
999 | yield "", error_metadata
1000 | return
1001 |
1002 | # Use default model if none specified
1003 | model_id = model or self.get_default_model()
1004 |
1005 | # Only remove our provider prefix if present, keep organization prefixes
1006 | if "/" in model_id and model_id.startswith(f"{self.provider_name}/"):
1007 | model_id = model_id.split("/", 1)[1]
1008 |
1009 | # Check for json_mode flag and remove it from kwargs
1010 | json_mode = kwargs.pop("json_mode", False)
1011 | format_param = None
1012 |
1013 | if json_mode:
1014 | # Ollama supports structured output via 'format' parameter at the ROOT level
1015 | # This can be either "json" for basic JSON mode or a JSON schema for structured output
1016 | format_param = "json" # Use simple "json" string for basic JSON mode
1017 | self.logger.debug("Setting format='json' for Ollama streaming")
1018 |
1019 | # Note: Format parameter may be less reliable with streaming
1020 | # due to how content is chunked, but Ollama should handle this.
1021 |
1022 | # Flag to track if we're using messages format
1023 | using_messages = False
1024 |
1025 | # Prepare the payload based on input type (messages or prompt)
1026 | if messages:
1027 | using_messages = True # noqa: F841
1028 | # Convert messages to Ollama's expected format
1029 | ollama_messages = []
1030 |
1031 | # Process messages
1032 | for msg in messages:
1033 | role = msg.get("role", "").lower()
1034 | content = msg.get("content", "")
1035 |
1036 | # Map roles to Ollama's expected format
1037 | if role == "system":
1038 | ollama_role = "system"
1039 | elif role == "user":
1040 | ollama_role = "user"
1041 | elif role == "assistant":
1042 | ollama_role = "assistant"
1043 | else:
1044 | # Default unknown roles to user
1045 | self.logger.warning(f"Unknown message role '{role}', treating as 'user'")
1046 | ollama_role = "user"
1047 |
1048 | ollama_messages.append({
1049 | "role": ollama_role,
1050 | "content": content
1051 | })
1052 |
1053 | # Build chat payload
1054 | payload = {
1055 | "model": model_id,
1056 | "messages": ollama_messages,
1057 | "stream": True,
1058 | "options": { # Ollama options go inside an 'options' dict
1059 | "temperature": temperature,
1060 | },
1061 | }
1062 |
1063 | # Use chat endpoint
1064 | api_endpoint = "chat"
1065 |
1066 | elif system is not None or model_id.startswith(
1067 | ("llama", "gpt", "claude", "phi", "mistral")
1068 | ):
1069 | # Use chat endpoint with system message (if provided) and prompt
1070 | messages = []
1071 | if system:
1072 | messages.append({"role": "system", "content": system})
1073 | messages.append({"role": "user", "content": prompt})
1074 |
1075 | payload = {
1076 | "model": model_id,
1077 | "messages": messages,
1078 | "stream": True,
1079 | "options": { # Ollama options go inside an 'options' dict
1080 | "temperature": temperature,
1081 | },
1082 | }
1083 |
1084 | # Use chat endpoint
1085 | api_endpoint = "chat"
1086 |
1087 | else:
1088 | # Use generate endpoint with prompt
1089 | payload = {
1090 | "model": model_id,
1091 | "prompt": prompt,
1092 | "stream": True,
1093 | "options": { # Ollama options go inside an 'options' dict
1094 | "temperature": temperature,
1095 | },
1096 | }
1097 |
1098 | # Use generate endpoint
1099 | api_endpoint = "generate"
1100 |
1101 | # Add common optional parameters
1102 | if max_tokens:
1103 | payload["options"]["num_predict"] = max_tokens
1104 | if stop:
1105 | payload["options"]["stop"] = stop
1106 |
1107 | # Add format parameter at the root level if JSON mode is enabled
1108 | if format_param:
1109 | payload["format"] = format_param
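# Illustrative resulting payload when JSON mode is on (chat case):
#   {"model": model_id, "messages": [...], "stream": True,
#    "options": {"temperature": temperature}, "format": "json"}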
1110 |
1111 | # Add any additional supported parameters from kwargs into options
1112 | for key, value in kwargs.items():
1113 | if key in ["seed", "top_k", "top_p", "num_ctx"]:
1114 | payload["options"][key] = value
1115 |
1116 | # Log request including JSON mode status
1117 | content_length = 0
1118 | if messages:
1119 | content_length = sum(len(m.get("content", "")) for m in messages)
1120 | elif prompt:
1121 | content_length = len(prompt)
1122 |
1123 | self.logger.info(
1124 | f"Generating Ollama streaming completion ({api_endpoint}) with model {model_id}",
1125 | emoji_key=self.provider_name,
1126 | prompt_length=content_length,
1127 | json_mode_requested=json_mode,
1128 | )
1129 |
1130 | start_time = time.time()
1131 | input_tokens = 0
1132 | output_tokens = 0
1133 | finish_reason = None
1134 | final_error = None
1135 |
1136 | async with aiohttp.ClientSession(**self.client_session_params) as streaming_session:
1137 | async with streaming_session.post(
1138 | self._build_api_url(api_endpoint), json=payload
1139 | ) as response:
1140 | if response.status != 200:
1141 | error_text = await response.text()
1142 | final_error = (
1143 | f"Ollama streaming API error: {response.status} - {error_text}"
1144 | )
1145 | # Yield error and stop
1146 | yield (
1147 | "",
1148 | {
1149 | "error": final_error,
1150 | "finished": True,
1151 | "provider": self.provider_name,
1152 | "model": model_id,
1153 | },
1154 | )
1155 | return
1156 |
1157 | buffer = ""
1158 | chunk_index = 0
1159 | async for line in response.content:
1160 | if not line.strip():
1161 | continue
1162 | buffer += line.decode("utf-8")
1163 |
1164 | # Process complete JSON objects in the buffer
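# Each complete line is one JSON object, e.g.
#   {"message": {"content": "Hi"}, "done": false}            (chat endpoint)
#   {"response": "Hi", "done": false}                         (generate endpoint)
# and the final line adds "done": true plus prompt_eval_count / eval_count.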
1165 | while "\n" in buffer:
1166 | json_str, buffer = buffer.split("\n", 1)
1167 | if not json_str.strip():
1168 | continue
1169 |
1170 | try:
1171 | data = json.loads(json_str)
1172 | chunk_index += 1
1173 |
1174 | # Extract content based on endpoint
1175 | if api_endpoint == "chat":
1176 | text_chunk = data.get("message", {}).get("content", "")
1177 | else: # generate endpoint
1178 | text_chunk = data.get("response", "")
1179 |
1180 | # Check if this is the final summary chunk
1181 | if data.get("done", False):
1182 | input_tokens = data.get("prompt_eval_count", input_tokens)
1183 | output_tokens = data.get("eval_count", output_tokens)
1184 | finish_reason = data.get(
1185 | "done_reason", "stop"
1186 | ) # Get finish reason if available
1187 | # Yield the final text chunk if any, then break to yield summary
1188 | if text_chunk:
1189 | metadata = {
1190 | "provider": self.provider_name,
1191 | "model": model_id,
1192 | "chunk_index": chunk_index,
1193 | "finished": False,
1194 | }
1195 | yield text_chunk, metadata
1196 | break # Exit inner loop after processing final chunk
1197 |
1198 | # Yield regular chunk
1199 | if text_chunk:
1200 | metadata = {
1201 | "provider": self.provider_name,
1202 | "model": model_id,
1203 | "chunk_index": chunk_index,
1204 | "finished": False,
1205 | }
1206 | yield text_chunk, metadata
1207 |
1208 | except json.JSONDecodeError:
1209 | self.logger.warning(
1210 | f"Could not decode JSON line: {json_str[:100]}..."
1211 | )
1212 | # Continue, maybe it's part of a larger object split across lines
1213 | except Exception as parse_error:
1214 | self.logger.warning(f"Error processing stream chunk: {parse_error}")
1215 | final_error = f"Error processing stream: {parse_error}"
1216 | break # Stop processing on unexpected error
1217 |
1218 | if final_error:
1219 | break # Exit outer loop if error occurred
1220 |
1221 | # --- Final Chunk ---
1222 | processing_time = time.time() - start_time
1223 | total_tokens = input_tokens + output_tokens
1224 | cost = self._estimate_token_cost(model_id, input_tokens, output_tokens)
1225 |
1226 | final_metadata = {
1227 | "model": f"{self.provider_name}/{model_id}",
1228 | "provider": self.provider_name,
1229 | "finished": True,
1230 | "finish_reason": finish_reason,
1231 | "input_tokens": input_tokens,
1232 | "output_tokens": output_tokens,
1233 | "total_tokens": total_tokens,
1234 | "cost": cost,
1235 | "processing_time": processing_time,
1236 | "error": final_error,
1237 | }
1238 | yield "", final_metadata # Yield empty chunk with final stats
1239 |
1240 | except aiohttp.ClientConnectionError as e:
1241 | # Yield connection error
1242 | yield (
1243 | "",
1244 | {
1245 | "error": f"Connection to Ollama failed: {str(e)}",
1246 | "finished": True,
1247 | "provider": self.provider_name,
1248 | "model": model_id,
1249 | },
1250 | )
1251 | except asyncio.TimeoutError:
1252 | # Yield timeout error
1253 | yield (
1254 | "",
1255 | {
1256 | "error": "Connection to Ollama timed out",
1257 | "finished": True,
1258 | "provider": self.provider_name,
1259 | "model": model_id,
1260 | },
1261 | )
1262 | except Exception as e:
1263 | # Yield generic error
1264 | if isinstance(e, ProviderError):
1265 | raise
1266 | yield (
1267 | "",
1268 | {
1269 | "error": f"Error generating streaming completion: {str(e)}",
1270 | "finished": True,
1271 | "provider": self.provider_name,
1272 | "model": model_id,
1273 | },
1274 | )
1275 |
1276 | async def create_embeddings(
1277 | self,
1278 | texts: List[str],
1279 | model: Optional[str] = None,
1280 | **kwargs: Any,
1281 | ) -> ModelResponse:
1282 | """
1283 | Generate embeddings for a list of texts using the Ollama API.
1284 |
1285 | Args:
1286 | texts: List of texts to generate embeddings for.
1287 | model: The model ID to use (defaults to provider's default).
1288 | **kwargs: Additional parameters to pass to the API.
1289 |
1290 | Returns:
1291 | A ModelResponse object with the embeddings (in metadata["embeddings"]) and related metadata.
1292 | If Ollama is not available, returns an error in the metadata.
1293 | """
1294 | # Check if provider is initialized before attempting to generate
1295 | if not self._initialized:
1296 | try:
1297 | initialized = await self.initialize()
1298 | if not initialized:
1299 | # Return a clear error without raising an exception
1300 | return ModelResponse(
1301 | text="",
1302 | model=f"{self.provider_name}/{model or self.get_default_model()}",
1303 | provider=self.provider_name,
1304 | input_tokens=0,
1305 | output_tokens=0,
1306 | total_tokens=0,
1307 | processing_time=0.0,
1308 | metadata={
1309 | "error": "Ollama service is not available. Make sure Ollama is installed and running: https://ollama.com/download",
1310 | "embeddings": [],
1311 | },
1312 | )
1313 | except Exception as e:
1314 | # Return a clear error without raising an exception
1315 | return ModelResponse(
1316 | text="",
1317 | model=f"{self.provider_name}/{model or self.get_default_model()}",
1318 | provider=self.provider_name,
1319 | input_tokens=0,
1320 | output_tokens=0,
1321 | total_tokens=0,
1322 | processing_time=0.0,
1323 | metadata={
1324 | "error": f"Failed to initialize Ollama provider: {str(e)}. Make sure Ollama is installed and running: https://ollama.com/download",
1325 | "embeddings": [],
1326 | },
1327 | )
1328 |
1329 | # Use default model if none specified
1330 | model_id = model or self.get_default_model()
1331 |
1332 | # Only remove our provider prefix if present, keep organization prefixes
1333 | if "/" in model_id and model_id.startswith(f"{self.provider_name}/"):
1334 | model_id = model_id.split("/", 1)[1]
1335 |
1336 | # Get total number of tokens in all texts
1337 | # This is an estimation since Ollama doesn't provide token counts for embeddings
1338 | total_tokens = sum(len(text.split()) for text in texts)
1339 |
1340 | # Prepare the result
1341 | result_embeddings = []
1342 | errors = []
1343 | all_dimensions = None
1344 |
1345 | try:
1346 | start_time = time.time()
1347 |
1348 | # Create a dedicated session for this embeddings request
1349 | async with aiohttp.ClientSession(**self.client_session_params) as session:
1350 | # Process each text individually (Ollama supports batching but we'll use same pattern as other providers)
1351 | for text in texts:
1352 | payload = {
1353 | "model": model_id,
1354 | "prompt": text,
1355 | }
1356 |
1357 | # Add any additional parameters
1358 | for key, value in kwargs.items():
1359 | if key not in payload and value is not None:
1360 | payload[key] = value
1361 |
1362 | try:
1363 | async with session.post(
1364 | self._build_api_url("embeddings"), json=payload, timeout=30.0
1365 | ) as response:
1366 | if response.status != 200:
1367 | error_text = await response.text()
1368 | errors.append(f"Ollama API error: {response.status} - {error_text}")
1369 | # Continue with the next text
1370 | continue
1371 |
1372 | data = await response.json()
1373 |
1374 | # Extract embeddings
1375 | embedding = data.get("embedding", [])
1376 |
1377 | if not embedding:
1378 | errors.append(f"No embedding returned for text: {text[:50]}...")
1379 | continue
1380 |
1381 | # Store the embedding
1382 | result_embeddings.append(embedding)
1383 |
1384 | # Check dimensions for consistency
1385 | dimensions = len(embedding)
1386 | if all_dimensions is None:
1387 | all_dimensions = dimensions
1388 | elif dimensions != all_dimensions:
1389 | errors.append(
1390 | f"Inconsistent embedding dimensions: got {dimensions}, expected {all_dimensions}"
1391 | )
1392 | except aiohttp.ClientConnectionError as e:
1393 | errors.append(
1394 | f"Connection to Ollama failed: {str(e)}. Make sure Ollama is running and accessible."
1395 | )
1396 | break
1397 | except asyncio.TimeoutError:
1398 | errors.append(
1399 | "Connection to Ollama timed out. Check if the service is overloaded."
1400 | )
1401 | break
1402 | except Exception as e:
1403 | errors.append(f"Error generating embedding: {str(e)}")
1404 | continue
1405 |
1406 | # Calculate processing time
1407 | processing_time = time.time() - start_time
1408 |
1409 | # Calculate cost (estimated)
1410 | estimated_cost = (total_tokens / 1_000_000) * 0.0001 # Very low cost estimation
1411 |
1412 | # Create response model with embeddings in metadata
1413 | return ModelResponse(
1414 | text="", # Embeddings don't have text content
1415 | model=f"{self.provider_name}/{model_id}",
1416 | provider=self.provider_name,
1417 | input_tokens=total_tokens, # Use total tokens as input tokens for embeddings
1418 | output_tokens=0, # No output tokens for embeddings
1419 | total_tokens=total_tokens,
1420 | processing_time=processing_time,
1421 | metadata={
1422 | "embeddings": result_embeddings,
1423 | "dimensions": all_dimensions or 0,
1424 | "errors": errors if errors else None,
1425 | "cost": estimated_cost,
1426 | },
1427 | )
1428 |
1429 | except aiohttp.ClientConnectionError as e:
1430 | # Return a clear error without raising an exception
1431 | return ModelResponse(
1432 | text="",
1433 | model=f"{self.provider_name}/{model_id}",
1434 | provider=self.provider_name,
1435 | input_tokens=0,
1436 | output_tokens=0,
1437 | total_tokens=0,
1438 | processing_time=0.0,
1439 | metadata={
1440 | "error": f"Connection to Ollama failed: {str(e)}. Make sure Ollama is running and accessible.",
1441 | "embeddings": [],
1442 | "cost": 0.0,
1443 | },
1444 | )
1445 | except Exception as e:
1446 | # Return a clear error without raising an exception
1447 | if isinstance(e, ProviderError):
1448 | raise
1449 | return ModelResponse(
1450 | text="",
1451 | model=f"{self.provider_name}/{model_id}",
1452 | provider=self.provider_name,
1453 | input_tokens=0,
1454 | output_tokens=0,
1455 | total_tokens=0,
1456 | processing_time=0.0,
1457 | metadata={
1458 | "error": f"Error generating embeddings: {str(e)}",
1459 | "embeddings": result_embeddings,
1460 | "cost": 0.0,
1461 | },
1462 | )
1463 |
1464 | def _extract_json_from_text(self, text: str) -> str:
1465 | """Extract JSON content from text that might include markdown code blocks or explanatory text.
1466 |
1467 | Args:
1468 | text: The raw text response that might contain JSON
1469 |
1470 | Returns:
1471 | Cleaned JSON content
1472 | """
1473 |
1474 | # First check if the text is already valid JSON
1475 | try:
1476 | json.loads(text)
1477 | return text # Already valid JSON
1478 | except json.JSONDecodeError:
1479 | pass # Continue with extraction
1480 |
1481 | # Extract JSON from code blocks - common pattern
1482 | code_block_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
1483 | if code_block_match:
1484 | code_content = code_block_match.group(1).strip()
1485 | try:
1486 | json.loads(code_content)
1487 | return code_content
1488 | except json.JSONDecodeError:
1489 | # Try to fix common JSON syntax issues like trailing commas
1490 | fixed_content = re.sub(r',\s*([}\]])', r'\1', code_content)
1491 | try:
1492 | json.loads(fixed_content)
1493 | return fixed_content
1494 | except json.JSONDecodeError:
1495 | pass # Continue with other extraction methods
1496 |
1497 | # Look for JSON array or object patterns in the content
1498 | # Find the first [ or { and the matching closing ] or }
1499 | stripped = text.strip()
1500 |
1501 | # Try to extract array
1502 | if '[' in stripped and ']' in stripped:
1503 | start = stripped.find('[')
1504 | # Find the matching closing bracket
1505 | end = -1
1506 | depth = 0
1507 | for i in range(start, len(stripped)):
1508 | if stripped[i] == '[':
1509 | depth += 1
1510 | elif stripped[i] == ']':
1511 | depth -= 1
1512 | if depth == 0:
1513 | end = i + 1
1514 | break
1515 |
1516 | if end > start:
1517 | array_content = stripped[start:end]
1518 | try:
1519 | json.loads(array_content)
1520 | return array_content
1521 | except json.JSONDecodeError:
1522 | pass # Try other methods
1523 |
1524 | # Try to extract object
1525 | if '{' in stripped and '}' in stripped:
1526 | start = stripped.find('{')
1527 | # Find the matching closing bracket
1528 | end = -1
1529 | depth = 0
1530 | for i in range(start, len(stripped)):
1531 | if stripped[i] == '{':
1532 | depth += 1
1533 | elif stripped[i] == '}':
1534 | depth -= 1
1535 | if depth == 0:
1536 | end = i + 1
1537 | break
1538 |
1539 | if end > start:
1540 | object_content = stripped[start:end]
1541 | try:
1542 | json.loads(object_content)
1543 | return object_content
1544 | except json.JSONDecodeError:
1545 | pass # Try other methods
1546 |
1547 | # If all else fails, return the original text
1548 | return text
1549 |
1550 | async def _generate_completion_via_streaming(
1551 | self,
1552 | prompt: Optional[str] = None,
1553 | messages: Optional[List[Dict[str, Any]]] = None,
1554 | model: Optional[str] = None,
1555 | max_tokens: Optional[int] = None,
1556 | temperature: float = 0.7,
1557 | stop: Optional[List[str]] = None,
1558 | top_p: Optional[float] = None,
1559 | top_k: Optional[int] = None,
1560 | frequency_penalty: Optional[float] = None,
1561 | presence_penalty: Optional[float] = None,
1562 | mirostat: Optional[int] = None,
1563 | mirostat_tau: Optional[float] = None,
1564 | mirostat_eta: Optional[float] = None,
1565 | system: Optional[str] = None,
1566 | json_mode: bool = False,  # Forwarded to generate_completion_stream so JSON mode is applied
1567 | **kwargs: Any,
1568 | ) -> ModelResponse:
1569 | """Generate a completion via streaming and collect the results.
1570 |
1571 | This is a workaround for Ollama's inconsistent behavior with JSON mode
1572 | in non-streaming completions. It uses the streaming API which works reliably
1573 | with JSON mode, and collects all chunks into a single result.
1574 |
1575 | Args:
1576 | Same as generate_completion and generate_completion_stream
1577 |
1578 | Returns:
1579 | ModelResponse: The complete response
1580 | """
1581 | self.logger.debug("Using streaming method internally to handle JSON mode reliably")
1582 |
1583 | # Start the streaming generator
1584 | stream_gen = self.generate_completion_stream(
1585 | prompt=prompt,
1586 | messages=messages,
1587 | model=model,
1588 | max_tokens=max_tokens,
1589 | temperature=temperature,
1590 | stop=stop,
1591 | top_p=top_p,
1592 | top_k=top_k,
1593 | frequency_penalty=frequency_penalty,
1594 | presence_penalty=presence_penalty,
1595 | mirostat=mirostat,
1596 | mirostat_tau=mirostat_tau,
1597 | mirostat_eta=mirostat_eta,
1598 | system=system,
1599 | json_mode=json_mode,
1600 | **kwargs
1601 | )
1602 |
1603 | # Collect all text chunks
1604 | combined_text = ""
1605 | metadata = {}
1606 | input_tokens = 0
1607 | output_tokens = 0
1608 | processing_time = 0
1609 |
1610 | try:
1611 | async for chunk, chunk_metadata in stream_gen:
1612 | if chunk_metadata.get("error"):
1613 | # If there's an error, raise it
1614 | raise RuntimeError(chunk_metadata["error"])
1615 |
1616 | # Add current chunk to result
1617 | combined_text += chunk
1618 |
1619 | # If this is the final chunk with stats, save the metadata
1620 | if chunk_metadata.get("finished", False):
1621 | metadata = chunk_metadata
1622 | input_tokens = chunk_metadata.get("input_tokens", 0)
1623 | output_tokens = chunk_metadata.get("output_tokens", 0)
1624 | processing_time = chunk_metadata.get("processing_time", 0)
1625 | except Exception as e:
1626 | # If streaming fails, re-raise the exception
1627 | raise RuntimeError(f"Error in streaming completion: {str(e)}") from e
1628 |
1629 | # Create a ModelResponse with the combined text
1630 | result = ModelResponse(
1631 | text=combined_text,
1632 | model=metadata.get("model", f"{self.provider_name}/{model or self.get_default_model()}"),
1633 | provider=self.provider_name,
1634 | input_tokens=input_tokens,
1635 | output_tokens=output_tokens,
1636 | processing_time=processing_time,
1637 | raw_response={"streaming_source": True, "metadata": metadata}
1638 | )
1639 |
1640 | # Add message field for chat_completion compatibility
1641 | result.message = {"role": "assistant", "content": combined_text}
1642 |
1643 | return result
1644 |
```
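
The following is a minimal, hypothetical sketch of how a caller might consume the streaming generator and the embeddings method shown on this page. The import path, the `OllamaProvider` class name, and the model names are assumptions for illustration; only the `generate_completion_stream` and `create_embeddings` signatures and their metadata keys come from the code above.

```python
import asyncio

# Assumed import path and class name; adjust to the actual module layout.
from ultimate_mcp_server.core.providers.ollama import OllamaProvider


async def main() -> None:
    provider = OllamaProvider()  # assumed constructor; real initialization may differ

    # Stream a completion: chunks arrive as (text, metadata) tuples; the final
    # tuple carries summary metadata (token counts, cost, finish_reason).
    async for chunk, meta in provider.generate_completion_stream(
        prompt="List three uses of embeddings as JSON.",
        model="llama3.1",  # assumed model name
        json_mode=True,    # forwarded as Ollama's format="json"
        max_tokens=256,
    ):
        if meta.get("error"):
            print("stream error:", meta["error"])
            break
        if meta.get("finished"):
            print(f"\n-- {meta.get('output_tokens')} tokens in {meta.get('processing_time'):.2f}s")
        else:
            print(chunk, end="", flush=True)

    # Embeddings: vectors are returned in response.metadata["embeddings"].
    response = await provider.create_embeddings(
        texts=["first document", "second document"],
        model="nomic-embed-text",  # assumed embedding model name
    )
    if response.metadata.get("error"):
        print("embedding error:", response.metadata["error"])
    else:
        print("embedding dimensions:", response.metadata["dimensions"])


if __name__ == "__main__":
    asyncio.run(main())
```

Because these methods surface errors through the yielded or returned metadata rather than raising, the sketch checks the `error` key instead of wrapping the calls in try/except.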