This is page 22 of 35. Use http://codebase.md/dicklesworthstone/llm_gateway_mcp_server?page={x} to view the full context.
# Directory Structure
```
├── .cursorignore
├── .env.example
├── .envrc
├── .gitignore
├── additional_features.md
├── check_api_keys.py
├── completion_support.py
├── comprehensive_test.py
├── docker-compose.yml
├── Dockerfile
├── empirically_measured_model_speeds.json
├── error_handling.py
├── example_structured_tool.py
├── examples
│ ├── __init__.py
│ ├── advanced_agent_flows_using_unified_memory_system_demo.py
│ ├── advanced_extraction_demo.py
│ ├── advanced_unified_memory_system_demo.py
│ ├── advanced_vector_search_demo.py
│ ├── analytics_reporting_demo.py
│ ├── audio_transcription_demo.py
│ ├── basic_completion_demo.py
│ ├── cache_demo.py
│ ├── claude_integration_demo.py
│ ├── compare_synthesize_demo.py
│ ├── cost_optimization.py
│ ├── data
│ │ ├── sample_event.txt
│ │ ├── Steve_Jobs_Introducing_The_iPhone_compressed.md
│ │ └── Steve_Jobs_Introducing_The_iPhone_compressed.mp3
│ ├── docstring_refiner_demo.py
│ ├── document_conversion_and_processing_demo.py
│ ├── entity_relation_graph_demo.py
│ ├── filesystem_operations_demo.py
│ ├── grok_integration_demo.py
│ ├── local_text_tools_demo.py
│ ├── marqo_fused_search_demo.py
│ ├── measure_model_speeds.py
│ ├── meta_api_demo.py
│ ├── multi_provider_demo.py
│ ├── ollama_integration_demo.py
│ ├── prompt_templates_demo.py
│ ├── python_sandbox_demo.py
│ ├── rag_example.py
│ ├── research_workflow_demo.py
│ ├── sample
│ │ ├── article.txt
│ │ ├── backprop_paper.pdf
│ │ ├── buffett.pdf
│ │ ├── contract_link.txt
│ │ ├── legal_contract.txt
│ │ ├── medical_case.txt
│ │ ├── northwind.db
│ │ ├── research_paper.txt
│ │ ├── sample_data.json
│ │ └── text_classification_samples
│ │ ├── email_classification.txt
│ │ ├── news_samples.txt
│ │ ├── product_reviews.txt
│ │ └── support_tickets.txt
│ ├── sample_docs
│ │ └── downloaded
│ │ └── attention_is_all_you_need.pdf
│ ├── sentiment_analysis_demo.py
│ ├── simple_completion_demo.py
│ ├── single_shot_synthesis_demo.py
│ ├── smart_browser_demo.py
│ ├── sql_database_demo.py
│ ├── sse_client_demo.py
│ ├── test_code_extraction.py
│ ├── test_content_detection.py
│ ├── test_ollama.py
│ ├── text_classification_demo.py
│ ├── text_redline_demo.py
│ ├── tool_composition_examples.py
│ ├── tournament_code_demo.py
│ ├── tournament_text_demo.py
│ ├── unified_memory_system_demo.py
│ ├── vector_search_demo.py
│ ├── web_automation_instruction_packs.py
│ └── workflow_delegation_demo.py
├── LICENSE
├── list_models.py
├── marqo_index_config.json.example
├── mcp_protocol_schema_2025-03-25_version.json
├── mcp_python_lib_docs.md
├── mcp_tool_context_estimator.py
├── model_preferences.py
├── pyproject.toml
├── quick_test.py
├── README.md
├── resource_annotations.py
├── run_all_demo_scripts_and_check_for_errors.py
├── storage
│ └── smart_browser_internal
│ ├── locator_cache.db
│ ├── readability.js
│ └── storage_state.enc
├── test_client.py
├── test_connection.py
├── TEST_README.md
├── test_sse_client.py
├── test_stdio_client.py
├── tests
│ ├── __init__.py
│ ├── conftest.py
│ ├── integration
│ │ ├── __init__.py
│ │ └── test_server.py
│ ├── manual
│ │ ├── test_extraction_advanced.py
│ │ └── test_extraction.py
│ └── unit
│ ├── __init__.py
│ ├── test_cache.py
│ ├── test_providers.py
│ └── test_tools.py
├── TODO.md
├── tool_annotations.py
├── tools_list.json
├── ultimate_mcp_banner.webp
├── ultimate_mcp_logo.webp
├── ultimate_mcp_server
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── commands.py
│ │ ├── helpers.py
│ │ └── typer_cli.py
│ ├── clients
│ │ ├── __init__.py
│ │ ├── completion_client.py
│ │ └── rag_client.py
│ ├── config
│ │ └── examples
│ │ └── filesystem_config.yaml
│ ├── config.py
│ ├── constants.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── evaluation
│ │ │ ├── base.py
│ │ │ └── evaluators.py
│ │ ├── providers
│ │ │ ├── __init__.py
│ │ │ ├── anthropic.py
│ │ │ ├── base.py
│ │ │ ├── deepseek.py
│ │ │ ├── gemini.py
│ │ │ ├── grok.py
│ │ │ ├── ollama.py
│ │ │ ├── openai.py
│ │ │ └── openrouter.py
│ │ ├── server.py
│ │ ├── state_store.py
│ │ ├── tournaments
│ │ │ ├── manager.py
│ │ │ ├── tasks.py
│ │ │ └── utils.py
│ │ └── ums_api
│ │ ├── __init__.py
│ │ ├── ums_database.py
│ │ ├── ums_endpoints.py
│ │ ├── ums_models.py
│ │ └── ums_services.py
│ ├── exceptions.py
│ ├── graceful_shutdown.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── analytics
│ │ │ ├── __init__.py
│ │ │ ├── metrics.py
│ │ │ └── reporting.py
│ │ ├── cache
│ │ │ ├── __init__.py
│ │ │ ├── cache_service.py
│ │ │ ├── persistence.py
│ │ │ ├── strategies.py
│ │ │ └── utils.py
│ │ ├── cache.py
│ │ ├── document.py
│ │ ├── knowledge_base
│ │ │ ├── __init__.py
│ │ │ ├── feedback.py
│ │ │ ├── manager.py
│ │ │ ├── rag_engine.py
│ │ │ ├── retriever.py
│ │ │ └── utils.py
│ │ ├── prompts
│ │ │ ├── __init__.py
│ │ │ ├── repository.py
│ │ │ └── templates.py
│ │ ├── prompts.py
│ │ └── vector
│ │ ├── __init__.py
│ │ ├── embeddings.py
│ │ └── vector_service.py
│ ├── tool_token_counter.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── audio_transcription.py
│ │ ├── base.py
│ │ ├── completion.py
│ │ ├── docstring_refiner.py
│ │ ├── document_conversion_and_processing.py
│ │ ├── enhanced-ums-lookbook.html
│ │ ├── entity_relation_graph.py
│ │ ├── excel_spreadsheet_automation.py
│ │ ├── extraction.py
│ │ ├── filesystem.py
│ │ ├── html_to_markdown.py
│ │ ├── local_text_tools.py
│ │ ├── marqo_fused_search.py
│ │ ├── meta_api_tool.py
│ │ ├── ocr_tools.py
│ │ ├── optimization.py
│ │ ├── provider.py
│ │ ├── pyodide_boot_template.html
│ │ ├── python_sandbox.py
│ │ ├── rag.py
│ │ ├── redline-compiled.css
│ │ ├── sentiment_analysis.py
│ │ ├── single_shot_synthesis.py
│ │ ├── smart_browser.py
│ │ ├── sql_databases.py
│ │ ├── text_classification.py
│ │ ├── text_redline_tools.py
│ │ ├── tournament.py
│ │ ├── ums_explorer.html
│ │ └── unified_memory_system.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── async_utils.py
│ │ ├── display.py
│ │ ├── logging
│ │ │ ├── __init__.py
│ │ │ ├── console.py
│ │ │ ├── emojis.py
│ │ │ ├── formatter.py
│ │ │ ├── logger.py
│ │ │ ├── panels.py
│ │ │ ├── progress.py
│ │ │ └── themes.py
│ │ ├── parse_yaml.py
│ │ ├── parsing.py
│ │ ├── security.py
│ │ └── text.py
│ └── working_memory_api.py
├── unified_memory_system_technical_analysis.md
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/python_sandbox.py:
--------------------------------------------------------------------------------
```python
# ultimate_mcp_server/tools/python_sandbox.py
"""Pyodide-backed sandbox tool for Ultimate MCP Server.
Provides a secure environment for executing Python code within a headless browser,
with stdout/stderr capture, package management, security controls, and optional REPL functionality.
Includes integrated offline asset caching for Pyodide.
"""
###############################################################################
# Standard library & typing
###############################################################################
import argparse
import asyncio
import atexit
import base64
import collections
import gzip
import hashlib
import json
import logging # Import logging for fallback
import mimetypes
import os
import pathlib
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, OrderedDict
# ---------------------------------
###############################################################################
# Third‑party – runtime dependency only on Playwright
###############################################################################
try:
import playwright.async_api as pw
if TYPE_CHECKING:
# Import SPECIFIC types for type hints inside TYPE_CHECKING
from playwright.async_api import Browser, Page, Request, Route
PLAYWRIGHT_AVAILABLE = True
except ImportError:
pw = None
    # Define placeholder types ONLY if Playwright is unavailable. They are kept
    # inside TYPE_CHECKING so type hints can still reference *something*; when
    # Playwright is installed, the real imports above cover the type checker.
if TYPE_CHECKING:
Browser = Any
Page = Any
Route = Any
Request = Any
PLAYWRIGHT_AVAILABLE = False
from rich import box
from rich.markup import escape
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
###############################################################################
# Project specific imports
###############################################################################
# Assuming these are correctly located within your project structure
try:
from ultimate_mcp_server.constants import TaskType
from ultimate_mcp_server.exceptions import (
ProviderError,
ToolError,
ToolInputError,
)
from ultimate_mcp_server.tools.base import with_error_handling, with_tool_metrics
from ultimate_mcp_server.utils import get_logger
except ImportError as e:
# Provide a fallback or clearer error if these imports fail
print(f"WARNING: Failed to import Ultimate MCP Server components: {e}")
print("Running in standalone mode or environment misconfiguration.")
# Define dummy logger/decorators if running standalone for preloading
def get_logger(name):
_logger = logging.getLogger(name)
if not _logger.handlers: # Setup basic config only if no handlers exist
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
return _logger
def with_tool_metrics(func):
return func
def with_error_handling(func):
return func
# Define dummy exceptions
class ProviderError(Exception):
pass
class ToolError(Exception):
pass
class ToolInputError(Exception):
pass
class TaskType:
CODE_EXECUTION = "code_execution" # Dummy enum value
from ultimate_mcp_server.utils.logging.console import console
logger = get_logger("ultimate_mcp_server.tools.python_sandbox")
# Constant for posting messages back to the sandbox page context
JS_POST_MESSAGE = "(msg) => globalThis.postMessage(msg, '*')"
###############################################################################
# Constants & Caching Configuration
###############################################################################
COMMON_PACKAGES: list[str] = [
"numpy",
"pandas",
"matplotlib",
"scipy",
"networkx",
]
# Define JSON string *after* COMMON_PACKAGES is defined
COMMON_PACKAGES_JSON = json.dumps(COMMON_PACKAGES)
MAX_SANDBOXES = 6 # Max number of concurrent browser tabs/sandboxes
GLOBAL_CONCURRENCY = 8 # Max number of simultaneous code executions across all sandboxes
MEM_LIMIT_MB = 512 # Memory limit for the heap watchdog in the browser tab
# --- Pyodide Version and CDN ---
_PYODIDE_VERSION = "0.27.5" # <<< Ensure this matches the intended version
_CDN_BASE = f"https://cdn.jsdelivr.net/pyodide/v{_PYODIDE_VERSION}/full"
# Note: PYODIDE_CDN variable might not be strictly necessary if importing .mjs directly
PYODIDE_CDN = f"{_CDN_BASE}/pyodide.js"
# --- Define the packages to be loaded AT STARTUP ---
# These will be baked into the loadPyodide call via the template
CORE_PACKAGES_TO_LOAD_AT_STARTUP: list[str] = [
"numpy",
"pandas",
"matplotlib",
"scipy",
"networkx",
"micropip", # Good to include if you often load wheels later
]
# Generate the JSON string to be injected into the HTML template
CORE_PACKAGES_JSON_FOR_TEMPLATE = json.dumps(CORE_PACKAGES_TO_LOAD_AT_STARTUP)
# --- Asset Caching Configuration ---
_CACHE_DIR = (
pathlib.Path(os.getenv("XDG_CACHE_HOME", "~/.cache")).expanduser()
/ "ultimate_mcp_server"
/ "pyodide"
/ _PYODIDE_VERSION # Versioned cache directory
)
try:
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
logger.info(f"Using Pyodide asset cache directory: {_CACHE_DIR}")
except OSError as e:
logger.error(
f"Failed to create Pyodide asset cache directory {_CACHE_DIR}: {e}. Caching might fail."
)
################################################################################
# Diagnostic logging helpers
################################################################################
# level 0 = quiet, 1 = basic req/resp, 2 = full body/hex dumps
_VERBOSE_SANDBOX_LOGGING = int(os.getenv("PYODIDE_SANDBOX_DEBUG", "0") or 0)
def _wire_page_logging(page: "Page", session_id: str) -> None: # type: ignore
"""
Mirrors everything interesting coming out of the browser tab back into our
Python logger. When PYODIDE_SANDBOX_DEBUG=2 we also dump request/response
headers and first 64 bytes of every body.
"""
# ───────── console / JS errors ───────────────────────────────────────────
def _log_console(msg):
try:
# Safely access properties, defaulting if necessary
lvl = msg.type if not callable(getattr(msg, "type", None)) else msg.type()
txt = msg.text if not callable(getattr(msg, "text", None)) else msg.text()
loc = msg.location if not callable(getattr(msg, "location", None)) else msg.location()
src = ""
if isinstance(loc, dict):
src = f"{loc.get('url', '')}:{loc.get('lineNumber', '?')}:{loc.get('columnNumber', '?')}"
elif loc:
src = str(loc)
line = f"SB[{session_id}] {src} ▶ {txt}" if src else f"SB[{session_id}] ▶ {txt}"
log_func = {
"error": logger.error,
"warning": logger.warning,
"warn": logger.warning,
"info": logger.info,
"log": logger.info,
"debug": logger.debug,
"trace": logger.debug,
}.get(str(lvl).lower(), logger.debug)
log_func(line)
except Exception as e:
logger.error(f"SB[{session_id}] Error in console message processing: {e}")
try:
page.on("console", _log_console)
page.on(
"pageerror",
lambda e: logger.error(f"SB[{session_id}] PageError ▶ {e.message}\n{e.stack}"),
)
page.on("crash", lambda: logger.critical(f"SB[{session_id}] **PAGE CRASHED**"))
except Exception as e:
logger.error(f"SB[{session_id}] Failed to attach basic page log listeners: {e}")
# ───────── high-level net trace ─────────────────────────────────────────
if _VERBOSE_SANDBOX_LOGGING > 0:
try:
page.on("request", lambda r: logger.debug(f"SB[{session_id}] → {r.method} {r.url}"))
page.on(
"requestfailed",
lambda r: logger.warning(f"SB[{session_id}] ✗ {r.method} {r.url} ▶ {r.failure}"),
)
async def _resp_logger(resp: "pw.Response"): # type: ignore # Use string literal hint
try:
status = resp.status
url = resp.url
if status == 200 and url.startswith("data:") and _VERBOSE_SANDBOX_LOGGING < 2:
return
# Use resp.all_headers() which returns a dict directly
hdrs = await resp.all_headers()
ce = hdrs.get("content-encoding", "")
ct = hdrs.get("content-type", "")
log_line = (
f"SB[{session_id}] ← {status} {url} (type='{ct}', enc='{ce or 'none'}')"
)
if _VERBOSE_SANDBOX_LOGGING > 1 or status >= 400: # Log body for errors too
try:
body = await resp.body()
sig = body[:64]
hexs = " ".join(f"{b:02x}" for b in sig)
log_line += f" (len={len(body)}, first-64: {hexs})"
except Exception as body_err:
# Handle cases where body might not be available (e.g., redirects)
log_line += f" (body unavailable: {body_err})"
logger.debug(log_line)
except Exception as e:
logger.warning(f"SB[{session_id}] Error in response logger: {e}")
page.on("response", lambda r: asyncio.create_task(_resp_logger(r)))
except Exception as e:
logger.error(f"SB[{session_id}] Failed to attach network trace log listeners: {e}")
###############################################################################
# Asset Caching Helper Functions (Integrated)
###############################################################################
def _local_path(remote_url: str) -> pathlib.Path:
"""Generates the local cache path for a given remote URL."""
try:
parsed_url = urllib.parse.urlparse(remote_url)
path_part = parsed_url.path if parsed_url.path else "/"
fname = pathlib.Path(path_part).name
if not fname or fname == "/":
fname = hashlib.md5(remote_url.encode()).hexdigest() + ".cache"
logger.debug(
f"No filename in path '{path_part}', using hash '{fname}' for {remote_url}"
)
except Exception as e:
logger.warning(f"Error parsing URL '{remote_url}' for filename: {e}. Falling back to hash.")
fname = hashlib.md5(remote_url.encode()).hexdigest() + ".cache"
return _CACHE_DIR / fname
def _fetch_asset_sync(remote_url: str, max_age_s: int = 7 * 24 * 3600) -> bytes:
"""
Synchronous version: Return requested asset from cache or download.
Used by Playwright interceptor and preloader.
"""
p = _local_path(remote_url)
use_cache = False
if p.exists():
try:
file_stat = p.stat()
file_age = time.time() - file_stat.st_mtime
if file_age < max_age_s:
if file_stat.st_size > 0:
logger.debug(
f"[Cache] HIT for {remote_url} (age: {file_age:.0f}s < {max_age_s}s)"
)
use_cache = True
else:
logger.warning(
f"[Cache] Hit for {remote_url}, but file is empty. Re-downloading."
)
else:
logger.info(
f"[Cache] STALE for {remote_url} (age: {file_age:.0f}s >= {max_age_s}s)"
)
except OSError as e:
logger.warning(f"[Cache] Error accessing cache file {p}: {e}. Will attempt download.")
if use_cache:
try:
return p.read_bytes()
except OSError as e:
logger.warning(f"[Cache] Error reading cached file {p}: {e}. Will attempt download.")
logger.info(f"[Cache] MISS or STALE/Error for {remote_url}. Downloading...")
downloaded_data = None
try:
req = urllib.request.Request(
remote_url,
headers={"User-Agent": "UltimateMCPServer-AssetCache/1.0", "Accept-Encoding": "gzip"},
)
with urllib.request.urlopen(req, timeout=30) as resp:
if resp.status != 200:
raise urllib.error.HTTPError(
remote_url, resp.status, resp.reason, resp.headers, None
)
downloaded_data = resp.read()
# Handle potential gzip encoding from server
if resp.headers.get("Content-Encoding") == "gzip":
try:
downloaded_data = gzip.decompress(downloaded_data)
logger.debug(f"[Cache] Decompressed gzip response for {remote_url}")
except gzip.BadGzipFile:
logger.warning(
f"[Cache] Received gzip header but invalid gzip data for {remote_url}. Using raw."
)
except Exception as gz_err:
logger.warning(
f"[Cache] Error decompressing gzip for {remote_url}: {gz_err}. Using raw."
)
logger.info(
f"[Cache] Downloaded {len(downloaded_data)} bytes from {remote_url} (status: {resp.status})"
)
except urllib.error.HTTPError as e:
logger.warning(f"[Cache] HTTP error downloading {remote_url}: {e.code} {e.reason}")
if p.exists():
try:
stale_stat = p.stat()
if stale_stat.st_size > 0:
logger.warning(
f"[Cache] Using STALE cache file {p} as fallback due to HTTP {e.code}."
)
return p.read_bytes()
except OSError as read_err:
logger.error(
f"[Cache] Failed reading fallback cache {p} after download error: {read_err}"
)
raise RuntimeError(
f"Cannot download {remote_url} (HTTP {e.code}) and no usable cache available"
) from e
except urllib.error.URLError as e:
logger.warning(f"[Cache] Network error downloading {remote_url}: {e.reason}")
if p.exists():
try:
stale_stat = p.stat()
if stale_stat.st_size > 0:
logger.warning(
f"[Cache] Using STALE cache file {p} as fallback due to network error."
)
return p.read_bytes()
except OSError as read_err:
logger.error(
f"[Cache] Failed reading fallback cache {p} after network error: {read_err}"
)
raise RuntimeError(
f"Cannot download {remote_url} ({e.reason}) and no usable cache available"
) from e
except Exception as e:
logger.error(f"[Cache] Unexpected error downloading {remote_url}: {e}", exc_info=True)
if p.exists():
try:
stale_stat = p.stat()
if stale_stat.st_size > 0:
logger.warning(
f"[Cache] Using STALE cache file {p} as fallback due to unexpected error."
)
return p.read_bytes()
except OSError as read_err:
logger.error(
f"[Cache] Failed reading fallback cache {p} after unexpected error: {read_err}"
)
raise RuntimeError(
f"Cannot download {remote_url} (unexpected error: {e}) and no usable cache available"
) from e
if downloaded_data is not None:
try:
tmp_suffix = f".tmp_{os.getpid()}_{uuid.uuid4().hex[:6]}"
tmp_path = p.with_suffix(p.suffix + tmp_suffix)
tmp_path.write_bytes(downloaded_data)
tmp_path.replace(p)
logger.info(f"[Cache] Saved {len(downloaded_data)} bytes for {remote_url} to {p}")
except OSError as e:
logger.error(f"[Cache] Failed write cache file {p}: {e}")
return downloaded_data
else:
raise RuntimeError(f"Download completed for {remote_url} but data is None (internal error)")
###############################################################################
# Browser / bookkeeping singletons
###############################################################################
_BROWSER: Optional["Browser"] = None # type: ignore # Use string literal hint
_PAGES: OrderedDict[str, "PyodideSandbox"] = collections.OrderedDict()
_GLOBAL_SEM: Optional[asyncio.Semaphore] = None
###############################################################################
# PyodideSandbox Class Definition
###############################################################################
@dataclass(slots=True)
class PyodideSandbox:
"""One Chromium tab with Pyodide runtime (optionally persistent)."""
page: "Page" # type: ignore # Use string literal hint
allow_network: bool = False
allow_fs: bool = False
ready_evt: asyncio.Event = field(default_factory=asyncio.Event)
created_at: float = field(default_factory=time.time)
last_used: float = field(default_factory=time.time)
_init_timeout: int = 90
_message_handlers: Dict[str, asyncio.Queue] = field(default_factory=dict)
_init_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
async def init(self):
"""Load boot HTML, set up messaging (direct callback), and wait for ready signal."""
if not PLAYWRIGHT_AVAILABLE:
raise RuntimeError("Playwright is not installed. Cannot initialize sandbox.")
logger.info(f"Initializing PyodideSandbox instance (Page: {self.page.url})...")
init_start_time = time.monotonic()
# === 1. Network Interception Setup ===
logger.debug("Setting up network interception...")
try:
# Define the interception logic inline or call an external helper
cdn_base_lower = _CDN_BASE.lower()
async def _block(route: "Route", request: "Request"): # type: ignore
url = request.url
low = url.lower()
is_cdn = low.startswith(cdn_base_lower)
is_pypi = "pypi.org/simple" in low or "files.pythonhosted.org" in low
if is_cdn:
try:
# Assuming _fetch_asset_sync is correctly defined elsewhere
body = _fetch_asset_sync(url)
ctype = mimetypes.guess_type(url)[0] or "application/octet-stream"
headers = {
"Content-Type": ctype,
"Access-Control-Allow-Origin": "*",
"Cache-Control": "public, max-age=31536000",
}
# Simple check for gzip magic bytes; don't decompress here, let browser handle it
if body.startswith(b"\x1f\x8b"):
headers["Content-Encoding"] = "gzip"
if _VERBOSE_SANDBOX_LOGGING > 1:
logger.debug(
f"[Intercept] FULFILL CDN {url} (type={ctype}, enc={headers.get('Content-Encoding', 'none')}, len={len(body)})"
)
await route.fulfill(status=200, body=body, headers=headers)
return
except Exception as exc:
logger.error(
f"[Intercept] FAILED serving CDN {url} from cache/download: {exc}",
exc_info=_VERBOSE_SANDBOX_LOGGING > 1,
)
await route.abort(error_code="failed")
return
# Allow PyPI only if explicitly enabled
if self.allow_network and is_pypi:
if _VERBOSE_SANDBOX_LOGGING > 0:
logger.debug(f"[Intercept] ALLOW PyPI {url}")
try:
await route.continue_()
except Exception as cont_err:
logger.warning(
f"[Intercept] Error continuing PyPI request {url}: {cont_err}"
)
try:
await route.abort(error_code="failed")
except Exception:
pass
return
# Block other network requests by default
# Log less aggressively for common browser noise
if not any(low.endswith(ext) for ext in [".ico", ".png", ".woff", ".woff2"]):
if _VERBOSE_SANDBOX_LOGGING > 0:
logger.debug(f"[Intercept] BLOCK {url}")
try:
await route.abort(error_code="blockedbyclient")
except Exception:
pass # Ignore errors aborting (e.g., already handled)
await self.page.route("**/*", _block)
logger.info("Network interception active.")
except Exception as e:
logger.error(f"Failed to set up network interception: {e}", exc_info=True)
await self._try_close_page("Network Intercept Setup Error")
raise ToolError(f"Failed to configure sandbox network rules: {e}") from e
# === 2. Load Boot HTML ===
logger.debug("Loading boot HTML template...")
try:
template_path = pathlib.Path(__file__).parent / "pyodide_boot_template.html"
if not template_path.is_file():
raise FileNotFoundError(f"Boot template not found at {template_path}")
boot_html_template = template_path.read_text(encoding="utf-8")
# Replace placeholders, including the CORE packages JSON
processed_boot_html = (
boot_html_template.replace("__CDN_BASE__", _CDN_BASE)
.replace("__PYODIDE_VERSION__", _PYODIDE_VERSION)
# *** Use the new constant for core packages ***
.replace("__CORE_PACKAGES_JSON__", CORE_PACKAGES_JSON_FOR_TEMPLATE)
.replace("__MEM_LIMIT_MB__", str(MEM_LIMIT_MB)) # Keep MEM_LIMIT if using watchdog
)
            # Verify that the placeholder tokens above were actually substituted.
            essential_placeholders = ["__CDN_BASE__", "__PYODIDE_VERSION__"]
            optional_placeholders = ["__CORE_PACKAGES_JSON__", "__MEM_LIMIT_MB__"]
            unreplaced_essential = [p for p in essential_placeholders if p in processed_boot_html]
            unreplaced_optional = [p for p in optional_placeholders if p in processed_boot_html]
            if unreplaced_essential:
                logger.critical(
                    f"CRITICAL: Essential placeholders were not replaced in boot HTML: {unreplaced_essential}. Aborting."
                )
                raise ToolError(
                    f"Essential placeholders not replaced in boot template: {unreplaced_essential}"
                )
            if unreplaced_optional:
                logger.warning(
                    f"Optional placeholders were not replaced in boot HTML: {unreplaced_optional}. Check the template if these features are expected."
                )
await self.page.set_content(
processed_boot_html,
wait_until="domcontentloaded",
timeout=60000, # Slightly longer timeout for package loading
)
logger.info("Boot HTML loaded into page.")
except FileNotFoundError as e:
logger.error(f"Failed to load boot HTML template: {e}", exc_info=True)
await self._try_close_page("Boot HTML Template Not Found")
raise ToolError(f"Could not find sandbox boot HTML template: {e}") from e
except Exception as e:
logger.error(f"Failed loading boot HTML content: {e}", exc_info=True)
await self._try_close_page("Boot HTML Load Error")
raise ToolError(f"Could not load sandbox boot HTML content: {e}") from e
# === 3. Setup Communication Channels ===
# This involves two parts:
# a) Exposing a Python function for JS to send *replies* directly.
# b) Exposing a Python function for JS to send the initial *ready/error* signal.
# --- 3a. Setup for Execution Replies ---
logger.debug("Setting up direct reply mechanism (JS->Python)...")
try:
# This Python function will be called by JavaScript's `window._deliverReplyToHost(reply)`
async def _deliver_reply_to_host(payload: Any):
msg_id = None # Define outside try block
try:
if not isinstance(payload, dict):
if _VERBOSE_SANDBOX_LOGGING > 1:
logger.debug(f"Host received non-dict reply payload: {type(payload)}")
return
data = payload
msg_id = data.get("id")
if not msg_id:
logger.warning(f"Host received reply payload without an ID: {data}")
return
# Log received reply
if _VERBOSE_SANDBOX_LOGGING > 0:
log_detail = (
f"ok={data.get('ok')}"
if _VERBOSE_SANDBOX_LOGGING == 1
else json.dumps(data, default=str)
)
logger.debug(
f"Host received reply via exposed function (id: {msg_id}): {log_detail}"
)
# Route reply to the waiting asyncio Queue in _message_handlers
if msg_id in self._message_handlers:
await self._message_handlers[msg_id].put(data)
if _VERBOSE_SANDBOX_LOGGING > 0:
logger.debug(f"Reply payload for ID {msg_id} routed.")
elif _VERBOSE_SANDBOX_LOGGING > 0:
logger.debug(
f"Host received reply for unknown/stale execution ID: {msg_id}"
)
except Exception as e:
logger.error(
f"Error processing execution reply payload (id: {msg_id or 'unknown'}) from sandbox: {e}",
exc_info=True,
)
reply_handler_name = "_deliverReplyToHost" # Must match the name called in JS template
await self.page.expose_function(reply_handler_name, _deliver_reply_to_host)
logger.info(f"Python function '{reply_handler_name}' exposed for JS execution replies.")
except Exception as e:
logger.error(f"Failed to expose reply handler function: {e}", exc_info=True)
await self._try_close_page("Reply Handler Setup Error")
raise ToolError(f"Could not expose reply handler to sandbox: {e}") from e
# --- 3b. Setup for Initial Ready/Error Signal ---
# The JS template sends the initial 'pyodide_ready' or 'pyodide_init_error' via postMessage.
# We need a way to capture *only* that specific message and put it on _init_queue.
logger.debug("Setting up listener for initial ready/error signal (JS->Python)...")
try:
# This Python function will be called by the JS listener below
async def _handle_initial_message(payload: Any):
try:
if not isinstance(payload, dict):
return # Ignore non-dicts
msg_id = payload.get("id")
if msg_id == "pyodide_ready" or msg_id == "pyodide_init_error":
log_level = logger.info if payload.get("ready") else logger.error
log_level(
f"Received initial status message from sandbox via exposed function: {payload}"
)
await self._init_queue.put(payload) # Put it on the init queue
# Optionally remove the listener after receiving the first signal? Might be risky.
# Ignore other messages potentially caught by this listener
except Exception as e:
logger.error(
f"Error processing initial message from sandbox: {e}", exc_info=True
)
# Put an error onto the queue to unblock init waiter
await self._init_queue.put(
{
"id": "pyodide_init_error",
"ok": False,
"error": {
"type": "HostProcessingError",
"message": f"Error handling init message: {e}",
},
}
)
init_handler_name = "_handleInitialMessage"
await self.page.expose_function(init_handler_name, _handle_initial_message)
# Evaluate JavaScript to add a *specific* listener that calls the exposed init handler
await self.page.evaluate(f"""
console.log('[PyodideBoot] Adding specific listener for initial ready/error messages...');
// Ensure we don't add multiple listeners if init is somehow re-run
if (!window._initialMessageListenerAdded) {{
window.addEventListener('message', (event) => {{
const data = event.data;
// Check if the exposed function exists and if it's the specific message we want
if (typeof window.{init_handler_name} === 'function' &&
typeof data === 'object' && data !== null &&
(data.id === 'pyodide_ready' || data.id === 'pyodide_init_error'))
{{
// Forward only specific initial messages to the exposed Python function
console.log('[PyodideBoot] Forwarding initial message to host:', data.id);
window.{init_handler_name}(data);
}}
}});
window._initialMessageListenerAdded = true; // Flag to prevent multiple adds
console.log('[PyodideBoot] Initial message listener added.');
}} else {{
console.log('[PyodideBoot] Initial message listener already added.');
}}
""")
logger.info(
f"Python function '{init_handler_name}' exposed and JS listener added for initial signal."
)
except Exception as e:
logger.error(f"Failed to set up initial signal listener: {e}", exc_info=True)
await self._try_close_page("Initial Signal Listener Setup Error")
raise ToolError(f"Could not set up initial signal listener: {e}") from e
# === 4. Wait for Ready Signal ===
logger.info(f"Waiting for sandbox ready signal (timeout: {self._init_timeout}s)...")
try:
# Wait for a message to appear on the _init_queue
init_data = await asyncio.wait_for(self._init_queue.get(), timeout=self._init_timeout)
# Check the content of the message
if init_data.get("id") == "pyodide_init_error" or init_data.get("ok") is False:
error_details = init_data.get(
"error", {"message": "Unknown initialization error reported by sandbox."}
)
error_msg = error_details.get("message", "Unknown Error")
logger.error(f"Pyodide sandbox initialization failed inside browser: {error_msg}")
await self._try_close_page("Initialization Error Reported by JS")
raise ToolError(f"Pyodide sandbox initialization failed: {error_msg}")
if not init_data.get("ready"):
logger.error(f"Received unexpected init message without 'ready' flag: {init_data}")
await self._try_close_page("Unexpected Init Message from JS")
raise ToolError("Received unexpected initialization message from sandbox.")
# If we received the correct ready message
self.ready_evt.set() # Set the event flag
boot_ms_reported = init_data.get("boot_ms", "N/A")
init_duration = time.monotonic() - init_start_time
logger.info(
f"Pyodide sandbox ready signal received (reported boot: {boot_ms_reported}ms, total init wait: {init_duration:.2f}s)"
)
except asyncio.TimeoutError as e:
logger.error(f"Timeout ({self._init_timeout}s) waiting for Pyodide ready signal.")
await self._check_page_responsiveness(
"Timeout Waiting for Ready"
) # Check if page is stuck
await self._try_close_page("Timeout Waiting for Ready")
raise ToolError(
f"Sandbox failed to initialize within timeout ({self._init_timeout}s)."
) from e
except Exception as e:
# Catch other errors during the wait/processing phase
logger.error(f"Error during sandbox initialization wait: {e}", exc_info=True)
await self._try_close_page("Initialization Wait Error")
if isinstance(e, ToolError): # Don't wrap existing ToolErrors
raise e
raise ToolError(f"Unexpected error during sandbox initialization wait: {e}") from e
async def _check_page_responsiveness(self, context: str) -> bool: # Return boolean
"""Tries to evaluate a simple JS command to check if the page is alive."""
if self.page and not self.page.is_closed():
try:
await asyncio.wait_for(self.page.evaluate("1+1"), timeout=5.0)
logger.debug(f"Page responded after {context}.")
return True
except Exception as page_err:
logger.error(f"Page seems unresponsive after {context}: {page_err}")
return False
else:
logger.debug(
f"Page already closed or non-existent during responsiveness check ({context})."
)
return False # Not responsive if closed
async def _try_close_page(self, reason: str):
"""Attempts to close the sandbox page, logging errors."""
if self.page and not self.page.is_closed():
logger.info(f"Attempting to close sandbox page due to: {reason}")
try:
await self.page.close()
logger.info(f"Sandbox page closed successfully after {reason}.")
except Exception as close_err:
logger.warning(f"Error closing page after {reason}: {close_err}")
else:
logger.debug(
f"Page already closed or non-existent when trying to close due to: {reason}"
)
async def execute(
self,
code: str,
packages: list[str] | None,
wheels: list[str] | None,
timeout_ms: int,
repl_mode: bool = False,
) -> Dict[str, Any]:
"""Sends code to the sandbox for execution and returns the result."""
if not PLAYWRIGHT_AVAILABLE:
# This condition should ideally be checked before creating/getting the sandbox
# but is included here for robustness.
raise ToolError("Playwright is not installed.")
if not self.page or self.page.is_closed():
raise ToolError("Cannot execute code: Sandbox page is closed.")
if not self.ready_evt.is_set():
# Wait briefly for the ready event if it's not set yet, in case of race conditions
try:
await asyncio.wait_for(self.ready_evt.wait(), timeout=1.0)
except asyncio.TimeoutError as e:
raise ToolError(
"Cannot execute code: Sandbox is not ready (or timed out becoming ready)."
) from e
self.last_used = time.time()
global _GLOBAL_SEM
if _GLOBAL_SEM is None:
# Initialize if it hasn't been already (should be done in _get_sandbox, but safety check)
logger.warning("Global execution semaphore not initialized, initializing now.")
_GLOBAL_SEM = asyncio.Semaphore(GLOBAL_CONCURRENCY)
# Acquire the semaphore to limit concurrency across all sandboxes
async with _GLOBAL_SEM:
exec_id = f"exec-{uuid.uuid4().hex[:8]}"
response_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._message_handlers[exec_id] = response_queue
try:
# Encode the user's Python code to Base64
code_b64 = base64.b64encode(code.encode("utf-8")).decode("ascii")
except Exception as enc_err:
# If encoding fails, it's an input error, no need to involve the sandbox
self._message_handlers.pop(exec_id, None) # Clean up handler
raise ToolInputError(f"Failed to encode code to base64: {enc_err}") from enc_err
# Prepare the message payload for the JavaScript side
payload = {
"type": "exec",
"id": exec_id,
"code_b64": code_b64,
"packages": packages or [],
"wheels": wheels or [],
"repl_mode": repl_mode,
}
data: dict[str, Any] = {} # Initialize response data dictionary
try:
logger.debug(
f"Sending execution request to sandbox (id: {exec_id}, repl={repl_mode})"
)
# Send the message to the sandbox page's window context
await self.page.evaluate("window.postMessage", payload)
logger.debug(
f"Waiting for execution result (id: {exec_id}, timeout: {timeout_ms}ms)..."
)
# Wait for the response message from the sandbox via the queue
data = await asyncio.wait_for(response_queue.get(), timeout=timeout_ms / 1000.0)
logger.debug(f"Received execution result (id: {exec_id}): ok={data.get('ok')}")
except asyncio.TimeoutError:
logger.warning(
f"Execution timed out waiting for response (id: {exec_id}, timeout: {timeout_ms}ms)"
)
# Check if the page is still responsive after the timeout
await self._check_page_responsiveness(f"Timeout id={exec_id}")
# Return a structured timeout error
return {
"ok": False,
"error": {
"type": "TimeoutError",
"message": f"Execution timed out after {timeout_ms}ms waiting for sandbox response.",
"traceback": None, # No Python traceback available in this case
},
"stdout": "", # Default values on timeout
"stderr": "",
"result": None,
"elapsed": 0, # No Python elapsed time available
"wall_ms": timeout_ms, # Wall time is the timeout duration
}
except Exception as e:
# Catch potential Playwright communication errors during evaluate/wait
logger.error(
f"Error communicating during execution (id: {exec_id}): {e}", exc_info=True
)
# Return a structured communication error
return {
"ok": False,
"error": {
"type": "CommunicationError",
"message": f"Error communicating with sandbox during execution: {e}",
"traceback": None, # Or potentially include JS stack if available from e
},
"stdout": "",
"stderr": "",
"result": None,
"elapsed": 0,
"wall_ms": 0,
}
finally:
# Always remove the message handler for this execution ID
self._message_handlers.pop(exec_id, None)
# --- Validate the structure of the received response ---
if not isinstance(data, dict) or "ok" not in data:
logger.error(
f"Received malformed response from sandbox (id: {exec_id}, structure invalid): {str(data)[:500]}"
)
# Return a structured error indicating the malformed response
return {
"ok": False,
"error": {
"type": "MalformedResponseError",
"message": "Received malformed or incomplete response from sandbox.",
"traceback": None,
# Safely include details for debugging, converting non-serializable types to string
"details": data
if isinstance(data, (dict, list, str, int, float, bool, type(None)))
else str(data),
},
# Provide default values for other fields
"stdout": "",
"stderr": "",
"result": None,
"elapsed": 0,
# Try to get wall_ms if data is a dict, otherwise default to 0
"wall_ms": data.get("wall_ms", 0) if isinstance(data, dict) else 0,
}
# --- Ensure essential fields exist with default values before returning ---
# This guarantees the caller receives a consistent structure even if the sandbox
# somehow missed fields (though the JS side now also sets defaults).
data.setdefault("stdout", "")
data.setdefault("stderr", "")
data.setdefault("result", None)
data.setdefault("elapsed", 0)
data.setdefault("wall_ms", 0)
# Ensure 'error' field is present if 'ok' is false
if not data.get("ok", False):
data.setdefault(
"error",
{
"type": "UnknownSandboxError",
"message": "Sandbox reported failure with no specific details.",
},
)
else:
# Ensure 'error' is None if 'ok' is true
data["error"] = None
# Return the validated and defaulted data dictionary
return data
async def reset_repl_state(self) -> Dict[str, Any]:
"""Sends a reset request to the REPL sandbox."""
if not PLAYWRIGHT_AVAILABLE:
return {
"ok": False,
"error": {"type": "SetupError", "message": "Playwright not installed."},
}
if not self.page or self.page.is_closed():
return {"ok": False, "error": {"type": "StateError", "message": "REPL page is closed."}}
if not self.ready_evt.is_set():
return {
"ok": False,
"error": {"type": "StateError", "message": "REPL sandbox is not ready."},
}
reset_id = f"reset-{uuid.uuid4().hex[:8]}"
response_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self._message_handlers[reset_id] = response_queue
try:
message_payload = {"type": "reset", "id": reset_id}
logger.debug(f"Sending REPL reset message (id: {reset_id})")
await self.page.evaluate("window.postMessage", message_payload)
logger.debug(f"Waiting for REPL reset confirmation (id: {reset_id}, timeout: 5s)...")
data = await asyncio.wait_for(response_queue.get(), timeout=5.0)
logger.debug(f"Received REPL reset confirmation: {data}")
return data
except asyncio.TimeoutError:
logger.warning(f"Timeout waiting for REPL reset confirmation (id: {reset_id})")
return {
"ok": False,
"error": {
"type": "TimeoutError",
"message": "Timeout waiting for reset confirmation.",
},
}
except Exception as e:
logger.error(f"Error during REPL reset call (id: {reset_id}): {e}", exc_info=True)
return {
"ok": False,
"error": {
"type": "CommunicationError",
"message": f"Error during reset operation: {e}",
},
}
finally:
self._message_handlers.pop(reset_id, None)
async def _inject_mcpfs_stub(self) -> None:
"""Creates a minimal stub module `mcpfs` inside the Pyodide interpreter."""
if not PLAYWRIGHT_AVAILABLE:
logger.warning("Playwright not available, cannot inject mcpfs stub.")
return
        # This stub code runs inside Pyodide; it does not need COMMON_PACKAGES_JSON from the host Python process.
stub_code = r"""
import sys
import types
import asyncio
import json
from js import globalThis
# Simple check if stub already exists
if "mcpfs" in sys.modules:
print("mcpfs module stub already exists.")
else:
print("Initializing mcpfs module stub...")
_mcpfs_msg_id_counter = 0
_mcpfs_pending_futures = {}
async def _mcpfs_roundtrip(op: str, *args):
'''Sends an operation to the host and waits for the response.'''
        global _mcpfs_msg_id_counter  # counter is defined at the top level of this stub script
_mcpfs_msg_id_counter += 1
current_id = f"mcpfs-{_mcpfs_msg_id_counter}"
loop = asyncio.get_running_loop()
fut = loop.create_future()
_mcpfs_pending_futures[current_id] = fut
payload = {"type": "mcpfs", "id": current_id, "op": op, "args": args}
globalThis.postMessage(payload)
try:
response = await asyncio.wait_for(fut, timeout=15.0)
except asyncio.TimeoutError:
raise RuntimeError(f"Timeout waiting for host mcpfs op '{op}' (id: {current_id})")
finally:
_mcpfs_pending_futures.pop(current_id, None)
if response is None: raise RuntimeError(f"Null response from host for mcpfs op '{op}' (id: {current_id})")
if "error" in response:
err_details = response.get('details', '')
raise RuntimeError(f"Host error for mcpfs op '{op}': {response['error']} {err_details}")
return response.get("result")
def _mcpfs_message_callback(event):
'''Callback attached to Pyodide's message listener to resolve futures.'''
data = event.data
if isinstance(data, dict) and data.get("type") == "mcpfs_response":
msg_id = data.get("id")
fut = _mcpfs_pending_futures.get(msg_id)
if fut and not fut.done(): fut.set_result(data)
globalThis.addEventListener("message", _mcpfs_message_callback)
mcpfs_module = types.ModuleType("mcpfs")
async def read_text_async(p): return await _mcpfs_roundtrip("read", p)
async def write_text_async(p, t): return await _mcpfs_roundtrip("write", p, t)
async def listdir_async(p): return await _mcpfs_roundtrip("list", p)
mcpfs_module.read_text_async = read_text_async
mcpfs_module.write_text_async = write_text_async
mcpfs_module.listdir_async = listdir_async
mcpfs_module.read_text = read_text_async
mcpfs_module.write_text = write_text_async
mcpfs_module.listdir = listdir_async
sys.modules["mcpfs"] = mcpfs_module
print("mcpfs Python module stub initialized successfully.")
# --- End of MCPFS Stub Logic ---
"""
if not self.page or self.page.is_closed():
logger.error("Cannot inject mcpfs stub: Sandbox page is closed.")
return
try:
logger.debug("Injecting mcpfs stub into Pyodide environment...")
await self.page.evaluate(
f"""(async () => {{
try {{
if (typeof self.pyodide === 'undefined' || !self.pyodide.runPythonAsync) {{
console.error('Pyodide instance not ready for mcpfs stub injection.'); return;
}}
await self.pyodide.runPythonAsync(`{stub_code}`);
console.log("mcpfs Python stub injection script executed.");
}} catch (err) {{
console.error("Error running mcpfs stub injection Python code:", err);
globalThis.postMessage({{ type: 'error', id:'mcpfs_stub_inject_fail', error: {{ type: 'InjectionError', message: 'Failed to inject mcpfs stub: ' + err.toString() }} }}, "*");
}}
}})();"""
)
logger.info("mcpfs stub injection command sent to sandbox.")
except Exception as e:
logger.error(f"Failed to evaluate mcpfs stub injection script: {e}", exc_info=True)
# Don't raise ToolError here, log it. FS might not be critical.
# --- End of PyodideSandbox Class ---
###############################################################################
# Browser / sandbox lifecycle helpers – with LRU eviction
###############################################################################
async def _get_browser() -> "Browser": # type: ignore # Use string literal hint
"""Initializes and returns the shared Playwright browser instance."""
global _BROWSER
if not PLAYWRIGHT_AVAILABLE:
raise RuntimeError("Playwright is not installed.")
browser_connected = False
if _BROWSER is not None:
try:
browser_connected = _BROWSER.is_connected()
except Exception as check_err:
logger.warning(
f"Error checking browser connection status: {check_err}. Assuming disconnected."
)
browser_connected = False
_BROWSER = None
if _BROWSER is None or not browser_connected:
logger.info("Launching headless Chromium for Pyodide sandbox...")
try:
playwright = await pw.async_playwright().start()
launch_options = {
"headless": True,
"args": [
"--no-sandbox",
"--disable-gpu",
"--disable-dev-shm-usage",
"--disable-features=Translate",
"--disable-extensions",
"--disable-component-extensions-with-background-pages",
"--disable-background-networking",
"--disable-sync",
"--metrics-recording-only",
"--disable-default-apps",
"--mute-audio",
"--no-first-run",
"--safebrowsing-disable-auto-update",
"--disable-popup-blocking",
"--disable-setuid-sandbox",
"--disable-web-security",
"--allow-file-access-from-files",
"--allow-universal-access-from-file-urls",
"--disable-permissions-api",
],
"timeout": 90000,
}
_BROWSER = await playwright.chromium.launch(**launch_options)
def _sync_cleanup():
global _BROWSER
if _BROWSER and _BROWSER.is_connected():
logger.info("Closing Playwright browser via atexit handler...")
try:
loop = asyncio.get_event_loop_policy().get_event_loop()
if loop.is_running():
future = asyncio.run_coroutine_threadsafe(_BROWSER.close(), loop)
future.result(timeout=15)
else:
loop.run_until_complete(_BROWSER.close())
logger.info("Playwright browser closed successfully via atexit.")
_BROWSER = None
except Exception as e:
logger.error(
f"Error during atexit Playwright browser cleanup: {e}", exc_info=True
)
atexit.register(_sync_cleanup)
logger.info("Headless Chromium launched successfully and atexit cleanup registered.")
except Exception as e:
logger.error(f"Failed to launch Playwright browser: {e}", exc_info=True)
_BROWSER = None
raise ProviderError(f"Failed to launch browser for sandbox: {e}") from e
if not _BROWSER:
raise ProviderError("Browser instance is None after launch attempt.")
return _BROWSER
async def _get_sandbox(session_id: str, **kwargs) -> PyodideSandbox:
"""Retrieves or creates a PyodideSandbox instance, managing LRU cache."""
global _GLOBAL_SEM, _PAGES
if not PLAYWRIGHT_AVAILABLE:
# Check upfront if Playwright is available
raise RuntimeError("Playwright is not installed. Cannot create sandboxes.")
# Initialize the global semaphore if this is the first call
if _GLOBAL_SEM is None:
_GLOBAL_SEM = asyncio.Semaphore(GLOBAL_CONCURRENCY)
logger.debug(
f"Initialized global execution semaphore with concurrency {GLOBAL_CONCURRENCY}"
)
# Check if a sandbox for this session ID already exists in our cache
sb = _PAGES.get(session_id)
if sb is not None:
page_valid = False
if sb.page:
# Verify the associated Playwright Page object is still open
try:
page_valid = not sb.page.is_closed()
except Exception as page_check_err:
# Handle potential errors during the check (e.g., context destroyed)
logger.warning(
f"Error checking page status for {session_id}: {page_check_err}. Assuming invalid."
)
page_valid = False
if page_valid:
# If the page is valid, reuse the existing sandbox
logger.debug(f"Reusing existing sandbox session: {session_id}")
# Move the accessed sandbox to the end of the OrderedDict (marks it as recently used)
_PAGES.move_to_end(session_id)
sb.last_used = time.time() # Update last used timestamp
return sb
else:
# If the page is closed or invalid, remove the entry from the cache
logger.warning(f"Removing closed/invalid sandbox session from cache: {session_id}")
_PAGES.pop(session_id, None)
# Attempt to gracefully close the page object if it exists
if sb.page:
await sb._try_close_page("Invalid page found in cache")
# If no valid sandbox found, create a new one, potentially evicting the LRU
while len(_PAGES) >= MAX_SANDBOXES:
# Remove the least recently used sandbox (first item in OrderedDict)
try:
victim_id, victim_sb = _PAGES.popitem(last=False)
logger.info(
f"Sandbox cache full ({len(_PAGES) + 1}/{MAX_SANDBOXES}). Evicting LRU session: {victim_id} "
f"(created: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(victim_sb.created_at))}, "
f"last used: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(victim_sb.last_used))})"
)
# Attempt to close the evicted sandbox's page
await victim_sb._try_close_page(f"LRU eviction (victim: {victim_id})")
except KeyError:
# Should not happen if len(_PAGES) >= MAX_SANDBOXES, but handle defensively
logger.warning("LRU eviction attempted but cache was empty.")
break # Avoid infinite loop if something is wrong
logger.info(f"Creating new sandbox session: {session_id}")
browser = await _get_browser() # Get or initialize the shared browser instance
page: Optional["Page"] = None # type: ignore # Initialize page variable
try:
# Create a new browser page (tab)
page = await browser.new_page()
# Set up logging and error handlers for this specific page
_wire_page_logging(page, session_id)
logger.debug(f"New browser page created for session {session_id}")
# Create the PyodideSandbox object instance
sb = PyodideSandbox(
page=page, **kwargs
) # Pass through any extra kwargs (like allow_network)
# Initialize the sandbox (loads HTML, waits for ready signal)
await sb.init()
# Add the newly created and initialized sandbox to the cache
_PAGES[session_id] = sb
logger.info(f"New sandbox session {session_id} created and initialized successfully.")
return sb
except Exception as e:
# Handle errors during page creation or sandbox initialization
logger.error(f"Failed to create or initialize new sandbox {session_id}: {e}", exc_info=True)
# If the page was created but initialization failed, try to close it
if page and not page.is_closed():
# Use a temporary sandbox object just to call the closing method
await PyodideSandbox(page=page)._try_close_page(
f"Failed sandbox creation/init ({session_id})"
)
# Re-raise the exception, preserving type if it's a known error type
if isinstance(e, (ToolError, ProviderError)):
raise e
# Wrap unexpected errors in ProviderError for consistent error handling upstream
raise ProviderError(f"Failed to create sandbox {session_id}: {e}") from e
async def _close_all_sandboxes():
"""Gracefully close all active sandbox pages and the browser."""
global _BROWSER, _PAGES
logger.info("Closing all active Pyodide sandboxes...")
page_close_tasks = []
sandboxes_to_close = list(_PAGES.values())
_PAGES.clear()
for sb in sandboxes_to_close:
close_task = asyncio.create_task(sb._try_close_page("Global shutdown"))
page_close_tasks.append(close_task)
if page_close_tasks:
gathered_results = await asyncio.gather(*page_close_tasks, return_exceptions=True)
closed_count = sum(1 for result in gathered_results if not isinstance(result, Exception))
errors = [result for result in gathered_results if isinstance(result, Exception)]
logger.info(
f"Attempted to close {len(page_close_tasks)} sandbox pages. Success: {closed_count}."
)
if errors:
logger.warning(f"{len(errors)} errors during page close: {errors}")
browser_needs_closing = False
if _BROWSER:
try:
browser_needs_closing = _BROWSER.is_connected()
except Exception as browser_check_err:
logger.warning(
f"Error checking browser connection during close: {browser_check_err}. Assuming needs closing."
)
browser_needs_closing = True
if browser_needs_closing:
logger.info("Closing Playwright browser instance...")
try:
await _BROWSER.close()
logger.info("Playwright browser closed successfully.")
except Exception as e:
logger.error(f"Error closing Playwright browser: {e}")
_BROWSER = None
def display_sandbox_result(
title: str, result: Optional[Dict[str, Any]], code_str: Optional[str] = None
) -> None:
"""Display sandbox execution result with enhanced formatting."""
console.print(Rule(f"[bold cyan]{escape(title)}[/bold cyan]"))
if code_str:
console.print(
Panel(
Syntax(
code_str.strip(), "python", theme="monokai", line_numbers=True, word_wrap=True
),
title="Executed Code",
border_style="blue",
padding=(1, 2),
)
)
if result is None:
console.print(
Panel(
"[bold yellow]No result object returned from tool call.[/]",
title="Warning",
border_style="yellow",
)
)
console.print()
return
# Check for errors based on the result structure from safe_tool_call
if not result.get("success", False) and "error" in result:
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "UnknownError")
error_code = result.get("error_code", "UNKNOWN")
details = result.get("details", {})
error_renderable = f"[bold red]:x: Operation Failed ({escape(error_type)} / {escape(error_code)}):[/]\n{escape(error_msg)}"
if details:
try:
details_str = escape(str(details)) # Basic string representation
error_renderable += f"\n\n[bold]Details:[/]\n{details_str}"
except Exception:
error_renderable += "\n\n[bold]Details:[/]\n(Could not display details)"
console.print(
Panel(error_renderable, title="Error", border_style="red", padding=(1, 2), expand=False)
)
console.print()
return
# --- Display Success Case ---
actual_result = result.get(
"result", {}
) # Get the nested result dict from execute_python/repl_python
# Create output panel for stdout/stderr
output_parts = []
if stdout := actual_result.get("stdout", ""):
output_parts.append(f"[bold green]STDOUT:[/]\n{escape(stdout)}")
if stderr := actual_result.get("stderr", ""):
if output_parts:
output_parts.append("\n" + ("-" * 20) + "\n") # Separator
output_parts.append(f"[bold red]STDERR:[/]\n{escape(stderr)}")
if output_parts:
console.print(
Panel(
"\n".join(output_parts),
title="Output (stdout/stderr)",
border_style="yellow",
padding=(1, 2),
)
)
else:
console.print("[dim]No stdout or stderr captured.[/dim]")
# Display result value if present and not None
result_value = actual_result.get(
"result"
) # This is the value assigned to 'result' in the executed code
if result_value is not None:
try:
# Attempt to pretty-print common types
if isinstance(result_value, (dict, list)):
result_str = str(result_value) # Keep it simple for now
else:
result_str = str(result_value)
# Limit length for display
max_len = 500
display_str = result_str[:max_len] + ("..." if len(result_str) > max_len else "")
console.print(
Panel(
Syntax(
display_str, "python", theme="monokai", line_numbers=False, word_wrap=True
),
title="Result Variable ('result')",
border_style="green",
padding=(1, 2),
)
)
except Exception as e:
console.print(
Panel(
f"[yellow]Could not format result value: {e}[/]",
title="Result Variable ('result')",
border_style="yellow",
)
)
console.print(f"Raw Result Type: {type(result_value)}")
try:
console.print(f"Raw Result Repr: {escape(repr(result_value)[:500])}...")
except Exception:
pass
# Display execution stats
stats_table = Table(
title="Execution Statistics",
box=box.ROUNDED,
show_header=False,
padding=(0, 1),
border_style="dim",
)
stats_table.add_column("Metric", style="cyan", justify="right")
stats_table.add_column("Value", style="white")
if "elapsed_py_ms" in actual_result:
stats_table.add_row("Python Execution Time", f"{actual_result['elapsed_py_ms']:.2f} ms")
if "elapsed_wall_ms" in actual_result:
stats_table.add_row("Sandbox Wall Clock Time", f"{actual_result['elapsed_wall_ms']:.2f} ms")
if "total_duration_ms" in result: # From safe_tool_call wrapper
stats_table.add_row("Total Tool Call Time", f"{result['total_duration_ms']:.2f} ms")
if "session_id" in actual_result:
stats_table.add_row("Session ID", actual_result["session_id"])
if "handle" in actual_result:
stats_table.add_row("REPL Handle", actual_result["handle"])
if stats_table.row_count > 0:
console.print(stats_table)
console.print() # Add spacing
###############################################################################
# mcpfs bridge – listens for postMessage & proxies to secure FS tool
###############################################################################
async def _listen_for_mcpfs_calls(page: "Page"): # type: ignore # Use string literal hint
"""Sets up listener for 'mcpfs' messages from the sandbox page."""
if not PLAYWRIGHT_AVAILABLE:
logger.warning("Playwright not available, cannot listen for mcpfs calls.")
return
async def _handle_mcpfs_message(payload: Any):
"""Processes 'mcpfs' request from Pyodide and sends 'mcpfs_response' back."""
data = payload
is_mcpfs_message = isinstance(data, dict) and data.get("type") == "mcpfs"
if not is_mcpfs_message:
return
call_id = data.get("id")
op = data.get("op")
args = data.get("args", [])
if not call_id or not op:
logger.warning(
f"MCPFS Bridge: Received invalid mcpfs message (missing id or op): {data}"
)
return
response_payload: dict[str, Any] = {"type": "mcpfs_response", "id": call_id}
try:
try:
from ultimate_mcp_server.tools import filesystem as fs
except ImportError as e:
logger.error("MCPFS Bridge: Failed to import 'filesystem' tool.", exc_info=True)
raise ToolError("Filesystem tool backend not available.") from e
if _VERBOSE_SANDBOX_LOGGING > 1:
logger.debug(f"MCPFS Bridge: Received op='{op}', args={args}, id={call_id}")
if op == "read":
if len(args) != 1 or not isinstance(args[0], str):
raise ValueError("read requires 1 string arg (path)")
res = await fs.read_file(path=args[0])
if res.get("success") and isinstance(res.get("content"), list) and res["content"]:
file_content = res["content"][0].get("text")
if file_content is None:
raise ToolError("Read succeeded but missing 'text' key.")
response_payload["result"] = file_content
else:
raise ToolError(res.get("error", "Read failed"), details=res.get("details"))
elif op == "write":
if len(args) != 2 or not isinstance(args[0], str) or not isinstance(args[1], str):
raise ValueError("write requires 2 string args (path, content)")
res = await fs.write_file(path=args[0], content=args[1])
if res.get("success"):
response_payload["result"] = True
else:
raise ToolError(res.get("error", "Write failed"), details=res.get("details"))
elif op == "list":
if len(args) != 1 or not isinstance(args[0], str):
raise ValueError("list requires 1 string arg (path)")
res = await fs.list_directory(path=args[0])
if res.get("success"):
response_payload["result"] = res.get("entries", [])
else:
raise ToolError(res.get("error", "List failed"), details=res.get("details"))
else:
raise ValueError(f"Unsupported mcpfs operation: '{op}'")
except (ToolError, ToolInputError, ProviderError, ValueError) as tool_exc:
error_message = f"{type(tool_exc).__name__}: {tool_exc}"
logger.warning(
f"MCPFS Bridge Error processing op='{op}' (id={call_id}): {error_message}"
)
response_payload["error"] = error_message
if hasattr(tool_exc, "details") and tool_exc.details:
try:
response_payload["details"] = json.loads(
json.dumps(tool_exc.details, default=str)
)
except Exception:
response_payload["details"] = {"error": "Serialization failed"}
except Exception as exc:
error_message = f"Unexpected Host Error: {exc}"
logger.error(
f"Unexpected MCPFS Bridge Error (op='{op}', id={call_id}): {error_message}",
exc_info=True,
)
response_payload["error"] = error_message
try:
response_successful = "error" not in response_payload
if _VERBOSE_SANDBOX_LOGGING > 1:
logger.debug(
f"MCPFS Bridge: Sending response (op='{op}', id={call_id}, success={response_successful})"
)
await page.evaluate(JS_POST_MESSAGE, response_payload)
except Exception as post_err:
logger.warning(
f"Failed to send mcpfs response back to sandbox (id: {call_id}, op: '{op}'): {post_err}"
)
handler_func_name = "_handleMcpFsMessageFromHost"
try:
await page.expose_function(handler_func_name, _handle_mcpfs_message)
await page.evaluate(f"""
if (!window._mcpfsListenerAttached) {{
console.log('Setting up MCPFS message listener in browser context...');
window.addEventListener('message', (event) => {{
if (event.data && event.data.type === 'mcpfs' && typeof window.{handler_func_name} === 'function') {{
window.{handler_func_name}(event.data);
}}
}});
window._mcpfsListenerAttached = true;
console.log('MCPFS message listener attached.');
}}
""")
logger.info("MCPFS listener bridge established successfully.")
except Exception as e:
logger.error(f"Failed to set up MCPFS listener bridge: {e}", exc_info=True)
raise ToolError(f"Filesystem bridge listener setup failed: {e}") from e
def _format_sandbox_error(error_payload: Optional[Dict[str, Any]]) -> str:
if not error_payload or not isinstance(error_payload, dict):
return "Unknown sandbox execution error."
err_type = error_payload.get("type", "UnknownError")
err_msg = error_payload.get("message", "No details provided.")
# Optionally include traceback snippet if needed, but keep main message clean
tb = error_payload.get("traceback")
if tb:
err_msg += f"\nTraceback (see logs/details):\n{str(tb)[:200]}..."
return f"{err_type} - {err_msg}"
###############################################################################
# Standalone Tool Functions (execute_python, repl_python)
###############################################################################
@with_tool_metrics
@with_error_handling
async def execute_python(
code: str,
packages: Optional[List[str]] = None,
wheels: Optional[List[str]] = None,
allow_network: bool = False,
allow_fs: bool = False,
session_id: Optional[str] = None,
timeout_ms: int = 15_000,
ctx: Optional[Dict[str, Any]] = None, # Context often used by decorators
) -> Dict[str, Any]:
"""
Runs Python code in a one-shot Pyodide sandbox.
Args:
code: The Python code string to execute.
packages: A list of Pyodide packages to ensure are loaded. Do not include stdlib modules.
wheels: A list of Python wheel URLs to install via micropip.
allow_network: If True, allows network access (e.g., for micropip to PyPI).
allow_fs: If True, enables the mcpfs filesystem bridge (requires host setup).
session_id: Optional ID to reuse or create a specific sandbox session. If None, a new ID is generated.
timeout_ms: Timeout for waiting for the sandbox execution result (in milliseconds).
ctx: Optional context dictionary, often passed by framework/decorators.
Returns:
A dictionary containing execution results:
{
'success': bool,
'stdout': str,
'stderr': str,
'result': Any, # Value of the 'result' variable in the Python code, if set
'elapsed_py_ms': int, # Time spent executing Python code (reported by sandbox)
'elapsed_wall_ms': int, # Total wall clock time from JS perspective (reported by sandbox)
'session_id': str,
'error_message': Optional[str], # Formatted error if success is False
'error_details': Optional[Dict], # Original error dict from sandbox if success is False
}
Raises:
ProviderError: If the sandbox environment (Playwright/browser) cannot be set up.
ToolInputError: If input arguments are invalid.
ToolError: If sandbox execution fails (contains formatted message and details).
"""
if not PLAYWRIGHT_AVAILABLE:
raise ProviderError("Playwright dependency is missing for Python Sandbox.")
if not isinstance(code, str) or not code:
raise ToolInputError(
"Input 'code' must be a non-empty string.", param="code", value=repr(code)
)
if not isinstance(timeout_ms, int) or timeout_ms <= 0:
raise ToolInputError(
"Input 'timeout_ms' must be a positive integer.", param="timeout_ms", value=timeout_ms
)
# Basic type checks for lists/bools - could add more specific validation
if packages is not None and not isinstance(packages, list):
raise ToolInputError(
"Input 'packages' must be a list or None.", param="packages", value=packages
)
if wheels is not None and not isinstance(wheels, list):
raise ToolInputError("Input 'wheels' must be a list or None.", param="wheels", value=wheels)
if not isinstance(allow_network, bool):
raise ToolInputError(
"Input 'allow_network' must be a boolean.", param="allow_network", value=allow_network
)
if not isinstance(allow_fs, bool):
raise ToolInputError(
"Input 'allow_fs' must be a boolean.", param="allow_fs", value=allow_fs
)
if session_id is not None and not isinstance(session_id, str):
raise ToolInputError(
"Input 'session_id' must be a string or None.", param="session_id", value=session_id
)
# Normalize package/wheel lists
# IMPORTANT: Filter out common stdlib modules that shouldn't be passed
stdlib_modules_to_filter = {
"math",
"sys",
"os",
"json",
"io",
"contextlib",
"time",
"base64",
"traceback",
"collections",
"re",
"datetime",
}
packages_normalized = [pkg for pkg in (packages or []) if pkg not in stdlib_modules_to_filter]
wheels_normalized = wheels or []
# Generate a session ID if one wasn't provided
current_session_id = session_id or f"exec-{uuid.uuid4().hex[:12]}" # Add prefix for clarity
# Get or create the sandbox instance
try:
# Assuming _get_sandbox is defined elsewhere and returns PyodideSandbox instance
sb = await _get_sandbox(current_session_id, allow_network=allow_network, allow_fs=allow_fs)
except Exception as e:
# Catch potential errors during sandbox acquisition/initialization
if isinstance(e, (ToolError, ProviderError)):
raise e # Re-raise known error types
# Wrap unexpected errors
raise ProviderError(
f"Failed to get or initialize sandbox '{current_session_id}': {e}",
tool_name="python_sandbox",
cause=e,
) from e
t0 = time.perf_counter() # Start timer just before execute call
data: Dict[str, Any] = {} # Initialize data dict
# Execute the code within the sandbox
try:
# Call the execute method on the sandbox object
# Pass repl_mode=False for one-shot execution
data = await sb.execute(
code, packages_normalized, wheels_normalized, timeout_ms, repl_mode=False
)
except Exception as e:
# Catch potential host-side errors during the .execute() call itself
# (e.g., Playwright communication errors not caught internally by execute)
wall_ms = int((time.perf_counter() - t0) * 1000)
logger.error(
f"Unexpected host error calling sandbox execute for {current_session_id}: {e}",
exc_info=True,
)
raise ToolError(
f"Unexpected host error during sandbox execution call: {e}",
error_code="HostExecutionError",
details={"session_id": current_session_id, "elapsed_wall_ms": wall_ms, "cause": str(e)},
) from e
# Process the results received from the sandbox
wall_ms_host = int((time.perf_counter() - t0) * 1000) # Wall time measured by host
is_success = data.get("ok", False)
error_info = data.get(
"error"
) # This is the structured {type, message, traceback} dict from sandbox
js_wall_ms = int(data.get("wall_ms", 0)) # Wall time reported by JS sandbox handler
# Format error message IF execution failed inside the sandbox
error_message_for_caller = None
error_code_for_caller = "UnknownSandboxError" # Default error code
if not is_success:
error_message_for_caller = _format_sandbox_error(
error_info
) # Use helper to get "Type - Message" string
if isinstance(error_info, dict):
error_code_for_caller = error_info.get(
"type", "UnknownSandboxError"
) # Get specific code
# Prepare structured logging details
log_details = {
"session_id": current_session_id,
"elapsed_wall_ms_host": wall_ms_host,
"elapsed_wall_ms_js": js_wall_ms, # Log both wall times for comparison
"elapsed_py_ms": int(data.get("elapsed", 0)),
"packages_requested": packages or [], # Log original requested packages
"packages_loaded": packages_normalized, # Log packages actually sent to loadPackage
"wheels_count": len(wheels_normalized),
"stdout_len": len(data.get("stdout", "")),
"stderr_len": len(data.get("stderr", "")),
"result_type": type(data.get("result")).__name__,
"success": is_success,
"repl_mode": False,
}
# Log and return/raise based on success
if is_success:
logger.success(
f"Python code executed successfully (session: {current_session_id})",
TaskType.CODE_EXECUTION, # Assumes TaskType is defined/imported
**log_details,
)
# Return success dictionary matching specified structure
return {
"success": True,
"stdout": data.get("stdout", ""),
"stderr": data.get("stderr", ""),
"result": data.get("result"), # Can be None
"elapsed_py_ms": int(data.get("elapsed", 0)),
"elapsed_wall_ms": js_wall_ms or wall_ms_host, # Prefer JS wall time
"session_id": current_session_id,
"error_message": None, # Explicitly None on success
"error_details": None, # Explicitly None on success
}
else:
# Log the failure with details
logger.error(
f"Python code execution failed (session: {current_session_id}): {error_message_for_caller}",
TaskType.CODE_EXECUTION, # Assumes TaskType is defined/imported
**log_details,
error_details=error_info, # Log the original structured error details
)
# Raise a ToolError containing the formatted message and original details
raise ToolError(
f"Python execution failed: {error_message_for_caller}", # User-friendly message
error_code=error_code_for_caller, # Specific error code from sandbox
details=error_info, # Original structured error from sandbox
)
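# --- Illustrative usage sketch (not part of the tool surface) ---
# A hedged example of calling execute_python directly from async host code; the
# helper name and literal values below are hypothetical and exist only to
# document the call shape described in the docstring above.
async def _example_execute_python_usage() -> Dict[str, Any]:
    outcome = await execute_python(
        code="import math\nresult = math.sqrt(2)",  # value bound to 'result' is returned
        packages=None,           # stdlib-only code needs no Pyodide packages
        allow_network=False,
        timeout_ms=10_000,
    )
    # On success the dict mirrors the documented structure (stdout, stderr, result, ...).
    return outcome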
@with_tool_metrics
@with_error_handling
async def repl_python(
code: str,
packages: Optional[List[str]] = None,
wheels: Optional[List[str]] = None,
allow_network: bool = False,
allow_fs: bool = False,
handle: Optional[str] = None, # Session handle for persistence
timeout_ms: int = 15_000,
reset: bool = False, # Flag to reset the REPL state before execution
ctx: Optional[Dict[str, Any]] = None, # Context often used by decorators
) -> Dict[str, Any]:
"""
Runs Python code in a persistent REPL-like sandbox environment.
Args:
code: The Python code string to execute in the session. Can be empty if only resetting.
packages: Additional Pyodide packages to ensure are loaded for this specific call.
wheels: Additional Python wheel URLs to install for this specific call.
allow_network: If True, allows network access for the sandbox session.
allow_fs: If True, enables the mcpfs filesystem bridge for the session.
handle: A specific session ID to use. If None, a new session is created.
Use the returned handle for subsequent calls to maintain state.
timeout_ms: Timeout for waiting for this specific execution call.
reset: If True, clears the REPL session's state (_MCP_REPL_NS) before executing code.
ctx: Optional context dictionary.
Returns:
A dictionary containing execution results for *this call*:
{
'success': bool, # Success of *this specific code execution* (or reset)
'stdout': str,
'stderr': str,
'result': Any, # Value of 'result' variable from this execution, if set
'elapsed_py_ms': int,
'elapsed_wall_ms': int,
'handle': str, # The session handle (same as input or newly generated)
'error_message': Optional[str], # Formatted error if success is False
'error_details': Optional[Dict], # Original error dict if success is False
'reset_status': Optional[Dict], # Included only if reset=True, contains reset ack
}
Raises:
ProviderError: If the sandbox environment cannot be set up.
ToolInputError: If input arguments are invalid.
ToolError: If a non-recoverable error occurs during execution (contains details).
Note: Standard Python errors within the code are returned in the 'error' fields,
not typically raised as ToolError unless they prevent result processing.
"""
if not PLAYWRIGHT_AVAILABLE:
raise ProviderError("Playwright dependency is missing for Python Sandbox.")
# Code can be empty if reset is True
if not isinstance(code, str):
raise ToolInputError("Input 'code' must be a string.", param="code", value=repr(code))
if not code and not reset:
raise ToolInputError(
"Input 'code' cannot be empty unless 'reset' is True.", param="code", value=repr(code)
)
if not isinstance(timeout_ms, int) or timeout_ms <= 0:
raise ToolInputError(
"Input 'timeout_ms' must be a positive integer.", param="timeout_ms", value=timeout_ms
)
# Basic type checks - can be expanded
if packages is not None and not isinstance(packages, list):
raise ToolInputError(
"Input 'packages' must be a list or None.", param="packages", value=packages
)
if wheels is not None and not isinstance(wheels, list):
raise ToolInputError("Input 'wheels' must be a list or None.", param="wheels", value=wheels)
if not isinstance(allow_network, bool):
raise ToolInputError(
"Input 'allow_network' must be a boolean.", param="allow_network", value=allow_network
)
if not isinstance(allow_fs, bool):
raise ToolInputError(
"Input 'allow_fs' must be a boolean.", param="allow_fs", value=allow_fs
)
if handle is not None and not isinstance(handle, str):
raise ToolInputError(
"Input 'handle' must be a string or None.", param="handle", value=handle
)
if not isinstance(reset, bool):
raise ToolInputError("Input 'reset' must be a boolean.", param="reset", value=reset)
# IMPORTANT: Filter out common stdlib modules that shouldn't be passed
stdlib_modules_to_filter = {
"math",
"sys",
"os",
"json",
"io",
"contextlib",
"time",
"base64",
"traceback",
"collections",
"re",
"datetime",
}
packages_normalized = [pkg for pkg in (packages or []) if pkg not in stdlib_modules_to_filter]
wheels_normalized = wheels or []
# Use provided handle or generate a new persistent one
session_id = handle or f"repl-{uuid.uuid4().hex[:12]}"
# Get or create the sandbox instance (will reuse if handle exists and page is open)
try:
# Pass allow_network/allow_fs, they are session-level properties
sb = await _get_sandbox(session_id, allow_network=allow_network, allow_fs=allow_fs)
except Exception as e:
if isinstance(e, (ToolError, ProviderError)):
raise e
raise ProviderError(
f"Failed to get or initialize REPL sandbox '{session_id}': {e}",
tool_name="python_sandbox",
cause=e,
) from e
t0 = time.perf_counter() # Start timer before potential reset/execute
reset_ack_data: Optional[Dict] = None # To store the ack from the reset call
# --- Handle Reset Request ---
if reset:
logger.info(f"Resetting REPL state for session: {session_id}")
try:
# Assuming PyodideSandbox has a method like this that uses the direct callback
reset_ack_data = await sb.reset_repl_state() # This should wait for the JS ack
if not reset_ack_data or not reset_ack_data.get("ok"):
# Log warning but don't necessarily fail the whole call yet
error_msg = (
_format_sandbox_error(reset_ack_data.get("error"))
if reset_ack_data
else "No confirmation received"
)
logger.warning(
f"REPL state reset failed or unconfirmed for session {session_id}: {error_msg}"
)
# Optionally add this warning to the final result?
except Exception as e:
# Handle errors during the reset call itself
logger.warning(
f"Error during REPL reset call for session {session_id}: {e}", exc_info=True
)
# Store this error to potentially include in final result if no code is run
reset_ack_data = {"ok": False, "error": {"type": "ResetHostError", "message": str(e)}}
# If ONLY resetting (no code provided), return immediately after reset attempt
if not code:
host_wall_ms = int((time.perf_counter() - t0) * 1000)
final_result = {
"success": reset_ack_data.get("ok", False)
if reset_ack_data
else False, # Reflect reset success
"stdout": "",
"stderr": "",
"result": None,
"elapsed_py_ms": 0,
"elapsed_wall_ms": host_wall_ms, # Only host time available
"handle": session_id,
"error_message": None
if (reset_ack_data and reset_ack_data.get("ok"))
else _format_sandbox_error(reset_ack_data.get("error") if reset_ack_data else None),
"error_details": reset_ack_data.get("error")
if (reset_ack_data and not reset_ack_data.get("ok"))
else None,
"reset_status": reset_ack_data, # Always include reset ack if reset was true
}
return final_result
# --- Execute Code (if provided) ---
data: Dict[str, Any] = {} # Initialize data dict for execution results
execution_successful_this_call = True # Assume success unless execution fails
if code:
try:
# Call the execute method, ensuring repl_mode=True is passed
data = await sb.execute(
code, packages_normalized, wheels_normalized, timeout_ms, repl_mode=True
)
execution_successful_this_call = data.get(
"ok", False
) # Get success status from execution result
except Exception as e:
# Catch host-side errors during the execute call
execution_successful_this_call = False
wall_ms_host_error = int((time.perf_counter() - t0) * 1000)
logger.error(
f"Unexpected host error calling REPL sandbox execute for {session_id}: {e}",
exc_info=True,
)
# Create a failure structure similar to what execute returns
data = {
"ok": False,
"error": {
"type": "HostExecutionError",
"message": f"Host error during REPL exec: {e}",
},
"wall_ms": wall_ms_host_error, # Use host time
"elapsed": 0,
"stdout": "",
"stderr": "",
"result": None,
}
# --- Format results and potential errors ---
wall_ms_host_final = int((time.perf_counter() - t0) * 1000)
js_wall_ms = int(data.get("wall_ms", 0)) # Wall time reported by JS sandbox handler
py_elapsed_ms = int(data.get("elapsed", 0))
stdout_content = data.get("stdout", "")
stderr_content = data.get("stderr", "")
result_val = data.get("result")
error_info = data.get("error") # Original error dict from sandbox execution
error_message_for_caller = None
error_code_for_caller = "UnknownError"
if not execution_successful_this_call:
error_message_for_caller = _format_sandbox_error(error_info)
if isinstance(error_info, dict):
error_code_for_caller = error_info.get("type", "UnknownSandboxError") # noqa: F841
# --- Logging ---
action_desc = "executed" if code else "accessed (no code run)"
action_desc += " with reset" if reset else ""
log_details = {
"session_id": session_id,
"action": action_desc,
"reset_requested": reset,
"reset_successful": reset_ack_data.get("ok") if reset_ack_data else None,
"elapsed_wall_ms_host": wall_ms_host_final,
"elapsed_wall_ms_js": js_wall_ms,
"elapsed_py_ms": py_elapsed_ms,
"packages_requested": packages or [],
"packages_loaded": packages_normalized,
"wheels_count": len(wheels_normalized),
"stdout_len": len(stdout_content),
"stderr_len": len(stderr_content),
"result_type": type(result_val).__name__,
"success_this_call": execution_successful_this_call,
"repl_mode": True,
}
log_level = logger.success if execution_successful_this_call else logger.warning
log_level(
f"Python code {action_desc} in REPL sandbox (session: {session_id})",
TaskType.CODE_EXECUTION, # Assumes TaskType is defined/imported
**log_details,
error_details=error_info if not execution_successful_this_call else None,
)
# --- Construct final return dictionary ---
final_result = {
"success": execution_successful_this_call, # Reflect success of *this* call
"stdout": stdout_content,
"stderr": stderr_content,
"result": result_val,
"elapsed_py_ms": py_elapsed_ms,
"elapsed_wall_ms": js_wall_ms or wall_ms_host_final, # Prefer JS wall time
"handle": session_id, # Always return the handle
"error_message": error_message_for_caller, # Formatted string or None
"error_details": error_info
if not execution_successful_this_call
else None, # Original dict or None
}
# Include reset status if reset was requested
if reset:
final_result["reset_status"] = reset_ack_data
# Do NOT raise ToolError for standard Python errors caught inside sandbox,
# return them in the dictionary structure instead.
# Only raise ToolError for host-level/unrecoverable issues earlier.
return final_result
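# --- Illustrative REPL usage sketch ---
# A hedged example of keeping state across calls by reusing the returned handle;
# the helper name and values are hypothetical documentation aids.
async def _example_repl_python_usage() -> Dict[str, Any]:
    first = await repl_python(code="x = 41", timeout_ms=10_000)
    handle = first["handle"]                      # reuse to keep session state
    second = await repl_python(code="result = x + 1", handle=handle, timeout_ms=10_000)
    await repl_python(code="", handle=handle, reset=True)  # clear the session namespace
    return second                                 # second["result"] should be 42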
###############################################################################
# Optional: Asset Preloader Function (Integrated)
###############################################################################
def _get_pyodide_asset_list_from_manifest(manifest_url: str) -> List[str]:
"""
Generates a list of essential Pyodide assets to preload based on version.
For v0.27.x, a hardcoded list is used because repodata.json no longer serves
as the core file listing the way it did in older Pyodide versions.
"""
global _PYODIDE_VERSION # Access the global version
logger.info(f"[Preload] Generating asset list for Pyodide v{_PYODIDE_VERSION}.")
# Version-specific logic (can be expanded for other versions)
if _PYODIDE_VERSION.startswith("0.27."):
# Hardcoded list for v0.27.x - VERIFY these against the actual CDN structure for 0.27.5!
# These are the most common core files needed for initialization.
core_files = {
# --- Core Runtime ---
"pyodide.js", # Main JS loader (UMD potentially)
"pyodide.mjs", # Main JS loader (ESM)
"pyodide.asm.js", # Wasm loader fallback/glue
"pyodide.asm.wasm", # The main WebAssembly module
# --- Standard Library ---
"python_stdlib.zip", # Packed standard library
# --- Metadata/Lock Files ---
"pyodide-lock.json", # Package lock file (crucial for loadPackage)
# --- Potential Dependencies (less common to preload, but check CDN) ---
# "distutils.tar",
# "pyodide_py.tar",
}
logger.info(f"[Preload] Using hardcoded core asset list for v{_PYODIDE_VERSION}.")
# Log the URL that was passed but ignored for clarity
if (
manifest_url != f"{_CDN_BASE}/repodata.json"
): # Only log if it differs from default assumption
logger.warning(f"[Preload] Ignoring provided manifest_url: {manifest_url}")
else:
# Placeholder for potentially different logic for other versions
# (e.g., actually fetching and parsing repodata.json if needed)
logger.warning(
f"[Preload] No specific asset list logic for Pyodide v{_PYODIDE_VERSION}. Using empty list."
)
core_files = set()
# If you needed to fetch/parse repodata.json for other versions:
# try:
# logger.info(f"[Preload] Fetching manifest from {manifest_url}")
# # ... (fetch manifest_url using _fetch_asset_sync or urllib) ...
# # ... (parse JSON) ...
# # ... (extract file names based on manifest structure) ...
# except Exception as e:
# logger.error(f"[Preload] Failed to fetch or parse manifest {manifest_url}: {e}")
# core_files = set() # Fallback to empty on error
if not core_files:
logger.warning("[Preload] The generated core asset list is empty!")
else:
logger.debug(f"[Preload] Identified {len(core_files)} essential core files to fetch.")
# Common packages are loaded on demand *within* the sandbox, not preloaded here; state this explicitly in the log.
logger.info(
"[Preload] Common packages (like numpy, pandas) are NOT included in this core preload list. "
"They will be fetched on demand by the sandbox if needed and cached separately "
"when `loadPackage` is called."
)
# Return the sorted list of unique filenames
return sorted(list(core_files))
def preload_pyodide_assets(force_download: bool = False):
"""Downloads Pyodide assets to the local cache directory."""
print("-" * 60)
print("Starting Pyodide Asset Preloader")
print(f"Target Pyodide Version: {_PYODIDE_VERSION}")
print(f"CDN Base URL: {_CDN_BASE}")
print(f"Cache Directory: {_CACHE_DIR}")
print(f"Force Re-download: {force_download}")
print("-" * 60)
try:
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(
f"ERROR: Failed to ensure cache directory exists at {_CACHE_DIR}: {e}\nPreloading cannot proceed."
)
return
manifest_url = f"{_CDN_BASE}/repodata.json" # Dummy URL for v0.27.5 preloader logic
asset_files = _get_pyodide_asset_list_from_manifest(manifest_url)
if not asset_files:
print("ERROR: No asset files were identified.\nPreloading cannot proceed.")
return
print(f"\nAttempting to cache/verify {len(asset_files)} assets...")
cached_count = 0
verified_count = 0
error_count = 0
total_bytes_downloaded = 0
total_bytes_verified = 0
max_age = 0 if force_download else (10 * 365 * 24 * 3600)
num_files = len(asset_files)
width = len(str(num_files))
for i, filename in enumerate(asset_files):
if not filename:
logger.warning(f"[Preload] Skipping empty filename at index {i}.")
continue
file_url = f"{_CDN_BASE}/{filename}"
progress = f"[{i + 1:>{width}}/{num_files}]"
local_file_path = _local_path(file_url)
file_exists = local_file_path.exists()
is_stale = False
action = "Fetching"
if file_exists:
try:
file_stat = local_file_path.stat()
if file_stat.st_size == 0:
logger.warning(
f"[Preload] Cached file {local_file_path} is empty. Will re-fetch."
)
file_exists = False
else:
file_age = time.time() - file_stat.st_mtime
if file_age >= max_age:
is_stale = True
else:
action = "Verifying" if not force_download else "Re-fetching (forced)"
except OSError as stat_err:
logger.warning(
f"[Preload] Error checking status of {local_file_path}: {stat_err}. Will re-fetch."
)
file_exists = False
action = "Fetching (stat failed)"
if file_exists and is_stale and not force_download:
action = "Re-fetching (stale)"
display_name = filename if len(filename) <= 60 else filename[:57] + "..."
print(f"{progress} {action:<25} {display_name:<60} ... ", end="", flush=True)
try:
data = _fetch_asset_sync(file_url, max_age_s=max_age)
file_size_kb = len(data) // 1024
if action == "Verifying":
verified_count += 1
total_bytes_verified += len(data)
print(f"OK (cached, {file_size_kb:>5} KB)")
else:
cached_count += 1
total_bytes_downloaded += len(data)
status = "OK" if action.startswith("Fetch") else "OK (updated)"
print(f"{status} ({file_size_kb:>5} KB)")
except Exception as e:
print(f"ERROR: {e}")
logger.error(f"[Preload] Failed to fetch/cache {file_url}: {e}", exc_info=False)
error_count += 1
print("\n" + "-" * 60)
print("Preload Summary")
print("-" * 60)
print(f"Assets already cached & verified: {verified_count}")
print(f"Assets newly downloaded/updated: {cached_count}")
print(f"Total assets processed: {verified_count + cached_count}")
print(f"Errors encountered: {error_count}")
print("-" * 60)
print(f"Size of verified assets: {total_bytes_verified / (1024 * 1024):,.1f} MB")
print(f"Size of downloaded assets: {total_bytes_downloaded / (1024 * 1024):,.1f} MB")
print(
f"Total cache size (approx): {(total_bytes_verified + total_bytes_downloaded) / (1024 * 1024):,.1f} MB"
)
print("-" * 60)
if error_count == 0:
print("Preloading completed successfully. Assets should be cached for offline use.")
else:
print(
f"WARNING: {error_count} assets failed to download. Offline functionality may be incomplete."
)
print("-" * 60)
###############################################################################
# Main execution block for preloading (if script is run directly)
###############################################################################
if __name__ == "__main__":
# Setup logging if run as main script
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
parser = argparse.ArgumentParser(
description="Utility for the Python Sandbox module. Includes Pyodide asset preloader.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" Examples:\n Cache Pyodide assets (download if missing/stale):\n python %(prog)s --preload\n\n Force re-download of all assets, ignoring cache:\n python %(prog)s --preload --force """,
)
parser.add_argument(
"--preload",
action="store_true",
help="Run the Pyodide asset preloader to cache files required for offline operation.",
)
parser.add_argument(
"--force",
"-f",
action="store_true",
help="Force re-download of all assets during preload, ignoring existing cache validity.",
)
args = parser.parse_args()
if args.preload:
preload_pyodide_assets(force_download=args.force)
else:
print(
"This script contains the PythonSandbox tool implementation.\nUse the --preload argument to cache Pyodide assets for offline use.\nExample: python path/to/python_sandbox.py --preload"
)
```
--------------------------------------------------------------------------------
/ultimate_mcp_server/tools/audio_transcription.py:
--------------------------------------------------------------------------------
```python
"""Advanced audio transcription and enhancement tools for Ultimate MCP Server.
This module provides tools for high-quality audio transcription, pre-processing,
and intelligent transcript enhancement with advanced features like speaker
diarization, custom vocabulary support, and semantic structuring.
"""
import asyncio
import concurrent.futures
import datetime
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import aiofiles
import httpx
from docx import Document
from pydantic import BaseModel, Field
from pydantic.functional_validators import field_validator
from ultimate_mcp_server.constants import Provider, TaskType
from ultimate_mcp_server.core.providers.base import get_provider
from ultimate_mcp_server.exceptions import (
ProviderError,
ResourceError,
ToolError,
ToolInputError,
)
from ultimate_mcp_server.services.cache import with_cache
from ultimate_mcp_server.tools.base import with_error_handling, with_retry, with_tool_metrics
from ultimate_mcp_server.tools.completion import chat_completion, generate_completion
from ultimate_mcp_server.utils import get_logger
from ultimate_mcp_server.utils.text import count_tokens
logger = get_logger("ultimate_mcp_server.tools.audio")
# --- Constants and Enums ---
class AudioEnhancementProfile(str, Enum):
"""Predefined audio enhancement profiles for different recording types."""
CONFERENCE_CALL = "conference_call" # Optimized for online meetings
INTERVIEW = "interview" # Optimized for interview recordings
LECTURE = "lecture" # Optimized for lectures/presentations
NOISY = "noisy" # Heavy noise reduction for noisy environments
PHONE_CALL = "phone_call" # Optimized for telephone audio
VOICEMAIL = "voicemail" # Optimized for voicemail recordings
CUSTOM = "custom" # User-defined settings
# Expected model sizes in bytes, with some tolerance (minimum size)
WHISPER_MODEL_SIZES = {
"large-v3": 2900000000, # ~2.9GB
"large-v3-turbo": 1500000000, # ~1.5GB
}
class TranscriptionQuality(str, Enum):
"""Quality settings for transcription, balancing speed vs accuracy."""
DRAFT = "draft" # Fastest, less accurate
STANDARD = "standard" # Balanced speed/accuracy
ENHANCED = "enhanced" # More accurate, slower
MAXIMUM = "maximum" # Most accurate, slowest
class EnhancementStyle(str, Enum):
"""Transcript enhancement styles for different use cases."""
RAW = "raw" # No enhancement, just cleaned
READABLE = "readable" # Basic readability improvements
POLISHED = "polished" # Well-formatted with proper punctuation
VERBATIM = "verbatim" # Preserve all speech patterns, hesitations
STRUCTURED = "structured" # Add semantic structure (paragraphs, sections)
class OutputFormat(str, Enum):
"""Available output formats for transcripts."""
JSON = "json" # Full JSON with all metadata
TEXT = "text" # Plain text
SRT = "srt" # SubRip subtitle format
VTT = "vtt" # WebVTT subtitle format
DOCX = "docx" # Microsoft Word format
MARKDOWN = "markdown" # Markdown with formatting
# --- Schema Validation Models ---
class AudioEnhancementParams(BaseModel):
"""Parameters for audio enhancement."""
profile: AudioEnhancementProfile = Field(
default=AudioEnhancementProfile.CONFERENCE_CALL,
description="Predefined audio enhancement profile"
)
volume: float = Field(
default=1.5,
ge=0.1,
le=5.0,
description="Volume adjustment factor"
)
noise_reduction: int = Field(
default=10,
ge=0,
le=30,
description="Noise reduction strength (0-30)"
)
highpass: int = Field(
default=200,
ge=50,
le=500,
description="Highpass filter frequency in Hz"
)
lowpass: int = Field(
default=3000,
ge=1000,
le=20000,
description="Lowpass filter frequency in Hz"
)
normalize: bool = Field(
default=True,
description="Apply dynamic audio normalization"
)
compression: bool = Field(
default=True,
description="Apply dynamic range compression"
)
dereverberation: bool = Field(
default=False,
description="Apply dereverberation filter"
)
custom_filters: Optional[str] = Field(
default=None,
description="Additional custom FFmpeg filters"
)
output_channels: int = Field(
default=2,
ge=1,
le=2,
description="Number of output channels (1=mono, 2=stereo)"
)
output_sample_rate: int = Field(
default=16000,
ge=8000,
le=48000,
description="Output sample rate in Hz"
)
@field_validator('custom_filters')
def validate_custom_filters(cls, v):
"""Validate that custom filters don't contain dangerous commands."""
if v:
# Check for shell escape attempts
dangerous_patterns = [';', '&&', '||', '`', '$', '\\', '>', '<', '|', '*', '?', '~', '#']
for pattern in dangerous_patterns:
if pattern in v:
raise ValueError(f"Custom filter contains disallowed character: {pattern}")
return v
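# --- Illustrative usage sketch ---
# A hedged example of building enhancement parameters for a noisy telephone
# recording; the constant name and values are illustrative, not recommended defaults.
_EXAMPLE_PHONE_ENHANCEMENT_PARAMS = AudioEnhancementParams(
    profile=AudioEnhancementProfile.PHONE_CALL,
    volume=2.0,                 # telephone audio is usually quiet
    noise_reduction=15,
    highpass=300,               # narrow band-pass around the telephone voice range
    lowpass=3400,
    output_channels=1,          # collapse to mono
    output_sample_rate=16000,   # 16 kHz is sufficient for speech models
)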
class WhisperParams(BaseModel):
"""Parameters for Whisper transcription."""
model: str = Field(
default="large-v3-turbo",
description="Whisper model name"
)
language: Optional[str] = Field(
default=None,
description="Language code (auto-detect if None)"
)
quality: TranscriptionQuality = Field(
default=TranscriptionQuality.STANDARD,
description="Transcription quality level"
)
beam_size: int = Field(
default=5,
ge=1,
le=10,
description="Beam size for beam search"
)
processors: int = Field(
default=2,
ge=1,
le=8,
description="Number of processors to use"
)
word_timestamps: bool = Field(
default=True,
description="Generate timestamps for each word"
)
translate: bool = Field(
default=False,
description="Translate non-English to English"
)
diarize: bool = Field(
default=False,
description="Attempt speaker diarization"
)
highlight_words: bool = Field(
default=False,
description="Highlight words with lower confidence"
)
max_context: int = Field(
default=-1,
ge=-1,
description="Maximum number of text tokens to consider from previous history"
)
custom_vocab: Optional[List[str]] = Field(
default=None,
description="Custom vocabulary terms to improve recognition"
)
class TranscriptEnhancementParams(BaseModel):
"""Parameters for transcript enhancement."""
style: EnhancementStyle = Field(
default=EnhancementStyle.READABLE,
description="Enhancement style"
)
provider: str = Field(
default=Provider.ANTHROPIC.value,
description="LLM provider for enhancement"
)
model: Optional[str] = Field(
default=None,
description="Specific model to use (provider default if None)"
)
identify_speakers: bool = Field(
default=False,
description="Attempt to identify and label speakers"
)
add_paragraphs: bool = Field(
default=True,
description="Add paragraph breaks at natural points"
)
fix_spelling: bool = Field(
default=True,
description="Fix spelling errors"
)
fix_grammar: bool = Field(
default=True,
description="Fix basic grammatical errors"
)
sections: bool = Field(
default=False,
description="Add section headings based on topic changes"
)
max_chunk_size: int = Field(
default=6500,
ge=1000,
le=100000,
description="Maximum chunk size in characters"
)
format_numbers: bool = Field(
default=True,
description="Format numbers consistently (e.g., '25' instead of 'twenty-five')"
)
custom_instructions: Optional[str] = Field(
default=None,
description="Additional custom instructions for enhancement"
)
class TranscriptionOptions(BaseModel):
"""Complete options for audio transcription."""
enhance_audio: bool = Field(
default=True,
description="Whether to preprocess audio with FFmpeg"
)
enhance_transcript: bool = Field(
default=True,
description="Whether to enhance the transcript with LLM"
)
parallel_processing: bool = Field(
default=True,
description="Process chunks in parallel when possible"
)
max_workers: int = Field(
default=4,
ge=1,
description="Maximum number of parallel workers"
)
output_formats: List[OutputFormat] = Field(
default=[OutputFormat.JSON, OutputFormat.TEXT],
description="Output formats to generate"
)
save_enhanced_audio: bool = Field(
default=False,
description="Save the enhanced audio file"
)
keep_artifacts: bool = Field(
default=False,
description="Keep temporary files and artifacts"
)
audio_params: AudioEnhancementParams = Field(
default_factory=AudioEnhancementParams,
description="Audio enhancement parameters"
)
whisper_params: WhisperParams = Field(
default_factory=WhisperParams,
description="Whisper transcription parameters"
)
enhancement_params: TranscriptEnhancementParams = Field(
default_factory=TranscriptEnhancementParams,
description="Transcript enhancement parameters"
)
language_detection: bool = Field(
default=False, # Disable language detection by default
description="Automatically detect language before transcription"
)
class Segment(BaseModel):
"""A segment of transcript with timing information."""
start: float = Field(..., description="Start time in seconds")
end: float = Field(..., description="End time in seconds")
text: str = Field(..., description="Segment text")
speaker: Optional[str] = Field(None, description="Speaker identifier")
words: Optional[List[Dict[str, Any]]] = Field(None, description="Word-level data")
confidence: Optional[float] = Field(None, description="Confidence score")
class AudioInfo(BaseModel):
"""Audio file information."""
duration: float = Field(..., description="Duration in seconds")
channels: int = Field(..., description="Number of audio channels")
sample_rate: int = Field(..., description="Sample rate in Hz")
format: str = Field(..., description="Audio format")
codec: Optional[str] = Field(None, description="Audio codec")
bit_depth: Optional[int] = Field(None, description="Bit depth")
bitrate: Optional[int] = Field(None, description="Bitrate in bits/second")
size_bytes: Optional[int] = Field(None, description="File size in bytes")
# --- Data Classes ---
@dataclass
class ProcessingContext:
"""Context for the transcription process."""
file_path: str
temp_dir: str
original_filename: str
base_filename: str
options: TranscriptionOptions
enhanced_audio_path: Optional[str] = None
processing_times: Optional[Dict[str, float]] = None
language_code: Optional[str] = None
def __post_init__(self):
if self.processing_times is None:
self.processing_times = {}
# --- Tool Functions ---
@with_cache(ttl=24 * 60 * 60) # Cache results for 24 hours
@with_tool_metrics
@with_retry(max_retries=1, retry_delay=1.0)
@with_error_handling
async def transcribe_audio(
file_path: str,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Transcribes an audio file to text with advanced preprocessing and enhancement.
This tool performs a multi-stage process:
1. Analyzes the audio file to determine optimal processing parameters
2. Enhances audio quality with adaptive filtering and preprocessing
3. Performs high-quality transcription with customizable settings
4. Intelligently enhances and structures the transcript for readability
5. Optionally identifies speakers and adds semantic structure
Args:
file_path: Path to the input audio file (.mp3, .m4a, .wav, etc.)
options: Optional dictionary with transcription options including:
- enhance_audio: Whether to preprocess audio (default True)
- enhance_transcript: Whether to enhance transcript (default True)
- parallel_processing: Process chunks in parallel (default True)
- output_formats: List of output formats (default ["json", "text"])
- audio_params: Audio enhancement parameters
- whisper_params: Whisper transcription parameters
- enhancement_params: Transcript enhancement parameters
Returns:
A dictionary containing:
{
"raw_transcript": "Original unmodified transcript from Whisper",
"enhanced_transcript": "LLM-enhanced transcript with improved formatting",
"segments": [
{
"start": 0.0,
"end": 10.5,
"text": "Segment text content",
"speaker": "Speaker 1", # If speaker diarization is enabled
"words": [...] # Word-level data if available
},
...
],
"metadata": {
"language": "en",
"duration": 120.5,
"title": "Automatically detected title",
"topics": ["Topic 1", "Topic 2"] # If topic extraction is enabled
},
"audio_info": {
"duration": 120.5,
"channels": 2,
"sample_rate": 44100,
"format": "wav",
"codec": "pcm_s16le",
"bit_depth": 16,
"bitrate": 1411000,
"size_bytes": 31000000
},
"processing_time": {
"audio_analysis": 0.5,
"audio_enhancement": 5.2,
"language_detection": 1.1,
"transcription": 45.3,
"transcript_enhancement": 10.2,
"total": 62.3
},
"artifacts": {
"enhanced_audio": "/path/to/enhanced.wav", # If save_enhanced_audio is True
"output_files": {
"json": "/path/to/transcript.json",
"text": "/path/to/transcript.txt",
"srt": "/path/to/transcript.srt"
}
},
"tokens": {
"input": 5000,
"output": 3200,
"total": 8200
},
"cost": 0.00185,
"success": true
}
Raises:
ToolInputError: If the file path is invalid or unsupported
ToolError: If transcription or enhancement fails
ResourceError: If required dependencies are not available
"""
# Start timing total processing
start_time = time.time()
# --- Input Validation ---
try:
if not file_path or not isinstance(file_path, str):
raise ToolInputError("File path must be a non-empty string.")
file_path = os.path.abspath(os.path.expanduser(file_path))
if not os.path.exists(file_path):
raise ToolInputError(f"File not found: {file_path}")
if not os.access(file_path, os.R_OK):
raise ToolInputError(f"File not readable: {file_path}")
# Validate file is an audio file
_, ext = os.path.splitext(file_path)
if ext.lower() not in ['.mp3', '.wav', '.m4a', '.mp4', '.flac', '.ogg', '.aac', '.wma', '.opus']:
raise ToolInputError(f"Unsupported file format: {ext}. Please provide an audio file.")
# Parse and validate options
parsed_options = parse_options(options or {})
except Exception as e:
if isinstance(e, ToolInputError):
raise
raise ToolInputError(f"Failed to validate input: {str(e)}") from e
# --- Initialize Processing Context ---
try:
temp_dir = tempfile.mkdtemp(prefix="llm_audio_")
original_filename = os.path.basename(file_path)
base_filename = os.path.splitext(original_filename)[0]
context = ProcessingContext(
file_path=file_path,
temp_dir=temp_dir,
original_filename=original_filename,
base_filename=base_filename,
options=parsed_options,
)
logger.info(
f"Starting audio transcription process for {original_filename}",
emoji_key="audio",
temp_dir=temp_dir
)
except Exception as e:
raise ToolError(f"Failed to initialize processing context: {str(e)}") from e
try:
# --- Check Dependencies ---
await check_dependencies(context)
# --- Process Audio ---
result = await process_audio_file(context)
# --- Calculate Total Time ---
total_time = time.time() - start_time
context.processing_times["total"] = total_time
result["processing_time"] = context.processing_times
logger.success(
f"Audio transcription completed in {total_time:.2f}s",
emoji_key="success",
file=context.original_filename,
duration=result.get("audio_info", {}).get("duration", 0)
)
return result
except Exception as e:
logger.error(
f"Audio transcription failed: {str(e)}",
emoji_key="error",
exc_info=True,
file=context.original_filename
)
# Temp-directory cleanup is handled by the finally block below; just note when artifacts are kept.
if context.options.keep_artifacts:
logger.info(f"Keeping artifacts in {temp_dir}")
if isinstance(e, (ToolError, ToolInputError, ResourceError)):
raise
# Return a structured error response instead of just raising an exception
return {
"raw_transcript": "",
"enhanced_transcript": "",
"segments": [],
"metadata": {},
"audio_info": {},
"processing_time": context.processing_times if hasattr(context, "processing_times") else {},
"success": False,
"error": f"Audio transcription failed: {str(e)}"
}
finally:
# Clean up temporary directory unless keep_artifacts is True
if not context.options.keep_artifacts:
try:
shutil.rmtree(temp_dir)
except Exception as cleanup_err:
logger.warning(f"Failed to clean up temporary directory: {cleanup_err}")
# --- Main Processing Functions ---
async def process_audio_file(context: ProcessingContext) -> Dict[str, Any]:
"""Process an audio file through the complete transcription pipeline."""
# Get detailed audio information
audio_analysis_start = time.time()
audio_info = await get_detailed_audio_info(context.file_path)
context.processing_times["audio_analysis"] = time.time() - audio_analysis_start
logger.info(
f"Audio analysis: {audio_info.get('duration', 0):.1f}s duration, "
f"{audio_info.get('sample_rate', 0)} Hz, "
f"{audio_info.get('channels', 0)} channels",
emoji_key="audio"
)
# Update parameters based on audio analysis if needed
_update_parameters_from_audio_info(context, audio_info)
# --- Audio Enhancement ---
enhanced_audio_path = context.file_path
if context.options.enhance_audio:
audio_enhance_start = time.time()
logger.info("Enhancing audio quality", emoji_key="audio", profile=context.options.audio_params.profile.value)
enhanced_audio_path = await enhance_audio(context, audio_info)
if not enhanced_audio_path:
logger.warning("Audio enhancement failed, falling back to original file", emoji_key="warning")
enhanced_audio_path = context.file_path
context.processing_times["audio_enhancement"] = time.time() - audio_enhance_start
else:
context.processing_times["audio_enhancement"] = 0
context.enhanced_audio_path = enhanced_audio_path
if not os.path.exists(enhanced_audio_path):
logger.warning(f"Enhanced audio path does not exist: {enhanced_audio_path}", emoji_key="warning")
return {
"raw_transcript": "",
"enhanced_transcript": "",
"segments": [],
"metadata": {},
"audio_info": audio_info,
"processing_time": context.processing_times,
"success": False,
"error": "Enhanced audio file not found"
}
# --- Skip Language Detection ---
# Language detection is disabled - always use Whisper's built-in detection
context.processing_times["language_detection"] = 0
if context.options.whisper_params.language:
logger.info(f"Using specified language: {context.options.whisper_params.language}", emoji_key="language")
else:
logger.info("Using Whisper's built-in language detection", emoji_key="language")
# --- Transcribe Audio ---
transcribe_start = time.time()
model = context.options.whisper_params.model
quality = context.options.whisper_params.quality.value
logger.info(
f"Transcribing audio with model '{model}' (quality: {quality})",
emoji_key="transcribe"
)
try:
transcript_result = await transcribe_with_whisper(context)
context.processing_times["transcription"] = time.time() - transcribe_start
raw_transcript = transcript_result.get("text", "")
segments = transcript_result.get("segments", [])
# Check if transcript is empty
if not raw_transcript:
logger.warning("Whisper returned an empty transcript", emoji_key="warning")
else:
transcript_length = len(raw_transcript)
segments_count = len(segments)
logger.info(
f"Transcription complete: {transcript_length} characters, {segments_count} segments",
emoji_key="success"
)
# Extract metadata if available
metadata = transcript_result.get("metadata", {})
if context.language_code and "language" not in metadata:
metadata["language"] = context.language_code
except Exception as e:
logger.error(f"Transcription failed: {str(e)}", emoji_key="error", exc_info=True)
context.processing_times["transcription"] = time.time() - transcribe_start
return {
"raw_transcript": "",
"enhanced_transcript": "",
"segments": [],
"metadata": {"language": context.language_code} if context.language_code else {},
"audio_info": audio_info,
"processing_time": context.processing_times,
"success": False,
"error": f"Transcription failed: {str(e)}"
}
# --- Enhance Transcript ---
enhanced_transcript = raw_transcript
enhancement_cost = 0.0
enhancement_tokens = {"input": 0, "output": 0, "total": 0}
if context.options.enhance_transcript and raw_transcript:
enhance_start = time.time()
logger.info(
f"Enhancing transcript with style: {context.options.enhancement_params.style.value}",
emoji_key="enhance",
provider=context.options.enhancement_params.provider
)
try:
enhancement_result = await enhance_transcript(context, raw_transcript, metadata)
enhanced_transcript = enhancement_result["transcript"]
enhancement_cost = enhancement_result["cost"]
enhancement_tokens = enhancement_result["tokens"]
if "topics" in enhancement_result and enhancement_result["topics"]:
metadata["topics"] = enhancement_result["topics"]
if "title" in enhancement_result and enhancement_result["title"]:
metadata["title"] = enhancement_result["title"]
context.processing_times["transcript_enhancement"] = time.time() - enhance_start
if not enhanced_transcript:
logger.warning("Enhancement returned an empty transcript, falling back to raw transcript", emoji_key="warning")
enhanced_transcript = raw_transcript
else:
enhancement_length = len(enhanced_transcript)
logger.info(f"Enhancement complete: {enhancement_length} characters", emoji_key="success")
except Exception as e:
logger.error(f"Transcript enhancement failed: {e}", emoji_key="error", exc_info=True)
context.processing_times["transcript_enhancement"] = time.time() - enhance_start
# Fall back to raw transcript
enhanced_transcript = raw_transcript
else:
if not raw_transcript:
logger.warning("Skipping transcript enhancement because raw transcript is empty", emoji_key="warning")
elif not context.options.enhance_transcript:
logger.info("Transcript enhancement disabled by options", emoji_key="info")
context.processing_times["transcript_enhancement"] = 0
# --- Generate Output Files ---
artifact_paths = await generate_output_files(context, raw_transcript, enhanced_transcript, segments, metadata)
# --- Prepare Result ---
result = {
"raw_transcript": raw_transcript,
"enhanced_transcript": enhanced_transcript,
"segments": segments,
"metadata": metadata,
"audio_info": audio_info,
"tokens": enhancement_tokens,
"cost": enhancement_cost,
"artifacts": artifact_paths,
"success": bool(raw_transcript or enhanced_transcript)
}
return result
def parse_options(options: Dict[str, Any]) -> TranscriptionOptions:
"""Parse and validate transcription options."""
# Convert string output formats to enum values
if "output_formats" in options:
if isinstance(options["output_formats"], list):
formats = []
for fmt in options["output_formats"]:
if isinstance(fmt, str):
try:
formats.append(OutputFormat(fmt.lower()))
except ValueError:
logger.warning(f"Invalid output format: {fmt}, ignoring")
elif isinstance(fmt, OutputFormat):
formats.append(fmt)
if formats: # Only update if we have valid formats
options["output_formats"] = formats
# Handle nested parameter objects
for key in ["audio_params", "whisper_params", "enhancement_params"]:
if key in options and options[key]:
# If a dictionary is provided, keep it for Pydantic
if not isinstance(options[key], dict):
# Convert non-dict to dict by serializing/deserializing if possible
try:
options[key] = json.loads(json.dumps(options[key]))
except Exception:
# If can't convert, remove the invalid value
logger.warning(f"Invalid format for {key}, using defaults")
options.pop(key)
# Set audio profile parameters if a profile is specified
if "audio_params" in options and "profile" in options["audio_params"]:
profile = options["audio_params"]["profile"]
if isinstance(profile, str):
try:
profile = AudioEnhancementProfile(profile.lower())
# Update audio parameters based on the selected profile
options["audio_params"].update(_get_audio_profile_params(profile))
except ValueError:
logger.warning(f"Invalid audio profile: {profile}, using default")
# Set whisper quality parameters if quality is specified
if "whisper_params" in options and "quality" in options["whisper_params"]:
quality = options["whisper_params"]["quality"]
if isinstance(quality, str):
try:
quality = TranscriptionQuality(quality.lower())
# Update whisper parameters based on the selected quality
options["whisper_params"].update(_get_whisper_quality_params(quality))
except ValueError:
logger.warning(f"Invalid transcription quality: {quality}, using default")
# Parse with Pydantic model
try:
return TranscriptionOptions(**options)
except Exception as e:
logger.warning(f"Error parsing options: {e}, using defaults with valid values", emoji_key="warning")
# Create with default options
return TranscriptionOptions()
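# --- Illustrative usage sketch ---
# A hedged example showing how parse_options normalizes loose string inputs into
# a validated TranscriptionOptions; the helper name and values are hypothetical.
def _example_parse_options() -> TranscriptionOptions:
    return parse_options({
        "output_formats": ["json", "srt"],            # strings become OutputFormat members
        "audio_params": {"profile": "noisy"},         # profile string expands to filter settings
        "whisper_params": {"quality": "enhanced"},    # quality string expands beam_size etc.
    })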
def _get_audio_profile_params(profile: AudioEnhancementProfile) -> Dict[str, Any]:
"""Get audio enhancement parameters for a specific profile."""
profiles = {
AudioEnhancementProfile.CONFERENCE_CALL: {
"volume": 1.5,
"noise_reduction": 10,
"highpass": 200,
"lowpass": 3000,
"compression": True,
"normalize": True,
"dereverberation": False
},
AudioEnhancementProfile.INTERVIEW: {
"volume": 1.3,
"noise_reduction": 8,
"highpass": 150,
"lowpass": 8000,
"compression": True,
"normalize": True,
"dereverberation": False
},
AudioEnhancementProfile.LECTURE: {
"volume": 1.4,
"noise_reduction": 6,
"highpass": 120,
"lowpass": 8000,
"compression": True,
"normalize": True,
"dereverberation": True
},
AudioEnhancementProfile.NOISY: {
"volume": 1.8,
"noise_reduction": 20,
"highpass": 250,
"lowpass": 3000,
"compression": True,
"normalize": True,
"dereverberation": True
},
AudioEnhancementProfile.PHONE_CALL: {
"volume": 2.0,
"noise_reduction": 15,
"highpass": 300,
"lowpass": 3400,
"compression": True,
"normalize": True,
"dereverberation": False
},
AudioEnhancementProfile.VOICEMAIL: {
"volume": 2.0,
"noise_reduction": 12,
"highpass": 250,
"lowpass": 3000,
"compression": True,
"normalize": True,
"dereverberation": False
}
}
return profiles.get(profile, {})
def _get_whisper_quality_params(quality: TranscriptionQuality) -> Dict[str, Any]:
"""Get whisper parameters for a specific quality level."""
quality_params = {
TranscriptionQuality.DRAFT: {
"beam_size": 1,
"processors": 1,
"word_timestamps": False,
"highlight_words": False
},
TranscriptionQuality.STANDARD: {
"beam_size": 5,
"processors": 2,
"word_timestamps": True,
"highlight_words": False
},
TranscriptionQuality.ENHANCED: {
"beam_size": 8,
"processors": 2,
"word_timestamps": True,
"highlight_words": True
},
TranscriptionQuality.MAXIMUM: {
"beam_size": 10,
"processors": 4,
"word_timestamps": True,
"highlight_words": True
}
}
return quality_params.get(quality, {})
def _update_parameters_from_audio_info(context: ProcessingContext, audio_info: Dict[str, Any]) -> None:
"""Update processing parameters based on audio file analysis."""
# If mono audio, adjust enhancement params
audio_params = context.options.audio_params
updated_params = False
if audio_info.get("channels", 0) == 1:
# Set output channels to match input if the caller did not explicitly set it
# (Pydantic v2's model_fields_set records which fields were passed at init)
if "output_channels" not in context.options.audio_params.model_fields_set:
# Create a copy with the updated parameter
audio_params = AudioEnhancementParams(
**{**audio_params.dict(), "output_channels": 1}
)
updated_params = True
# If low-quality audio, adjust enhancement profile
sample_rate = audio_info.get("sample_rate", 0)
if sample_rate < 16000 and context.options.audio_params.profile == AudioEnhancementProfile.CONFERENCE_CALL:
logger.info(f"Detected low sample rate ({sample_rate} Hz), adjusting enhancement profile", emoji_key="audio")
# Use phone_call profile for low sample rate audio
params = _get_audio_profile_params(AudioEnhancementProfile.PHONE_CALL)
# Create a copy with the updated parameters
audio_params = AudioEnhancementParams(
**{**audio_params.dict(), **params}
)
updated_params = True
# If params were updated, create a new options object with the updated audio_params
if updated_params:
context.options = TranscriptionOptions(
**{**context.options.dict(), "audio_params": audio_params}
)
# If very short audio (<10 seconds), adjust transcription quality
duration = audio_info.get("duration", 0)
if duration < 10 and context.options.whisper_params.quality != TranscriptionQuality.MAXIMUM:
logger.info(f"Short audio detected ({duration:.1f}s), increasing transcription quality", emoji_key="audio")
# Create a new whisper_params with enhanced quality
whisper_params = WhisperParams(
**{**context.options.whisper_params.dict(), "quality": TranscriptionQuality.ENHANCED}
)
# Update options with new whisper_params
context.options = TranscriptionOptions(
**{**context.options.dict(), "whisper_params": whisper_params}
)
# --- Dependency and Audio Processing Functions ---
async def download_whisper_model(model_name: str, output_path: str) -> bool:
"""Download a Whisper model from Hugging Face using httpx.
Args:
model_name: Name of the model to download (e.g. 'large-v3')
output_path: Path where to save the model file
Returns:
True if download was successful, False otherwise
"""
url = f"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-{model_name}.bin"
logger.info(f"Downloading Whisper model '{model_name}' from {url}", emoji_key="download")
# Expected model size
min_size_bytes = WHISPER_MODEL_SIZES.get(model_name, 100000000) # Default to 100MB minimum if unknown
expected_size_bytes = min_size_bytes # Initialize with minimum size
try:
async with httpx.AsyncClient(timeout=None) as client:
# Make HEAD request first to check if the URL is valid and get expected size
try:
head_response = await client.head(url)
if head_response.status_code != 200:
logger.warning(f"Model URL is not accessible: HTTP {head_response.status_code} for {url}", emoji_key="warning")
return False
# If Content-Length is available, use it to check expected size
content_length = int(head_response.headers.get("content-length", 0))
if content_length > 0:
expected_size_mb = content_length / (1024 * 1024)
logger.info(f"Expected model size: {expected_size_mb:.1f} MB", emoji_key="info")
# Update expected size if it's larger than our preset minimum
if content_length > min_size_bytes:
expected_size_bytes = int(content_length * 0.95) # Allow 5% tolerance
except Exception as e:
logger.warning(f"Failed to validate model URL: {url} - Error: {str(e)}", emoji_key="warning")
# Continue anyway, the GET might still work
# Stream the response to handle large files
try:
async with client.stream("GET", url) as response:
if response.status_code != 200:
logger.warning(f"Failed to download model: HTTP {response.status_code} for {url}", emoji_key="warning")
if response.status_code == 404:
logger.warning(f"Model '{model_name}' not found on Hugging Face repository", emoji_key="warning")
return False
# Get content length if available
content_length = int(response.headers.get("content-length", 0))
if content_length == 0:
logger.warning("Content length is zero or not provided", emoji_key="warning")
downloaded = 0
# Open file for writing
try:
with open(output_path, "wb") as f:
# Display progress
last_log_time = time.time()
async for chunk in response.aiter_bytes(chunk_size=8192):
f.write(chunk)
downloaded += len(chunk)
# Log progress every 5 seconds
now = time.time()
if now - last_log_time > 5:
if content_length > 0:
percent = downloaded / content_length * 100
logger.info(f"Download progress: {percent:.1f}% ({downloaded/(1024*1024):.1f} MB)", emoji_key="download")
else:
logger.info(f"Downloaded {downloaded/(1024*1024):.1f} MB", emoji_key="download")
last_log_time = now
except IOError as e:
logger.warning(f"Failed to write to file {output_path}: {str(e)}", emoji_key="warning")
return False
except httpx.RequestError as e:
logger.warning(f"HTTP request error while downloading model: {str(e)}", emoji_key="warning")
return False
# Verify the file was downloaded
if not os.path.exists(output_path):
logger.warning(f"Downloaded file doesn't exist at {output_path}", emoji_key="warning")
return False
actual_size = os.path.getsize(output_path)
if actual_size == 0:
logger.warning(f"Downloaded file is empty at {output_path}", emoji_key="warning")
return False
# Verify file size meets expectations
actual_size_mb = actual_size / (1024 * 1024)
if actual_size < expected_size_bytes:
logger.warning(
f"Model file size ({actual_size_mb:.1f} MB) is smaller than expected minimum size "
f"({expected_size_bytes/(1024*1024):.1f} MB). File may be corrupted or incomplete.",
emoji_key="warning"
)
return False
logger.info(f"Successfully downloaded model to {output_path} ({actual_size_mb:.1f} MB)", emoji_key="success")
return True
except Exception as e:
logger.warning(f"Error downloading whisper model: {e}", emoji_key="warning", exc_info=True)
return False
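# Illustrative usage sketch (assumes a running event loop, network access, and an
# existing models directory; "large-v3-turbo" is one of the model names mentioned
# elsewhere in this module):
#   ok = await download_whisper_model(
#       "large-v3-turbo",
#       os.path.expanduser("~/whisper.cpp/models/ggml-large-v3-turbo.bin"),
#   )
#   # ok is True only if the file downloaded and passed the size sanity checks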
async def check_dependencies(context: ProcessingContext) -> bool:
"""Verifies that required dependencies are installed and accessible."""
# Check ffmpeg
try:
result = await asyncio.create_subprocess_exec(
"ffmpeg", "-version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await result.communicate()
if result.returncode != 0:
raise ResourceError(
"ffmpeg is not installed or not accessible. Please install it with 'apt install ffmpeg'."
)
# Extract ffmpeg version for logging
version_match = re.search(r'ffmpeg version (\S+)', stdout.decode('utf-8', errors='ignore'))
version = version_match.group(1) if version_match else "unknown"
logger.debug(f"Found ffmpeg version {version}", emoji_key="dependency")
except FileNotFoundError as e:
raise ResourceError(
"ffmpeg is not installed. Please install it with 'apt install ffmpeg'."
) from e
# Check whisper.cpp
whisper_path = os.path.expanduser("~/whisper.cpp")
# Use user-supplied model
model = context.options.whisper_params.model
model_path = os.path.join(whisper_path, "models", f"ggml-{model}.bin")
if not os.path.exists(whisper_path):
raise ResourceError(
f"whisper.cpp not found at {whisper_path}. Please install it first."
)
if not os.path.exists(model_path):
# Check if models directory exists
models_dir = os.path.join(whisper_path, "models")
if not os.path.exists(models_dir):
try:
os.makedirs(models_dir)
logger.info(f"Created models directory at {models_dir}", emoji_key="info")
except Exception as e:
raise ResourceError(f"Failed to create models directory: {e}") from e
# Try to automatically download the model using httpx
logger.info(f"Whisper model '{model}' not found at {model_path}", emoji_key="info")
logger.info(f"Attempting to download model '{model}' now...", emoji_key="download")
# Download the model directly using httpx - first check if model exists
if model == "large-v3":
# Double check if the model actually exists
test_url = f"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-{model}.bin"
async with httpx.AsyncClient(timeout=10.0) as client:
try:
head_response = await client.head(test_url)
if head_response.status_code != 200:
# Model might not be directly available - show a clear error
if head_response.status_code == 404:
raise ResourceError(
f"Whisper model 'large-v3' not found in the HuggingFace repository at {test_url}\n"
f"Please download it manually with one of these commands:\n"
f"1. ~/whisper.cpp/models/download-ggml-model.sh large-v3\n"
f"2. Or try a different model like 'large-v3-turbo' which is known to be available"
)
except ResourceError:
raise
except Exception as e:
logger.warning(f"Error checking model availability: {e}", emoji_key="warning")
# Continue with download attempt anyway
# Attempt to download the model
download_success = await download_whisper_model(model, model_path)
if not download_success:
# Modified error message with alternative suggestions
if model == "large-v3":
raise ResourceError(
f"Failed to download Whisper model '{model}'.\n"
f"You can:\n"
f"1. Try using a different model like 'large-v3-turbo' which is more reliable\n"
f"2. Or download manually with: ~/whisper.cpp/models/download-ggml-model.sh {model}"
)
else:
raise ResourceError(
f"Failed to download Whisper model '{model}'.\n"
f"Please download it manually with: ~/whisper.cpp/models/download-ggml-model.sh {model}"
)
# Verify that the model was downloaded
if not os.path.exists(model_path):
raise ResourceError(
f"Model download completed but model file not found at {model_path}. "
f"Please check the download and try again."
)
# Verify model file size
actual_size = os.path.getsize(model_path)
expected_min_size = WHISPER_MODEL_SIZES.get(model, 100000000) # Default to 100MB minimum if unknown
if actual_size < expected_min_size:
actual_size_mb = actual_size / (1024 * 1024)
expected_mb = expected_min_size / (1024 * 1024)
raise ResourceError(
f"Downloaded model file at {model_path} is too small ({actual_size_mb:.1f} MB). "
f"Expected at least {expected_mb:.1f} MB. File may be corrupted or incomplete. "
f"Please download it manually with: ~/whisper.cpp/models/download-ggml-model.sh {model}"
)
logger.info(f"Successfully downloaded Whisper model '{model}'", emoji_key="success")
else:
# Verify existing model file size
actual_size = os.path.getsize(model_path)
expected_min_size = WHISPER_MODEL_SIZES.get(model, 100000000) # Default to 100MB minimum if unknown
file_size_mb = actual_size / (1024 * 1024)
if actual_size < expected_min_size:
expected_mb = expected_min_size / (1024 * 1024)
logger.warning(
f"Existing model at {model_path} is suspiciously small ({file_size_mb:.1f} MB). "
f"Expected at least {expected_mb:.1f} MB. Model may be corrupted.",
emoji_key="warning"
)
else:
logger.info(f"Found existing model at {model_path} ({file_size_mb:.1f} MB)", emoji_key="dependency")
# Check if the whisper-cli binary is available in PATH
try:
result = subprocess.run(["which", "whisper-cli"], capture_output=True, text=True)
if result.returncode == 0:
whisper_path_found = result.stdout.strip()
logger.debug(f"Found whisper-cli in PATH: {whisper_path_found}", emoji_key="dependency")
else:
# Check in the expected location
whisper_path = os.path.expanduser("~/whisper.cpp")
whisper_bin = os.path.join(whisper_path, "build", "bin", "whisper-cli")
if not os.path.exists(whisper_bin):
raise ResourceError(
f"whisper-cli binary not found at {whisper_bin}. "
f"Please build whisper.cpp first with: "
f"cd ~/whisper.cpp && cmake -B build && cmake --build build -j --config Release"
)
logger.debug(f"Found whisper-cli at {whisper_bin}", emoji_key="dependency")
except FileNotFoundError as e:
raise ResourceError("Command 'which' not found. Cannot check for whisper-cli installation.") from e
logger.debug(f"Found whisper model: {model}", emoji_key="dependency")
return True
async def get_detailed_audio_info(file_path: str) -> Dict[str, Any]:
"""Gets detailed information about an audio file using ffprobe."""
cmd = [
"ffprobe",
"-v", "error",
"-show_entries", "format=duration,bit_rate,size",
"-select_streams", "a:0",
"-show_entries", "stream=channels,sample_rate,codec_name,bits_per_sample",
"-of", "json",
file_path
]
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
logger.warning(f"Failed to get audio info: {stderr.decode('utf-8', errors='ignore')}", emoji_key="warning")
return {
"duration": 0,
"channels": 0,
"sample_rate": 0,
"format": os.path.splitext(file_path)[1][1:],
"codec": None,
"bit_depth": None,
"bitrate": None,
"size_bytes": os.path.getsize(file_path) if os.path.exists(file_path) else 0
}
info = json.loads(stdout)
format_info = info.get("format", {})
stream_info = info.get("streams", [{}])[0] if info.get("streams") else {}
# Extract information
duration = float(format_info.get("duration", 0))
channels = int(stream_info.get("channels", 0))
sample_rate = int(stream_info.get("sample_rate", 0))
codec = stream_info.get("codec_name")
bit_depth = int(stream_info.get("bits_per_sample", 0)) or None
bitrate = int(format_info.get("bit_rate", 0)) or None
size_bytes = int(format_info.get("size", 0)) or os.path.getsize(file_path)
audio_format = os.path.splitext(file_path)[1][1:]
return {
"duration": duration,
"channels": channels,
"sample_rate": sample_rate,
"format": audio_format,
"codec": codec,
"bit_depth": bit_depth,
"bitrate": bitrate,
"size_bytes": size_bytes
}
except Exception as e:
logger.warning(f"Error getting audio info: {e}", emoji_key="warning", exc_info=True)
try:
size_bytes = os.path.getsize(file_path) if os.path.exists(file_path) else 0
except Exception:
size_bytes = 0
return {
"duration": 0,
"channels": 0,
"sample_rate": 0,
"format": os.path.splitext(file_path)[1][1:],
"codec": None,
"bit_depth": None,
"bitrate": None,
"size_bytes": size_bytes
}
async def enhance_audio(context: ProcessingContext, audio_info: Dict[str, Any]) -> Optional[str]:
"""Enhances audio quality using ffmpeg preprocessing."""
# Create output path in temp directory
output_path = os.path.join(context.temp_dir, f"{context.base_filename}_enhanced.wav")
# Use optimized audio enhancement settings
# Build the complete command with the standardized enhancement settings
cmd = [
"ffmpeg",
"-i", context.file_path,
"-threads", str(os.cpu_count() or 1),
"-af", "volume=1.5, highpass=f=200, lowpass=f=3000, afftdn=nr=10:nf=-20, "
"compand=attacks=0:points=-80/-80|-45/-15|-27/-9|0/-7|20/-7:gain=5, "
"dynaudnorm=f=150:g=15:p=1:m=1:s=0, "
"pan=stereo|c0=c0|c1=c0",
"-ar", "16000",
"-ac", "2",
"-c:a", "pcm_s16le",
"-y", # Overwrite output if exists
output_path
]
logger.debug(f"Running ffmpeg command: {' '.join(cmd)}", emoji_key="command")
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
error_msg = stderr.decode('utf-8', errors='ignore')
logger.error(f"FFmpeg error: {error_msg}", emoji_key="error")
return None
# Verify the output file was created and has a reasonable size
if os.path.exists(output_path) and os.path.getsize(output_path) > 1000:
file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
logger.info(f"Enhanced audio saved to {output_path} ({file_size_mb:.2f} MB)", emoji_key="audio")
else:
logger.warning("Enhanced audio file is suspiciously small or doesn't exist", emoji_key="warning")
return None
# If we're keeping enhanced audio, copy it to a persistent location
if context.options.save_enhanced_audio:
original_dir = os.path.dirname(context.file_path)
persistent_path = os.path.join(original_dir, f"{context.base_filename}_enhanced.wav")
shutil.copy2(output_path, persistent_path)
logger.info(f"Saved enhanced audio to {persistent_path}", emoji_key="save")
logger.info("Audio enhancement completed", emoji_key="audio")
return output_path
except Exception as e:
logger.error(f"Error enhancing audio: {e}", emoji_key="error", exc_info=True)
return None
async def transcribe_with_whisper(context: ProcessingContext) -> Dict[str, Any]:
"""Transcribes audio using Whisper.cpp with advanced options."""
# Create output base name in temp directory
output_base = os.path.join(context.temp_dir, context.base_filename)
output_json = f"{output_base}.json"
output_txt = f"{output_base}.txt"
# Get whisper parameters
params = context.options.whisper_params
# Build command with configurable parameters
whisper_bin = os.path.expanduser("~/whisper.cpp/build/bin/whisper-cli")
model_path = os.path.expanduser(f"~/whisper.cpp/models/ggml-{params.model}.bin")
# Validate required files exist
if not os.path.exists(whisper_bin):
logger.error(f"Whisper binary not found at {whisper_bin}", emoji_key="error")
raise ToolError(f"Whisper binary not found at {whisper_bin}")
if not os.path.exists(model_path):
logger.error(f"Whisper model not found at {model_path}", emoji_key="error")
raise ToolError(f"Whisper model not found at {model_path}")
# Verify model file size
actual_size = os.path.getsize(model_path)
expected_min_size = WHISPER_MODEL_SIZES.get(params.model, 100000000) # Default to 100MB minimum if unknown
if actual_size < expected_min_size:
actual_size_mb = actual_size / (1024 * 1024)
expected_mb = expected_min_size / (1024 * 1024)
logger.warning(
f"Model file at {model_path} is smaller than expected ({actual_size_mb:.1f} MB, expected {expected_mb:.1f} MB)",
emoji_key="warning"
)
# Use file_path as fallback if enhanced_audio_path is None
audio_path = context.enhanced_audio_path or context.file_path
if not os.path.exists(audio_path):
logger.error(f"Audio file not found at {audio_path}", emoji_key="error")
raise ToolError(f"Audio file not found at {audio_path}")
cmd = [
whisper_bin,
"-m", model_path,
"-f", audio_path,
"-of", output_base,
"-oj" # Always output JSON for post-processing
]
# Add boolean flags
if params.word_timestamps:
cmd.append("-pc")
if params.translate:
cmd.append("-tr")
# Always output text for readability
cmd.append("-otxt")
# Add numeric parameters
cmd.extend(["-t", str(os.cpu_count() if params.processors <= 0 else params.processors)])
if params.beam_size:
cmd.extend(["-bs", str(params.beam_size)])
# Add language parameter if specified
if params.language:
cmd.extend(["-l", params.language])
# Add max context parameter if specified
if params.max_context > 0:
cmd.extend(["-mc", str(params.max_context)])
# Additional optimizations
cmd.append("-fa") # Full sentence timestamps (improved segmentation)
cmd.append("-pp") # Enable post-processing
# Add custom vocab if specified (create a vocab file)
if params.custom_vocab:
vocab_path = os.path.join(context.temp_dir, "custom_vocab.txt")
try:
async with aiofiles.open(vocab_path, 'w') as f:
await f.write("\n".join(params.custom_vocab))
cmd.extend(["-kv", vocab_path])
except Exception as e:
logger.warning(f"Failed to create custom vocab file: {e}", emoji_key="warning")
# Add diarization if requested
if params.diarize:
cmd.append("-dm")
cmd_str = ' '.join(cmd)
logger.debug(f"Running whisper command: {cmd_str}", emoji_key="command")
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
stderr_output = stderr.decode('utf-8', errors='ignore') if stderr else ""
stdout_output = stdout.decode('utf-8', errors='ignore') if stdout else ""
# Log outputs for debugging
if stdout_output:
logger.debug(f"Whisper stdout (first 500 chars): {stdout_output[:500]}", emoji_key="info")
if stderr_output:
if process.returncode != 0:
logger.error(f"Whisper stderr: {stderr_output}", emoji_key="error")
elif "error" in stderr_output.lower() or "warning" in stderr_output.lower():
logger.warning(f"Whisper warnings/errors: {stderr_output}", emoji_key="warning")
else:
logger.debug(f"Whisper stderr: {stderr_output}", emoji_key="info")
if process.returncode != 0:
error_msg = stderr_output or "Unknown error"
logger.error(f"Whisper transcription error (exit code {process.returncode}): {error_msg}", emoji_key="error")
raise ToolError(f"Whisper transcription failed with exit code {process.returncode}: {error_msg}")
# Check if output files exist
if not os.path.exists(output_json) and not os.path.exists(output_txt):
logger.error("Whisper completed successfully but no output files were created", emoji_key="error")
raise ToolError("Whisper completed successfully but no output files were created")
# Read results from the JSON file
if os.path.exists(output_json):
json_file_size = os.path.getsize(output_json)
logger.debug(f"Reading JSON output file: {output_json} ({json_file_size} bytes)", emoji_key="info")
if json_file_size < 10: # Suspiciously small
logger.warning(f"JSON output file is suspiciously small: {json_file_size} bytes", emoji_key="warning")
try:
async with aiofiles.open(output_json, 'r') as f:
content = await f.read()
try:
result = json.loads(content)
# Fix missing fields in result if needed
if "segments" not in result:
logger.warning("No segments found in Whisper JSON output", emoji_key="warning")
result["segments"] = []
if "text" not in result or not result.get("text"):
logger.warning("No transcript text found in Whisper JSON output", emoji_key="warning")
# Try to construct text from segments
if result.get("segments"):
reconstructed_text = " ".join([seg.get("text", "") for seg in result["segments"]])
if reconstructed_text:
logger.info("Reconstructed transcript text from segments", emoji_key="info")
result["text"] = reconstructed_text
else:
result["text"] = ""
else:
result["text"] = ""
# Extract metadata
metadata = {
"language": context.language_code or result.get("language"),
"duration": result.get("duration", 0)
}
result["metadata"] = metadata
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Whisper JSON output: {e}", emoji_key="error")
logger.error(f"JSON content: {content[:1000]}...", emoji_key="error")
raise ToolError(f"Failed to parse Whisper output JSON: {e}") from e
except Exception as e:
logger.error(f"Failed to read Whisper JSON output file: {e}", emoji_key="error")
raise ToolError(f"Failed to read Whisper output: {e}") from e
else:
logger.warning(f"Whisper JSON output not found at expected path: {output_json}", emoji_key="warning")
# Fallback to text file
if os.path.exists(output_txt):
txt_file_size = os.path.getsize(output_txt)
logger.info(f"Falling back to text output file: {output_txt} ({txt_file_size} bytes)", emoji_key="info")
if txt_file_size < 10: # Suspiciously small
logger.warning(f"Text output file is suspiciously small: {txt_file_size} bytes", emoji_key="warning")
try:
async with aiofiles.open(output_txt, 'r') as f:
text = await f.read()
# Create minimal result structure
result = {
"text": text,
"segments": [{"text": text, "start": 0, "end": 0}],
"metadata": {
"language": context.language_code,
"duration": 0
}
}
if not text:
logger.warning("Text output file is empty", emoji_key="warning")
except Exception as e:
logger.error(f"Failed to read text output file: {e}", emoji_key="error")
raise ToolError(f"Failed to read Whisper text output: {e}") from e
else:
logger.error(f"No output files found from Whisper at {output_base}.*", emoji_key="error")
raise ToolError("No output files found from Whisper transcription")
# Check if we actually got a transcript
if not result.get("text"):
logger.warning("Whisper returned an empty transcript", emoji_key="warning")
# Clean up results (remove empty/duplicate segments)
cleaned_segments = clean_segments(result.get("segments", []))
result["segments"] = cleaned_segments
# Clean up text (remove dots-only lines, etc.)
result["text"] = clean_raw_transcript(result.get("text", ""))
return result
except ToolError:
raise
except Exception as e:
logger.error(f"Error in Whisper transcription: {e}", emoji_key="error", exc_info=True)
raise ToolError(f"Whisper transcription failed: {str(e)}") from e
def clean_segments(segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Clean and normalize segment data."""
cleaned_segments = []
seen_texts = set()
for segment in segments:
# Skip segments with empty or meaningless text
text = segment.get("text", "").strip()
if not text or re.match(r'^[\s.]*$', text):
continue
# Skip exact duplicates unless they have meaningful timing differences
# (within 0.5s of a previous segment)
is_duplicate = False
if text in seen_texts:
for prev_segment in cleaned_segments:
if prev_segment.get("text") == text:
start_diff = abs(prev_segment.get("start", 0) - segment.get("start", 0))
end_diff = abs(prev_segment.get("end", 0) - segment.get("end", 0))
if start_diff < 0.5 and end_diff < 0.5:
is_duplicate = True
break
if is_duplicate:
continue
# Add to seen texts
seen_texts.add(text)
# Standardize segment structure
clean_segment = {
"start": float(segment.get("start", 0)),
"end": float(segment.get("end", 0)),
"text": text
}
# Add optional fields if available
for field in ["speaker", "words", "confidence"]:
if field in segment:
clean_segment[field] = segment[field]
cleaned_segments.append(clean_segment)
# Sort by start time
cleaned_segments.sort(key=lambda x: x["start"])
return cleaned_segments
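# Illustrative example: empty/dot-only segments are dropped and near-identical
# segments (same text with start/end within 0.5s) are de-duplicated, e.g.
#   clean_segments([
#       {"start": 1.0, "end": 2.0, "text": "Hi"},
#       {"start": 1.1, "end": 2.1, "text": "Hi"},
#       {"start": 0.0, "end": 0.5, "text": "..."},
#   ])
#   -> [{"start": 1.0, "end": 2.0, "text": "Hi"}]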
def clean_raw_transcript(text: str) -> str:
"""Cleans raw transcript."""
if not text:
return ""
# Split into lines and process
lines = text.split("\n")
cleaned_lines = []
seen_lines = set()
for line in lines:
line = line.strip()
# Skip empty lines
if not line:
continue
# Skip lines with just dots or other meaningless patterns
if re.match(r'^[\s.]*$', line) or line == '[BLANK_AUDIO]':
continue
# Standardize multiple spaces
line = re.sub(r'\s+', ' ', line)
# Keep duplicates if they're long (likely not duplicates but legitimate repetition)
if line in seen_lines and len(line) <= 50:
continue
seen_lines.add(line)
cleaned_lines.append(line)
# Join but ensure there's proper spacing
return "\n".join(cleaned_lines)
# --- Transcript Enhancement Functions ---
async def enhance_transcript(
context: ProcessingContext,
transcript: str,
metadata: Dict[str, Any]
) -> Dict[str, Any]:
"""Enhance a transcript with formatting, readability and semantic structuring."""
if not transcript or transcript.strip() == "":
return {
"transcript": "",
"tokens": {"input": 0, "output": 0, "total": 0},
"cost": 0.0,
"topics": [],
"title": None
}
# Extract key parameters
params = context.options.enhancement_params
# Split transcript into manageable chunks
chunks = await chunk_text(transcript, params.max_chunk_size)
if not chunks:
return {
"transcript": transcript,
"tokens": {"input": 0, "output": 0, "total": 0},
"cost": 0.0,
"topics": [],
"title": None
}
# First analyze context to get a summary of the content
context_data = await detect_subject_matter(
chunks[0],
params.provider,
params.model,
metadata
)
# Track topics if available
topics = context_data.get("topics", [])
title = context_data.get("title")
context_info = context_data.get("context", "")
logger.info(f"Content analysis complete: {len(topics)} topics identified", emoji_key="analyze")
# Process chunks concurrently if parallel processing is enabled
if context.options.parallel_processing and len(chunks) > 1:
enhanced_chunks = await process_chunks_parallel(
context,
chunks,
context_info,
params
)
else:
enhanced_chunks = await process_chunks_sequential(
context,
chunks,
context_info,
params
)
# Calculate total metrics
total_tokens = {"input": 0, "output": 0, "total": 0}
total_cost = 0.0
for chunk_data in enhanced_chunks:
chunk_tokens = chunk_data.get("tokens", {})
total_tokens["input"] += chunk_tokens.get("input", 0)
total_tokens["output"] += chunk_tokens.get("output", 0)
total_tokens["total"] += chunk_tokens.get("total", 0)
total_cost += chunk_data.get("cost", 0.0)
# Join the enhanced chunks
enhanced_transcript = "\n\n".join(chunk_data["text"] for chunk_data in enhanced_chunks)
# If sections are enabled, try to add section headings
if params.sections and topics:
enhanced_transcript = await add_section_headings(
enhanced_transcript,
topics,
params.provider,
params.model
)
return {
"transcript": enhanced_transcript,
"tokens": total_tokens,
"cost": total_cost,
"topics": topics,
"title": title
}
async def chunk_text(text: str, max_chunk_size: int = 6500) -> List[str]:
"""Split text into chunks with intelligent boundary detection."""
if len(text) <= max_chunk_size:
return [text]
# Define patterns for natural breaks, prioritized
patterns = [
r'\n\s*\n\s*\n', # Triple line break (highest priority)
r'\n\s*\n', # Double line break
r'(?<=[.!?])\s+(?=[A-Z])', # Sentence boundary with capital letter following
r'(?<=[.!?])\s', # Any sentence boundary
r'(?<=[,:;])\s' # Phrase boundary (lowest priority)
]
chunks = []
remaining_text = text
while remaining_text:
if len(remaining_text) <= max_chunk_size:
chunks.append(remaining_text)
break
# Start with an initial chunk at max size
chunk_candidate = remaining_text[:max_chunk_size]
split_position = None
# Try each pattern in order of priority
for pattern in patterns:
# Look for the last occurrence of the pattern
matches = list(re.finditer(pattern, chunk_candidate))
if matches:
# Use the last match as the split point
split_position = matches[-1].end()
break
# Fallback if no natural breaks found
if split_position is None or split_position < max_chunk_size // 2:
# Look for the last space after a minimum chunk size
min_size = max(max_chunk_size // 2, 1000)
last_space = chunk_candidate.rfind(' ', min_size)
if last_space > min_size:
split_position = last_space
else:
# Forced split at max_chunk_size
split_position = max_chunk_size
# Create chunk and update remaining text
chunks.append(remaining_text[:split_position].strip())
remaining_text = remaining_text[split_position:].strip()
# Validate chunks
validated_chunks = []
for chunk in chunks:
if chunk and len(chunk) >= 100: # Minimum viable chunk size
validated_chunks.append(chunk)
elif chunk:
# If the chunk is too small, merge it into the previous chunk when one exists
if validated_chunks:
validated_chunks[-1] += "\n\n" + chunk
else:
# This is the first chunk and it's too small - rare case
validated_chunks.append(chunk)
return validated_chunks
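# Illustrative behavior: a ~20,000-character transcript with paragraph breaks is
# typically split into three or four chunks of at most max_chunk_size characters,
# each cut at the highest-priority boundary available (blank line, then sentence,
# then phrase), with fragments under 100 characters folded into the previous chunk.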
async def detect_subject_matter(
text: str,
provider: str,
model: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Analyze transcript to determine subject matter, topics, and title."""
prompt = """Analyze this transcript excerpt for the following:
1. CONTEXT: The primary domain or topic being discussed (e.g., technology, business, healthcare, etc.)
2. SPEAKERS: The likely type and number of speakers (e.g., interview, panel, lecture, etc.)
3. TOPICS: List 2-5 specific topics covered, in order of importance
4. TITLE: A short, descriptive title for this content (under 10 words)
Return your analysis in JSON format ONLY:
{{
"context": "Brief description of the domain and conversation type",
"topics": ["Topic 1", "Topic 2", "Topic 3"],
"title": "Concise descriptive title"
}}
Transcript excerpt:
{text}"""
# Include metadata if available
if metadata and metadata.get("language"):
language = metadata.get("language")
prompt += f"\n\nMetadata: The transcript language is {language}."
try:
result = await chat_completion(
messages=[{"role": "user", "content": prompt.format(text=text)}],
provider=provider,
model=model,
temperature=0
)
if result.get("success") and "message" in result:
content = result["message"].get("content", "")
# Try to parse JSON from content
try:
# Extract JSON if it's embedded in text
json_match = re.search(r'({[\s\S]*})', content)
if json_match:
analysis = json.loads(json_match.group(1))
else:
analysis = json.loads(content)
return {
"context": analysis.get("context", ""),
"topics": analysis.get("topics", []),
"title": analysis.get("title")
}
except json.JSONDecodeError:
# Fallback: extract fields manually
context_match = re.search(r'context["\s:]+([^"}\n]+)', content, re.IGNORECASE)
title_match = re.search(r'title["\s:]+([^"}\n]+)', content, re.IGNORECASE)
topics_match = re.search(r'topics["\s:]+\[(.*?)\]', content, re.IGNORECASE | re.DOTALL)
context = context_match.group(1).strip() if context_match else ""
title = title_match.group(1).strip() if title_match else None
topics = []
if topics_match:
topics_text = topics_match.group(1)
topics = [t.strip().strip('"\'') for t in re.findall(r'"([^"]+)"', topics_text)]
if not topics:
topics = [t.strip() for t in topics_text.split(',') if t.strip()]
return {
"context": context,
"topics": topics,
"title": title
}
except Exception as e:
logger.warning(f"Subject matter detection failed: {e}", emoji_key="warning")
return {
"context": "",
"topics": [],
"title": None
}
async def process_chunks_parallel(
context: ProcessingContext,
chunks: List[str],
context_info: str,
params: TranscriptEnhancementParams
) -> List[Dict[str, Any]]:
"""Process transcript chunks in parallel for better performance."""
# Limit max workers to CPU count or specified max
max_workers = min(context.options.max_workers, os.cpu_count() or 4, len(chunks))
# Create a semaphore to limit concurrency
sem = asyncio.Semaphore(max_workers)
# Results are collected from concurrently scheduled tasks, bounded by the semaphore
chunk_results = []
async def process_chunk(i, chunk):
"""Process an individual chunk."""
async with sem: # Use semaphore to limit concurrency
logger.info(f"Enhancing chunk {i+1}/{len(chunks)}", emoji_key="enhance")
try:
result = await enhance_chunk(chunk, context_info, params, i, len(chunks))
return result
except Exception as e:
logger.error(f"Error enhancing chunk {i+1}: {e}", emoji_key="error", exc_info=True)
return {"text": chunk, "tokens": {"input": 0, "output": 0, "total": 0}, "cost": 0.0}
# Create tasks for parallel execution
tasks = [process_chunk(i, chunk) for i, chunk in enumerate(chunks)]
chunk_results = await asyncio.gather(*tasks)
# Make sure results are in original order (they should be)
return chunk_results
async def process_chunks_sequential(
context: ProcessingContext,
chunks: List[str],
context_info: str,
params: TranscriptEnhancementParams
) -> List[Dict[str, Any]]:
"""Process transcript chunks sequentially to preserve context flow."""
enhanced_chunks = []
accumulated_context = context_info
for i, chunk in enumerate(chunks):
logger.info(f"Enhancing chunk {i+1}/{len(chunks)}", emoji_key="enhance")
# Update context with information from previous chunks
if i > 0 and enhanced_chunks:
# Add brief summary of what was covered in previous chunk
previous_text = enhanced_chunks[-1]["text"]
if len(previous_text) > 500:
accumulated_context += f"\nPrevious chunk ended with: {previous_text[-500:]}"
else:
accumulated_context += f"\nPrevious chunk: {previous_text}"
try:
result = await enhance_chunk(chunk, accumulated_context, params, i, len(chunks))
enhanced_chunks.append(result)
except Exception as e:
logger.error(f"Error enhancing chunk {i+1}: {e}", emoji_key="error", exc_info=True)
# Use original text on error
enhanced_chunks.append({"text": chunk, "tokens": {"input": 0, "output": 0, "total": 0}, "cost": 0.0})
return enhanced_chunks
async def enhance_chunk(
chunk: str,
context_info: str,
params: TranscriptEnhancementParams,
chunk_index: int,
total_chunks: int
) -> Dict[str, Any]:
"""Enhance a single transcript chunk with LLM."""
# Build the prompt based on enhancement parameters
style_instructions = _get_style_instructions(params.style)
fix_instructions = []
if params.add_paragraphs:
fix_instructions.append("- Add paragraph breaks at natural topic transitions")
if params.fix_spelling:
fix_instructions.append("- Fix obvious spelling errors while preserving domain-specific terms")
if params.fix_grammar:
fix_instructions.append("- Fix basic grammatical errors without changing style or meaning")
if params.format_numbers:
fix_instructions.append("- Format numbers consistently (e.g., '25' instead of 'twenty-five')")
if params.identify_speakers:
fix_instructions.append("- Try to identify different speakers and label them as Speaker 1, Speaker 2, etc.")
fix_instructions.append("- Format speaker changes as 'Speaker N: text' on a new line")
fix_section = "\n".join(fix_instructions) if fix_instructions else "None"
# Add custom instructions if provided
custom_section = f"\nADDITIONAL INSTRUCTIONS:\n{params.custom_instructions}" if params.custom_instructions else ""
# Mention chunk position in context
position_info = f"This is chunk {chunk_index+1} of {total_chunks}." if total_chunks > 1 else ""
prompt = f"""You are cleaning up a raw transcript from a recorded conversation. {position_info}
CONTENT CONTEXT: {context_info}
ENHANCEMENT STYLE: {style_instructions}
CLEANUP INSTRUCTIONS:
1. Remove filler sounds: "um", "uh", "er", "ah", "hmm"
2. Remove stutters and word repetitions: "the- the", "I- I"
3. Remove meaningless filler phrases when used as pure filler: "you know", "like", "sort of"
4. Fix clear transcription errors and garbled text
5. Add proper punctuation for readability
{fix_section}
STRICT PRESERVATION RULES:
1. DO NOT modify, rephrase, or restructure ANY of the speaker's content
2. DO NOT add ANY new content or explanations
3. DO NOT make the language more formal or technical
4. DO NOT summarize or condense anything
5. PRESERVE ALL technical terms, numbers, and specific details exactly as spoken
6. PRESERVE the speaker's unique speaking style and personality
{custom_section}
Here's the transcript chunk to clean:
{chunk}
Return ONLY the cleaned transcript text with no explanations, comments, or metadata."""
try:
result = await chat_completion(
messages=[{"role": "user", "content": prompt}],
provider=params.provider,
model=params.model,
temperature=0.1, # Low temperature for consistent results
max_tokens=min(len(chunk) * 2, 8192) # Reasonable token limit based on input size
)
if result.get("success") and "message" in result:
enhanced_text = result["message"].get("content", "").strip()
# Validation: if enhanced text is much shorter, it might have been summarized
if len(enhanced_text) < len(chunk) * 0.6:
logger.warning(
f"Enhanced text suspiciously short ({len(enhanced_text)} vs {len(chunk)} chars), "
f"may have been summarized. Using original.",
emoji_key="warning"
)
enhanced_text = chunk
return {
"text": enhanced_text,
"tokens": result.get("tokens", {"input": 0, "output": 0, "total": 0}),
"cost": result.get("cost", 0.0)
}
return {"text": chunk, "tokens": {"input": 0, "output": 0, "total": 0}, "cost": 0.0}
except Exception as e:
logger.error(f"Chunk enhancement error: {e}", emoji_key="error")
return {"text": chunk, "tokens": {"input": 0, "output": 0, "total": 0}, "cost": 0.0}
def _get_style_instructions(style: EnhancementStyle) -> str:
"""Get instructions for the specified enhancement style."""
styles = {
EnhancementStyle.RAW: (
"Minimal cleaning only. Preserve all speech patterns and informality. "
"Focus on removing only transcription errors and unintelligible elements."
),
EnhancementStyle.READABLE: (
"Basic readability improvements. Light cleanup while preserving natural speech patterns. "
"Remove only clear disfluencies and maintain conversational flow."
),
EnhancementStyle.POLISHED: (
"Well-formatted with proper punctuation and clean sentences. "
"Remove speech disfluencies but preserve the speaker's voice and style. "
"Create a professional but authentic reading experience."
),
EnhancementStyle.VERBATIM: (
"Preserve all speech patterns, hesitations, and repetitions. "
"Format for readability but maintain every verbal quirk and pause. "
"Indicate hesitations with ellipses [...] and preserve every repeated word or phrase."
),
EnhancementStyle.STRUCTURED: (
"Add semantic structure with clear paragraphs around topics. "
"Clean speech for maximum readability while preserving content accuracy. "
"Organize into logical sections while keeping all original information."
)
}
return styles.get(style, styles[EnhancementStyle.READABLE])
async def add_section_headings(
transcript: str,
topics: List[str],
provider: str,
model: Optional[str] = None
) -> str:
"""Add section headings to the transcript based on topic changes."""
if not transcript or not topics:
return transcript
prompt = """Add clear section headings to this transcript based on topic changes.
TOPICS COVERED (in approximate order):
{topics}
RULES:
1. Insert section headings as "## [Topic]" on their own line
2. Place headings ONLY where there is a clear topic change
3. Use at most {max_sections} headings total
4. NEVER add content or edit the existing text
5. NEVER remove any original content
6. Base headings on the given topics list, but you can adjust wording for clarity
7. Don't duplicate headings for the same topic
8. Keep headings short and descriptive (2-6 words each)
TRANSCRIPT:
{text}
Return the full transcript with section headings added."""
# Adjust max sections based on transcript length
token_estimate = len(transcript) // 4
max_sections = min(len(topics) + 1, token_estimate // 1000 + 1)
topics_text = "\n".join([f"- {topic}" for topic in topics])
try:
result = await chat_completion(
messages=[{
"role": "user",
"content": prompt.format(topics=topics_text, text=transcript, max_sections=max_sections)
}],
provider=provider,
model=model,
temperature=0.1,
max_tokens=min(len(transcript) * 2, 8192)
)
if result.get("success") and "message" in result:
return result["message"].get("content", "").strip()
return transcript
except Exception as e:
logger.warning(f"Failed to add section headings: {e}", emoji_key="warning")
return transcript
# --- Output File Generation ---
async def generate_output_files(
context: ProcessingContext,
raw_transcript: str,
enhanced_transcript: str,
segments: List[Dict[str, Any]],
metadata: Dict[str, Any]
) -> Dict[str, Any]:
"""Generate output files in requested formats."""
artifact_paths = {
"output_files": {}
}
# Save enhanced audio path if requested
if context.options.save_enhanced_audio and context.enhanced_audio_path:
original_dir = os.path.dirname(context.file_path)
persistent_path = os.path.join(original_dir, f"{context.base_filename}_enhanced.wav")
# May have already been saved during enhancement
if not os.path.exists(persistent_path) and os.path.exists(context.enhanced_audio_path):
shutil.copy2(context.enhanced_audio_path, persistent_path)
artifact_paths["enhanced_audio"] = persistent_path
# Generate requested output formats
output_formats = context.options.output_formats
# Use a Path for the output directory and a timestamp for unique filenames
output_dir = Path(os.path.dirname(context.file_path))
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# Generate JSON output
if OutputFormat.JSON in output_formats:
json_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.json"
# Create JSON data structure
json_data = {
"metadata": {
"filename": context.original_filename,
"processed_at": timestamp,
**metadata
},
"raw_transcript": raw_transcript,
"enhanced_transcript": enhanced_transcript,
"segments": segments
}
# Write JSON file
async with aiofiles.open(str(json_path), 'w') as f:
await f.write(json.dumps(json_data, indent=2))
artifact_paths["output_files"]["json"] = json_path
# Generate TEXT output
if OutputFormat.TEXT in output_formats:
text_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.txt"
# Create plain text file
async with aiofiles.open(str(text_path), 'w') as f:
# Add metadata header if available
if metadata:
if "title" in metadata and metadata["title"]:
await f.write(f"Title: {metadata['title']}\n")
if "language" in metadata and metadata["language"]:
await f.write(f"Language: {metadata['language']}\n")
if "topics" in metadata and metadata["topics"]:
topics_str = ", ".join(metadata["topics"])
await f.write(f"Topics: {topics_str}\n")
await f.write("\n")
# Write enhanced transcript if available, otherwise use raw transcript
await f.write(enhanced_transcript or raw_transcript)
artifact_paths["output_files"]["text"] = text_path
# Generate SRT output
if OutputFormat.SRT in output_formats:
srt_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.srt"
# Convert segments to SRT format
srt_content = generate_srt(segments)
# Write SRT file
async with aiofiles.open(str(srt_path), 'w') as f:
await f.write(srt_content)
artifact_paths["output_files"]["srt"] = srt_path
# Generate VTT output
if OutputFormat.VTT in output_formats:
vtt_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.vtt"
# Convert segments to VTT format
vtt_content = generate_vtt(segments)
# Write VTT file
async with aiofiles.open(str(vtt_path), 'w') as f:
await f.write(vtt_content)
artifact_paths["output_files"]["vtt"] = vtt_path
# Generate Markdown output
if OutputFormat.MARKDOWN in output_formats:
md_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.md"
# Create markdown content
md_content = generate_markdown(enhanced_transcript, metadata)
# Write markdown file
async with aiofiles.open(str(md_path), 'w') as f:
await f.write(md_content)
artifact_paths["output_files"]["markdown"] = md_path
# Generate DOCX output (if supported)
if OutputFormat.DOCX in output_formats:
try:
docx_path = output_dir / f"{context.base_filename}_transcript_{timestamp}.docx"
# Generate DOCX in a thread pool to avoid blocking
with concurrent.futures.ThreadPoolExecutor() as executor:
await asyncio.get_event_loop().run_in_executor(
executor,
generate_docx,
docx_path,
enhanced_transcript,
metadata
)
artifact_paths["output_files"]["docx"] = docx_path
except Exception as e:
logger.warning(f"Failed to generate DOCX output: {e}", emoji_key="warning")
return artifact_paths
def generate_srt(segments: List[Dict[str, Any]]) -> str:
"""Generate SRT format from segments."""
srt_lines = []
for i, segment in enumerate(segments):
# Convert times to SRT format (HH:MM:SS,mmm)
start_time = segment.get("start", 0)
end_time = segment.get("end", 0)
start_str = format_srt_time(start_time)
end_str = format_srt_time(end_time)
# Format text
text = segment.get("text", "").replace("\n", " ")
# Add speaker if available
if "speaker" in segment and segment["speaker"]:
text = f"[{segment['speaker']}] {text}"
# Add to SRT
srt_lines.append(f"{i+1}")
srt_lines.append(f"{start_str} --> {end_str}")
srt_lines.append(f"{text}")
srt_lines.append("") # Empty line between entries
return "\n".join(srt_lines)
def format_srt_time(seconds: float) -> str:
"""Format seconds as SRT time: HH:MM:SS,mmm."""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
sec_int = int(seconds % 60)
ms = int((seconds % 1) * 1000)
return f"{hours:02d}:{minutes:02d}:{sec_int:02d},{ms:03d}"
def generate_vtt(segments: List[Dict[str, Any]]) -> str:
"""Generate WebVTT format from segments."""
vtt_lines = ["WEBVTT", ""]
for segment in segments:
# Convert times to VTT format (HH:MM:SS.mmm)
start_time = segment.get("start", 0)
end_time = segment.get("end", 0)
start_str = format_vtt_time(start_time)
end_str = format_vtt_time(end_time)
# Format text
text = segment.get("text", "").replace("\n", " ")
# Add speaker if available
if "speaker" in segment and segment["speaker"]:
text = f"<v {segment['speaker']}>{text}</v>"
# Add to VTT
vtt_lines.append(f"{start_str} --> {end_str}")
vtt_lines.append(f"{text}")
vtt_lines.append("") # Empty line between entries
return "\n".join(vtt_lines)
def format_vtt_time(seconds: float) -> str:
"""Format seconds as WebVTT time: HH:MM:SS.mmm."""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
sec_fractional = seconds % 60
return f"{hours:02d}:{minutes:02d}:{sec_fractional:06.3f}"
def generate_markdown(transcript: str, metadata: Dict[str, Any]) -> str:
"""Generate Markdown format for the transcript."""
md_lines = []
# Add title
if "title" in metadata and metadata["title"]:
md_lines.append(f"# {metadata['title']}")
md_lines.append("")
else:
md_lines.append("# Transcript")
md_lines.append("")
# Add metadata section
md_lines.append("## Metadata")
md_lines.append("")
if "language" in metadata and metadata["language"]:
md_lines.append(f"- **Language:** {metadata['language']}")
if "duration" in metadata and metadata["duration"]:
duration_min = int(metadata["duration"] // 60)
duration_sec = int(metadata["duration"] % 60)
md_lines.append(f"- **Duration:** {duration_min} min {duration_sec} sec")
if "topics" in metadata and metadata["topics"]:
topics_str = ", ".join(metadata["topics"])
md_lines.append(f"- **Topics:** {topics_str}")
md_lines.append("")
md_lines.append("## Transcript")
md_lines.append("")
# Add transcript with proper line breaks preserved
for line in transcript.split("\n"):
md_lines.append(line)
return "\n".join(md_lines)
def generate_docx(docx_path: str, transcript: str, metadata: Dict[str, Any]) -> None:
"""Generate DOCX format for the transcript."""
# Must be run in a ThreadPoolExecutor since python-docx is not async
doc = Document()
# Add title
if "title" in metadata and metadata["title"]:
title = doc.add_heading(metadata["title"], 0)
else:
title = doc.add_heading("Transcript", 0) # noqa: F841
# Add metadata section
doc.add_heading("Metadata", 1)
if "language" in metadata and metadata["language"]:
doc.add_paragraph(f"Language: {metadata['language']}")
if "duration" in metadata and metadata["duration"]:
duration_min = int(metadata["duration"] // 60)
duration_sec = int(metadata["duration"] % 60)
doc.add_paragraph(f"Duration: {duration_min} min {duration_sec} sec")
if "topics" in metadata and metadata["topics"]:
topics_str = ", ".join(metadata["topics"])
doc.add_paragraph(f"Topics: {topics_str}")
# Add transcript
doc.add_heading("Transcript", 1)
# Split into paragraphs and add
for paragraph in transcript.split("\n\n"):
if paragraph.strip():
# Check if paragraph starts with a heading marker
if paragraph.startswith("##"):
parts = paragraph.split(" ", 1)
if len(parts) > 1:
doc.add_heading(parts[1], 2)
continue
# Regular paragraph
p = doc.add_paragraph()
p.add_run(paragraph)
# Save the document
doc.save(docx_path)
async def chat_with_transcript(
transcript: str,
query: str,
provider: str = Provider.ANTHROPIC.value,
model: Optional[str] = None,
context: Optional[str] = None
) -> Dict[str, Any]:
"""Chat with a transcript to extract information or answer questions about its content.
Args:
transcript: The transcript text to analyze
query: The question or instruction to process regarding the transcript
provider: LLM provider to use (default: Anthropic)
model: Specific model to use (default: provider's default model)
context: Optional additional context about the audio/transcript
Returns:
A dictionary containing the response and related metadata
"""
if not transcript or not isinstance(transcript, str):
raise ToolInputError("Transcript must be a non-empty string.")
if not query or not isinstance(query, str):
raise ToolInputError("Query must be a non-empty string.")
# Calculate token count for logging
try:
transcript_tokens = count_tokens(transcript, model)
query_tokens = count_tokens(query, model)
logger.info(
f"Transcript: {transcript_tokens} tokens, Query: {query_tokens} tokens",
emoji_key=TaskType.CHAT.value
)
except Exception as e:
logger.warning(f"Failed to count tokens: {e}", emoji_key="warning")
# Build the prompt
system_prompt = """You are an expert at analyzing transcripts and extracting information.
Provide concise, accurate answers based solely on the provided transcript.
If the answer is not in the transcript, say so clearly."""
if context:
system_prompt += f"\n\nAdditional context about this transcript: {context}"
# Get provider instance to ensure it exists and is available
try:
provider_instance = await get_provider(provider)
if model is None:
# Check if the provider has a default model or use claude-3-7-sonnet as fallback
default_models = await provider_instance.list_models()
if default_models and len(default_models) > 0:
model = default_models[0].get("id")
else:
model = "claude-3-7-sonnet-20250219" if provider == Provider.ANTHROPIC.value else None
logger.info(f"Using model: {provider}/{model}", emoji_key="model")
except Exception as e:
raise ProviderError(
f"Failed to initialize provider '{provider}': {str(e)}",
provider=provider,
cause=e
) from e
# Use relative file paths for any file references
rel_transcript_path = None
if os.path.exists(transcript):
rel_transcript_path = Path(transcript).relative_to(Path.cwd()) # noqa: F841
# Create the message with the transcript and query
user_message = f"""Here is a transcript to analyze:
---TRANSCRIPT BEGIN---
{transcript}
---TRANSCRIPT END---
{query}"""
# Send to LLM
result = await chat_completion(
messages=[{"role": "user", "content": user_message}],
provider=provider,
model=model,
system_prompt=system_prompt,
temperature=0.1
)
return result
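# Illustrative usage sketch (transcript_text is a placeholder for an existing
# transcript string; assumes the default Anthropic provider is configured):
#   answer = await chat_with_transcript(
#       transcript_text,
#       "What decisions were made in this meeting?",
#   )
#   print(answer.get("message", {}).get("content", ""))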
@with_cache(ttl=24 * 60 * 60) # Cache results for 24 hours
@with_tool_metrics
@with_retry(max_retries=1, retry_delay=1.0)
@with_error_handling
async def extract_audio_transcript_key_points(
file_path_or_transcript: str,
is_file: bool = True,
provider: str = Provider.ANTHROPIC.value,
model: Optional[str] = None,
max_points: int = 10
) -> Dict[str, Any]:
"""Extracts the most important key points from an audio transcript.
This tool can process either an audio file (which it will transcribe first)
or directly analyze an existing transcript to identify the most important
information, main topics, and key takeaways.
Args:
file_path_or_transcript: Path to audio file or transcript text content
is_file: Whether the input is a file path (True) or transcript text (False)
provider: LLM provider to use for analysis
model: Specific model to use (provider default if None)
max_points: Maximum number of key points to extract
Returns:
A dictionary containing:
{
"key_points": ["Point 1", "Point 2", ...],
"summary": "Brief summary of the content",
"topics": ["Topic 1", "Topic 2", ...],
"speakers": ["Speaker 1", "Speaker 2", ...] (if multiple speakers detected),
"tokens": { statistics about token usage },
"cost": estimated cost of the operation,
"processing_time": total processing time in seconds
}
"""
start_time = time.time()
# Get transcript from file or use provided text
transcript = ""
if is_file:
try:
# Validate file path
file_path = os.path.abspath(os.path.expanduser(file_path_or_transcript))
if not os.path.exists(file_path):
raise ToolInputError(f"File not found: {file_path}")
# Get file info for logging
file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
file_name = Path(file_path).name
logger.info(
f"Extracting key points from audio file: {file_name} ({file_size_mb:.2f} MB)",
emoji_key="audio"
)
# Transcribe audio
transcription_result = await transcribe_audio(file_path, {
"enhance_audio": True,
"enhance_transcript": True,
"output_formats": ["json"]
})
transcript = transcription_result.get("enhanced_transcript", "")
if not transcript:
transcript = transcription_result.get("raw_transcript", "")
if not transcript:
raise ToolError("Failed to generate transcript from audio")
except Exception as e:
if isinstance(e, (ToolError, ToolInputError)):
raise
raise ToolError(f"Failed to process audio file: {str(e)}") from e
else:
# Input is already a transcript
transcript = file_path_or_transcript
if not transcript or not isinstance(transcript, str):
raise ToolInputError("Transcript text must be a non-empty string")
# Calculate token count for the transcript
try:
token_count = count_tokens(transcript, model)
logger.info(f"Transcript token count: {token_count}", emoji_key="tokens")
except Exception as e:
logger.warning(f"Failed to count tokens: {e}")
# Create prompt for key points extraction
prompt = f"""Extract the most important key points from this transcript.
Identify:
1. The {max_points} most important key points or takeaways
2. Main topics discussed
3. Any speakers or main entities mentioned (if identifiable)
4. A brief summary (2-3 sentences max)
Format your response as JSON with these fields:
{{
"key_points": ["Point 1", "Point 2", ...],
"topics": ["Topic 1", "Topic 2", ...],
"speakers": ["Speaker 1", "Speaker 2", ...],
"summary": "Brief summary here"
}}
TRANSCRIPT:
{transcript}
"""
# Get provider instance
try:
provider_instance = await get_provider(provider) # noqa: F841
except Exception as e:
raise ProviderError(
f"Failed to initialize provider '{provider}': {str(e)}",
provider=provider,
cause=e
) from e
# Generate completion
try:
completion_result = await generate_completion(
prompt=prompt,
provider=provider,
model=model,
temperature=0.1,
max_tokens=1000
)
# Parse JSON response
response_text = completion_result.get("text", "")
# Find JSON in the response
json_match = re.search(r'({[\s\S]*})', response_text)
if json_match:
try:
extracted_data = json.loads(json_match.group(1))
except json.JSONDecodeError:
# Fallback to regex extraction
extracted_data = _extract_key_points_with_regex(response_text)
else:
# Fallback to regex extraction
extracted_data = _extract_key_points_with_regex(response_text)
processing_time = time.time() - start_time
# Add token and cost info
result = {
"key_points": extracted_data.get("key_points", []),
"topics": extracted_data.get("topics", []),
"speakers": extracted_data.get("speakers", []),
"summary": extracted_data.get("summary", ""),
"tokens": completion_result.get("tokens", {"input": 0, "output": 0, "total": 0}),
"cost": completion_result.get("cost", 0.0),
"processing_time": processing_time
}
return result
except Exception as e:
error_model = model or f"{provider}/default"
raise ProviderError(
f"Key points extraction failed for model '{error_model}': {str(e)}",
provider=provider,
model=error_model,
cause=e
) from e
def _extract_key_points_with_regex(text: str) -> Dict[str, Any]:
"""Extract key points data using regex when JSON parsing fails."""
result = {
"key_points": [],
"topics": [],
"speakers": [],
"summary": ""
}
# Extract key points
key_points_pattern = r'key_points"?\s*:?\s*\[\s*"([^"]+)"(?:\s*,\s*"([^"]+)")*\s*\]'
key_points_match = re.search(key_points_pattern, text, re.IGNORECASE | re.DOTALL)
if key_points_match:
point_list = re.findall(r'"([^"]+)"', key_points_match.group(0))
# Filter out empty strings
result["key_points"] = [p for p in point_list if p.strip()]
else:
# Try alternative pattern for non-JSON format
point_list = re.findall(r'(?:^|\n)(?:•|\*|-|[0-9]+\.)\s*([^\n]+?)(?:\n|$)', text)
# Filter out empty strings
result["key_points"] = [p.strip() for p in point_list if p.strip()][:10] # Limit to 10 points
# Extract topics
topics_pattern = r'topics"?\s*:?\s*\[\s*"([^"]+)"(?:\s*,\s*"([^"]+)")*\s*\]'
topics_match = re.search(topics_pattern, text, re.IGNORECASE | re.DOTALL)
if topics_match:
topic_list = re.findall(r'"([^"]+)"', topics_match.group(0))
# Filter out empty strings
result["topics"] = [t for t in topic_list if t.strip()]
# Extract speakers
speakers_pattern = r'speakers"?\s*:?\s*\[\s*"([^"]+)"(?:\s*,\s*"([^"]+)")*\s*\]'
speakers_match = re.search(speakers_pattern, text, re.IGNORECASE | re.DOTALL)
if speakers_match:
speaker_list = re.findall(r'"([^"]+)"', speakers_match.group(0))
# Filter out empty strings
result["speakers"] = [s for s in speaker_list if s.strip()]
# Extract summary
summary_pattern = r'summary"?\s*:?\s*"([^"]+)"'
summary_match = re.search(summary_pattern, text, re.IGNORECASE)
if summary_match and summary_match.group(1).strip():
result["summary"] = summary_match.group(1).strip()
return result
```