This is page 2 of 4. Use http://codebase.md/arthurcolle/openai-mcp?page={x} to view the full context.
# Directory Structure
```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```
# Files
--------------------------------------------------------------------------------
/claude_code/commands/multi_agent_client.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# claude_code/commands/multi_agent_client.py
"""Multi-agent MCP client implementation with synchronization capabilities."""
import asyncio
import sys
import os
import json
import logging
import uuid
import argparse
import time
from typing import Optional, Dict, Any, List, Set, Tuple
from contextlib import AsyncExitStack
from dataclasses import dataclass, field, asdict
from rich.console import Console
from rich.prompt import Prompt
from rich.panel import Panel
from rich.markdown import Markdown
from rich.table import Table
from rich.live import Live
from rich import print as rprint
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from anthropic import Anthropic
from dotenv import load_dotenv
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
# Console for rich output
console = Console()
@dataclass
class Agent:
"""Agent representation for multi-agent scenarios."""
id: str
name: str
role: str
model: str
system_prompt: str
conversation: List[Dict[str, Any]] = field(default_factory=list)
connected_agents: Set[str] = field(default_factory=set)
message_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
    def __post_init__(self):
        """Start with an empty conversation history.

        The system prompt is not stored as a message: the Anthropic Messages
        API only accepts "user" and "assistant" roles inside `messages`, so
        the prompt is passed via the top-level `system` parameter instead.
        """
        self.conversation = []
@dataclass
class Message:
"""Message for agent communication."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
sender_id: str = ""
sender_name: str = ""
recipient_id: Optional[str] = None # None means broadcast to all
recipient_name: Optional[str] = None
content: str = ""
timestamp: float = field(default_factory=time.time)
@classmethod
def create(cls, sender_id: str, sender_name: str, content: str,
recipient_id: Optional[str] = None, recipient_name: Optional[str] = None) -> 'Message':
"""Create a new message."""
return cls(
sender_id=sender_id,
sender_name=sender_name,
recipient_id=recipient_id,
recipient_name=recipient_name,
content=content
)
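# Illustrative note: a Message whose recipient_id is None is treated as a
# broadcast; AgentCoordinator.broadcast_message (below) queues it to every
# agent except the sender. For example:
#   Message.create(sender_id="user", sender_name="User", content="Hello, team")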
class AgentCoordinator:
"""Coordinates communication between multiple agents."""
def __init__(self):
"""Initialize the agent coordinator."""
self.agents: Dict[str, Agent] = {}
self.message_history: List[Message] = []
self.broadcast_queue: asyncio.Queue = asyncio.Queue()
def add_agent(self, agent: Agent) -> None:
"""Add a new agent to the coordinator.
Args:
agent: The agent to add
"""
self.agents[agent.id] = agent
def remove_agent(self, agent_id: str) -> None:
"""Remove an agent from the coordinator.
Args:
agent_id: ID of the agent to remove
"""
if agent_id in self.agents:
del self.agents[agent_id]
async def broadcast_message(self, message: Message) -> None:
"""Broadcast a message to all agents.
Args:
message: The message to broadcast
"""
self.message_history.append(message)
for agent_id, agent in self.agents.items():
# Don't send message back to sender
if agent_id != message.sender_id:
await agent.message_queue.put(message)
logger.debug(f"Queued message from {message.sender_name} to {agent.name}")
async def send_direct_message(self, message: Message) -> None:
"""Send a message to a specific agent.
Args:
message: The message to send
"""
        self.message_history.append(message)
        if message.recipient_id in self.agents:
            recipient = self.agents[message.recipient_id]
            await recipient.message_queue.put(message)
            logger.debug(f"Queued direct message from {message.sender_name} to {recipient.name}")
        else:
            logger.warning(f"Direct message addressed to unknown agent id {message.recipient_id}; not delivered")
async def process_message(self, message: Message) -> None:
"""Process an incoming message and route appropriately.
Args:
message: The message to process
"""
if message.recipient_id is None:
# Broadcast message
await self.broadcast_message(message)
else:
# Direct message
await self.send_direct_message(message)
def get_message_history_for_agent(self, agent_id: str) -> List[Dict[str, Any]]:
"""Get conversation messages formatted for a specific agent.
Args:
agent_id: ID of the agent
Returns:
List of messages in the format expected by Claude
"""
agent = self.agents.get(agent_id)
if not agent:
return []
messages = []
# Start with the agent's conversation history
messages.extend(agent.conversation)
# Add relevant messages from the message history
for msg in self.message_history:
# Include messages sent by this agent or addressed to this agent
# or broadcast messages from other agents
if (msg.sender_id == agent_id or
msg.recipient_id == agent_id or
(msg.recipient_id is None and msg.sender_id != agent_id)):
if msg.sender_id == agent_id:
# This agent's own messages
messages.append({
"role": "assistant",
"content": msg.content
})
else:
# Messages from other agents
sender = self.agents.get(msg.sender_id)
sender_name = sender.name if sender else msg.sender_name
if msg.recipient_id is None:
# Broadcast message
messages.append({
"role": "user",
"content": f"{sender_name}: {msg.content}"
})
else:
# Direct message
messages.append({
"role": "user",
"content": f"{sender_name} (direct): {msg.content}"
})
return messages
class MultiAgentMCPClient:
"""Multi-agent Model Context Protocol client with synchronization capabilities."""
def __init__(self, config_path: str = None):
"""Initialize the multi-agent MCP client.
Args:
config_path: Path to the agent configuration file
"""
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self.anthropic = Anthropic()
self.coordinator = AgentCoordinator()
self.available_tools = []
# Configuration
self.config_path = config_path
self.agents_config = self._load_agents_config()
def _load_agents_config(self) -> List[Dict[str, Any]]:
"""Load agent configurations from file.
Returns:
List of agent configurations
"""
default_config = [{
"name": "Assistant",
"role": "general assistant",
"model": "claude-3-5-sonnet-20241022",
"system_prompt": "You are a helpful AI assistant participating in a multi-agent conversation. You can communicate with other agents and humans to solve complex problems."
}]
if not self.config_path:
return default_config
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load agent configuration: {e}")
return default_config
def setup_agents(self) -> None:
"""Set up agents based on configuration."""
for idx, agent_config in enumerate(self.agents_config):
agent_id = str(uuid.uuid4())
agent = Agent(
id=agent_id,
name=agent_config.get("name", f"Agent-{idx+1}"),
role=agent_config.get("role", "assistant"),
model=agent_config.get("model", "claude-3-5-sonnet-20241022"),
system_prompt=agent_config.get("system_prompt", "You are a helpful AI assistant.")
)
self.coordinator.add_agent(agent)
logger.info(f"Created agent: {agent.name} ({agent.role})")
async def connect_to_server(self, server_script_path: str):
"""Connect to an MCP server.
Args:
server_script_path: Path to the server script (.py or .js)
"""
is_python = server_script_path.endswith('.py')
is_js = server_script_path.endswith('.js')
if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")
command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command,
args=[server_script_path],
env=None
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
# List available tools
response = await self.session.list_tools()
tools = response.tools
self.available_tools = [{
"name": tool.name,
"description": tool.description,
"input_schema": tool.inputSchema
} for tool in response.tools]
tool_names = [tool.name for tool in tools]
logger.info(f"Connected to server with tools: {tool_names}")
console.print(Panel.fit(
f"[bold green]Connected to MCP server[/bold green]\n"
f"Available tools: {', '.join(tool_names)}",
title="Connection Status",
border_style="green"
))
async def process_agent_query(self, agent_id: str, query: str, is_direct_message: bool = False) -> str:
"""Process a query using Claude and available tools for a specific agent.
Args:
agent_id: The ID of the agent processing the query
query: The query to process
is_direct_message: Whether this is a direct message from user
Returns:
The response text
"""
agent = self.coordinator.agents.get(agent_id)
if not agent:
return "Error: Agent not found"
# Get the conversation history for this agent
messages = self.coordinator.get_message_history_for_agent(agent_id)
# Add the current query if it's a direct message
if is_direct_message:
messages.append({
"role": "user",
"content": query
})
        # Initial Claude API call (the system prompt goes in the dedicated
        # `system` parameter rather than the message list)
        response = self.anthropic.messages.create(
            model=agent.model,
            max_tokens=1000,
            system=agent.system_prompt,
            messages=messages,
            tools=self.available_tools
        )
# Process response and handle tool calls
tool_results = []
final_text = ""
assistant_message_content = []
for content in response.content:
if content.type == 'text':
final_text = content.text
assistant_message_content.append(content)
elif content.type == 'tool_use':
tool_name = content.name
tool_args = content.input
# Execute tool call
result = await self.session.call_tool(tool_name, tool_args)
tool_results.append({"call": tool_name, "result": result})
console.print(f"[bold cyan]Agent {agent.name} calling tool {tool_name}[/bold cyan]")
assistant_message_content.append(content)
messages.append({
"role": "assistant",
"content": assistant_message_content
})
messages.append({
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": content.id,
"content": result.content
}
]
})
                # Get next response from Claude
                response = self.anthropic.messages.create(
                    model=agent.model,
                    max_tokens=1000,
                    system=agent.system_prompt,
                    messages=messages,
                    tools=self.available_tools
                )
                # Join any text blocks in the follow-up response (indexing
                # content[0].text would fail if the first block is not text)
                final_text = "".join(
                    block.text for block in response.content if block.type == "text"
                )
# Create a message from the agent's response
message = Message.create(
sender_id=agent_id,
sender_name=agent.name,
content=final_text,
recipient_id=None # Broadcast to all
)
# Process the message
await self.coordinator.process_message(message)
return final_text
async def process_user_query(self, query: str, target_agent_id: Optional[str] = None) -> None:
"""Process a query from the user and route it to agents.
Args:
query: The user query
target_agent_id: Optional ID of a specific agent to target
"""
# Handle special commands
if query.startswith("/"):
await self._handle_special_command(query)
return
if target_agent_id:
# Direct message to a specific agent
agent = self.coordinator.agents.get(target_agent_id)
if not agent:
console.print("[bold red]Error: Agent not found[/bold red]")
return
console.print(f"[bold blue]User → {agent.name}:[/bold blue] {query}")
response = await self.process_agent_query(target_agent_id, query, is_direct_message=True)
console.print(f"[bold green]{agent.name}:[/bold green] {response}")
else:
# Broadcast to all agents
console.print(f"[bold blue]User (broadcast):[/bold blue] {query}")
# Create a message from the user
message = Message.create(
sender_id="user",
sender_name="User",
content=query,
recipient_id=None # Broadcast
)
# Process the message
await self.coordinator.process_message(message)
# Process in parallel for all agents
tasks = []
for agent_id in self.coordinator.agents:
tasks.append(asyncio.create_task(self.process_agent_query(agent_id, query)))
# Wait for all agents to respond
await asyncio.gather(*tasks)
async def run_agent_thought_loops(self) -> None:
"""Run continuous thought loops for each agent in the background."""
while True:
for agent_id, agent in self.coordinator.agents.items():
try:
# Check if there are new messages for this agent
if not agent.message_queue.empty():
message = await agent.message_queue.get()
# Log the message
if message.recipient_id is None:
console.print(f"[bold cyan]{message.sender_name} (broadcast):[/bold cyan] {message.content}")
else:
console.print(f"[bold cyan]{message.sender_name} → {agent.name}:[/bold cyan] {message.content}")
# Give the agent a chance to respond
await self.process_agent_query(agent_id, message.content)
# Mark the message as processed
agent.message_queue.task_done()
except Exception as e:
logger.exception(f"Error in agent thought loop for {agent.name}: {e}")
# Small delay to prevent CPU hogging
await asyncio.sleep(0.1)
async def _handle_special_command(self, command: str) -> None:
"""Handle special commands.
Args:
command: The command string starting with /
"""
parts = command.strip().split()
cmd = parts[0].lower()
args = parts[1:]
if cmd == "/help":
self._show_help()
elif cmd == "/agents":
self._show_agents()
elif cmd == "/talk":
if len(args) < 2:
console.print("[bold red]Error: /talk requires agent name and message[/bold red]")
return
agent_name = args[0]
message = " ".join(args[1:])
# Find agent by name
target_agent = None
for agent_id, agent in self.coordinator.agents.items():
if agent.name.lower() == agent_name.lower():
target_agent = agent
break
if target_agent:
await self.process_user_query(message, target_agent.id)
else:
console.print(f"[bold red]Error: Agent '{agent_name}' not found[/bold red]")
elif cmd == "/history":
self._show_message_history()
elif cmd == "/quit" or cmd == "/exit":
console.print("[bold yellow]Exiting multi-agent session...[/bold yellow]")
sys.exit(0)
else:
console.print(f"[bold red]Unknown command: {cmd}[/bold red]")
self._show_help()
def _show_help(self) -> None:
"""Show help information."""
help_text = """
# Multi-Agent MCP Client Commands
- **/help**: Show this help message
- **/agents**: List all active agents
- **/talk <agent> <message>**: Send a direct message to a specific agent
- **/history**: Show message history
- **/quit**, **/exit**: Exit the application
To broadcast a message to all agents, simply type your message without any command.
"""
console.print(Markdown(help_text))
def _show_agents(self) -> None:
"""Show information about all active agents."""
table = Table(title="Active Agents")
table.add_column("Name", style="cyan")
table.add_column("Role", style="green")
table.add_column("Model", style="blue")
for agent_id, agent in self.coordinator.agents.items():
table.add_row(agent.name, agent.role, agent.model)
console.print(table)
def _show_message_history(self) -> None:
"""Show the message history."""
if not self.coordinator.message_history:
console.print("[yellow]No messages in history yet.[/yellow]")
return
table = Table(title="Message History")
table.add_column("Time", style="cyan")
table.add_column("From", style="green")
table.add_column("To", style="blue")
table.add_column("Message", style="white")
for msg in self.coordinator.message_history:
timestamp = time.strftime("%H:%M:%S", time.localtime(msg.timestamp))
recipient = msg.recipient_name if msg.recipient_name else "All"
table.add_row(timestamp, msg.sender_name, recipient, msg.content[:50] + ("..." if len(msg.content) > 50 else ""))
console.print(table)
async def chat_loop(self) -> None:
"""Run the interactive chat loop."""
console.print(Panel.fit(
"[bold green]Multi-Agent MCP Client Started![/bold green]\n"
"Type your messages to broadcast to all agents or use /help for commands.",
title="Welcome",
border_style="green"
))
# Start the agent thought loop in the background
thought_loop_task = asyncio.create_task(self.run_agent_thought_loops())
try:
# First, show active agents
self._show_agents()
# Main chat loop
while True:
try:
query = Prompt.ask("\n[bold blue]>[/bold blue]").strip()
if not query:
continue
if query.lower() == "quit" or query.lower() == "exit":
break
await self.process_user_query(query)
except KeyboardInterrupt:
console.print("\n[bold yellow]Operation cancelled.[/bold yellow]")
continue
except Exception as e:
console.print(f"\n[bold red]Error: {str(e)}[/bold red]")
logger.exception("Error processing query")
finally:
# Cancel the thought loop task
thought_loop_task.cancel()
try:
await thought_loop_task
except asyncio.CancelledError:
pass
async def cleanup(self) -> None:
"""Clean up resources."""
await self.exit_stack.aclose()
def add_arguments(parser: argparse.ArgumentParser) -> None:
"""Add command-specific arguments to the parser.
Args:
parser: Argument parser
"""
parser.add_argument(
"server_script",
type=str,
help="Path to the server script (.py or .js)"
)
parser.add_argument(
"--config",
type=str,
help="Path to agent configuration JSON file"
)
def execute(args: argparse.Namespace) -> int:
"""Execute the multi-agent client command.
Args:
args: Command arguments
Returns:
Exit code
"""
try:
client = MultiAgentMCPClient(config_path=args.config)
client.setup_agents()
async def run_client():
try:
await client.connect_to_server(args.server_script)
await client.chat_loop()
finally:
await client.cleanup()
asyncio.run(run_client())
return 0
except Exception as e:
logger.exception(f"Error running multi-agent MCP client: {e}")
console.print(f"[bold red]Error: {str(e)}[/bold red]")
return 1
def main() -> int:
"""Run the multi-agent client command as a standalone script."""
parser = argparse.ArgumentParser(description="Run the Claude Code Multi-Agent MCP client")
add_arguments(parser)
args = parser.parse_args()
return execute(args)
if __name__ == "__main__":
sys.exit(main())
```
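A minimal, hypothetical launcher sketch for this client. The configuration fields mirror exactly what `MultiAgentMCPClient._load_agents_config` reads (`name`, `role`, `model`, `system_prompt`); `examples/echo_server.py` is the server script shipped in this repository, and `ANTHROPIC_API_KEY` must be available in the environment (or a `.env` file, which the client loads via python-dotenv). File names here are illustrative.

```python
#!/usr/bin/env python3
# Hypothetical launcher: writes a two-agent config in the shape that
# MultiAgentMCPClient._load_agents_config expects, then starts the client.
import json
import subprocess

agents = [
    {
        "name": "Researcher",
        "role": "research specialist",
        "model": "claude-3-5-sonnet-20241022",
        "system_prompt": "You research topics in depth and share findings with the other agents.",
    },
    {
        "name": "Critic",
        "role": "quality reviewer",
        "model": "claude-3-5-sonnet-20241022",
        "system_prompt": "You critically review claims made by the other agents.",
    },
]

with open("agents_config.json", "w", encoding="utf-8") as f:
    json.dump(agents, f, indent=2)

# The module has no relative imports, so it can be run directly as a script.
subprocess.run([
    "python", "claude_code/commands/multi_agent_client.py",
    "examples/echo_server.py",
    "--config", "agents_config.json",
])
```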
--------------------------------------------------------------------------------
/claude_code/lib/tools/code_tools.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# claude_code/lib/tools/code_tools.py
"""Code analysis and manipulation tools."""
import os
import logging
import subprocess
import tempfile
import json
from typing import Dict, List, Optional, Any, Union
import ast
import re
from .base import tool, ToolRegistry
logger = logging.getLogger(__name__)
@tool(
name="CodeAnalyze",
description="Analyze code to extract structure, dependencies, and complexity metrics",
parameters={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to analyze"
},
"analysis_type": {
"type": "string",
"description": "Type of analysis to perform",
"enum": ["structure", "complexity", "dependencies", "all"],
"default": "all"
}
},
"required": ["file_path"]
},
category="code"
)
def analyze_code(file_path: str, analysis_type: str = "all") -> str:
"""Analyze code to extract structure and metrics.
Args:
file_path: Path to the file to analyze
analysis_type: Type of analysis to perform
Returns:
Analysis results as formatted text
"""
logger.info(f"Analyzing code in {file_path} (type: {analysis_type})")
if not os.path.isabs(file_path):
return f"Error: File path must be absolute: {file_path}"
if not os.path.exists(file_path):
return f"Error: File not found: {file_path}"
try:
# Read the file
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
code = f.read()
# Get file extension
_, ext = os.path.splitext(file_path)
ext = ext.lower()
# Determine language
if ext in ['.py']:
return _analyze_python(code, analysis_type)
elif ext in ['.js', '.jsx', '.ts', '.tsx']:
return _analyze_javascript(code, analysis_type)
elif ext in ['.java']:
return _analyze_java(code, analysis_type)
elif ext in ['.c', '.cpp', '.cc', '.h', '.hpp']:
return _analyze_cpp(code, analysis_type)
else:
return _analyze_generic(code, analysis_type)
except Exception as e:
logger.exception(f"Error analyzing code: {str(e)}")
return f"Error analyzing code: {str(e)}"
def _analyze_python(code: str, analysis_type: str) -> str:
"""Analyze Python code."""
result = []
# Structure analysis
if analysis_type in ["structure", "all"]:
try:
tree = ast.parse(code)
# Extract classes
classes = [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]
if classes:
result.append("Classes:")
for cls in classes:
methods = [node.name for node in ast.walk(cls) if isinstance(node, ast.FunctionDef)]
result.append(f" - {cls.name}")
if methods:
result.append(" Methods:")
for method in methods:
result.append(f" - {method}")
            # Extract module-level functions; iterating tree.body (instead of
            # ast.walk) naturally excludes methods defined inside classes
            functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
if functions:
result.append("\nFunctions:")
for func in functions:
result.append(f" - {func.name}")
# Extract imports
imports = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for name in node.names:
imports.append(name.name)
elif isinstance(node, ast.ImportFrom):
module = node.module or ""
for name in node.names:
imports.append(f"{module}.{name.name}")
if imports:
result.append("\nImports:")
for imp in imports:
result.append(f" - {imp}")
except SyntaxError as e:
result.append(f"Error parsing Python code: {str(e)}")
# Complexity analysis
if analysis_type in ["complexity", "all"]:
try:
# Count lines of code
lines = code.count('\n') + 1
non_empty_lines = sum(1 for line in code.split('\n') if line.strip())
comment_lines = sum(1 for line in code.split('\n') if line.strip().startswith('#'))
result.append("\nComplexity Metrics:")
result.append(f" - Total lines: {lines}")
result.append(f" - Non-empty lines: {non_empty_lines}")
result.append(f" - Comment lines: {comment_lines}")
result.append(f" - Code lines: {non_empty_lines - comment_lines}")
# Cyclomatic complexity (simplified)
tree = ast.parse(code)
complexity = 1 # Base complexity
for node in ast.walk(tree):
if isinstance(node, (ast.If, ast.While, ast.For, ast.comprehension)):
complexity += 1
elif isinstance(node, ast.BoolOp) and isinstance(node.op, ast.And):
complexity += len(node.values) - 1
result.append(f" - Cyclomatic complexity (estimated): {complexity}")
except Exception as e:
result.append(f"Error calculating complexity: {str(e)}")
# Dependencies analysis
if analysis_type in ["dependencies", "all"]:
try:
# Extract imports
tree = ast.parse(code)
std_lib_imports = []
third_party_imports = []
local_imports = []
std_lib_modules = [
"abc", "argparse", "ast", "asyncio", "base64", "collections", "concurrent", "contextlib",
"copy", "csv", "datetime", "decimal", "enum", "functools", "glob", "gzip", "hashlib",
"http", "io", "itertools", "json", "logging", "math", "multiprocessing", "os", "pathlib",
"pickle", "random", "re", "shutil", "socket", "sqlite3", "string", "subprocess", "sys",
"tempfile", "threading", "time", "typing", "unittest", "urllib", "uuid", "xml", "zipfile"
]
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for name in node.names:
module = name.name.split('.')[0]
if module in std_lib_modules:
std_lib_imports.append(name.name)
else:
third_party_imports.append(name.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
module = node.module.split('.')[0]
if module in std_lib_modules:
for name in node.names:
std_lib_imports.append(f"{node.module}.{name.name}")
elif node.level > 0: # Relative import
for name in node.names:
local_imports.append(f"{'.' * node.level}{node.module or ''}.{name.name}")
else:
for name in node.names:
third_party_imports.append(f"{node.module}.{name.name}")
result.append("\nDependencies:")
if std_lib_imports:
result.append(" Standard Library:")
for imp in sorted(set(std_lib_imports)):
result.append(f" - {imp}")
if third_party_imports:
result.append(" Third-Party:")
for imp in sorted(set(third_party_imports)):
result.append(f" - {imp}")
if local_imports:
result.append(" Local/Project:")
for imp in sorted(set(local_imports)):
result.append(f" - {imp}")
except Exception as e:
result.append(f"Error analyzing dependencies: {str(e)}")
return "\n".join(result)
def _analyze_javascript(code: str, analysis_type: str) -> str:
"""Analyze JavaScript/TypeScript code."""
result = []
# Structure analysis
if analysis_type in ["structure", "all"]:
try:
# Extract functions using regex (simplified)
function_pattern = r'(function\s+(\w+)|const\s+(\w+)\s*=\s*function|const\s+(\w+)\s*=\s*\(.*?\)\s*=>)'
functions = re.findall(function_pattern, code)
if functions:
result.append("Functions:")
for func in functions:
# Get the first non-empty group which is the function name
func_name = next((name for name in func[1:] if name), "anonymous")
result.append(f" - {func_name}")
# Extract classes
class_pattern = r'class\s+(\w+)'
classes = re.findall(class_pattern, code)
if classes:
result.append("\nClasses:")
for cls in classes:
result.append(f" - {cls}")
# Extract imports
import_pattern = r'import\s+.*?from\s+[\'"](.+?)[\'"]'
imports = re.findall(import_pattern, code)
if imports:
result.append("\nImports:")
for imp in imports:
result.append(f" - {imp}")
except Exception as e:
result.append(f"Error parsing JavaScript code: {str(e)}")
# Complexity analysis
if analysis_type in ["complexity", "all"]:
try:
# Count lines of code
lines = code.count('\n') + 1
non_empty_lines = sum(1 for line in code.split('\n') if line.strip())
comment_lines = sum(1 for line in code.split('\n')
if line.strip().startswith('//') or line.strip().startswith('/*'))
result.append("\nComplexity Metrics:")
result.append(f" - Total lines: {lines}")
result.append(f" - Non-empty lines: {non_empty_lines}")
result.append(f" - Comment lines: {comment_lines}")
result.append(f" - Code lines: {non_empty_lines - comment_lines}")
# Simplified cyclomatic complexity
control_structures = len(re.findall(r'\b(if|for|while|switch|catch)\b', code))
logical_operators = len(re.findall(r'(&&|\|\|)', code))
complexity = 1 + control_structures + logical_operators
result.append(f" - Cyclomatic complexity (estimated): {complexity}")
except Exception as e:
result.append(f"Error calculating complexity: {str(e)}")
# Dependencies analysis
if analysis_type in ["dependencies", "all"]:
try:
# Extract imports
import_pattern = r'import\s+.*?from\s+[\'"](.+?)[\'"]'
imports = re.findall(import_pattern, code)
node_std_libs = [
"fs", "path", "http", "https", "url", "querystring", "crypto", "os",
"util", "stream", "events", "buffer", "assert", "zlib", "child_process"
]
std_lib_imports = []
third_party_imports = []
local_imports = []
for imp in imports:
if imp in node_std_libs:
std_lib_imports.append(imp)
elif imp.startswith('.'):
local_imports.append(imp)
else:
third_party_imports.append(imp)
result.append("\nDependencies:")
if std_lib_imports:
result.append(" Standard Library:")
for imp in sorted(set(std_lib_imports)):
result.append(f" - {imp}")
if third_party_imports:
result.append(" Third-Party:")
for imp in sorted(set(third_party_imports)):
result.append(f" - {imp}")
if local_imports:
result.append(" Local/Project:")
for imp in sorted(set(local_imports)):
result.append(f" - {imp}")
except Exception as e:
result.append(f"Error analyzing dependencies: {str(e)}")
return "\n".join(result)
def _analyze_java(code: str, analysis_type: str) -> str:
"""Analyze Java code."""
# Simplified Java analysis
result = []
# Structure analysis
if analysis_type in ["structure", "all"]:
try:
# Extract class names
class_pattern = r'(public|private|protected)?\s+class\s+(\w+)'
classes = re.findall(class_pattern, code)
if classes:
result.append("Classes:")
for cls in classes:
result.append(f" - {cls[1]}")
# Extract methods
method_pattern = r'(public|private|protected)?\s+\w+\s+(\w+)\s*\([^)]*\)\s*\{'
methods = re.findall(method_pattern, code)
if methods:
result.append("\nMethods:")
for method in methods:
result.append(f" - {method[1]}")
# Extract imports
import_pattern = r'import\s+(.+?);'
imports = re.findall(import_pattern, code)
if imports:
result.append("\nImports:")
for imp in imports:
result.append(f" - {imp}")
except Exception as e:
result.append(f"Error parsing Java code: {str(e)}")
# Complexity analysis
if analysis_type in ["complexity", "all"]:
try:
# Count lines of code
lines = code.count('\n') + 1
non_empty_lines = sum(1 for line in code.split('\n') if line.strip())
comment_lines = sum(1 for line in code.split('\n')
if line.strip().startswith('//') or line.strip().startswith('/*'))
result.append("\nComplexity Metrics:")
result.append(f" - Total lines: {lines}")
result.append(f" - Non-empty lines: {non_empty_lines}")
result.append(f" - Comment lines: {comment_lines}")
result.append(f" - Code lines: {non_empty_lines - comment_lines}")
# Simplified cyclomatic complexity
control_structures = len(re.findall(r'\b(if|for|while|switch|catch)\b', code))
logical_operators = len(re.findall(r'(&&|\|\|)', code))
complexity = 1 + control_structures + logical_operators
result.append(f" - Cyclomatic complexity (estimated): {complexity}")
except Exception as e:
result.append(f"Error calculating complexity: {str(e)}")
return "\n".join(result)
def _analyze_cpp(code: str, analysis_type: str) -> str:
"""Analyze C/C++ code."""
# Simplified C/C++ analysis
result = []
# Structure analysis
if analysis_type in ["structure", "all"]:
try:
# Extract class names
class_pattern = r'class\s+(\w+)'
classes = re.findall(class_pattern, code)
if classes:
result.append("Classes:")
for cls in classes:
result.append(f" - {cls}")
# Extract functions
function_pattern = r'(\w+)\s+(\w+)\s*\([^)]*\)\s*\{'
functions = re.findall(function_pattern, code)
if functions:
result.append("\nFunctions:")
for func in functions:
# Filter out keywords that might be matched
if func[1] not in ['if', 'for', 'while', 'switch']:
result.append(f" - {func[1]} (return type: {func[0]})")
# Extract includes
include_pattern = r'#include\s+[<"](.+?)[>"]'
includes = re.findall(include_pattern, code)
if includes:
result.append("\nIncludes:")
for inc in includes:
result.append(f" - {inc}")
except Exception as e:
result.append(f"Error parsing C/C++ code: {str(e)}")
# Complexity analysis
if analysis_type in ["complexity", "all"]:
try:
# Count lines of code
lines = code.count('\n') + 1
non_empty_lines = sum(1 for line in code.split('\n') if line.strip())
comment_lines = sum(1 for line in code.split('\n')
if line.strip().startswith('//') or line.strip().startswith('/*'))
result.append("\nComplexity Metrics:")
result.append(f" - Total lines: {lines}")
result.append(f" - Non-empty lines: {non_empty_lines}")
result.append(f" - Comment lines: {comment_lines}")
result.append(f" - Code lines: {non_empty_lines - comment_lines}")
# Simplified cyclomatic complexity
control_structures = len(re.findall(r'\b(if|for|while|switch|catch)\b', code))
logical_operators = len(re.findall(r'(&&|\|\|)', code))
complexity = 1 + control_structures + logical_operators
result.append(f" - Cyclomatic complexity (estimated): {complexity}")
except Exception as e:
result.append(f"Error calculating complexity: {str(e)}")
return "\n".join(result)
def _analyze_generic(code: str, analysis_type: str) -> str:
"""Generic code analysis for unsupported languages."""
result = []
# Basic analysis for any language
try:
# Count lines of code
lines = code.count('\n') + 1
non_empty_lines = sum(1 for line in code.split('\n') if line.strip())
result.append("Basic Code Metrics:")
result.append(f" - Total lines: {lines}")
result.append(f" - Non-empty lines: {non_empty_lines}")
# Try to identify language
language = "unknown"
if "def " in code and "import " in code:
language = "Python"
elif "function " in code or "const " in code or "let " in code:
language = "JavaScript"
elif "public class " in code or "private class " in code:
language = "Java"
elif "#include" in code and "{" in code:
language = "C/C++"
result.append(f" - Detected language: {language}")
# Find potential functions/methods using a generic pattern
function_pattern = r'\b(\w+)\s*\([^)]*\)\s*\{'
functions = re.findall(function_pattern, code)
if functions:
result.append("\nPotential Functions/Methods:")
for func in functions:
# Filter out common keywords
if func not in ['if', 'for', 'while', 'switch', 'catch']:
result.append(f" - {func}")
except Exception as e:
result.append(f"Error analyzing code: {str(e)}")
return "\n".join(result)
@tool(
name="LintCode",
description="Lint code to find potential issues and style violations",
parameters={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to lint"
},
"linter": {
"type": "string",
"description": "Linter to use (auto, pylint, eslint, etc.)",
"default": "auto"
}
},
"required": ["file_path"]
},
category="code"
)
def lint_code(file_path: str, linter: str = "auto") -> str:
"""Lint code to find potential issues.
Args:
file_path: Path to the file to lint
linter: Linter to use
Returns:
Linting results as formatted text
"""
logger.info(f"Linting code in {file_path} using {linter}")
if not os.path.isabs(file_path):
return f"Error: File path must be absolute: {file_path}"
if not os.path.exists(file_path):
return f"Error: File not found: {file_path}"
try:
# Get file extension
_, ext = os.path.splitext(file_path)
ext = ext.lower()
# Auto-detect linter if not specified
if linter == "auto":
if ext in ['.py']:
linter = "pylint"
elif ext in ['.js', '.jsx']:
linter = "eslint"
elif ext in ['.ts', '.tsx']:
linter = "tslint"
elif ext in ['.java']:
linter = "checkstyle"
elif ext in ['.c', '.cpp', '.cc', '.h', '.hpp']:
linter = "cppcheck"
else:
return f"Error: Could not auto-detect linter for file type {ext}"
# Run appropriate linter
if linter == "pylint":
return _run_pylint(file_path)
elif linter == "eslint":
return _run_eslint(file_path)
elif linter == "tslint":
return _run_tslint(file_path)
elif linter == "checkstyle":
return _run_checkstyle(file_path)
elif linter == "cppcheck":
return _run_cppcheck(file_path)
else:
return f"Error: Unsupported linter: {linter}"
except Exception as e:
logger.exception(f"Error linting code: {str(e)}")
return f"Error linting code: {str(e)}"
def _run_pylint(file_path: str) -> str:
"""Run pylint on a Python file."""
try:
# Check if pylint is installed
try:
subprocess.run(["pylint", "--version"], capture_output=True, check=True)
except (subprocess.SubprocessError, FileNotFoundError):
return "Error: pylint is not installed. Please install it with 'pip install pylint'."
# Run pylint
result = subprocess.run(
["pylint", "--output-format=text", file_path],
capture_output=True,
text=True
)
if result.returncode == 0:
return "No issues found."
# Format output
output = result.stdout or result.stderr
# Summarize output
lines = output.split('\n')
summary_lines = [line for line in lines if "rated at" in line]
issue_lines = [line for line in lines if re.match(r'^.*?:\d+:\d+:', line)]
formatted_output = []
if issue_lines:
formatted_output.append("Issues found:")
for line in issue_lines:
formatted_output.append(f" {line}")
if summary_lines:
formatted_output.append("\nSummary:")
for line in summary_lines:
formatted_output.append(f" {line}")
return "\n".join(formatted_output)
except Exception as e:
return f"Error running pylint: {str(e)}"
def _run_eslint(file_path: str) -> str:
"""Run eslint on a JavaScript file."""
try:
# Check if eslint is installed
try:
subprocess.run(["eslint", "--version"], capture_output=True, check=True)
except (subprocess.SubprocessError, FileNotFoundError):
return "Error: eslint is not installed. Please install it with 'npm install -g eslint'."
# Run eslint
result = subprocess.run(
["eslint", "--format=stylish", file_path],
capture_output=True,
text=True
)
if result.returncode == 0:
return "No issues found."
# Format output
output = result.stdout or result.stderr
# Clean up output
lines = output.split('\n')
filtered_lines = [line for line in lines if line.strip() and not line.startswith("eslint:")]
return "\n".join(filtered_lines)
except Exception as e:
return f"Error running eslint: {str(e)}"
def _run_tslint(file_path: str) -> str:
"""Run tslint on a TypeScript file."""
try:
# Check if tslint is installed
try:
subprocess.run(["tslint", "--version"], capture_output=True, check=True)
except (subprocess.SubprocessError, FileNotFoundError):
return "Error: tslint is not installed. Please install it with 'npm install -g tslint'."
# Run tslint
result = subprocess.run(
["tslint", "-t", "verbose", file_path],
capture_output=True,
text=True
)
if result.returncode == 0:
return "No issues found."
# Format output
output = result.stdout or result.stderr
return output
except Exception as e:
return f"Error running tslint: {str(e)}"
def _run_checkstyle(file_path: str) -> str:
"""Run checkstyle on a Java file."""
return "Checkstyle support not implemented. Please install and run checkstyle manually."
def _run_cppcheck(file_path: str) -> str:
"""Run cppcheck on a C/C++ file."""
try:
# Check if cppcheck is installed
try:
subprocess.run(["cppcheck", "--version"], capture_output=True, check=True)
except (subprocess.SubprocessError, FileNotFoundError):
return "Error: cppcheck is not installed. Please install it using your system package manager."
        # Run cppcheck (pass the template as a single argv element; wrapping it
        # in extra quotes would make the quote characters literal)
        result = subprocess.run(
            ["cppcheck", "--enable=all", "--template={file}:{line}: {severity}: {message}", file_path],
            capture_output=True,
            text=True
        )
# Format output
output = result.stderr # cppcheck outputs to stderr
if not output or "no errors found" in output.lower():
return "No issues found."
# Clean up output
lines = output.split('\n')
filtered_lines = [line for line in lines if line.strip() and "Checking" not in line]
return "\n".join(filtered_lines)
except Exception as e:
return f"Error running cppcheck: {str(e)}"
def register_code_tools(registry: ToolRegistry) -> None:
"""Register all code analysis tools with the registry.
Args:
registry: Tool registry to register with
"""
from .base import create_tools_from_functions
code_tools = [
analyze_code,
lint_code
]
create_tools_from_functions(registry, code_tools)
```
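As a quick sanity check that the analyzer behaves as described, the functions can be exercised directly. This sketch assumes the `@tool` decorator in `base.py` (not shown on this page) leaves the decorated function callable as a plain Python function:

```python
# Minimal sketch: run the structure analysis on a throwaway file.
import os
import tempfile

from claude_code.lib.tools.code_tools import analyze_code

SAMPLE = '''import os

class Greeter:
    def greet(self, name):
        if name:
            return f"Hello, {name}"
        return "Hello"
'''

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(SAMPLE)
    path = f.name

try:
    # analyze_code requires an absolute path; NamedTemporaryFile provides one.
    print(analyze_code(path, analysis_type="structure"))
finally:
    os.remove(path)
```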
--------------------------------------------------------------------------------
/claude_code/lib/tools/manager.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# claude_code/lib/tools/manager.py
"""Tool execution manager."""
import logging
import time
import json
import uuid
import os
from typing import Dict, List, Any, Optional, Callable, Union, Sequence
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, Future
from .base import Tool, ToolResult, ToolRegistry, Routine, RoutineStep, RoutineDefinition
logger = logging.getLogger(__name__)
class RoutineExecutionManager:
"""Manages the execution of tool routines."""
def __init__(self, registry: ToolRegistry, execution_manager: 'ToolExecutionManager'):
"""Initialize the routine execution manager.
Args:
registry: Tool registry containing available tools and routines
execution_manager: Tool execution manager for executing individual tools
"""
self.registry = registry
self.execution_manager = execution_manager
self.active_routines: Dict[str, Dict[str, Any]] = {}
self.progress_callback: Optional[Callable[[str, str, float], None]] = None
self.result_callback: Optional[Callable[[str, List[ToolResult]], None]] = None
# Load existing routines
self.registry.load_routines()
def set_progress_callback(self, callback: Callable[[str, str, float], None]) -> None:
"""Set a callback function for routine progress updates.
Args:
callback: Function that takes routine_id, step_name, and progress (0-1) as arguments
"""
self.progress_callback = callback
def set_result_callback(self, callback: Callable[[str, List[ToolResult]], None]) -> None:
"""Set a callback function for routine results.
Args:
callback: Function that takes routine_id and list of ToolResults as arguments
"""
self.result_callback = callback
def create_routine(self, definition: RoutineDefinition) -> str:
"""Create a new routine from a definition.
Args:
definition: Routine definition
Returns:
Routine ID
Raises:
ValueError: If a routine with the same name already exists
"""
# Convert step objects to dictionaries
steps = []
for step in definition.steps:
step_dict = {
"tool_name": step.tool_name,
"args": step.args
}
if step.condition is not None:
step_dict["condition"] = step.condition
if step.store_result:
step_dict["store_result"] = True
if step.result_var is not None:
step_dict["result_var"] = step.result_var
steps.append(step_dict)
# Create routine
routine = Routine(
name=definition.name,
description=definition.description,
steps=steps
)
# Register routine
self.registry.register_routine(routine)
return routine.name
def create_routine_from_tool_history(
self,
name: str,
description: str,
tool_results: List[ToolResult],
context_variables: Dict[str, Any] = None
) -> str:
"""Create a routine from a history of tool executions.
Args:
name: Name for the routine
description: Description of the routine
tool_results: List of tool results to base the routine on
context_variables: Optional dictionary of context variables to identify
Returns:
Routine ID
"""
steps = []
# Process tool results into steps
for i, result in enumerate(tool_results):
# Skip failed tool calls
if result.status != "success":
continue
# Get tool
tool = self.registry.get_tool(result.name)
if not tool:
continue
# Extract arguments from tool call
args = {}
# Here we would need to extract the arguments from the tool call
# This is a simplification and would need to be adapted to the actual structure
# Create step
step = {
"tool_name": result.name,
"args": args,
"store_result": True,
"result_var": f"result_{i}"
}
steps.append(step)
# Create routine
routine = Routine(
name=name,
description=description,
steps=steps
)
# Register routine
self.registry.register_routine(routine)
return routine.name
def execute_routine(self, name: str, context: Dict[str, Any] = None) -> str:
"""Execute a routine with the given context.
Args:
name: Name of the routine to execute
context: Context variables for the routine
Returns:
Routine execution ID
Raises:
ValueError: If the routine is not found
"""
# Get routine
routine = self.registry.get_routine(name)
if not routine:
raise ValueError(f"Routine not found: {name}")
# Create execution ID
execution_id = str(uuid.uuid4())
# Initialize context
if context is None:
context = {}
# Initialize execution state
self.active_routines[execution_id] = {
"routine": routine,
"context": context.copy(),
"results": [],
"current_step": 0,
"start_time": time.time(),
"status": "running"
}
# Record routine usage
self.registry.record_routine_usage(name)
        # Start execution in a background thread; shut down without waiting so
        # the submitted routine keeps running while this call returns and the
        # pool's resources are released once it finishes
        executor = ThreadPoolExecutor(max_workers=1)
        executor.submit(self._execute_routine_steps, execution_id)
        executor.shutdown(wait=False)
return execution_id
def _execute_routine_steps(self, execution_id: str) -> None:
"""Execute the steps of a routine in sequence.
Args:
execution_id: Routine execution ID
"""
if execution_id not in self.active_routines:
logger.error(f"Routine execution not found: {execution_id}")
return
execution = self.active_routines[execution_id]
routine = execution["routine"]
context = execution["context"]
results = execution["results"]
        try:
            # Execute each step. An index-based while loop is used instead of
            # enumerate() so that "repeat_until" can jump back to an earlier
            # step (reassigning the variable of a `for` loop has no effect on
            # its iteration).
            i = 0
            while i < len(routine.steps):
                # Honor external cancellation (see cancel_routine)
                if execution["status"] == "canceled":
                    break
                step = routine.steps[i]
                # Update current step
                execution["current_step"] = i
                # Check for conditions
                if "condition" in step and not self._evaluate_condition(step["condition"], context, results):
                    logger.info(f"Skipping step {i+1}/{len(routine.steps)} due to condition")
                    i += 1
                    continue
                # Process tool arguments with variable substitution
                processed_args = self._process_arguments(step["args"], context, results)
                # Create tool call
                tool_call = {
                    "id": f"{execution_id}_{i}",
                    "function": {
                        "name": step["tool_name"],
                        "arguments": json.dumps(processed_args)
                    }
                }
                # Report progress
                self._report_routine_progress(execution_id, i, len(routine.steps), step["tool_name"])
                # Execute tool
                result = self.execution_manager.execute_tool(tool_call)
                # Add result to results
                results.append(result)
                # Store result in context if requested
                if step.get("store_result", False):
                    var_name = step.get("result_var", f"result_{i}")
                    context[var_name] = result.result
                # Check for loop control
                if "repeat_until" in step and not self._evaluate_condition(step["repeat_until"], context, results):
                    # Go back to the specified step
                    target_step = step.get("repeat_target", 0)
                    if 0 <= target_step < i:
                        i = target_step
                        continue
                # Check for exit condition
                if "exit_condition" in step and self._evaluate_condition(step["exit_condition"], context, results):
                    logger.info(f"Exiting routine early due to exit condition at step {i+1}/{len(routine.steps)}")
                    break
                i += 1
            # Update execution status (unless it was canceled mid-run)
            if execution["status"] == "running":
                execution["status"] = "completed"
                # Report final progress
                self._report_routine_progress(execution_id, len(routine.steps), len(routine.steps), "completed")
            # Call result callback
            if self.result_callback:
                self.result_callback(execution_id, results)
except Exception as e:
logger.exception(f"Error executing routine: {e}")
execution["status"] = "error"
execution["error"] = str(e)
# Report error progress
self._report_routine_progress(execution_id, execution["current_step"], len(routine.steps), "error")
def _process_arguments(
self,
args: Dict[str, Any],
context: Dict[str, Any],
results: List[ToolResult]
) -> Dict[str, Any]:
"""Process tool arguments with variable substitution.
Args:
args: Tool arguments
context: Context variables
results: Previous tool results
Returns:
Processed arguments
"""
processed_args = {}
for key, value in args.items():
if isinstance(value, str) and value.startswith("$"):
# Variable reference
var_name = value[1:]
if var_name in context:
processed_args[key] = context[var_name]
elif var_name.startswith("result[") and var_name.endswith("]"):
# Reference to previous result
try:
idx = int(var_name[7:-1])
if 0 <= idx < len(results):
processed_args[key] = results[idx].result
else:
processed_args[key] = value
except (ValueError, IndexError):
processed_args[key] = value
else:
processed_args[key] = value
else:
processed_args[key] = value
return processed_args
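    # Illustrative example of the substitution syntax above: with
    # context == {"path": "/tmp/x"} and results[0].result == "ok",
    #   {"file_path": "$path", "previous": "$result[0]"}
    # resolves to {"file_path": "/tmp/x", "previous": "ok"}.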
def _evaluate_condition(
self,
condition: Dict[str, Any],
context: Dict[str, Any],
results: List[ToolResult]
) -> bool:
"""Evaluate a condition for a routine step.
Args:
condition: Condition specification
context: Context variables
results: Previous tool results
Returns:
Whether the condition is met
"""
condition_type = condition.get("type", "simple")
if condition_type == "simple":
# Simple variable comparison
var_name = condition.get("variable", "")
operation = condition.get("operation", "equals")
value = condition.get("value")
# Get variable value
var_value = None
if var_name.startswith("$"):
var_name = var_name[1:]
var_value = context.get(var_name)
elif var_name.startswith("result[") and var_name.endswith("]"):
try:
idx = int(var_name[7:-1])
if 0 <= idx < len(results):
var_value = results[idx].result
except (ValueError, IndexError):
return False
# Compare
if operation == "equals":
return var_value == value
elif operation == "not_equals":
return var_value != value
elif operation == "contains":
return value in var_value if var_value is not None else False
elif operation == "greater_than":
return var_value > value if var_value is not None else False
elif operation == "less_than":
return var_value < value if var_value is not None else False
return False
elif condition_type == "and":
# Logical AND of multiple conditions
sub_conditions = condition.get("conditions", [])
return all(self._evaluate_condition(c, context, results) for c in sub_conditions)
elif condition_type == "or":
# Logical OR of multiple conditions
sub_conditions = condition.get("conditions", [])
return any(self._evaluate_condition(c, context, results) for c in sub_conditions)
elif condition_type == "not":
# Logical NOT
sub_condition = condition.get("condition", {})
return not self._evaluate_condition(sub_condition, context, results)
return True # Default to True
def _report_routine_progress(
self,
execution_id: str,
current_step: int,
total_steps: int,
step_name: str
) -> None:
"""Report progress for a routine execution.
Args:
execution_id: Routine execution ID
current_step: Current step index
total_steps: Total number of steps
step_name: Name of the current step
"""
progress = current_step / total_steps if total_steps > 0 else 1.0
# Call progress callback if set
if self.progress_callback:
self.progress_callback(execution_id, step_name, progress)
def get_active_routines(self) -> Dict[str, Dict[str, Any]]:
"""Get information about active routine executions.
Returns:
Dictionary mapping execution ID to routine execution information
"""
return {
k: {
"routine_name": v["routine"].name,
"current_step": v["current_step"],
"total_steps": len(v["routine"].steps),
"status": v["status"],
"start_time": v["start_time"],
"elapsed_time": time.time() - v["start_time"]
}
for k, v in self.active_routines.items()
}
def get_routine_results(self, execution_id: str) -> Optional[List[ToolResult]]:
"""Get the results of a routine execution.
Args:
execution_id: Routine execution ID
Returns:
List of tool results, or None if the routine execution is not found
"""
if execution_id in self.active_routines:
return self.active_routines[execution_id]["results"]
return None
def cancel_routine(self, execution_id: str) -> bool:
"""Cancel a routine execution.
Args:
execution_id: Routine execution ID
Returns:
Whether the routine was canceled successfully
"""
if execution_id in self.active_routines:
self.active_routines[execution_id]["status"] = "canceled"
return True
return False
class ToolExecutionManager:
"""Manages tool execution, including parallel execution and progress tracking."""
def __init__(self, registry: ToolRegistry):
"""Initialize the tool execution manager.
Args:
registry: Tool registry containing available tools
"""
self.registry = registry
self.active_executions: Dict[str, Dict[str, Any]] = {}
self.progress_callback: Optional[Callable[[str, float], None]] = None
self.result_callback: Optional[Callable[[str, ToolResult], None]] = None
self.max_workers = 10
# Initialize routine manager
self.routine_manager = RoutineExecutionManager(registry, self)
def set_progress_callback(self, callback: Callable[[str, float], None]) -> None:
"""Set a callback function for progress updates.
Args:
callback: Function that takes tool_call_id and progress (0-1) as arguments
"""
self.progress_callback = callback
def set_result_callback(self, callback: Callable[[str, ToolResult], None]) -> None:
"""Set a callback function for results.
Args:
callback: Function that takes tool_call_id and ToolResult as arguments
"""
self.result_callback = callback
def execute_tool(self, tool_call: Dict[str, Any]) -> ToolResult:
"""Execute a single tool synchronously.
Args:
tool_call: Dictionary containing tool call information
Returns:
ToolResult with execution result
Raises:
ValueError: If the tool is not found
"""
function_name = tool_call.get("function", {}).get("name", "")
tool_call_id = tool_call.get("id", "unknown")
# Check if it's a routine
if function_name.startswith("routine."):
routine_name = function_name[9:] # Remove "routine." prefix
return self._execute_routine_as_tool(routine_name, tool_call)
# Get the tool
tool = self.registry.get_tool(function_name)
if not tool:
error_msg = f"Tool not found: {function_name}"
logger.error(error_msg)
return ToolResult(
tool_call_id=tool_call_id,
name=function_name,
result=f"Error: {error_msg}",
execution_time=0,
status="error",
error=error_msg
)
# Check if tool needs permission and handle it
if tool.needs_permission:
# TODO: Implement permission handling
logger.warning(f"Tool {function_name} needs permission, but permission handling is not implemented")
# Track progress
self._track_execution_start(tool_call_id, function_name)
try:
# Execute the tool
result = tool.execute(tool_call)
# Track completion
self._track_execution_complete(tool_call_id, result)
return result
except Exception as e:
logger.exception(f"Error executing tool {function_name}: {e}")
result = ToolResult(
tool_call_id=tool_call_id,
name=function_name,
result=f"Error: {str(e)}",
execution_time=0,
status="error",
error=str(e)
)
self._track_execution_complete(tool_call_id, result)
return result
def _execute_routine_as_tool(self, routine_name: str, tool_call: Dict[str, Any]) -> ToolResult:
"""Execute a routine as if it were a tool.
Args:
routine_name: Name of the routine
tool_call: Dictionary containing tool call information
Returns:
ToolResult with execution result
"""
tool_call_id = tool_call.get("id", "unknown")
start_time = time.time()
try:
# Extract context from arguments
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
try:
context = json.loads(arguments_str)
except json.JSONDecodeError:
context = {}
# Execute routine
execution_id = self.routine_manager.execute_routine(routine_name, context)
# Wait for routine to complete
while True:
routine_status = self.routine_manager.get_active_routines().get(execution_id, {})
if routine_status.get("status") != "running":
break
time.sleep(0.1)
# Get results
results = self.routine_manager.get_routine_results(execution_id)
if not results:
raise ValueError(f"No results from routine: {routine_name}")
# Format results
result_summary = f"Routine {routine_name} executed successfully with {len(results)} steps\n\n"
for i, result in enumerate(results):
result_summary += f"Step {i+1}: {result.name} - {'SUCCESS' if result.status == 'success' else 'ERROR'}\n"
if result.status != "success":
result_summary += f" Error: {result.error}\n"
# Track execution time
execution_time = time.time() - start_time
# Create result
return ToolResult(
tool_call_id=tool_call_id,
name=f"routine.{routine_name}",
result=result_summary,
execution_time=execution_time,
status="success"
)
except Exception as e:
logger.exception(f"Error executing routine {routine_name}: {e}")
return ToolResult(
tool_call_id=tool_call_id,
name=f"routine.{routine_name}",
result=f"Error: {str(e)}",
execution_time=time.time() - start_time,
status="error",
error=str(e)
)
def execute_tools_parallel(self, tool_calls: List[Dict[str, Any]]) -> List[ToolResult]:
"""Execute multiple tools in parallel.
Args:
tool_calls: List of dictionaries containing tool call information
Returns:
List of ToolResult with execution results
"""
results = []
futures: Dict[Future, str] = {}
# Use ThreadPoolExecutor for parallel execution
with ThreadPoolExecutor(max_workers=min(self.max_workers, len(tool_calls))) as executor:
# Submit all tool calls
for tool_call in tool_calls:
tool_call_id = tool_call.get("id", "unknown")
future = executor.submit(self.execute_tool, tool_call)
futures[future] = tool_call_id
# Wait for completion and collect results
for future in concurrent.futures.as_completed(futures):
tool_call_id = futures[future]
try:
result = future.result()
results.append(result)
except Exception as e:
logger.exception(f"Error in parallel tool execution for {tool_call_id}: {e}")
# Create an error result
function_name = next(
(tc.get("function", {}).get("name", "") for tc in tool_calls
if tc.get("id", "") == tool_call_id),
"unknown"
)
results.append(ToolResult(
tool_call_id=tool_call_id,
name=function_name,
result=f"Error: {str(e)}",
execution_time=0,
status="error",
error=str(e)
))
return results
def create_routine(self, definition: RoutineDefinition) -> str:
"""Create a new routine.
Args:
definition: Routine definition
Returns:
Routine ID
"""
return self.routine_manager.create_routine(definition)
def create_routine_from_tool_history(
self,
name: str,
description: str,
tool_results: List[ToolResult],
context_variables: Dict[str, Any] = None
) -> str:
"""Create a routine from a history of tool executions.
Args:
name: Name for the routine
description: Description of the routine
tool_results: List of tool results to base the routine on
context_variables: Optional dictionary of context variables to identify
Returns:
Routine ID
"""
return self.routine_manager.create_routine_from_tool_history(
name, description, tool_results, context_variables
)
def execute_routine(self, name: str, context: Dict[str, Any] = None) -> str:
"""Execute a routine with the given context.
Args:
name: Name of the routine to execute
context: Context variables for the routine
Returns:
Routine execution ID
"""
return self.routine_manager.execute_routine(name, context)
def get_routine_results(self, execution_id: str) -> Optional[List[ToolResult]]:
"""Get the results of a routine execution.
Args:
execution_id: Routine execution ID
Returns:
List of tool results, or None if the routine execution is not found
"""
return self.routine_manager.get_routine_results(execution_id)
def _track_execution_start(self, tool_call_id: str, tool_name: str) -> None:
"""Track the start of tool execution.
Args:
tool_call_id: ID of the tool call
tool_name: Name of the tool
"""
self.active_executions[tool_call_id] = {
"tool_name": tool_name,
"start_time": time.time(),
"progress": 0.0
}
# Call progress callback if set
if self.progress_callback:
self.progress_callback(tool_call_id, 0.0)
def _track_execution_progress(self, tool_call_id: str, progress: float) -> None:
"""Track the progress of tool execution.
Args:
tool_call_id: ID of the tool call
progress: Progress value (0-1)
"""
if tool_call_id in self.active_executions:
self.active_executions[tool_call_id]["progress"] = progress
# Call progress callback if set
if self.progress_callback:
self.progress_callback(tool_call_id, progress)
def _track_execution_complete(self, tool_call_id: str, result: ToolResult) -> None:
"""Track the completion of tool execution.
Args:
tool_call_id: ID of the tool call
result: Tool execution result
"""
if tool_call_id in self.active_executions:
# Update progress
self._track_execution_progress(tool_call_id, 1.0)
# Calculate execution time
start_time = self.active_executions[tool_call_id]["start_time"]
execution_time = time.time() - start_time
# Clean up
del self.active_executions[tool_call_id]
# Call result callback if set
if self.result_callback:
self.result_callback(tool_call_id, result)
def get_active_executions(self) -> Dict[str, Dict[str, Any]]:
"""Get information about active tool executions.
Returns:
Dictionary mapping tool_call_id to execution information
"""
return self.active_executions.copy()
def cancel_execution(self, tool_call_id: str) -> bool:
"""Cancel a tool execution if possible.
Args:
tool_call_id: ID of the tool call to cancel
Returns:
True if canceled successfully, False otherwise
"""
# TODO: Implement cancellation logic
# This would require more sophisticated execution tracking
logger.warning(f"Cancellation not implemented for tool_call_id: {tool_call_id}")
return False
```
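The manager above is driven with OpenAI-style tool-call dictionaries. A minimal usage sketch (the tool names and argument keys below are illustrative assumptions based on the `/help` text, not verified schemas from this repo):

```python
# Hypothetical driver for ToolExecutionManager; argument keys are assumed.
from claude_code.lib.tools.base import ToolRegistry
from claude_code.lib.tools.manager import ToolExecutionManager
from claude_code.lib.tools.file_tools import register_file_tools

registry = ToolRegistry()
register_file_tools(registry)
manager = ToolExecutionManager(registry)

tool_calls = [
    {"id": "call_1", "function": {"name": "View", "arguments": '{"file_path": "README.md"}'}},
    {"id": "call_2", "function": {"name": "LS", "arguments": '{"path": "."}'}},
]

# Runs both calls on a thread pool and returns ToolResult objects
for result in manager.execute_tools_parallel(tool_calls):
    print(result.tool_call_id, result.status, f"{result.execution_time:.2f}s")
```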
--------------------------------------------------------------------------------
/claude_code/claude.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# claude_code/claude.py
"""Claude Code Python Edition - CLI entry point."""
import os
import sys
import logging
import argparse
from typing import Dict, List, Optional, Any
import json
import signal
import time
from datetime import datetime
import typer
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from rich.prompt import Prompt
from rich.syntax import Syntax
from rich.logging import RichHandler
from dotenv import load_dotenv
from claude_code.lib.providers import get_provider, list_available_providers
from claude_code.lib.tools.base import ToolRegistry
from claude_code.lib.tools.manager import ToolExecutionManager
from claude_code.lib.tools.file_tools import register_file_tools
from claude_code.lib.ui.tool_visualizer import ToolCallVisualizer, MultiPanelLayout
from claude_code.lib.monitoring.cost_tracker import CostTracker
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler(rich_tracebacks=True)]
)
logger = logging.getLogger("claude_code")
# Load environment variables
load_dotenv()
# Package version (kept in sync with setup.py manually)
VERSION = "0.1.0"
# Create typer app
app = typer.Typer(help="Claude Code Python Edition")
console = Console()
# Global state
conversation: List[Dict[str, Any]] = []
tool_registry = ToolRegistry()
tool_manager: Optional[ToolExecutionManager] = None
cost_tracker: Optional[CostTracker] = None
visualizer: Optional[ToolCallVisualizer] = None
provider_name: str = ""
model_name: str = ""
user_config: Dict[str, Any] = {}
def initialize_tools() -> None:
"""Initialize all available tools."""
global tool_registry, tool_manager
# Create the registry and manager
tool_registry = ToolRegistry()
tool_manager = ToolExecutionManager(tool_registry)
# Register file tools
register_file_tools(tool_registry)
# TODO: Register more tools
# register_search_tools(tool_registry)
# register_bash_tools(tool_registry)
# register_agent_tools(tool_registry)
logger.info(f"Initialized {len(tool_registry.get_all_tools())} tools")
def setup_visualizer() -> None:
"""Set up the tool visualizer with callbacks."""
global tool_manager, visualizer
if not tool_manager:
return
# Create visualizer
visualizer = ToolCallVisualizer(console)
# Set up callbacks
def progress_callback(tool_call_id: str, progress: float) -> None:
if visualizer:
visualizer.update_progress(tool_call_id, progress)
def result_callback(tool_call_id: str, result: Any) -> None:
if visualizer:
visualizer.complete_tool_call(tool_call_id, result)
tool_manager.set_progress_callback(progress_callback)
tool_manager.set_result_callback(result_callback)
def load_configuration() -> Dict[str, Any]:
"""Load user configuration from file."""
config_path = os.path.expanduser("~/.config/claude_code/config.json")
# Default configuration
default_config = {
"provider": "openai",
"model": None, # Use provider default
"budget_limit": None,
"history_file": os.path.expanduser("~/.config/claude_code/usage_history.json"),
"ui": {
"theme": "dark",
"show_tool_calls": True,
"show_cost": True
}
}
# If configuration file doesn't exist, create it with defaults
if not os.path.exists(config_path):
try:
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, 'w', encoding='utf-8') as f:
json.dump(default_config, f, indent=2)
except Exception as e:
logger.warning(f"Failed to create default configuration: {e}")
return default_config
# Load configuration
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
# Merge with defaults for any missing keys
for key, value in default_config.items():
if key not in config:
config[key] = value
return config
except Exception as e:
logger.warning(f"Failed to load configuration: {e}")
return default_config
def handle_compact_command() -> str:
"""Handle the /compact command to compress conversation history."""
global conversation, provider_name, model_name
if not conversation:
return "No conversation to compact."
# Add a system message requesting summarization
compact_prompt = (
"Summarize the conversation so far, focusing on the key points, decisions, and context. "
"Keep important details about the code and tasks. Retain critical file paths, commands, "
"and code snippets. The summary should be concise but complete enough to continue the "
"conversation effectively."
)
conversation.append({"role": "user", "content": compact_prompt})
# Get the provider
provider = get_provider(provider_name, model=model_name)
# Make non-streaming API call for compaction
response = provider.generate_completion(conversation, stream=False)
# Extract summary
summary = response["content"] or ""
# Reset conversation with summary
system_message = next((m for m in conversation if m["role"] == "system"), None)
if system_message:
conversation = [system_message]
else:
conversation = []
# Add compacted context
conversation.append({
"role": "system",
"content": f"This is a compacted conversation. Previous context: {summary}"
})
return "Conversation compacted successfully."
def handle_help_command() -> str:
"""Handle the /help command."""
help_text = """
# Claude Code Python Edition Help
## Commands
- **/help**: Show this help message
- **/compact**: Compact the conversation to reduce token usage
- **/version**: Show version information
- **/providers**: List available LLM providers
- **/cost**: Show cost and usage information
- **/budget [amount]**: Set a budget limit (e.g., /budget 5.00)
- **/quit, /exit**: Exit the application
## Routine Commands
- **/routine list**: List all available routines
- **/routine create <name> <description>**: Create a routine from recent tool executions
- **/routine run <name>**: Run a routine
- **/routine delete <name>**: Delete a routine
## Tools
Claude Code has access to these tools:
- **View**: Read files
- **Edit**: Edit files (replace text)
- **Replace**: Overwrite or create files
- **GlobTool**: Find files by pattern
- **GrepTool**: Search file contents
- **LS**: List directory contents
- **Bash**: Execute shell commands
## CLI Commands
- **claude**: Start the Claude Code assistant (main interface)
- **claude mcp-client**: Start the MCP client to connect to MCP servers
- Usage: `claude mcp-client path/to/server.py [--model MODEL]`
- **claude mcp-multi-agent**: Start the multi-agent MCP client with synchronized agents
- Usage: `claude mcp-multi-agent path/to/server.py [--config CONFIG_FILE]`
## Multi-Agent Commands
When using the multi-agent client:
- **/agents**: List all active agents
- **/talk <agent> <message>**: Send a direct message to a specific agent
- **/history**: Show message history
- **/help**: Show multi-agent help
## Tips
- Be specific about file paths when requesting file operations
- For complex tasks, break them down into smaller steps
- Use /compact periodically for long sessions to save tokens
- Create routines for repetitive sequences of tool operations
- In multi-agent mode, use agent specialization for complex problems
"""
return help_text
def handle_version_command() -> str:
"""Handle the /version command."""
import platform
python_version = platform.python_version()
version_info = f"""
# Claude Code Python Edition v{VERSION}
- Python: {python_version}
- Provider: {provider_name}
- Model: {model_name}
- Tools: {len(tool_registry.get_all_tools()) if tool_registry else 0} available
"""
return version_info
def handle_providers_command() -> str:
"""Handle the /providers command."""
providers = list_available_providers()
providers_text = "# Available LLM Providers\n\n"
for name, info in providers.items():
providers_text += f"## {info['name']}\n"
if info['available']:
providers_text += f"- Status: Available\n"
providers_text += f"- Current model: {info['current_model']}\n"
providers_text += f"- Available models: {', '.join(info['models'])}\n"
else:
providers_text += f"- Status: Not available ({info['error']})\n"
providers_text += "\n"
return providers_text
def handle_cost_command() -> str:
"""Handle the /cost command."""
global cost_tracker
if not cost_tracker:
return "Cost tracking is not available."
# Generate a usage report
return cost_tracker.generate_usage_report(format="markdown")
def handle_budget_command(args: List[str]) -> str:
"""Handle the /budget command."""
global cost_tracker
if not cost_tracker:
return "Cost tracking is not available."
if not args:
# Show current budget
budget = cost_tracker.check_budget()
if not budget["has_budget"]:
return "No budget limit is currently set."
return f"Current budget: ${budget['limit']:.2f} (${budget['used']:.2f} used, ${budget['remaining']:.2f} remaining)"
# Set new budget
try:
budget_amount = float(args[0])
if budget_amount <= 0:
return "Budget must be a positive number."
cost_tracker.budget_limit = budget_amount
# Update configuration
user_config["budget_limit"] = budget_amount
# Save configuration
config_path = os.path.expanduser("~/.config/claude_code/config.json")
try:
with open(config_path, 'w', encoding='utf-8') as f:
json.dump(user_config, f, indent=2)
except Exception as e:
logger.warning(f"Failed to save configuration: {e}")
return f"Budget set to ${budget_amount:.2f}"
except ValueError:
return f"Invalid budget amount: {args[0]}"
def handle_routine_list_command() -> str:
"""Handle the /routine list command."""
global tool_manager
if not tool_manager:
return "Tool manager is not initialized."
routines = tool_manager.registry.get_all_routines()
if not routines:
return "No routines available."
routines_text = "# Available Routines\n\n"
for routine in routines:
usage = f" (Used {routine.usage_count} times)" if routine.usage_count > 0 else ""
last_used = ""
if routine.last_used_at:
last_used_time = datetime.fromtimestamp(routine.last_used_at)
last_used = f" (Last used: {last_used_time.strftime('%Y-%m-%d %H:%M')})"
routines_text += f"## {routine.name}{usage}{last_used}\n"
routines_text += f"{routine.description}\n\n"
routines_text += f"**Steps:** {len(routine.steps)}\n\n"
return routines_text
def handle_routine_create_command(args: List[str]) -> str:
"""Handle the /routine create command."""
global tool_manager, visualizer
if not tool_manager:
return "Tool manager is not initialized."
if len(args) < 2:
return "Usage: /routine create <name> <description>"
name = args[0]
description = " ".join(args[1:])
# Get recent tool results from visualizer
if not visualizer or not hasattr(visualizer, "recent_tool_results"):
return "No recent tool executions to create a routine from."
recent_tool_results = visualizer.recent_tool_results
if not recent_tool_results:
return "No recent tool executions to create a routine from."
try:
routine_id = tool_manager.create_routine_from_tool_history(
name, description, recent_tool_results
)
return f"Created routine '{name}' with {len(recent_tool_results)} steps."
except Exception as e:
logger.exception(f"Error creating routine: {e}")
return f"Error creating routine: {str(e)}"
def handle_routine_run_command(args: List[str]) -> str:
"""Handle the /routine run command."""
global tool_manager, visualizer
if not tool_manager:
return "Tool manager is not initialized."
if not args:
return "Usage: /routine run <name>"
name = args[0]
# Check if routine exists
routine = tool_manager.registry.get_routine(name)
if not routine:
return f"Routine '{name}' not found."
try:
# Execute the routine
execution_id = tool_manager.execute_routine(name)
# Wait for completion
while True:
routine_status = tool_manager.routine_manager.get_active_routines().get(execution_id, {})
if routine_status.get("status") != "running":
break
time.sleep(0.1)
# Get results
results = tool_manager.get_routine_results(execution_id)
if not results:
return f"Routine '{name}' completed but returned no results."
# Format results
result_text = f"# Routine '{name}' Results\n\n"
result_text += f"Executed {len(results)} steps:\n\n"
for i, result in enumerate(results):
status = "✅" if result.status == "success" else "❌"
result_text += f"## Step {i+1}: {result.name} {status}\n"
result_text += f"```\n{result.result}\n```\n\n"
return result_text
except Exception as e:
logger.exception(f"Error executing routine: {e}")
return f"Error executing routine: {str(e)}"
def handle_routine_delete_command(args: List[str]) -> str:
"""Handle the /routine delete command."""
global tool_manager
if not tool_manager:
return "Tool manager is not initialized."
if not args:
return "Usage: /routine delete <name>"
name = args[0]
# Check if routine exists
routine = tool_manager.registry.get_routine(name)
if not routine:
return f"Routine '{name}' not found."
try:
# Remove from registry and save
tool_manager.registry.routines.pop(name, None)
tool_manager.registry._save_routines()
return f"Deleted routine '{name}'."
except Exception as e:
logger.exception(f"Error deleting routine: {e}")
return f"Error deleting routine: {str(e)}"
def process_special_command(user_input: str) -> Optional[str]:
"""Process special commands starting with /."""
# Split into command and arguments
parts = user_input.strip().split()
command = parts[0].lower()
args = parts[1:]
# Handle commands
if command == "/help":
return handle_help_command()
elif command == "/compact":
return handle_compact_command()
elif command == "/version":
return handle_version_command()
elif command == "/providers":
return handle_providers_command()
elif command == "/cost":
return handle_cost_command()
elif command == "/budget":
return handle_budget_command(args)
elif command in ["/quit", "/exit"]:
console.print("[bold yellow]Goodbye![/bold yellow]")
sys.exit(0)
# Handle routine commands
elif command == "/routine":
if not args:
return "Usage: /routine [list|create|run|delete]"
subcmd = args[0].lower()
if subcmd == "list":
return handle_routine_list_command()
elif subcmd == "create":
return handle_routine_create_command(args[1:])
elif subcmd == "run":
return handle_routine_run_command(args[1:])
elif subcmd == "delete":
return handle_routine_delete_command(args[1:])
else:
return f"Unknown routine command: {subcmd}\nUsage: /routine [list|create|run|delete]"
# Not a recognized command
return None
def process_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Process tool calls and return results.
Args:
tool_calls: List of tool call dictionaries
Returns:
List of tool responses
"""
global tool_manager, visualizer
if not tool_manager:
logger.error("Tool manager not initialized")
return []
# Add tool calls to visualizer
if visualizer:
for tool_call in tool_calls:
function_name = tool_call.get("function", {}).get("name", "")
tool_call_id = tool_call.get("id", "unknown")
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
try:
parameters = json.loads(arguments_str)
visualizer.add_tool_call(tool_call_id, function_name, parameters)
except json.JSONDecodeError:
visualizer.add_tool_call(tool_call_id, function_name, {})
# Execute tools in parallel
tool_results = tool_manager.execute_tools_parallel(tool_calls)
# Format results for the conversation
tool_responses = []
for result in tool_results:
tool_responses.append({
"tool_call_id": result.tool_call_id,
"role": "tool",
"name": result.name,
"content": result.result
})
return tool_responses
@app.command(name="mcp-client")
def mcp_client(
server_script: str = typer.Argument(..., help="Path to the server script (.py or .js)"),
model: str = typer.Option("claude-3-5-sonnet-20241022", "--model", "-m", help="Claude model to use")
):
"""Run the MCP client to interact with an MCP server."""
from claude_code.commands.client import execute as client_execute
import argparse
# Create a namespace with the arguments
args = argparse.Namespace()
args.server_script = server_script
args.model = model
# Execute the client
return client_execute(args)
@app.command(name="mcp-multi-agent")
def mcp_multi_agent(
server_script: str = typer.Argument(..., help="Path to the server script (.py or .js)"),
config: str = typer.Option(None, "--config", "-c", help="Path to agent configuration JSON file")
):
"""Run the multi-agent MCP client with agent synchronization."""
from claude_code.commands.multi_agent_client import execute as multi_agent_execute
import argparse
# Create a namespace with the arguments
args = argparse.Namespace()
args.server_script = server_script
args.config = config
# Execute the multi-agent client
return multi_agent_execute(args)
@app.command()
def main(
provider: str = typer.Option(None, "--provider", "-p", help="LLM provider to use"),
model: str = typer.Option(None, "--model", "-m", help="Model to use"),
budget: Optional[float] = typer.Option(None, "--budget", "-b", help="Budget limit in dollars"),
system_prompt: Optional[str] = typer.Option(None, "--system", "-s", help="System prompt file"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output")
):
"""Claude Code Python Edition - A LLM-powered coding assistant."""
global conversation, tool_registry, tool_manager, cost_tracker, visualizer
global provider_name, model_name, user_config
# Set logging level
if verbose:
logging.getLogger("claude_code").setLevel(logging.DEBUG)
# Show welcome message
console.print(Panel.fit(
f"[bold green]Claude Code Python Edition v{VERSION}[/bold green]\n"
"Type your questions or commands. Use /help for available commands.",
title="Welcome",
border_style="green"
))
# Load configuration
user_config = load_configuration()
# Override with command line arguments
if provider:
user_config["provider"] = provider
if model:
user_config["model"] = model
if budget is not None:
user_config["budget_limit"] = budget
# Set provider and model
provider_name = user_config["provider"]
model_name = user_config["model"]
try:
# Initialize tools
initialize_tools()
# Set up cost tracking
cost_tracker = CostTracker(
budget_limit=user_config["budget_limit"],
history_file=user_config["history_file"]
)
# Get provider
provider = get_provider(provider_name, model=model_name)
provider_name = provider.name
model_name = provider.current_model
logger.info(f"Using {provider_name} with model {model_name}")
# Set up tool visualizer
setup_visualizer()
if visualizer:
visualizer.start()
# Load system prompt
system_message = ""
if system_prompt:
try:
with open(system_prompt, 'r', encoding='utf-8') as f:
system_message = f.read()
except Exception as e:
logger.error(f"Failed to load system prompt: {e}")
system_message = get_default_system_prompt()
else:
system_message = get_default_system_prompt()
# Initialize conversation
conversation = [{"role": "system", "content": system_message}]
# Main interaction loop
while True:
try:
# Get user input
user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")
# Handle special commands
if user_input.startswith("/"):
result = process_special_command(user_input)
if result:
console.print(Markdown(result))
continue
# Add user message to conversation
conversation.append({"role": "user", "content": user_input})
# Get schemas for all tools
tool_schemas = tool_registry.get_tool_schemas() if tool_registry else None
# Call the LLM
with console.status("[bold blue]Thinking...[/bold blue]", spinner="dots"):
# Stream the response
response_stream = provider.generate_completion(
messages=conversation,
tools=tool_schemas,
stream=True
)
# Track tool calls from streaming response
current_content = ""
current_tool_calls = []
# Process streaming response
for chunk in response_stream:
# If there's content, print it
if chunk.get("content"):
content_piece = chunk["content"]
current_content += content_piece
console.print(content_piece, end="")
# Process tool calls
if chunk.get("tool_calls") and not chunk.get("delta", True):
# This is a complete tool call
current_tool_calls = chunk["tool_calls"]
break
console.print() # Add newline after content
# Add assistant response to conversation
conversation.append({
"role": "assistant",
"content": current_content,
"tool_calls": current_tool_calls
})
# Process tool calls if any
if current_tool_calls:
console.print("[bold green]Executing tools...[/bold green]")
# Process tool calls
tool_responses = process_tool_calls(current_tool_calls)
# Add tool responses to conversation
conversation.extend(tool_responses)
# Continue the conversation with tool responses
console.print("[bold blue]Continuing with tool results...[/bold blue]")
follow_up = provider.generate_completion(
messages=conversation,
tools=tool_schemas,
stream=False
)
follow_up_text = follow_up.get("content", "")
if follow_up_text:
console.print(Markdown(follow_up_text))
# Add to conversation
conversation.append({
"role": "assistant",
"content": follow_up_text
})
# Track token usage and cost
if cost_tracker:
# Get token counts - this is an approximation
token_counts = provider.count_message_tokens(conversation[-3:])
cost_info = provider.cost_per_1k_tokens
# Add request to tracker
cost_tracker.add_request(
provider=provider_name,
model=model_name,
tokens_input=token_counts["input"],
tokens_output=token_counts.get("output", 0) or 150, # Estimate if not available
input_cost_per_1k=cost_info["input"],
output_cost_per_1k=cost_info["output"]
)
# Check budget
budget_status = cost_tracker.check_budget()
if budget_status["has_budget"] and budget_status["status"] in ["critical", "exceeded"]:
console.print(f"[bold red]{budget_status['message']}[/bold red]")
except KeyboardInterrupt:
console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
continue
except Exception as e:
logger.exception(f"Error: {str(e)}")
console.print(f"[bold red]Error:[/bold red] {str(e)}")
finally:
# Clean up
if visualizer:
visualizer.stop()
# Save cost history
if cost_tracker and hasattr(cost_tracker, '_save_history'):
cost_tracker._save_history()
def get_default_system_prompt() -> str:
"""Get the default system prompt."""
return """You are Claude Code Python Edition, a CLI tool that helps users with software engineering tasks.
Use the available tools to assist the user with their requests.
# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command,
you should explain what the command does and why you are running it.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user.
Remember that your output will be displayed on a command line interface.
# Tool usage policy
- When doing file search, remember to search effectively with the available tools.
- Always use the appropriate tool for the task.
- Use parallel tool calls when appropriate to improve performance.
- NEVER commit changes unless the user explicitly asks you to.
# Routines
You have access to Routines, which are sequences of tool calls that can be created and reused.
To create a routine from recent tool executions, use `/routine create <name> <description>`.
To run a routine, use `/routine run <name>`.
Routines are ideal for repetitive task sequences like:
- Deep research across multiple sources
- Multi-step code updates across files
- Complex search and replace operations
- Data processing pipelines
# Tasks
The user will primarily request you perform software engineering tasks:
1. Solving bugs
2. Adding new functionality
3. Refactoring code
4. Explaining code
5. Writing tests
For these tasks:
1. Use search tools to understand the codebase
2. Implement solutions using the available tools
3. Verify solutions with tests if possible
4. Run lint and typecheck commands when appropriate
5. Consider creating routines for repetitive operations
# Code style
- Follow the existing code style of the project
- Maintain consistent naming conventions
- Use appropriate libraries that are already in the project
- Add comments when code is complex or non-obvious
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness,
quality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.
"""
if __name__ == "__main__":
# Handle Ctrl+C gracefully
signal.signal(signal.SIGINT, lambda sig, frame: sys.exit(0))
# Run app
app()
```
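The main loop above consumes providers through a small implicit protocol: `generate_completion(messages, tools=None, stream=...)` returns either a dict (`content`, optional `tool_calls`) or an iterator of chunk dicts (`content` deltas, then a final chunk with `delta=False` carrying complete `tool_calls`), alongside `name`, `current_model`, `cost_per_1k_tokens`, and `count_message_tokens`. A toy provider sketch satisfying that protocol (how providers register with `get_provider` is not shown on this page, so treat this as illustrative only):

```python
# Minimal provider stub matching the protocol claude.py consumes (assumed shape).
from typing import Any, Dict, Iterator, List, Optional


class EchoProvider:
    """Toy provider: streams the last user message back, never calls tools."""

    name = "echo"
    current_model = "echo-1"
    cost_per_1k_tokens = {"input": 0.0, "output": 0.0}

    def generate_completion(self, messages: List[Dict[str, Any]],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            stream: bool = False):
        text = str(messages[-1].get("content", ""))
        if not stream:
            return {"content": text, "tool_calls": []}

        def chunks() -> Iterator[Dict[str, Any]]:
            for word in text.split():
                yield {"content": word + " ", "delta": True}  # streamed text
            yield {"tool_calls": [], "delta": False}  # final, complete chunk

        return chunks()

    def count_message_tokens(self, messages: List[Dict[str, Any]]) -> Dict[str, int]:
        # Crude ~4 chars/token estimate, enough for cost-tracking demos
        return {"input": sum(len(str(m.get("content", ""))) for m in messages) // 4,
                "output": 0}
```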
--------------------------------------------------------------------------------
/claude_code/lib/rl/grpo.py:
--------------------------------------------------------------------------------
```python
"""
Group Relative Policy Optimization (GRPO) for multi-agent learning in Claude Code.
This module provides a multi-agent GRPO implementation that learns from interactions.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from typing import List, Dict, Tuple, Optional, Any, Union, Callable
from dataclasses import dataclass
from collections import deque
import random
import time
@dataclass
class Experience:
"""A single step of experience for reinforcement learning."""
state: Any
action: Any
reward: float
next_state: Any
done: bool
info: Optional[Dict[str, Any]] = None
class ExperienceBuffer:
"""Buffer to store and sample experiences for training."""
def __init__(self, capacity: int = 100000):
"""
Initialize the experience buffer.
Args:
capacity: Maximum number of experiences to store
"""
self.buffer = deque(maxlen=capacity)
def add(self, experience: Experience) -> None:
"""Add an experience to the buffer."""
self.buffer.append(experience)
def sample(self, batch_size: int) -> List[Experience]:
"""Sample a batch of experiences from the buffer."""
return random.sample(self.buffer, min(batch_size, len(self.buffer)))
def __len__(self) -> int:
"""Get the current size of the buffer."""
return len(self.buffer)
class PolicyNetwork(nn.Module):
"""Neural network to represent a policy."""
def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
"""
Initialize the policy network.
Args:
input_dim: Dimension of the input state
hidden_dims: List of hidden layer dimensions
output_dim: Dimension of the action space
"""
super(PolicyNetwork, self).__init__()
# Create the input layer
layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
# Create hidden layers
for i in range(len(hidden_dims) - 1):
layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
layers.append(nn.ReLU())
# Create output layer
layers.append(nn.Linear(hidden_dims[-1], output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network."""
return self.network(x)
class ValueNetwork(nn.Module):
"""Neural network to represent a value function."""
def __init__(self, input_dim: int, hidden_dims: List[int]):
"""
Initialize the value network.
Args:
input_dim: Dimension of the input state
hidden_dims: List of hidden layer dimensions
"""
super(ValueNetwork, self).__init__()
# Create the input layer
layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
# Create hidden layers
for i in range(len(hidden_dims) - 1):
layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
layers.append(nn.ReLU())
# Create output layer (scalar value)
layers.append(nn.Linear(hidden_dims[-1], 1))
self.network = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network."""
return self.network(x)
class GRPO:
"""
Group Relative Policy Optimization implementation for multi-agent learning.
GRPO extends PPO by considering relative performance within a group of agents.
"""
def __init__(
self,
state_dim: int,
action_dim: int,
hidden_dims: List[int] = [64, 64],
lr_policy: float = 3e-4,
lr_value: float = 1e-3,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_ratio: float = 0.2,
target_kl: float = 0.01,
value_coef: float = 0.5,
entropy_coef: float = 0.01,
max_grad_norm: float = 0.5,
use_gae: bool = True,
normalize_advantages: bool = True,
relative_advantage_weight: float = 0.5,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the GRPO agent.
Args:
state_dim: Dimension of the state space
action_dim: Dimension of the action space
hidden_dims: Dimensions of hidden layers in networks
lr_policy: Learning rate for policy network
lr_value: Learning rate for value network
gamma: Discount factor
gae_lambda: Lambda for GAE
clip_ratio: PPO clipping parameter
target_kl: Target KL divergence for early stopping
value_coef: Value loss coefficient
entropy_coef: Entropy bonus coefficient
max_grad_norm: Maximum gradient norm for clipping
use_gae: Whether to use GAE
normalize_advantages: Whether to normalize advantages
relative_advantage_weight: Weight for relative advantage component
device: Device to run the model on
"""
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_ratio = clip_ratio
self.target_kl = target_kl
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.use_gae = use_gae
self.normalize_advantages = normalize_advantages
self.relative_advantage_weight = relative_advantage_weight
self.device = device
# Initialize networks
self.policy = PolicyNetwork(state_dim, hidden_dims, action_dim).to(device)
self.value = ValueNetwork(state_dim, hidden_dims).to(device)
# Initialize optimizers
self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr_policy)
self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr_value)
# Initialize experience buffer
self.buffer = ExperienceBuffer()
# Group-level buffers for relative advantage computation
self.group_rewards = []
self.agent_id = None # Will be set when joining a group
def set_agent_id(self, agent_id: str) -> None:
"""Set the agent's ID within the group."""
self.agent_id = agent_id
def get_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float]:
"""
Get an action from the policy for the given state.
Args:
state: The current state
deterministic: Whether to return the most likely action
Returns:
Tuple of (action, log probability)
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get action distributions
with torch.no_grad():
logits = self.policy(state_tensor)
distribution = Categorical(logits=logits)
if deterministic:
action = torch.argmax(logits, dim=1).item()
else:
action = distribution.sample().item()
log_prob = distribution.log_prob(torch.tensor(action)).item()
return action, log_prob
def get_value(self, state: np.ndarray) -> float:
"""
Get the estimated value of a state.
Args:
state: The state to evaluate
Returns:
The estimated value
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get value estimate
with torch.no_grad():
value = self.value(state_tensor).item()
return value
def learn(
self,
batch_size: int = 64,
epochs: int = 10,
group_rewards: Optional[Dict[str, List[float]]] = None
) -> Dict[str, float]:
"""
Update policy and value networks based on collected experience.
Args:
batch_size: Size of batches to use for updates
epochs: Number of epochs to train for
group_rewards: Rewards collected by all agents in the group
Returns:
Dictionary of training metrics
"""
if len(self.buffer) < batch_size:
return {"policy_loss": 0, "value_loss": 0, "kl": 0}
# Prepare data for training
states, actions, old_log_probs, returns, advantages = self._prepare_training_data(
group_rewards)
# Training metrics
metrics = {
"policy_loss": 0,
"value_loss": 0,
"entropy": 0,
"kl": 0,
}
        # Track how many gradient updates are actually performed
        num_updates = 0
        # Run training for multiple epochs
for epoch in range(epochs):
# Generate random indices for batching
indices = np.random.permutation(len(states))
# Process in batches
for start_idx in range(0, len(states), batch_size):
# Get batch indices
batch_indices = indices[start_idx:start_idx + batch_size]
# Extract batch data
batch_states = states[batch_indices]
batch_actions = actions[batch_indices]
batch_old_log_probs = old_log_probs[batch_indices]
batch_returns = returns[batch_indices]
batch_advantages = advantages[batch_indices]
# Update policy
policy_loss, entropy, kl = self._update_policy(
batch_states, batch_actions, batch_old_log_probs, batch_advantages)
# Early stopping based on KL divergence
if kl > 1.5 * self.target_kl:
break
# Update value function
value_loss = self._update_value(batch_states, batch_returns)
                # Update metrics
                metrics["policy_loss"] += policy_loss
                metrics["value_loss"] += value_loss
                metrics["entropy"] += entropy
                metrics["kl"] += kl
                num_updates += 1
# Check for early stopping after each epoch
if metrics["kl"] / (epoch + 1) > self.target_kl:
break
        # Normalize metrics by the number of updates actually performed
        # (KL-based early stopping may end training before the full schedule)
        for key in metrics:
            metrics[key] /= max(1, num_updates)
# Clear buffer after training
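        # (GRPO, like PPO, is on-policy: reusing stale experiences would bias
        # the clipped objective, so the buffer is discarded after each update.)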
self.buffer = ExperienceBuffer()
return metrics
def _prepare_training_data(
self, group_rewards: Optional[Dict[str, List[float]]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare data for training from the experience buffer.
Args:
group_rewards: Rewards collected by all agents in the group
Returns:
Tuple of (states, actions, old_log_probs, returns, advantages)
"""
# Collect experiences from buffer
experiences = list(self.buffer.buffer)
# Extract components
        states = torch.FloatTensor(np.array([exp.state for exp in experiences])).to(self.device)
        actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
        rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
        next_states = torch.FloatTensor(np.array([exp.next_state for exp in experiences])).to(self.device)
        dones = torch.FloatTensor([float(exp.done) for exp in experiences]).to(self.device)
# Compute values for all states and next states
with torch.no_grad():
values = self.value(states).squeeze()
next_values = self.value(next_states).squeeze()
# Compute advantages and returns
if self.use_gae:
# Generalized Advantage Estimation
advantages = self._compute_gae(rewards, values, next_values, dones)
else:
# Regular advantages
advantages = rewards + self.gamma * next_values * (1 - dones) - values
# Compute returns (for value function)
returns = advantages + values
# If group rewards are provided, compute relative advantages
if group_rewards is not None and self.agent_id in group_rewards:
relative_advantages = self._compute_relative_advantages(
advantages, group_rewards)
# Combine regular and relative advantages
advantages = (1 - self.relative_advantage_weight) * advantages + \
self.relative_advantage_weight * relative_advantages
# Normalize advantages if enabled
if self.normalize_advantages:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Get old log probabilities
old_log_probs = torch.FloatTensor(
[self._compute_log_prob(exp.state, exp.action) for exp in experiences]
).to(self.device)
return states, actions, old_log_probs, returns, advantages
def _compute_gae(
self, rewards: torch.Tensor, values: torch.Tensor,
next_values: torch.Tensor, dones: torch.Tensor
) -> torch.Tensor:
"""
Compute advantages using Generalized Advantage Estimation.
Args:
rewards: Batch of rewards
values: Batch of state values
next_values: Batch of next state values
dones: Batch of done flags
Returns:
Batch of advantage estimates
"""
# Initialize advantages
advantages = torch.zeros_like(rewards)
# Initialize gae
gae = 0
# Compute advantages in reverse order
for t in reversed(range(len(rewards))):
# Compute TD error
delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
# Update gae
gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
# Store advantage
advantages[t] = gae
return advantages
def _compute_relative_advantages(
self, advantages: torch.Tensor, group_rewards: Dict[str, List[float]]
) -> torch.Tensor:
"""
Compute relative advantages compared to other agents in the group.
Args:
advantages: This agent's advantages
group_rewards: Rewards collected by all agents in the group
Returns:
Relative advantages
"""
# Compute mean reward for each agent
agent_mean_rewards = {
agent_id: sum(rewards) / max(1, len(rewards))
for agent_id, rewards in group_rewards.items()
}
# Compute mean reward across all agents
group_mean_reward = sum(agent_mean_rewards.values()) / len(agent_mean_rewards)
# Compute relative performance factor
# Higher if this agent is doing better than the group average
if self.agent_id in agent_mean_rewards:
relative_factor = agent_mean_rewards[self.agent_id] / (group_mean_reward + 1e-8)
else:
relative_factor = 1.0
# Apply the relative factor to the advantages
relative_advantages = advantages * relative_factor
return relative_advantages
def _compute_log_prob(self, state: np.ndarray, action: int) -> float:
"""
Compute the log probability of an action given a state.
Args:
state: The state
action: The action
Returns:
The log probability
"""
# Convert state to tensor
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Get action distribution
with torch.no_grad():
logits = self.policy(state_tensor)
distribution = Categorical(logits=logits)
log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
return log_prob
def _update_policy(
self,
states: torch.Tensor,
actions: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor
) -> Tuple[float, float, float]:
"""
Update the policy network using PPO.
Args:
states: Batch of states
actions: Batch of actions
old_log_probs: Batch of old log probabilities
advantages: Batch of advantages
Returns:
Tuple of (policy_loss, entropy, kl_divergence)
"""
# Get action distributions
logits = self.policy(states)
distribution = Categorical(logits=logits)
# Get new log probabilities
new_log_probs = distribution.log_prob(actions)
# Compute probability ratio
ratio = torch.exp(new_log_probs - old_log_probs)
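        # PPO clipped surrogate (Schulman et al., 2017):
        #   L_clip = E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)]
        # which bounds how far a single update can move the policy.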
# Compute surrogate objectives
surrogate1 = ratio * advantages
surrogate2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
# Compute policy loss (negative because we're maximizing)
policy_loss = -torch.min(surrogate1, surrogate2).mean()
# Compute entropy bonus
entropy = distribution.entropy().mean()
# Add entropy bonus to loss
loss = policy_loss - self.entropy_coef * entropy
# Compute approximate KL divergence for monitoring
with torch.no_grad():
kl = (old_log_probs - new_log_probs).mean().item()
# Update policy network
self.policy_optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy_optimizer.step()
return policy_loss.item(), entropy.item(), kl
def _update_value(self, states: torch.Tensor, returns: torch.Tensor) -> float:
"""
Update the value network.
Args:
states: Batch of states
returns: Batch of returns
Returns:
Value loss
"""
# Get value predictions
values = self.value(states).squeeze()
# Compute value loss
value_loss = F.mse_loss(values, returns)
# Update value network
self.value_optimizer.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value.parameters(), self.max_grad_norm)
self.value_optimizer.step()
return value_loss.item()
class MultiAgentGroupRL:
"""
Multi-agent reinforcement learning system using GRPO for Claude Code.
This class manages multiple GRPO agents that learn in a coordinated way.
"""
def __init__(
self,
agent_configs: List[Dict[str, Any]],
feature_extractor: Callable[[Dict[str, Any]], np.ndarray],
reward_function: Callable[[Dict[str, Any], str, Any], float],
update_interval: int = 1000,
training_epochs: int = 10,
batch_size: int = 64,
save_dir: str = "./models",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the multi-agent RL system.
Args:
agent_configs: List of configurations for each agent
feature_extractor: Function to extract state features
reward_function: Function to compute rewards
update_interval: How often to update agents (in steps)
training_epochs: Number of epochs to train for each update
batch_size: Batch size for training
save_dir: Directory to save models
device: Device to run on
"""
self.feature_extractor = feature_extractor
self.reward_function = reward_function
self.update_interval = update_interval
self.training_epochs = training_epochs
self.batch_size = batch_size
self.save_dir = save_dir
self.device = device
# Initialize agents
self.agents = {}
for config in agent_configs:
agent_id = config["id"]
state_dim = config["state_dim"]
action_dim = config["action_dim"]
# Create GRPO agent
agent = GRPO(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=config.get("hidden_dims", [64, 64]),
device=device,
**{k: v for k, v in config.items() if k not in ["id", "state_dim", "action_dim", "hidden_dims"]}
)
# Set agent ID
agent.set_agent_id(agent_id)
self.agents[agent_id] = agent
# Track steps for periodic updates
self.total_steps = 0
# Store rewards for relative advantage computation
self.agent_rewards = {agent_id: [] for agent_id in self.agents}
def select_action(
self, agent_id: str, observation: Dict[str, Any], deterministic: bool = False
) -> Tuple[Any, float]:
"""
Select an action for the specified agent.
Args:
agent_id: ID of the agent
observation: Current observation
deterministic: Whether to select deterministically
Returns:
Tuple of (action, log probability)
"""
if agent_id not in self.agents:
raise ValueError(f"Unknown agent ID: {agent_id}")
# Extract features
state = self.feature_extractor(observation)
# Get action from agent
action, log_prob = self.agents[agent_id].get_action(state, deterministic)
return action, log_prob
def observe(
self,
agent_id: str,
observation: Dict[str, Any],
action: Any,
reward: float,
next_observation: Dict[str, Any],
done: bool,
info: Optional[Dict[str, Any]] = None
) -> None:
"""
Record an observation for the specified agent.
Args:
agent_id: ID of the agent
observation: Current observation
action: Action taken
reward: Reward received
next_observation: Next observation
done: Whether the episode is done
info: Additional information
"""
if agent_id not in self.agents:
raise ValueError(f"Unknown agent ID: {agent_id}")
# Extract features
state = self.feature_extractor(observation)
next_state = self.feature_extractor(next_observation)
# Create experience
exp = Experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
info=info
)
# Add experience to agent's buffer
self.agents[agent_id].buffer.add(exp)
# Store reward for relative advantage computation
self.agent_rewards[agent_id].append(reward)
# Increment step counter
self.total_steps += 1
# Perform updates if needed
if self.total_steps % self.update_interval == 0:
self.update_all_agents()
def update_all_agents(self) -> Dict[str, Dict[str, float]]:
"""
Update all agents' policies.
Returns:
Dictionary of training metrics for each agent
"""
# Store metrics for each agent
metrics = {}
# Update each agent
for agent_id, agent in self.agents.items():
# Train the agent with group rewards
agent_metrics = agent.learn(
batch_size=self.batch_size,
epochs=self.training_epochs,
group_rewards=self.agent_rewards
)
metrics[agent_id] = agent_metrics
# Reset reward tracking
self.agent_rewards = {agent_id: [] for agent_id in self.agents}
return metrics
def save_agents(self, suffix: str = "") -> None:
"""
Save all agents' models.
Args:
suffix: Optional suffix for saved files
"""
import os
# Create save directory if it doesn't exist
os.makedirs(self.save_dir, exist_ok=True)
# Save each agent
for agent_id, agent in self.agents.items():
# Create file path
file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
# Save model
torch.save({
"policy_state_dict": agent.policy.state_dict(),
"value_state_dict": agent.value.state_dict(),
"policy_optimizer_state_dict": agent.policy_optimizer.state_dict(),
"value_optimizer_state_dict": agent.value_optimizer.state_dict(),
}, file_path)
def load_agents(self, suffix: str = "") -> None:
"""
Load all agents' models.
Args:
suffix: Optional suffix for loaded files
"""
import os
# Load each agent
for agent_id, agent in self.agents.items():
# Create file path
file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
# Check if file exists
if not os.path.exists(file_path):
print(f"Warning: Model file not found for agent {agent_id}")
continue
# Load model
checkpoint = torch.load(file_path, map_location=self.device)
# Load state dicts
agent.policy.load_state_dict(checkpoint["policy_state_dict"])
agent.value.load_state_dict(checkpoint["value_state_dict"])
agent.policy_optimizer.load_state_dict(checkpoint["policy_optimizer_state_dict"])
agent.value_optimizer.load_state_dict(checkpoint["value_optimizer_state_dict"])
class ToolSelectionGRPO:
"""
Specialized GRPO implementation for tool selection in Claude Code.
This class adapts the MultiAgentGroupRL for the specific context of tool selection.
"""
def __init__(
self,
tool_registry: Any, # Should be a reference to the tool registry
context_evaluator: Callable, # Function to evaluate quality of response given context
state_dim: int = 768, # Embedding dimension for query
num_agents: int = 3, # Number of agents in the group
update_interval: int = 100,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize the GRPO tool selector.
Args:
tool_registry: Registry containing available tools
context_evaluator: Function to evaluate response quality
state_dim: Dimension of state features
num_agents: Number of agents in the group
update_interval: How often to update agents
device: Device to run on
"""
self.tool_registry = tool_registry
self.context_evaluator = context_evaluator
# Get all available tools
self.tool_names = tool_registry.get_all_tool_names()
self.action_dim = len(self.tool_names)
# Define agent configurations
agent_configs = [
{
"id": f"tool_agent_{i}",
"state_dim": state_dim,
"action_dim": self.action_dim,
"hidden_dims": [256, 128],
"relative_advantage_weight": 0.7 if i > 0 else 0.3, # Different weights
"entropy_coef": 0.02 if i == 0 else 0.01, # Different exploration rates
}
for i in range(num_agents)
]
# Initialize multi-agent RL system
self.rl_system = MultiAgentGroupRL(
agent_configs=agent_configs,
feature_extractor=self._extract_features,
reward_function=self._compute_reward,
update_interval=update_interval,
device=device,
)
# Track current episode
self.current_episode = {agent_id: {} for agent_id in self.rl_system.agents}
def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> str:
"""
Select the best tool to use for a given user query and context.
Args:
user_query: The user's query
context: The current conversation context
visualizer: Optional visualizer to display the selection process
Returns:
The name of the best tool to use
"""
# Create observation
observation = {
"query": user_query,
"context": context,
}
# If visualizer is provided, start it
if visualizer:
visualizer.start()
visualizer.add_execution(
execution_id="tool_selection",
tool_name="GRPO Tool Selection",
parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
)
# Select agent to use (round-robin for now)
agent_id = f"tool_agent_{self.rl_system.total_steps % len(self.rl_system.agents)}"
# Update visualizer if provided
if visualizer:
visualizer.update_progress("tool_selection", 0.3)
# Get action from agent
action_idx, _ = self.rl_system.select_action(
agent_id=agent_id,
observation=observation,
deterministic=False # Use exploratory actions during learning
)
# Update visualizer if provided
if visualizer:
visualizer.update_progress("tool_selection", 0.6)
# Store initial information for the episode
self.current_episode[agent_id] = {
"observation": observation,
"action_idx": action_idx,
"initial_quality": self.context_evaluator(context),
}
# Map action index to tool name
tool_name = self.tool_names[action_idx]
# Complete visualization if provided
if visualizer:
# Create detailed metrics for visualization
agent_data = {}
for aid, agent in self.rl_system.agents.items():
# Get all tool probabilities for this agent
with torch.no_grad():
                    state = self._extract_features(observation)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
logits = agent.policy(state_tensor)
probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
# Add to metrics
agent_data[aid] = {
"selected": aid == agent_id,
"tool_probabilities": {
self.tool_names[i]: float(prob)
for i, prob in enumerate(probs)
}
}
# Complete the visualization
visualizer.complete_execution(
execution_id="tool_selection",
result={
"selected_tool": tool_name,
"selected_agent": agent_id,
"agent_data": agent_data
},
status="success"
)
visualizer.stop()
return tool_name
def observe_result(
self, agent_id: str, result: Any, context: Dict[str, Any], done: bool = True
) -> None:
"""
Observe the result of using a tool.
Args:
agent_id: The ID of the agent that selected the tool
result: The result of using the tool
context: The updated context after using the tool
done: Whether the interaction is complete
"""
if agent_id not in self.current_episode:
return
# Get episode information
episode = self.current_episode[agent_id]
observation = episode["observation"]
action_idx = episode["action_idx"]
initial_quality = episode["initial_quality"]
# Create next observation
next_observation = {
"query": observation["query"],
"context": context,
"result": result,
}
# Compute reward
reward = self._compute_reward(observation, action_idx, result, context, initial_quality)
# Record observation
self.rl_system.observe(
agent_id=agent_id,
observation=observation,
action=action_idx,
reward=reward,
next_observation=next_observation,
done=done,
)
# Clear episode if done
if done:
self.current_episode[agent_id] = {}
def _extract_features(self, observation: Dict[str, Any]) -> np.ndarray:
"""Extract features from an observation."""
# This would ideally use an embedding model
# For now, return a random vector as a placeholder
return np.random.randn(768)
def _compute_reward(
self,
observation: Dict[str, Any],
action_idx: int,
result: Any,
context: Dict[str, Any],
initial_quality: float
) -> float:
"""Compute the reward for an action."""
# Compute the quality improvement
final_quality = self.context_evaluator(context)
quality_improvement = final_quality - initial_quality
# Base reward on quality improvement
reward = max(0, quality_improvement * 10) # Scale for better learning
return reward
def update(self) -> Dict[str, Dict[str, float]]:
"""
Trigger an update of all agents.
Returns:
Dictionary of training metrics
"""
return self.rl_system.update_all_agents()
def save(self, suffix: str = "") -> None:
"""Save all agents."""
self.rl_system.save_agents(suffix)
def load(self, suffix: str = "") -> None:
"""Load all agents."""
self.rl_system.load_agents(suffix)
```
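A hedged wiring sketch for the classes above (the featurizer and reward below are toy stand-ins; a real deployment would plug in an embedding model and a response-quality metric):

```python
# Illustrative MultiAgentGroupRL setup; dimensions and callbacks are toy values.
import numpy as np
from claude_code.lib.rl.grpo import MultiAgentGroupRL

def extract_features(observation):
    # Stand-in featurizer: a real system would embed observation["query"]
    return np.zeros(8, dtype=np.float32)

def compute_reward(observation, agent_id, action):
    return 0.0  # placeholder; rewards are passed to observe() directly

group = MultiAgentGroupRL(
    agent_configs=[
        {"id": "agent_a", "state_dim": 8, "action_dim": 4},
        {"id": "agent_b", "state_dim": 8, "action_dim": 4},
    ],
    feature_extractor=extract_features,
    reward_function=compute_reward,
    update_interval=256,
    device="cpu",
)

obs = {"query": "example"}
action, log_prob = group.select_action("agent_a", obs)
group.observe("agent_a", obs, action, reward=1.0, next_observation=obs, done=True)
# After update_interval steps, observe() triggers update_all_agents() automatically.
```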
--------------------------------------------------------------------------------
/claude_code/lib/rl/mcts.py:
--------------------------------------------------------------------------------
```python
"""
Monte Carlo Tree Search implementation for decision making in Claude Code.
This module provides an advanced MCTS implementation that can be used to select
optimal actions/tools based on simulated outcomes.
"""
import math
import numpy as np
import random
import time
from typing import List, Dict, Any, Callable, Tuple, Optional, Union
from dataclasses import dataclass
@dataclass
class MCTSNode:
"""Represents a node in the Monte Carlo search tree."""
state: Any
parent: Optional['MCTSNode'] = None
action_taken: Any = None
visits: int = 0
value: float = 0.0
children: Dict[Any, 'MCTSNode'] = None
def __post_init__(self):
if self.children is None:
self.children = {}
def is_fully_expanded(self, possible_actions: List[Any]) -> bool:
"""Check if all possible actions have been tried from this node."""
return all(action in self.children for action in possible_actions)
def is_terminal(self) -> bool:
"""Check if this node represents a terminal state."""
# This should be customized based on your environment
return False
def best_child(self, exploration_weight: float = 1.0) -> 'MCTSNode':
"""Select the best child node according to UCB1 formula."""
if not self.children:
return None
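        # UCB1 score: mean value (exploitation) plus an exploration bonus of
        # sqrt(2 * ln(parent_visits) / child_visits); unvisited children score
        # infinity so every action is tried at least once.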
def ucb_score(child: MCTSNode) -> float:
exploitation = child.value / child.visits if child.visits > 0 else 0
exploration = math.sqrt(2 * math.log(self.visits) / child.visits) if child.visits > 0 else float('inf')
return exploitation + exploration_weight * exploration
return max(self.children.values(), key=ucb_score)
class AdvancedMCTS:
"""
Advanced Monte Carlo Tree Search implementation with various enhancements:
- Progressive widening for large/continuous action spaces
- RAVE (Rapid Action Value Estimation)
- Parallel simulations
- Dynamic exploration weight
- Customizable simulation and backpropagation strategies
"""
def __init__(
self,
state_evaluator: Callable[[Any], float],
action_generator: Callable[[Any], List[Any]],
simulator: Callable[[Any, Any], Any],
max_iterations: int = 1000,
exploration_weight: float = 1.0,
time_limit: Optional[float] = None,
progressive_widening: bool = False,
pw_coef: float = 0.5,
pw_power: float = 0.5,
use_rave: bool = False,
rave_equiv_param: float = 1000,
):
"""
Initialize the MCTS algorithm.
Args:
state_evaluator: Function to evaluate the value of a state (terminal or not)
action_generator: Function to generate possible actions from a state
simulator: Function to simulate taking an action in a state, returning new state
max_iterations: Maximum number of search iterations
exploration_weight: Controls exploration vs exploitation balance
time_limit: Optional time limit for search in seconds
progressive_widening: Whether to use progressive widening for large action spaces
pw_coef: Coefficient for progressive widening
pw_power: Power for progressive widening
use_rave: Whether to use RAVE (Rapid Action Value Estimation)
rave_equiv_param: RAVE equivalence parameter
"""
self.state_evaluator = state_evaluator
self.action_generator = action_generator
self.simulator = simulator
self.max_iterations = max_iterations
self.exploration_weight = exploration_weight
self.time_limit = time_limit
# Progressive widening parameters
self.progressive_widening = progressive_widening
self.pw_coef = pw_coef
self.pw_power = pw_power
# RAVE parameters
self.use_rave = use_rave
self.rave_equiv_param = rave_equiv_param
self.rave_values = {} # (state, action) -> (value, visits)
def search(self, initial_state: Any, visualizer=None) -> Any:
"""
Perform MCTS search from the initial state and return the best action.
Args:
initial_state: The starting state for the search
visualizer: Optional visualizer to show progress
Returns:
The best action found by the search
"""
root = MCTSNode(state=initial_state)
# Initialize visualizer if provided
if visualizer:
visualizer.set_search_parameters(root, self.max_iterations)
        # Run iterations of the MCTS algorithm, honoring the optional time limit
        start_time = time.time()
        for iteration in range(self.max_iterations):
            if self.time_limit is not None and time.time() - start_time > self.time_limit:
                break
# Selection phase
selected_node = self._select(root)
# Expansion phase (if not terminal)
expanded_node = None
if not selected_node.is_terminal():
expanded_node = self._expand(selected_node)
else:
expanded_node = selected_node
# Simulation phase
simulation_path = []
if visualizer:
# Track simulation path for visualization
current = expanded_node
while current.parent:
simulation_path.insert(0, (current.parent.state, current.action_taken))
current = current.parent
simulation_result = self._simulate(expanded_node)
# Backpropagation phase
self._backpropagate(expanded_node, simulation_result)
# Update visualization
if visualizer:
# Find current best action
best_action = None
if root.children:
best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
# Update visualizer
visualizer.update_iteration(
iteration=iteration + 1,
selected_node=selected_node,
expanded_node=expanded_node,
simulation_path=simulation_path,
simulation_result=simulation_result,
best_action=best_action
)
# Return the action that leads to the child with the highest value
if not root.children:
possible_actions = self.action_generator(root.state)
if possible_actions:
best_action = random.choice(possible_actions)
if visualizer:
visualizer.update_iteration(
iteration=self.max_iterations,
best_action=best_action
)
return best_action
return None
best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
if visualizer:
visualizer.update_iteration(
iteration=self.max_iterations,
best_action=best_action
)
return best_action
def _select(self, node: MCTSNode) -> MCTSNode:
"""
Select a node to expand using UCB1 and progressive widening if enabled.
Args:
node: The current node
Returns:
The selected node for expansion
"""
while not node.is_terminal():
possible_actions = self.action_generator(node.state)
# Handle progressive widening if enabled
if self.progressive_widening:
max_children = max(1, int(self.pw_coef * (node.visits ** self.pw_power)))
if len(node.children) < min(max_children, len(possible_actions)):
return node
# If not fully expanded, select this node for expansion
if not node.is_fully_expanded(possible_actions):
return node
# Otherwise, select the best child according to UCB1
node = node.best_child(self.exploration_weight)
if node is None:
break
return node
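    # Progressive-widening example (illustrative): with pw_coef = 0.5 and
    # pw_power = 0.5, a node visited 100 times may hold at most
    # int(0.5 * 100 ** 0.5) = 5 children, so the action set widens roughly
    # with the square root of the visit count rather than all at once.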
def _expand(self, node: MCTSNode) -> MCTSNode:
"""
Expand the node by selecting an untried action and creating a new child node.
Args:
node: The node to expand
Returns:
The newly created child node
"""
possible_actions = self.action_generator(node.state)
untried_actions = [a for a in possible_actions if a not in node.children]
if not untried_actions:
return node
action = random.choice(untried_actions)
new_state = self.simulator(node.state, action)
child_node = MCTSNode(
state=new_state,
parent=node,
action_taken=action
)
node.children[action] = child_node
return child_node
def _simulate(self, node: MCTSNode, depth: int = 10) -> float:
"""
Simulate a random playout from the given node until a terminal state or max depth.
Args:
node: The node to start simulation from
depth: Maximum simulation depth
Returns:
The value of the simulated outcome
"""
state = node.state
current_depth = 0
# Continue simulation until we reach a terminal state or max depth
while current_depth < depth:
if self._is_terminal_state(state):
break
possible_actions = self.action_generator(state)
if not possible_actions:
break
action = random.choice(possible_actions)
state = self.simulator(state, action)
current_depth += 1
return self.state_evaluator(state)
def _is_terminal_state(self, state: Any) -> bool:
"""Determine if the state is terminal."""
# This should be customized based on your environment
return False
def _backpropagate(self, node: MCTSNode, value: float) -> None:
"""
Backpropagate the simulation result up the tree.
Args:
node: The leaf node where simulation started
value: The value from the simulation
"""
while node is not None:
node.visits += 1
node.value += value
# Update RAVE values if enabled
if self.use_rave and node.parent is not None:
state_hash = self._hash_state(node.parent.state)
action = node.action_taken
if (state_hash, action) not in self.rave_values:
self.rave_values[(state_hash, action)] = [0, 0] # [value, visits]
rave_value, rave_visits = self.rave_values[(state_hash, action)]
self.rave_values[(state_hash, action)] = [
rave_value + value,
rave_visits + 1
]
node = node.parent
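        # Note: rave_values is accumulated here, but best_child() does not
        # currently blend these RAVE statistics into its UCB1 score.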
    def _hash_state(self, state: Any) -> int:
        """Create a hash of the state for RAVE table lookups."""
        # This should be customized based on your state representation.
        # Every object has a __hash__ attribute (it is None for unhashable
        # types such as dict), so check hashability explicitly.
        if getattr(state, "__hash__", None) is not None:
            try:
                return hash(state)
            except TypeError:
                pass
        return hash(str(state))
class MCTSToolSelector:
"""
Specialized MCTS implementation for selecting optimal tools in Claude Code.
This class adapts the AdvancedMCTS for the specific context of tool selection.
"""
def __init__(
self,
tool_registry: Any, # Should be a reference to the tool registry
context_evaluator: Callable, # Function to evaluate quality of response given context
max_iterations: int = 200,
exploration_weight: float = 1.0,
use_learning: bool = True,
tool_history_weight: float = 0.7,
enable_plan_generation: bool = True,
use_semantic_similarity: bool = True,
adaptation_rate: float = 0.05
):
"""
Initialize the MCTS tool selector with enhanced intelligence.
Args:
tool_registry: Registry containing available tools
context_evaluator: Function to evaluate response quality
max_iterations: Maximum search iterations
exploration_weight: Controls exploration vs exploitation
use_learning: Whether to use learning from past tool selections
tool_history_weight: Weight given to historical tool performance
enable_plan_generation: Generate complete tool sequences as plans
use_semantic_similarity: Use semantic similarity for tool relevance
adaptation_rate: Rate at which the system adapts to new patterns
"""
self.tool_registry = tool_registry
self.context_evaluator = context_evaluator
self.use_learning = use_learning
self.tool_history_weight = tool_history_weight
self.enable_plan_generation = enable_plan_generation
self.use_semantic_similarity = use_semantic_similarity
self.adaptation_rate = adaptation_rate
# Tool performance history by query type
self.tool_history = {}
# Tool sequence effectiveness records
self.sequence_effectiveness = {}
# Semantic fingerprints for tools and queries
self.tool_fingerprints = {}
self.query_clusters = {}
# Cached simulation results for similar queries
self.simulation_cache = {}
# Initialize the MCTS algorithm
self.mcts = AdvancedMCTS(
state_evaluator=self._evaluate_state,
action_generator=self._generate_actions,
simulator=self._simulate_action,
max_iterations=max_iterations,
exploration_weight=exploration_weight,
progressive_widening=True
)
# Initialize tool fingerprints
self._initialize_tool_fingerprints()
def _initialize_tool_fingerprints(self):
"""Initialize semantic fingerprints for all available tools."""
if not self.use_semantic_similarity:
return
for tool_name in self.tool_registry.get_all_tool_names():
tool = self.tool_registry.get_tool(tool_name)
if tool and hasattr(tool, 'description'):
# In a real implementation, this would compute an embedding
# For now, we'll use a simple keyword extraction as a placeholder
keywords = set(word.lower() for word in tool.description.split()
if len(word) > 3)
self.tool_fingerprints[tool_name] = {
'keywords': keywords,
'description': tool.description,
'usage_contexts': set()
}
    def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> Optional[Union[str, List[str]]]:
"""
Select the best tool to use for a given user query and context.
Args:
user_query: The user's query
context: The current conversation context
visualizer: Optional visualizer to show the selection process
Returns:
            Either a single tool name, a sequence of tool names (if plan
            generation is enabled), or None if no action could be determined
"""
# Analyze query to determine its type/characteristics
query_type = self._analyze_query(user_query)
# Update semantic fingerprints with this query
if self.use_semantic_similarity:
self._update_query_clusters(user_query, query_type)
initial_state = {
'query': user_query,
'query_type': query_type,
'context': context,
'actions_taken': [],
'response_quality': 0.0,
'steps_remaining': 3 if self.enable_plan_generation else 1,
'step_results': {}
}
# First check if we have a high-confidence cached result for similar queries
cached_result = self._check_cache(user_query, query_type)
if cached_result and random.random() > 0.1: # 10% random exploration
if visualizer:
visualizer.add_execution(
execution_id="mcts_cache_hit",
tool_name="MCTS Tool Selection (cached)",
parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
)
visualizer.complete_execution(
execution_id="mcts_cache_hit",
result={"selected_tool": cached_result, "source": "cache"},
status="success"
)
return cached_result
# Run MCTS search
best_action = self.mcts.search(initial_state, visualizer)
# If plan generation is enabled, we might want to return a sequence
if self.enable_plan_generation:
# Extract the most promising action sequence from search
plan = self._extract_plan_from_search()
if plan and len(plan) > 1:
# Store this plan in our cache
self._cache_result(user_query, query_type, plan)
return plan
# Store single action in cache
self._cache_result(user_query, query_type, best_action)
return best_action
def _analyze_query(self, query: str) -> str:
"""
Analyze a query to determine its type and characteristics.
Args:
query: The user query
Returns:
A string identifying the query type
"""
query_lower = query.lower()
# Check for search-related queries
if any(term in query_lower for term in ['find', 'search', 'where', 'look for']):
return 'search'
# Check for explanation queries
if any(term in query_lower for term in ['explain', 'how', 'why', 'what is']):
return 'explanation'
# Check for file operation queries
if any(term in query_lower for term in ['file', 'read', 'write', 'edit', 'create']):
return 'file_operation'
# Check for execution queries
if any(term in query_lower for term in ['run', 'execute', 'start']):
return 'execution'
# Check for debugging queries
if any(term in query_lower for term in ['debug', 'fix', 'error', 'problem']):
return 'debugging'
# Default to general
return 'general'
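        # Examples (illustrative): "find the config loader" -> 'search',
        # "explain how caching works" -> 'explanation',
        # "fix this error" -> 'debugging'.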
def _update_query_clusters(self, query: str, query_type: str):
"""
Update query clusters with new query information.
Args:
query: The user query
query_type: The type of query
"""
# Extract query keywords
keywords = set(word.lower() for word in query.split() if len(word) > 3)
# Update query clusters
if query_type not in self.query_clusters:
self.query_clusters[query_type] = {
'keywords': set(),
'queries': []
}
# Add keywords to cluster
self.query_clusters[query_type]['keywords'].update(keywords)
# Add query to cluster (limit to last 50)
self.query_clusters[query_type]['queries'].append(query)
if len(self.query_clusters[query_type]['queries']) > 50:
self.query_clusters[query_type]['queries'].pop(0)
# Update tool fingerprints with these keywords
for tool_name, fingerprint in self.tool_fingerprints.items():
# If tool has been used successfully for this query type before
if tool_name in self.tool_history.get(query_type, {}) and \
self.tool_history[query_type][tool_name]['success_rate'] > 0.6:
fingerprint['usage_contexts'].add(query_type)
def _check_cache(self, query: str, query_type: str) -> Union[str, List[str], None]:
"""
Check if we have a cached result for a similar query.
Args:
query: The user query
query_type: The type of query
Returns:
A cached tool selection or None
"""
if not self.use_learning or query_type not in self.tool_history:
return None
# Find the most successful tool for this query type
type_history = self.tool_history[query_type]
best_tools = sorted(
[(tool, data['success_rate']) for tool, data in type_history.items()],
key=lambda x: x[1],
reverse=True
)
# Only use cache if we have a high confidence result
if best_tools and best_tools[0][1] > 0.75:
return best_tools[0][0]
return None
def _cache_result(self, query: str, query_type: str, action: Union[str, List[str]]):
"""
Cache a result for future similar queries.
Args:
query: The user query
query_type: The type of query
action: The selected action or plan
"""
# Store in simulation cache
query_key = self._get_query_cache_key(query)
self.simulation_cache[query_key] = {
'action': action,
'timestamp': self._get_timestamp(),
'query_type': query_type
}
# Limit cache size
if len(self.simulation_cache) > 1000:
# Remove oldest entries
oldest_key = min(self.simulation_cache.keys(),
key=lambda k: self.simulation_cache[k]['timestamp'])
del self.simulation_cache[oldest_key]
def _get_query_cache_key(self, query: str) -> str:
"""Generate a cache key for a query."""
# In a real implementation, this might use a hash of query embeddings
# For now, use a simple keyword approach
keywords = ' '.join(sorted(set(word.lower() for word in query.split() if len(word) > 3)))
return keywords[:100] # Limit key length
    def _get_timestamp(self) -> float:
        """Get the current timestamp (time is imported at module level)."""
        return time.time()
def _evaluate_state(self, state: Dict[str, Any]) -> float:
"""
Evaluate the quality of a state based on response quality and steps.
Args:
state: The current state
Returns:
A quality score
"""
# Base score is the response quality
score = state['response_quality']
# If plan generation is enabled, we want to encourage complete plans
if self.enable_plan_generation:
steps_completed = len(state['actions_taken'])
total_steps = steps_completed + state['steps_remaining']
# Add bonus for completing more steps
if total_steps > 0:
step_completion_bonus = steps_completed / total_steps
score += step_completion_bonus * 0.2 # 20% bonus for step completion
return score
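        # Example (illustrative): with response_quality = 0.5 and 2 of 3 plan
        # steps completed, the score is 0.5 + (2 / 3) * 0.2 ≈ 0.63.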
def _generate_actions(self, state: Dict[str, Any]) -> List[str]:
"""
Generate possible tool actions from the current state with intelligent filtering.
Args:
state: The current state
Returns:
List of possible actions
"""
# Get query type
query_type = state['query_type']
query = state['query']
# Get all available tools
all_tools = set(self.tool_registry.get_all_tool_names())
# Tools already used in this sequence
used_tools = set(state['actions_taken'])
# Remaining tools
remaining_tools = all_tools - used_tools
# If we're using learning, prioritize tools based on history
if self.use_learning and query_type in self.tool_history:
prioritized_tools = []
# First, add tools that have been successful for this query type
type_history = self.tool_history[query_type]
# Check for successful tools
for tool in remaining_tools:
if tool in type_history and type_history[tool]['success_rate'] > 0.5:
prioritized_tools.append(tool)
# If we have at least some tools, return them
if prioritized_tools and random.random() < self.tool_history_weight:
return prioritized_tools
# If using semantic similarity, filter by relevant tools
if self.use_semantic_similarity:
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
# Score tools by semantic similarity to query
scored_tools = []
for tool in remaining_tools:
if tool in self.tool_fingerprints:
fingerprint = self.tool_fingerprints[tool]
# Calculate keyword overlap
keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
# Check if tool has been used for this query type
context_match = 1.0 if query_type in fingerprint['usage_contexts'] else 0.0
# Combined score
score = keyword_overlap * 0.7 + context_match * 0.3
scored_tools.append((tool, score))
# Sort and filter tools
scored_tools.sort(key=lambda x: x[1], reverse=True)
# Take top half of tools if we have enough
if len(scored_tools) > 2:
return [t[0] for t in scored_tools[:max(2, len(scored_tools) // 2)]]
# If we reach here, use all remaining tools
return list(remaining_tools)
def _simulate_action(self, state: Dict[str, Any], action: str) -> Dict[str, Any]:
"""
Simulate taking an action (using a tool) in the given state with enhanced modeling.
Args:
state: The current state
action: The tool action to simulate
Returns:
The new state after taking the action
"""
# Create a new state with the action added
new_state = state.copy()
new_actions = state['actions_taken'].copy()
new_actions.append(action)
new_state['actions_taken'] = new_actions
# Decrement steps remaining if using plan generation
if self.enable_plan_generation and new_state['steps_remaining'] > 0:
new_state['steps_remaining'] -= 1
# Get query type and query
query_type = state['query_type']
query = state['query']
# Simulate step result
step_results = state['step_results'].copy()
step_results[action] = self._simulate_tool_result(action, query)
new_state['step_results'] = step_results
# Estimate tool relevance based on learning or semantic similarity
tool_relevance = self._estimate_tool_relevance(action, query, query_type)
# Check for sequence effects (tools that work well together)
sequence_bonus = 0.0
if len(new_actions) > 1:
prev_tool = new_actions[-2]
sequence_key = f"{prev_tool}->{action}"
if sequence_key in self.sequence_effectiveness:
sequence_bonus = self.sequence_effectiveness[sequence_key] * 0.3 # 30% weight for sequence effects
# Update quality based on relevance and sequence effects
current_quality = state['response_quality']
quality_improvement = tool_relevance + sequence_bonus
# Add diminishing returns effect for additional tools
if len(new_actions) > 1:
diminishing_factor = 1.0 / len(new_actions)
quality_improvement *= diminishing_factor
new_quality = min(1.0, current_quality + quality_improvement)
new_state['response_quality'] = new_quality
return new_state
def _simulate_tool_result(self, tool_name: str, query: str) -> Dict[str, Any]:
"""
Simulate the result of using a tool for a query.
Args:
tool_name: The name of the tool
query: The user query
Returns:
A simulated result
"""
# In a real implementation, this would be a more sophisticated simulation
return {
"tool": tool_name,
"success_probability": self._estimate_tool_relevance(tool_name, query),
"simulated": True
}
def _estimate_tool_relevance(self, tool_name: str, query: str, query_type: str = None) -> float:
"""
Estimate how relevant a tool is for a given query using history and semantics.
Args:
tool_name: The name of the tool
query: The user query
query_type: Optional query type
Returns:
A relevance score between 0.0 and 1.0
"""
relevance_score = 0.0
# If we have historical data for this query type
if self.use_learning and query_type and query_type in self.tool_history and \
tool_name in self.tool_history[query_type]:
# Get historical success rate
history_score = self.tool_history[query_type][tool_name]['success_rate']
relevance_score += history_score * self.tool_history_weight
# If we're using semantic similarity
if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
fingerprint = self.tool_fingerprints[tool_name]
# Calculate keyword overlap
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
# Normalize by query keywords
if query_keywords:
semantic_score = keyword_overlap / len(query_keywords)
relevance_score += semantic_score * (1.0 - self.tool_history_weight)
# Ensure we have a minimum score for exploration
if relevance_score < 0.1:
relevance_score = 0.1 + (random.random() * 0.1) # Random boost between 0.1-0.2
return relevance_score
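        # Example (illustrative): with tool_history_weight = 0.7, a historical
        # success rate of 0.8 and 2 of 4 query keywords overlapping the tool's
        # fingerprint, relevance = 0.8 * 0.7 + (2 / 4) * 0.3 = 0.71.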
    def _extract_plan_from_search(self) -> Optional[List[str]]:
"""
Extract a complete plan (tool sequence) from the search results.
Returns:
A list of tool names representing the plan
"""
# In a real implementation, this would extract the highest value path
# from the search tree. For now, return None to indicate no plan extraction.
return None
def update_tool_history(self, tool_name: str, query: str, success: bool,
execution_time: float, result: Any = None):
"""
Update the tool history with the results of using a tool.
Args:
tool_name: The name of the tool used
query: The query the tool was used for
success: Whether the tool was successful
execution_time: The execution time in seconds
result: Optional result of the tool execution
"""
if not self.use_learning:
return
# Get query type
query_type = self._analyze_query(query)
# Initialize history entry if needed
if query_type not in self.tool_history:
self.tool_history[query_type] = {}
if tool_name not in self.tool_history[query_type]:
self.tool_history[query_type][tool_name] = {
'success_count': 0,
'failure_count': 0,
'total_time': 0.0,
'success_rate': 0.0,
'avg_time': 0.0,
'examples': []
}
# Update history
history = self.tool_history[query_type][tool_name]
# Update counts
if success:
history['success_count'] += 1
else:
history['failure_count'] += 1
# Update time
history['total_time'] += execution_time
# Update success rate
total = history['success_count'] + history['failure_count']
history['success_rate'] = history['success_count'] / total if total > 0 else 0.0
# Update average time
history['avg_time'] = history['total_time'] / total if total > 0 else 0.0
# Add example (limit to last 5)
history['examples'].append({
'query': query,
'success': success,
'timestamp': self._get_timestamp()
})
if len(history['examples']) > 5:
history['examples'].pop(0)
# Update tool fingerprint
if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
if success:
# Add query type to usage contexts
self.tool_fingerprints[tool_name]['usage_contexts'].add(query_type)
# Add query keywords to tool fingerprint (with decay)
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
current_keywords = self.tool_fingerprints[tool_name]['keywords']
# Add new keywords with adaptation rate
for keyword in query_keywords:
if keyword not in current_keywords:
if random.random() < self.adaptation_rate:
current_keywords.add(keyword)
def update_sequence_effectiveness(self, tool_sequence: List[str], success: bool, quality_score: float):
"""
Update the effectiveness record for a sequence of tools.
Args:
tool_sequence: The sequence of tools used
success: Whether the sequence was successful
quality_score: A quality score for the sequence
"""
if not self.use_learning or len(tool_sequence) < 2:
return
# Update pairwise effectiveness
for i in range(len(tool_sequence) - 1):
first_tool = tool_sequence[i]
second_tool = tool_sequence[i + 1]
sequence_key = f"{first_tool}->{second_tool}"
if sequence_key not in self.sequence_effectiveness:
self.sequence_effectiveness[sequence_key] = 0.5 # Initial neutral score
# Update score with decay
current_score = self.sequence_effectiveness[sequence_key]
if success:
# Increase score with quality bonus
new_score = current_score + self.adaptation_rate * quality_score
else:
# Decrease score
new_score = current_score - self.adaptation_rate
# Clamp between 0 and 1
self.sequence_effectiveness[sequence_key] = max(0.0, min(1.0, new_score))
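# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module): drives AdvancedMCTS on a toy problem where the state is an integer,
# actions add -1/0/+1, and states closer to 10 evaluate higher. All names
# below are local to this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def evaluate(state: int) -> float:
        # Score in [0, 1], maximal at the target value 10
        return max(0.0, 1.0 - abs(10 - state) / 10.0)
    def actions(state: int) -> List[int]:
        return [-1, 0, 1]
    def step(state: int, action: int) -> int:
        return state + action
    demo_mcts = AdvancedMCTS(
        state_evaluator=evaluate,
        action_generator=actions,
        simulator=step,
        max_iterations=300,
    )
    print("Best first action from state 0:", demo_mcts.search(initial_state=0))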
```
--------------------------------------------------------------------------------
/claude_code/lib/ui/tool_visualizer.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# claude_code/lib/ui/tool_visualizer.py
"""Real-time tool execution visualization."""
import logging
import math
import time
import json
from typing import Dict, List, Any, Optional
from rich.console import Console, Group
from rich.panel import Panel
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn
from rich.table import Table
from rich.box import ROUNDED
from rich.text import Text
from rich.live import Live
from rich.layout import Layout
from rich.syntax import Syntax
from ..tools.base import ToolResult
logger = logging.getLogger(__name__)
class ToolCallVisualizer:
"""Visualizes tool calls in real-time."""
def __init__(self, console: Console):
"""Initialize the tool call visualizer.
Args:
console: Rich console instance
"""
self.console = console
self.active_calls: Dict[str, Dict[str, Any]] = {}
self.completed_calls: List[Dict[str, Any]] = []
self.layout = self._create_layout()
self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)
self.max_completed_calls = 5
# Keep track of recent tool results for routines
self.recent_tool_results: List[ToolResult] = []
self.max_recent_results = 20 # Maximum number of recent results to track
def _create_layout(self) -> Layout:
"""Create the layout for the tool call visualization.
Returns:
Layout object
"""
layout = Layout()
layout.split(
Layout(name="active", size=3),
Layout(name="completed", size=3)
)
return layout
def _create_active_calls_panel(self) -> Panel:
"""Create a panel with active tool calls.
Returns:
Panel with active call information
"""
if not self.active_calls:
return Panel(
"No active tool calls",
title="[bold blue]Active Tool Calls[/bold blue]",
border_style="blue",
box=ROUNDED
)
# Create progress bars for each active call
progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
expand=True,
console=self.console
)
# Add tasks for each active call
for call_id, call_info in self.active_calls.items():
if "task_id" not in call_info:
# Create a new task for this call
description = f"{call_info['tool_name']} ({call_id[:6]}...)"
task_id = progress.add_task(description, total=100, completed=int(call_info["progress"] * 100))
call_info["task_id"] = task_id
else:
# Update existing task
progress.update(call_info["task_id"], completed=int(call_info["progress"] * 100))
# Create a table with parameter information
table = Table(show_header=True, header_style="bold cyan", box=ROUNDED, expand=True)
table.add_column("Tool")
table.add_column("Parameters")
for call_id, call_info in self.active_calls.items():
# Format parameters nicely
params = call_info.get("parameters", {})
if params:
formatted_params = "\n".join([f"{k}: {self._format_value(v)}" for k, v in params.items()])
else:
formatted_params = "None"
table.add_row(call_info["tool_name"], formatted_params)
        # Stack the progress bars and the parameter table so the table built
        # above is actually rendered inside the panel
        return Panel(
            Group(progress, table),
            title="[bold blue]Active Tool Calls[/bold blue]",
            border_style="blue",
            box=ROUNDED
        )
def _create_completed_calls_panel(self) -> Panel:
"""Create a panel with completed tool calls.
Returns:
Panel with completed call information
"""
if not self.completed_calls:
return Panel(
"No completed tool calls",
title="[bold green]Recent Tool Results[/bold green]",
border_style="green",
box=ROUNDED
)
# Create a table for results
table = Table(show_header=True, header_style="bold green", box=ROUNDED, expand=True)
table.add_column("Tool")
table.add_column("Status")
table.add_column("Time")
table.add_column("Result Preview")
# Show only the most recent completed calls
for call_info in self.completed_calls[-self.max_completed_calls:]:
tool_name = call_info["tool_name"]
status = call_info["status"]
execution_time = f"{call_info['execution_time']:.2f}s"
# Format result preview
result = call_info.get("result", "")
if result:
# Truncate and format result
preview = self._format_result_preview(result, tool_name)
else:
preview = "No result"
# Status with color
status_text = Text(status)
if status == "success":
status_text.stylize("bold green")
else:
status_text.stylize("bold red")
table.add_row(tool_name, status_text, execution_time, preview)
return Panel(
table,
title="[bold green]Recent Tool Results[/bold green]",
border_style="green",
box=ROUNDED
)
def _format_value(self, value: Any) -> str:
"""Format a parameter value for display.
Args:
value: Parameter value
Returns:
Formatted string
"""
if isinstance(value, (dict, list)):
# Convert complex structures to JSON with indentation
return json.dumps(value, indent=2)
return str(value)
def _format_result_preview(self, result: str, tool_name: str) -> str:
"""Format a result preview.
Args:
result: Result string
tool_name: Name of the tool
Returns:
Formatted preview string
"""
# Truncate result for preview
if len(result) > 200:
preview = result[:200] + "..."
else:
preview = result
# Clean up newlines for display
preview = preview.replace("\n", "\\n")
return preview
def start(self) -> None:
"""Start the visualization."""
self.live.start()
self.refresh()
def stop(self) -> None:
"""Stop the visualization."""
self.live.stop()
def refresh(self) -> None:
"""Refresh the visualization."""
# Update the layout with current information
self.layout["active"].update(self._create_active_calls_panel())
self.layout["completed"].update(self._create_completed_calls_panel())
# Refresh the live display
self.live.refresh()
def add_tool_call(self, tool_call_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
"""Add a new tool call to visualize.
Args:
tool_call_id: ID of the tool call
tool_name: Name of the tool
parameters: Tool parameters
"""
self.active_calls[tool_call_id] = {
"tool_name": tool_name,
"parameters": parameters,
"start_time": time.time(),
"progress": 0.0
}
self.refresh()
def update_progress(self, tool_call_id: str, progress: float) -> None:
"""Update the progress of a tool call.
Args:
tool_call_id: ID of the tool call
progress: Progress value (0-1)
"""
if tool_call_id in self.active_calls:
self.active_calls[tool_call_id]["progress"] = progress
self.refresh()
def complete_tool_call(self, tool_call_id: str, result: ToolResult) -> None:
"""Mark a tool call as complete.
Args:
tool_call_id: ID of the tool call
result: Tool execution result
"""
if tool_call_id in self.active_calls:
call_info = self.active_calls[tool_call_id].copy()
# Add result information
call_info["result"] = result.result
call_info["status"] = result.status
call_info["execution_time"] = result.execution_time
call_info["end_time"] = time.time()
# Add to completed calls
self.completed_calls.append(call_info)
# Trim completed calls if needed
if len(self.completed_calls) > self.max_completed_calls * 2:
self.completed_calls = self.completed_calls[-self.max_completed_calls:]
# Remove from active calls
del self.active_calls[tool_call_id]
# Store in recent tool results for routines
if result.status == "success":
self.recent_tool_results.append(result)
# Keep only the most recent results
if len(self.recent_tool_results) > self.max_recent_results:
self.recent_tool_results.pop(0)
self.refresh()
def show_result_detail(self, result: ToolResult) -> None:
"""Display detailed result information.
Args:
result: Tool execution result
"""
# Detect if result might be code
content = result.result
if content.startswith(("def ", "class ", "import ", "from ")) or "```" in content:
# Try to extract code blocks
if "```" in content:
blocks = content.split("```")
# Find a code block with a language specifier
for i in range(1, len(blocks), 2):
if i < len(blocks):
lang = blocks[i].split("\n")[0].strip()
code = "\n".join(blocks[i].split("\n")[1:])
if lang and code:
# Attempt to display as syntax-highlighted code
try:
syntax = Syntax(code, lang, theme="monokai", line_numbers=True)
self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
return
except Exception:
pass
        # If we can't extract a code block, fall back to a default lexer.
        # Note: Syntax() does not raise for a mismatched language name, so
        # looping over candidate languages would always settle on the first.
        try:
            syntax = Syntax(content, "python", theme="monokai", line_numbers=True)
            self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
            return
        except Exception:
            pass
# Just print as regular text if not code or if highlighting failed
self.console.print(Panel(content, title=f"[bold]Result: {result.name}[/bold]"))
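# Typical ToolCallVisualizer lifecycle (illustrative sketch; `tool_result` is
# assumed to be a ToolResult with .result, .status and .execution_time set):
#     visualizer = ToolCallVisualizer(Console())
#     visualizer.start()
#     visualizer.add_tool_call("call-1", "read_file", {"path": "README.md"})
#     visualizer.update_progress("call-1", 0.5)
#     visualizer.complete_tool_call("call-1", tool_result)
#     visualizer.stop()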
class MCTSVisualizer:
"""Visualizes the Monte Carlo Tree Search process in real-time with enhanced intelligence."""
def __init__(self, console: Console):
"""Initialize the MCTS visualizer.
Args:
console: Rich console instance
"""
self.console = console
self.root_node = None
self.current_iteration = 0
self.max_iterations = 0
self.best_action = None
        self.active_simulation = None
        self.simulation_path = []
        # Initialize per-iteration attributes so refresh() is safe to call
        # before the first update_iteration()
        self.selected_node = None
        self.expanded_node = None
        self.simulation_result = None
self.layout = self._create_layout()
self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
# Intelligence enhancement - track history
self.action_history = {} # Track action performance over time
self.visit_distribution = {} # Track how visits are distributed
self.exploration_patterns = [] # Track exploration patterns
self.quality_metrics = {"search_efficiency": 0.0, "exploration_balance": 0.0}
self.auto_improvement_enabled = True
def _create_layout(self) -> Layout:
"""Create the layout for MCTS visualization.
Returns:
Layout object
"""
layout = Layout()
# Create the main sections with more detailed visualization
layout.split(
Layout(name="header", size=3),
Layout(name="main"),
Layout(name="intelligence", size=7), # New section for intelligence metrics
Layout(name="stats", size=5)
)
# Split the main section into tree, simulation and action insights
layout["main"].split_row(
Layout(name="tree", ratio=2),
Layout(name="simulation", ratio=1),
Layout(name="insights", ratio=1) # New section for action insights
)
return layout
def set_search_parameters(self, root_node: Any, max_iterations: int, additional_params: Dict[str, Any] = None) -> None:
"""Set the search parameters with enhanced intelligence options.
Args:
root_node: The root node of the search tree
max_iterations: Maximum number of iterations
additional_params: Additional parameters for intelligent search
"""
self.root_node = root_node
self.max_iterations = max_iterations
self.current_iteration = 0
# Initialize intelligence tracking
self.action_history = {}
self.visit_distribution = {}
self.exploration_patterns = []
# Set additional intelligence parameters
if additional_params:
self.auto_improvement_enabled = additional_params.get('auto_improvement', True)
# Apply any initial intelligence strategies
if additional_params.get('initial_action_bias'):
self.action_history = additional_params['initial_action_bias']
self.refresh()
def update_iteration(self, iteration: int, selected_node: Any = None,
expanded_node: Any = None, simulation_path: List[Any] = None,
simulation_result: float = None, best_action: Any = None,
node_values: Dict[str, float] = None) -> None:
"""Update the current iteration status with enhanced tracking.
Args:
iteration: Current iteration number
selected_node: Node selected in this iteration
expanded_node: Node expanded in this iteration
simulation_path: Path of the simulation
simulation_result: Result of the simulation
best_action: Current best action
node_values: Values of important nodes in the search (for visualization)
"""
self.current_iteration = iteration
self.selected_node = selected_node
self.expanded_node = expanded_node
self.simulation_path = simulation_path or []
self.simulation_result = simulation_result
if best_action is not None:
self.best_action = best_action
# Intelligence tracking - update action history
if self.simulation_path and simulation_result is not None:
for _, action in self.simulation_path:
if action is not None:
action_str = str(action)
if action_str not in self.action_history:
self.action_history[action_str] = {
"visits": 0,
"total_value": 0.0,
"iterations": []
}
self.action_history[action_str]["visits"] += 1
self.action_history[action_str]["total_value"] += simulation_result
self.action_history[action_str]["iterations"].append(iteration)
# Update exploration pattern
if selected_node:
# Record exploration choice
self.exploration_patterns.append({
"iteration": iteration,
"node_depth": self._get_node_depth(selected_node),
"node_breadth": len(getattr(selected_node, "children", {})),
"value_estimate": getattr(selected_node, "value", 0) / max(1, getattr(selected_node, "visits", 1))
})
# Update visit distribution
if self.root_node and hasattr(self.root_node, "children"):
self._update_visit_distribution()
# Update quality metrics
self._update_quality_metrics()
self.refresh()
def start(self) -> None:
"""Start the visualization."""
self.live.start()
self.refresh()
def stop(self) -> None:
"""Stop the visualization."""
self.live.stop()
def refresh(self) -> None:
"""Refresh the visualization."""
# Update header
header_content = f"[bold blue]Enhanced Monte Carlo Tree Search - Iteration {self.current_iteration}/{self.max_iterations}[/bold blue]"
if self.best_action:
header_content += f" | Best Action: {self.best_action}"
intelligence_status = "[green]Enabled[/green]" if self.auto_improvement_enabled else "[yellow]Disabled[/yellow]"
header_content += f" | Intelligent Search: {intelligence_status}"
self.layout["header"].update(Panel(header_content, border_style="blue"))
# Update tree visualization
self.layout["tree"].update(self._create_tree_panel())
# Update simulation visualization
self.layout["simulation"].update(self._create_simulation_panel())
# Update action insights panel
self.layout["insights"].update(self._create_insights_panel())
# Update intelligence metrics
self.layout["intelligence"].update(self._create_intelligence_panel())
# Update stats
self.layout["stats"].update(self._create_stats_panel())
# Refresh the live display
self.live.refresh()
def _create_tree_panel(self) -> Panel:
"""Create a panel showing the current state of the search tree.
Returns:
Panel with tree visualization
"""
if not self.root_node:
return Panel("No search tree initialized", title="[bold]Search Tree[/bold]")
# Create a table to show the tree structure
from rich.tree import Tree
from rich import box
tree = Tree("🔍 Root Node", guide_style="bold blue")
# Limit the depth and breadth for display
max_depth = 3
max_children = 5
def add_node(node, tree_node, depth=0, path=None):
if depth >= max_depth or not node or not hasattr(node, "children"):
return
if path is None:
path = []
# Add children nodes
children = list(node.children.items())
if not children:
return
# Sort children by a combination of visits and value
def node_score(node_pair):
child_node = node_pair[1]
visits = getattr(child_node, "visits", 0)
value = getattr(child_node, "value", 0)
# Combine visits and value for scoring
if visits > 0:
# Use UCB-style formula for ranking
exploitation = value / visits
exploration = (2 * 0.5 * (math.log(node.visits) / visits)) if node.visits > 0 and visits > 0 else 0
return exploitation + exploration
return 0
# Sort by this smarter formula
children.sort(key=node_score, reverse=True)
children = children[:max_children]
for action, child in children:
# Format node information
visits = getattr(child, "visits", 0)
value = getattr(child, "value", 0)
# Highlight the node with more sophisticated coloring
style = ""
if child == self.selected_node:
style = "bold yellow"
elif child == self.expanded_node:
style = "bold green"
else:
# Color based on value
if visits > 0:
avg_value = value / visits
if avg_value > 0.7:
style = "green"
elif avg_value > 0.4:
style = "blue"
elif avg_value > 0.2:
style = "yellow"
else:
style = "red"
# Create the node label with enhanced information
current_path = path + [action]
if visits > 0:
avg_value = value / visits
confidence = min(1.0, math.sqrt(visits) / 5) * 100 # Simple confidence estimate
label = f"[{style}]{action}: (Visits: {visits}, Value: {avg_value:.3f}, Conf: {confidence:.0f}%)[/{style}]"
else:
label = f"[{style}]{action}: (New)[/{style}]"
# Add the child node to the tree
child_tree = tree_node.add(label)
# Recursively add its children
add_node(child, child_tree, depth + 1, current_path)
        # Start building the tree from the root (math is imported at module
        # level for the UCB-style node scoring above)
        if hasattr(self.root_node, "children"):
            add_node(self.root_node, tree)
return Panel(tree, title="[bold]Search Tree[/bold]", border_style="blue")
def _create_simulation_panel(self) -> Panel:
"""Create a panel showing the current simulation with enhanced analytics.
Returns:
Panel with simulation visualization
"""
if not self.simulation_path:
return Panel("No active simulation", title="[bold]Current Simulation[/bold]")
# Create a list of simulation steps
from rich.table import Table
table = Table(box=None, expand=True)
table.add_column("Step")
table.add_column("Action")
table.add_column("Expected Value") # New column
for i, (state, action) in enumerate(self.simulation_path):
# Get expected value for this action
action_str = str(action) if action is not None else "None"
expected_value = "N/A"
if action_str in self.action_history:
history = self.action_history[action_str]
if history["visits"] > 0:
expected_value = f"{history['total_value'] / history['visits']:.3f}"
table.add_row(f"Step {i+1}", f"{action}", expected_value)
if self.simulation_result is not None:
# Add path quality metric
path_quality = "Low"
if self.simulation_result > 0.7:
path_quality = "[bold green]High[/bold green]"
elif self.simulation_result > 0.4:
path_quality = "[yellow]Medium[/yellow]"
else:
path_quality = "[red]Low[/red]"
table.add_row("Result",
f"[bold green]{self.simulation_result:.3f}[/bold green]",
f"Path Quality: {path_quality}")
return Panel(table, title="[bold]Current Simulation[/bold]", border_style="green")
def _create_insights_panel(self) -> Panel:
"""Create a panel showing action insights from learned patterns.
Returns:
Panel with action insights
"""
from rich.table import Table
if not self.action_history:
return Panel("No action insights available yet", title="[bold]Action Insights[/bold]")
# Get top performing actions
top_actions = []
for action, data in self.action_history.items():
if data["visits"] >= 3: # Only consider actions with enough samples
avg_value = data["total_value"] / data["visits"]
top_actions.append((action, avg_value, data["visits"]))
# Sort by value and take top 5
top_actions.sort(key=lambda x: x[1], reverse=True)
top_actions = top_actions[:5]
# Create insights table
table = Table(box=None, expand=True)
table.add_column("Action")
table.add_column("Avg Value")
table.add_column("Visits")
table.add_column("Trend")
for action, avg_value, visits in top_actions:
# Generate trend indicator based on recent performance
trend = "→"
history = self.action_history[action]["iterations"]
if len(history) >= 5:
recent = set(history[-3:]) # Last 3 iterations
if self.current_iteration - max(recent) <= 5:
trend = "↑" # Recently used
elif self.current_iteration - max(recent) >= 10:
trend = "↓" # Not used recently
# Color code based on value
if avg_value > 0.7:
value_str = f"[green]{avg_value:.3f}[/green]"
elif avg_value > 0.4:
value_str = f"[blue]{avg_value:.3f}[/blue]"
else:
value_str = f"[yellow]{avg_value:.3f}[/yellow]"
table.add_row(str(action), value_str, str(visits), trend)
return Panel(table, title="[bold]Action Insights[/bold]", border_style="cyan")
def _create_intelligence_panel(self) -> Panel:
"""Create a panel showing intelligence metrics and learning patterns.
Returns:
Panel with intelligence visualization
"""
from rich.table import Table
from rich.columns import Columns
# Create metrics table
metrics_table = Table(box=None, expand=True)
metrics_table.add_column("Metric")
metrics_table.add_column("Value")
# Add search quality metrics
for metric, value in self.quality_metrics.items():
formatted_name = metric.replace("_", " ").title()
# Color based on value
if value > 0.7:
value_str = f"[green]{value:.2f}[/green]"
elif value > 0.4:
value_str = f"[blue]{value:.2f}[/blue]"
else:
value_str = f"[yellow]{value:.2f}[/yellow]"
metrics_table.add_row(formatted_name, value_str)
# Create exploration table
exploration_table = Table(box=None, expand=True)
exploration_table.add_column("Pattern")
exploration_table.add_column("Value")
# Add exploration patterns
if self.exploration_patterns:
# Average depth of exploration
avg_depth = sum(p["node_depth"] for p in self.exploration_patterns) / len(self.exploration_patterns)
exploration_table.add_row("Avg Exploration Depth", f"{avg_depth:.2f}")
# Depth trend (increasing or decreasing)
if len(self.exploration_patterns) >= 5:
recent_avg = sum(p["node_depth"] for p in self.exploration_patterns[-5:]) / 5
earlier_avg = sum(p["node_depth"] for p in self.exploration_patterns[:-5]) / max(1, len(self.exploration_patterns) - 5)
if recent_avg > earlier_avg * 1.2:
trend = "[green]Deepening[/green]"
elif recent_avg < earlier_avg * 0.8:
trend = "[yellow]Shallowing[/yellow]"
else:
trend = "[blue]Stable[/blue]"
exploration_table.add_row("Depth Trend", trend)
# Exploration-exploitation balance
if len(self.exploration_patterns) >= 3:
# Higher values = more exploitation of known good paths
exploitation_ratio = sum(1 for p in self.exploration_patterns[-10:]
if p["value_estimate"] > 0.5) / min(10, len(self.exploration_patterns))
if exploitation_ratio > 0.7:
balance = "[yellow]Heavy Exploitation[/yellow]"
elif exploitation_ratio < 0.3:
balance = "[yellow]Heavy Exploration[/yellow]"
else:
balance = "[green]Balanced[/green]"
exploration_table.add_row("Search Balance", balance)
# Combine tables into columns
columns = Columns([metrics_table, exploration_table])
return Panel(columns, title="[bold]Intelligence Metrics[/bold]", border_style="magenta")
def _create_stats_panel(self) -> Panel:
"""Create a panel showing search statistics with enhanced metrics.
Returns:
Panel with statistics
"""
if not self.root_node:
return Panel("No statistics available", title="[bold]Search Statistics[/bold]")
# Collect statistics
total_nodes = 0
max_depth = 0
total_visits = getattr(self.root_node, "visits", 0)
        avg_branching = 0
        internal_nodes = 0
        def count_nodes(node, depth=0):
            nonlocal total_nodes, max_depth, avg_branching, internal_nodes
            if not node or not hasattr(node, "children"):
                return
            total_nodes += 1
            max_depth = max(max_depth, depth)
            # Count children for branching factor
            num_children = len(node.children)
            if num_children > 0:
                avg_branching += num_children
                internal_nodes += 1
            for child in node.children.values():
                count_nodes(child, depth + 1)
        count_nodes(self.root_node)
        # Average over internal (non-leaf) nodes; dividing by total_nodes - 1
        # would always yield 1.0, since a tree has exactly total_nodes - 1 edges
        if internal_nodes > 0:
            avg_branching /= internal_nodes
# Create a table of statistics
from rich.table import Table
table = Table(box=None, expand=True)
table.add_column("Metric")
table.add_column("Value")
table.add_row("Total Nodes", str(total_nodes))
table.add_row("Max Depth", str(max_depth))
table.add_row("Total Visits", str(total_visits))
table.add_row("Avg Branching", f"{avg_branching:.2f}")
table.add_row("Progress", f"{self.current_iteration / self.max_iterations:.1%}")
# Efficiency estimate (higher is better)
if total_visits > 0:
visit_efficiency = total_nodes / total_visits
efficiency_str = f"{visit_efficiency:.2f}"
table.add_row("Search Efficiency", efficiency_str)
return Panel(table, title="[bold]Search Statistics[/bold]", border_style="magenta")
def _get_node_depth(self, node):
"""Calculate the depth of a node in the tree."""
depth = 0
current = node
while getattr(current, "parent", None) is not None:
depth += 1
current = current.parent
return depth
def _update_visit_distribution(self):
"""Update the distribution of visits across the tree."""
levels = {}
def count_visits_by_level(node, depth=0):
if not node or not hasattr(node, "children"):
return
# Initialize level if not present
if depth not in levels:
levels[depth] = {"visits": 0, "nodes": 0}
# Update level stats
levels[depth]["visits"] += getattr(node, "visits", 0)
levels[depth]["nodes"] += 1
# Process children
for child in node.children.values():
count_visits_by_level(child, depth + 1)
# Start counting from root
count_visits_by_level(self.root_node)
# Update visit distribution
self.visit_distribution = levels
def _update_quality_metrics(self):
"""Update quality metrics for the search process."""
# Search efficiency - ratio of valuable nodes to total nodes
# Higher values indicate more efficient search
if self.visit_distribution:
useful_visits = sum(level["visits"] for depth, level in self.visit_distribution.items()
if depth > 0) # Exclude root
total_visits = sum(level["visits"] for level in self.visit_distribution.values())
if total_visits > 0:
self.quality_metrics["search_efficiency"] = useful_visits / total_visits
# Exploration balance - how well the algorithm balances exploration vs exploitation
if self.exploration_patterns:
# Calculate variance in exploration depth
depths = [p["node_depth"] for p in self.exploration_patterns[-20:]] # Last 20 iterations
if depths:
import statistics
try:
depth_variance = statistics.variance(depths) if len(depths) > 1 else 0
# Normalize to 0-1 range (higher variance = more balanced exploration)
normalized_variance = min(1.0, depth_variance / 5.0) # Assume variance > 5 is high
self.quality_metrics["exploration_balance"] = normalized_variance
except statistics.StatisticsError:
pass
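# Wiring sketch (illustrative): AdvancedMCTS.search() accepts this visualizer,
# calling set_search_parameters() once and update_iteration() per iteration:
#     viz = MCTSVisualizer(Console())
#     viz.start()
#     best_action = mcts.search(initial_state, visualizer=viz)
#     viz.stop()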
class ParallelExecutionVisualizer:
"""Visualizes parallel execution of tool calls in real-time."""
def __init__(self, console: Console):
"""Initialize the parallel execution visualizer.
Args:
console: Rich console instance
"""
self.console = console
self.active_executions = {}
self.completed_executions = []
self.layout = self._create_layout()
self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
def _create_layout(self) -> Layout:
"""Create the layout for parallel execution visualization.
Returns:
Layout object
"""
layout = Layout()
# Create the main sections
layout.split(
Layout(name="header", size=3),
Layout(name="executions"),
Layout(name="metrics", size=5)
)
return layout
def add_execution(self, execution_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
"""Add a new execution to visualize.
Args:
execution_id: Unique ID for the execution
tool_name: Name of the tool being executed
parameters: Parameters for the execution
"""
self.active_executions[execution_id] = {
"tool_name": tool_name,
"parameters": parameters,
"start_time": time.time(),
"progress": 0.0,
"status": "running"
}
self.refresh()
def update_progress(self, execution_id: str, progress: float) -> None:
"""Update the progress of an execution.
Args:
execution_id: ID of the execution
progress: Progress value (0-1)
"""
if execution_id in self.active_executions:
self.active_executions[execution_id]["progress"] = progress
self.refresh()
def complete_execution(self, execution_id: str, result: Any, status: str = "success") -> None:
"""Mark an execution as complete.
Args:
execution_id: ID of the execution
result: Result of the execution
status: Status of completion
"""
if execution_id in self.active_executions:
execution = self.active_executions[execution_id].copy()
execution["end_time"] = time.time()
execution["duration"] = execution["end_time"] - execution["start_time"]
execution["result"] = result
execution["status"] = status
# Move to completed executions
self.completed_executions.append(execution)
del self.active_executions[execution_id]
# Limit completed executions list
if len(self.completed_executions) > 20:
self.completed_executions = self.completed_executions[-20:]
self.refresh()
def start(self) -> None:
"""Start the visualization."""
self.live.start()
self.refresh()
def stop(self) -> None:
"""Stop the visualization."""
self.live.stop()
def refresh(self) -> None:
"""Refresh the visualization."""
# Update header
header_content = f"[bold blue]Parallel Execution Monitor[/bold blue] | Active: {len(self.active_executions)} | Completed: {len(self.completed_executions)}"
self.layout["header"].update(Panel(header_content, border_style="blue"))
# Update executions visualization
self.layout["executions"].update(self._create_executions_panel())
# Update metrics
self.layout["metrics"].update(self._create_metrics_panel())
# Refresh the live display
self.live.refresh()
    def _create_executions_panel(self) -> Layout:
        """Create a layout showing active and recent executions.
        Returns:
            Layout with the active and completed execution panels
        """
from rich.table import Table
from rich.progress import BarColumn, Progress, TextColumn
# Create progress bars for active executions
progress_group = Table.grid(expand=True)
if self.active_executions:
# Create a progress group
progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=None),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("| Elapsed: {task.elapsed:.2f}s"),
expand=True
)
# Add tasks for each active execution
for exec_id, execution in self.active_executions.items():
tool_name = execution["tool_name"]
description = f"{tool_name} ({exec_id[:8]}...)"
                progress.add_task(description, total=100, completed=int(execution["progress"] * 100))
progress_group.add_row(progress)
else:
progress_group.add_row("[italic]No active executions[/italic]")
# Create a table for completed executions
completed_table = Table(show_header=True, header_style="bold blue", expand=True)
completed_table.add_column("Tool")
completed_table.add_column("Duration")
completed_table.add_column("Status")
completed_table.add_column("Result Preview")
if self.completed_executions:
# Most recent first
for execution in reversed(self.completed_executions[-10:]):
tool_name = execution["tool_name"]
duration = f"{execution['duration']:.2f}s"
status = execution["status"]
# Format result preview
result = str(execution.get("result", ""))
preview = result[:50] + "..." if len(result) > 50 else result
# Add status with color
status_text = f"[green]{status}[/green]" if status == "success" else f"[red]{status}[/red]"
completed_table.add_row(tool_name, duration, status_text, preview)
else:
completed_table.add_row("[italic]No completed executions[/italic]", "", "", "")
# Combine both into a layout
layout = Layout()
layout.split(
Layout(name="active", size=len(self.active_executions) * 2 + 3 if self.active_executions else 3),
Layout(name="completed")
)
layout["active"].update(Panel(progress_group, title="[bold]Active Executions[/bold]", border_style="blue"))
layout["completed"].update(Panel(completed_table, title="[bold]Recent Completions[/bold]", border_style="green"))
return layout
def _create_metrics_panel(self) -> Panel:
"""Create a panel showing execution metrics.
Returns:
Panel with metrics visualization
"""
from rich.table import Table
# Calculate metrics
total_executions = len(self.completed_executions)
successful = sum(1 for e in self.completed_executions if e["status"] == "success")
failed = total_executions - successful
if total_executions > 0:
success_rate = successful / total_executions
avg_duration = sum(e["duration"] for e in self.completed_executions) / total_executions
else:
success_rate = 0
avg_duration = 0
# Create metrics table
table = Table(box=None, expand=True)
table.add_column("Metric")
table.add_column("Value")
table.add_row("Total Executions", str(total_executions))
table.add_row("Success Rate", f"{success_rate:.1%}")
table.add_row("Average Duration", f"{avg_duration:.2f}s")
table.add_row("Current Parallelism", str(len(self.active_executions)))
return Panel(table, title="[bold]Execution Metrics[/bold]", border_style="magenta")
class MultiPanelLayout:
"""Creates a multi-panel layout for the entire UI."""
def __init__(self, console: Console):
"""Initialize the multi-panel layout.
Args:
console: Rich console instance
"""
self.console = console
self.layout = self._create_layout()
self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)
def _create_layout(self) -> Layout:
"""Create the main application layout.
Returns:
Layout object
"""
layout = Layout()
# Split into three main sections
layout.split(
Layout(name="conversation", ratio=3),
Layout(name="tools", ratio=2),
Layout(name="input", ratio=1)
)
# Further split the tools section
layout["tools"].split_row(
Layout(name="active_tools"),
Layout(name="cost", size=30)
)
return layout
def start(self) -> None:
"""Start the live display."""
self.live.start()
def stop(self) -> None:
"""Stop the live display."""
self.live.stop()
def refresh(self) -> None:
"""Refresh the display."""
self.live.refresh()
def update_section(self, section: str, content: Any) -> None:
"""Update a section of the layout.
Args:
section: Section name
content: Content to display
"""
if section in self.layout:
self.layout[section].update(content)
self.refresh()
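# ---------------------------------------------------------------------------
# Minimal demo sketch (an illustrative addition, not part of the original
# module): exercises ParallelExecutionVisualizer with one fake execution.
# The execution ID, tool name and parameters are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_console = Console()
    demo_viz = ParallelExecutionVisualizer(demo_console)
    demo_viz.start()
    demo_viz.add_execution("exec-1", "search_tool", {"pattern": "TODO"})
    for pct in (0.25, 0.5, 0.75, 1.0):
        time.sleep(0.2)
        demo_viz.update_progress("exec-1", pct)
    demo_viz.complete_execution("exec-1", result="3 matches found", status="success")
    time.sleep(0.5)
    demo_viz.stop()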
```