This is page 3 of 4. Use http://codebase.md/arthurcolle/openai-mcp?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```
# Files
--------------------------------------------------------------------------------
/mcp_server.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Model Context Protocol (MCP) Server Implementation
This module implements the Model Context Protocol server capabilities,
allowing the assistant to be used as an MCP-compatible context provider.
"""
import os
import json
import time
import uuid
import sys
import logging
import asyncio
import tiktoken
import re
from datetime import datetime
from typing import Dict, List, Any, Optional, Union, AsyncGenerator
from fastapi import FastAPI, HTTPException, Request, Response, Depends, BackgroundTasks, Query
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
import uvicorn
import openai
from openai import OpenAI
import prometheus_client
from prometheus_client import Counter, Histogram, Gauge
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("mcp_server")
# MCP Protocol Models
class MCPHealthResponse(BaseModel):
"""Health check response for MCP protocol"""
status: str = "healthy"
version: str = "1.0.0"
protocol_version: str = "0.1.0"
provider: str = "OpenAI Code Assistant"
models: List[str] = ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"]
uptime: Optional[float] = None
request_count: Optional[int] = None
cache_hit_ratio: Optional[float] = None
class MCPContextRequest(BaseModel):
"""Request for context generation from a prompt template"""
prompt_id: str
parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to fill in the prompt template")
model: Optional[str] = Field(None, description="Model to use for context generation")
stream: bool = Field(False, description="Whether to stream the response")
user: Optional[str] = Field(None, description="User identifier for tracking")
conversation_id: Optional[str] = Field(None, description="Conversation identifier")
message_id: Optional[str] = Field(None, description="Message identifier")
class MCPContextResponse(BaseModel):
"""Response containing generated context"""
context: str = Field(..., description="The generated context")
context_id: str = Field(..., description="Unique identifier for this context")
model: str = Field(..., description="Model used for generation")
usage: Dict[str, int] = Field(default_factory=dict, description="Token usage statistics")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class MCPErrorResponse(BaseModel):
"""Error response format"""
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Type of error")
status_code: int = Field(..., description="HTTP status code")
details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
class MCPPromptTemplate(BaseModel):
"""Prompt template definition"""
id: str = Field(..., description="Unique identifier for the template")
template: str = Field(..., description="The prompt template with parameter placeholders")
description: Optional[str] = Field(None, description="Description of the template")
parameters: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Parameter definitions")
default_model: Optional[str] = Field(None, description="Default model to use with this template")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class MCPPromptLibraryResponse(BaseModel):
"""Response containing a list of prompt templates"""
prompts: List[MCPPromptTemplate] = Field(..., description="List of prompt templates")
count: int = Field(..., description="Number of templates")
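# Illustrative request/response shapes for the /context endpoint defined below
# (field names follow the models above; values are examples only):
#   POST /context
#   {"prompt_id": "code_review",
#    "parameters": {"language": "python", "code": "...", "focus_areas": "security"},
#    "model": "gpt-4o"}
#   -> {"context": "...", "context_id": "<uuid4>", "model": "gpt-4o",
#       "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
#       "metadata": {"prompt_id": "code_review", "timestamp": 0.0, "parameters": {...}}}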
# MCP Server Implementation
# Prometheus metrics
REQUEST_COUNT = Counter('mcp_requests_total', 'Total number of requests processed', ['endpoint', 'status'])
REQUEST_LATENCY = Histogram('mcp_request_latency_seconds', 'Request latency in seconds', ['endpoint'])
CACHE_HIT = Counter('mcp_cache_hits_total', 'Total number of cache hits')
CACHE_MISS = Counter('mcp_cache_misses_total', 'Total number of cache misses')
ACTIVE_CONNECTIONS = Gauge('mcp_active_connections', 'Number of active connections')
TOKEN_USAGE = Counter('mcp_token_usage_total', 'Total number of tokens used', ['model', 'type'])
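# These metrics are exported in Prometheus text format by the /metrics endpoint
# defined below; a scraped sample might look like (illustrative):
#   mcp_requests_total{endpoint="/context",status="200"} 42.0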
# Cache implementation
class CacheManager:
"""Manages caching for context responses"""
def __init__(self, cache_type="memory", redis_url=None, ttl=3600):
self.cache_type = cache_type
self.redis_url = redis_url
self.ttl = ttl
self.memory_cache = {}
self.redis_client = None
if cache_type == "redis" and redis_url:
try:
import redis
self.redis_client = redis.from_url(redis_url)
logging.info(f"Redis cache initialized with URL: {redis_url}")
except ImportError:
logging.warning("Redis package not installed. Falling back to memory cache.")
self.cache_type = "memory"
except Exception as e:
logging.error(f"Failed to connect to Redis: {str(e)}")
self.cache_type = "memory"
async def get(self, key):
"""Get item from cache"""
if self.cache_type == "redis" and self.redis_client:
try:
value = self.redis_client.get(key)
if value:
CACHE_HIT.inc()
return json.loads(value)
CACHE_MISS.inc()
return None
except Exception as e:
logging.error(f"Redis get error: {str(e)}")
CACHE_MISS.inc()
return None
else:
# Memory cache
if key in self.memory_cache:
if time.time() - self.memory_cache[key]["timestamp"] < self.ttl:
CACHE_HIT.inc()
return self.memory_cache[key]["data"]
else:
# Expired
del self.memory_cache[key]
CACHE_MISS.inc()
return None
async def set(self, key, value, ttl=None):
"""Set item in cache"""
if ttl is None:
ttl = self.ttl
if self.cache_type == "redis" and self.redis_client:
try:
self.redis_client.setex(key, ttl, json.dumps(value))
except Exception as e:
logging.error(f"Redis set error: {str(e)}")
else:
# Memory cache
self.memory_cache[key] = {
"data": value,
"timestamp": time.time()
}
async def delete(self, key):
"""Delete item from cache"""
if self.cache_type == "redis" and self.redis_client:
try:
self.redis_client.delete(key)
except Exception as e:
logging.error(f"Redis delete error: {str(e)}")
else:
# Memory cache
if key in self.memory_cache:
del self.memory_cache[key]
async def clear(self):
"""Clear all cache"""
if self.cache_type == "redis" and self.redis_client:
try:
self.redis_client.flushdb()
except Exception as e:
logging.error(f"Redis flush error: {str(e)}")
else:
# Memory cache
self.memory_cache = {}
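# Note: context responses are cached under a key built from the prompt ID, the
# JSON-serialized parameters (sorted keys), and the model name (see get_context
# below), so identical requests can be answered from cache.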
class MCPServer:
"""Model Context Protocol Server Implementation"""
def __init__(self, cache_type="memory", redis_url=None):
self.app = FastAPI(
title="OpenAI Code Assistant MCP Server",
description="Model Context Protocol server for OpenAI Code Assistant",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
)
# Initialize OpenAI client
self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Initialize cache
self.cache = CacheManager(cache_type=cache_type, redis_url=redis_url)
# Initialize tokenizer
self.tokenizer = tiktoken.get_encoding("cl100k_base")
# Setup routes and middleware
self.setup_routes()
self.setup_middleware()
# Load templates and static files
self.templates_dir = os.path.join(os.path.dirname(__file__), "templates")
os.makedirs(self.templates_dir, exist_ok=True)
self.static_dir = os.path.join(os.path.dirname(__file__), "static")
os.makedirs(self.static_dir, exist_ok=True)
# Create default template if it doesn't exist
self._create_default_template()
# Initialize templates
self.templates = Jinja2Templates(directory=self.templates_dir)
# Mount static files
self.app.mount("/static", StaticFiles(directory=self.static_dir), name="static")
# Load prompt templates
self.prompt_templates = self._load_prompt_templates()
# Initialize metrics
self.request_count = 0
self.start_time = time.time()
def setup_middleware(self):
"""Configure middleware for the FastAPI app"""
# Add CORS middleware
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add request tracking middleware
@self.app.middleware("http")
async def track_requests(request: Request, call_next):
# Increment active connections
ACTIVE_CONNECTIONS.inc()
# Track request start time
start_time = time.time()
# Process request
try:
response = await call_next(request)
# Record metrics
endpoint = request.url.path
status = response.status_code
REQUEST_COUNT.labels(endpoint=endpoint, status=status).inc()
REQUEST_LATENCY.labels(endpoint=endpoint).observe(time.time() - start_time)
# Increment total request count
self.request_count += 1
return response
finally:
# Decrement active connections
ACTIVE_CONNECTIONS.dec()
def _create_default_template(self):
"""Create default dashboard template if it doesn't exist"""
index_path = os.path.join(self.templates_dir, "index.html")
if not os.path.exists(index_path):
with open(index_path, "w") as f:
f.write("""
<!DOCTYPE html>
<html>
<head>
<title>OpenAI Code Assistant MCP Server</title>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css">
<style>
body { padding: 20px; }
.card { margin-bottom: 20px; }
</style>
</head>
<body>
<div class="container">
<h1>OpenAI Code Assistant MCP Server</h1>
<div class="row">
<div class="col-md-6">
<div class="card">
<div class="card-header">Server Status</div>
<div class="card-body">
<p><strong>Status:</strong> {{ status }}</p>
<p><strong>Uptime:</strong> {{ uptime }}</p>
<p><strong>Requests Served:</strong> {{ request_count }}</p>
<p><strong>Cache Hit Ratio:</strong> {{ cache_hit_ratio }}%</p>
</div>
</div>
</div>
<div class="col-md-6">
<div class="card">
<div class="card-header">Available Models</div>
<div class="card-body">
<ul>
{% for model in models %}
<li>{{ model }}</li>
{% endfor %}
</ul>
</div>
</div>
</div>
</div>
<h2>Available Prompt Templates</h2>
<div class="row">
{% for template in templates %}
<div class="col-md-6">
<div class="card">
<div class="card-header">{{ template.id }}</div>
<div class="card-body">
<p><strong>Description:</strong> {{ template.description }}</p>
<p><strong>Parameters:</strong> {{ template.parameters|join(", ") }}</p>
<p><strong>Default Model:</strong> {{ template.default_model }}</p>
</div>
</div>
</div>
{% endfor %}
</div>
<h2>API Documentation</h2>
<p>
<a href="/docs" class="btn btn-primary">Interactive API Docs</a>
<a href="/redoc" class="btn btn-secondary">ReDoc API Docs</a>
<a href="/metrics" class="btn btn-info">Prometheus Metrics</a>
</p>
</div>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
</body>
</html>
""")
def setup_routes(self):
"""Configure API routes for MCP protocol"""
# MCP Protocol Routes
# Dashboard route
@self.app.get("/", tags=["Dashboard"])
async def dashboard(request: Request):
"""Dashboard showing server status and available templates"""
# Calculate cache hit ratio
cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
total_cache_requests = cache_hits + cache_misses
cache_hit_ratio = (cache_hits / total_cache_requests * 100) if total_cache_requests > 0 else 0
# Format uptime
uptime_seconds = time.time() - self.start_time
days, remainder = divmod(uptime_seconds, 86400)
hours, remainder = divmod(remainder, 3600)
minutes, seconds = divmod(remainder, 60)
uptime_str = f"{int(days)}d {int(hours)}h {int(minutes)}m {int(seconds)}s"
# Get template information
templates = []
for template_id, template in self.prompt_templates.items():
templates.append({
"id": template_id,
"description": template.get("description", ""),
"parameters": list(template.get("parameters", {}).keys()),
"default_model": template.get("default_model", "gpt-4o")
})
return self.templates.TemplateResponse("index.html", {
"request": request,
"status": "Healthy",
"uptime": uptime_str,
"request_count": self.request_count,
"cache_hit_ratio": round(cache_hit_ratio, 2),
"models": ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
"templates": templates
})
# Prometheus metrics endpoint
@self.app.get("/metrics", tags=["Monitoring"])
async def metrics():
"""Expose Prometheus metrics"""
return Response(prometheus_client.generate_latest(), media_type="text/plain")
# Health check endpoints
@self.app.get("/health", response_model=MCPHealthResponse, tags=["Health"])
async def health():
"""Health check endpoint"""
# Calculate cache hit ratio
cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
total_cache_requests = cache_hits + cache_misses
cache_hit_ratio = (cache_hits / total_cache_requests) if total_cache_requests > 0 else 0
return MCPHealthResponse(
status="healthy",
uptime=time.time() - self.start_time,
request_count=self.request_count,
cache_hit_ratio=cache_hit_ratio
)
@self.app.post("/context", response_model=MCPContextResponse, tags=["Context"])
async def get_context(
request: MCPContextRequest,
background_tasks: BackgroundTasks,
use_cache: bool = Query(True, description="Whether to use cached results if available")
):
"""
Get context for a prompt template with parameters.
This endpoint processes a prompt template with the provided parameters
and returns the generated context. It can optionally use OpenAI models
to enhance the context.
"""
try:
# Check if prompt template exists
if request.prompt_id not in self.prompt_templates:
raise HTTPException(
status_code=404,
detail=f"Prompt template '{request.prompt_id}' not found"
)
# Get prompt template
template = self.prompt_templates[request.prompt_id]
# Use default model if not specified
model = request.model or template.get("default_model", "gpt-4o")
# Generate context ID
context_id = str(uuid.uuid4())
# Generate cache key
cache_key = f"{request.prompt_id}:{json.dumps(request.parameters, sort_keys=True)}:{model}"
# Check cache if enabled
if use_cache:
cached_result = await self.cache.get(cache_key)
if cached_result:
# Update context ID for this request
cached_result["context_id"] = context_id
return MCPContextResponse(**cached_result)
# Process template with parameters
processed_template = self._process_template(template["template"], request.parameters)
# Check if we should use OpenAI to enhance the context
if template.get("use_openai", False):
# Generate context using OpenAI
context, usage = await self._generate_with_openai(
processed_template,
model,
template.get("system_prompt")
)
else:
# Use the processed template directly
context = processed_template
# Calculate token usage
token_count = len(self.tokenizer.encode(context))
usage = {
"prompt_tokens": token_count,
"completion_tokens": 0,
"total_tokens": token_count
}
# Track token usage in Prometheus
TOKEN_USAGE.labels(model=model, type="prompt").inc(usage["prompt_tokens"])
TOKEN_USAGE.labels(model=model, type="completion").inc(usage["completion_tokens"])
# Create response
response = MCPContextResponse(
context=context,
context_id=context_id,
model=model,
usage=usage,
metadata={
"prompt_id": request.prompt_id,
"timestamp": time.time(),
"parameters": request.parameters
}
)
# Store in cache
await self.cache.set(cache_key, response.dict())
return response
except Exception as e:
logger.error(f"Error processing context request: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error processing context: {str(e)}"
)
@self.app.post("/context/stream", tags=["Context"])
async def stream_context(request: MCPContextRequest):
"""
Stream context generation.
Similar to /context but streams the response as it's generated.
"""
try:
# Check if prompt template exists
if request.prompt_id not in self.prompt_templates:
raise HTTPException(
status_code=404,
detail=f"Prompt template '{request.prompt_id}' not found"
)
# Get prompt template
template = self.prompt_templates[request.prompt_id]
# Use default model if not specified
model = request.model or template.get("default_model", "gpt-4o")
# Generate context ID
context_id = str(uuid.uuid4())
# Process template with parameters
processed_template = self._process_template(template["template"], request.parameters)
# Stream the context generation
return StreamingResponse(
self._stream_context(processed_template, model, context_id, template.get("system_prompt")),
media_type="text/event-stream"
)
except Exception as e:
logger.error(f"Error streaming context: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Error streaming context: {str(e)}"
)
@self.app.get("/prompts", response_model=MCPPromptLibraryResponse, tags=["Prompts"])
async def get_prompts():
"""
Get available prompt templates.
Returns a list of all prompt templates available in the system.
"""
prompts = [
MCPPromptTemplate(
id=prompt_id,
template=template["template"],
description=template.get("description", ""),
parameters=template.get("parameters", {}),
default_model=template.get("default_model", "gpt-4o"),
metadata=template.get("metadata", {})
)
for prompt_id, template in self.prompt_templates.items()
]
return MCPPromptLibraryResponse(
prompts=prompts,
count=len(prompts)
)
@self.app.get("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
async def get_prompt(prompt_id: str):
"""
Get a specific prompt template.
Returns the details of a specific prompt template by ID.
"""
if prompt_id not in self.prompt_templates:
raise HTTPException(
status_code=404,
detail=f"Prompt template '{prompt_id}' not found"
)
template = self.prompt_templates[prompt_id]
return MCPPromptTemplate(
id=prompt_id,
template=template["template"],
description=template.get("description", ""),
parameters=template.get("parameters", {}),
default_model=template.get("default_model", "gpt-4o"),
metadata=template.get("metadata", {})
)
@self.app.post("/prompts", response_model=MCPPromptTemplate, status_code=201, tags=["Prompts"])
async def create_prompt(prompt: MCPPromptTemplate):
"""
Create a new prompt template.
Adds a new prompt template to the system.
"""
if prompt.id in self.prompt_templates:
raise HTTPException(
status_code=409,
detail=f"Prompt template '{prompt.id}' already exists"
)
self.prompt_templates[prompt.id] = {
"template": prompt.template,
"description": prompt.description,
"parameters": prompt.parameters,
"default_model": prompt.default_model,
"metadata": prompt.metadata
}
# Save updated templates
self._save_prompt_templates()
return prompt
@self.app.put("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
async def update_prompt(prompt_id: str, prompt: MCPPromptTemplate):
"""
Update an existing prompt template.
Updates the details of an existing prompt template.
"""
if prompt_id != prompt.id:
raise HTTPException(
status_code=400,
detail="Prompt ID in path must match prompt ID in body"
)
if prompt_id not in self.prompt_templates:
raise HTTPException(
status_code=404,
detail=f"Prompt template '{prompt_id}' not found"
)
self.prompt_templates[prompt_id] = {
"template": prompt.template,
"description": prompt.description,
"parameters": prompt.parameters,
"default_model": prompt.default_model,
"metadata": prompt.metadata
}
# Save updated templates
self._save_prompt_templates()
return prompt
@self.app.delete("/prompts/{prompt_id}", tags=["Prompts"])
async def delete_prompt(prompt_id: str):
"""
Delete a prompt template.
Removes a prompt template from the system.
"""
if prompt_id not in self.prompt_templates:
raise HTTPException(
status_code=404,
detail=f"Prompt template '{prompt_id}' not found"
)
del self.prompt_templates[prompt_id]
# Save updated templates
self._save_prompt_templates()
return {"status": "deleted", "prompt_id": prompt_id}
# Additional endpoints for a more complete MCP server
@self.app.get("/models", tags=["Models"])
async def get_models():
"""
Get available models.
Returns a list of models that can be used with this MCP server.
"""
return {
"models": [
{
"id": "gpt-4o",
"name": "GPT-4o",
"description": "OpenAI's most advanced model",
"context_length": 128000,
"is_default": True
},
{
"id": "gpt-4-turbo",
"name": "GPT-4 Turbo",
"description": "Optimized version of GPT-4",
"context_length": 128000,
"is_default": False
},
{
"id": "gpt-3.5-turbo",
"name": "GPT-3.5 Turbo",
"description": "Fast and efficient model",
"context_length": 16385,
"is_default": False
}
],
"count": 3
}
@self.app.get("/stats", tags=["System"])
async def get_stats():
"""
Get server statistics.
Returns usage statistics and system information.
"""
return {
"uptime": time.time() - self.start_time,
"prompt_templates_count": len(self.prompt_templates),
"cache_size": len(self.context_cache),
"requests_served": {
"context": 0, # This would be tracked in a real implementation
"prompts": 0,
"total": 0
},
"system_info": {
"python_version": sys.version,
"platform": sys.platform
}
}
@self.app.post("/context/stream", tags=["Context"])
async def stream_context(request: MCPContextRequest):
"""
Stream context generation.
Similar to /context but streams the response as it's generated.
"""
# In a real implementation, this would stream the response
# For now, we'll just return a simple response
return JSONResponse(
content={"message": "Streaming not implemented in this version"},
status_code=501
)
# Error handlers
@self.app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""Handle HTTP exceptions in MCP format"""
return JSONResponse(
status_code=exc.status_code,
content={
"error": exc.detail,
"error_type": "http_error",
"status_code": exc.status_code,
"details": exc.detail if isinstance(exc.detail, dict) else None
}
)
@self.app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle general exceptions in MCP format"""
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": str(exc),
"error_type": "server_error",
"status_code": 500,
"details": None
}
)
def _load_prompt_templates(self) -> Dict[str, Dict[str, Any]]:
"""Load prompt templates from file or initialize defaults"""
templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(templates_file), exist_ok=True)
# Try to load existing templates
if os.path.exists(templates_file):
try:
with open(templates_file, "r") as f:
templates = json.load(f)
logger.info(f"Loaded {len(templates)} prompt templates from {templates_file}")
return templates
except Exception as e:
logger.error(f"Error loading prompt templates: {str(e)}")
# Initialize with enhanced default templates
default_templates = {
"greeting": {
"template": "Hello! The current time is {time}. How can I help you today?",
"description": "A simple greeting template",
"parameters": {
"time": {
"type": "string",
"description": "The current time"
}
},
"default_model": "gpt-4o",
"metadata": {
"category": "general"
}
},
"code_review": {
"template": "Please review the following code:\n\n```{language}\n{code}\n```\n\nFocus on: {focus_areas}",
"description": "Template for code review requests",
"parameters": {
"language": {
"type": "string",
"description": "Programming language of the code"
},
"code": {
"type": "string",
"description": "The code to review"
},
"focus_areas": {
"type": "string",
"description": "Areas to focus on during review (e.g., 'performance, security')"
}
},
"default_model": "gpt-4o",
"use_openai": True,
"system_prompt": "You are a code review expert. Analyze the provided code and provide constructive feedback focusing on the specified areas.",
"metadata": {
"category": "development"
}
},
"system_prompt": {
"template": "You are OpenAI Code Assistant, a CLI tool that helps users with software engineering tasks and general information.\nUse the available tools to assist the user with their requests.\n\n# Tone and style\nYou should be concise, direct, and to the point. When you run a non-trivial bash command, \nyou should explain what the command does and why you are running it.\nOutput text to communicate with the user; all text you output outside of tool use is displayed to the user.\nRemember that your output will be displayed on a command line interface.\n\n# Tool usage policy\n- When doing file search, remember to search effectively with the available tools.\n- Always use the appropriate tool for the task.\n- Use parallel tool calls when appropriate to improve performance.\n- NEVER commit changes unless the user explicitly asks you to.\n- For weather queries, use the Weather tool to provide real-time information.\n\n# Tasks\nThe user will primarily request you perform software engineering tasks:\n1. Solving bugs\n2. Adding new functionality \n3. Refactoring code\n4. Explaining code\n5. Writing tests\n\nFor these tasks:\n1. Use search tools to understand the codebase\n2. Implement solutions using the available tools\n3. Verify solutions with tests if possible\n4. Run lint and typecheck commands when appropriate\n\nThe user may also ask for general information:\n1. Weather conditions\n2. Simple calculations\n3. General knowledge questions\n\n# Code style\n- Follow the existing code style of the project\n- Maintain consistent naming conventions\n- Use appropriate libraries that are already in the project\n- Add comments when code is complex or non-obvious\n\nIMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, \nquality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.",
"description": "System prompt for the assistant",
"parameters": {},
"default_model": "gpt-4o",
"metadata": {
"category": "system"
}
},
"documentation": {
"template": "Generate documentation for the following code:\n\n```{language}\n{code}\n```\n\nFormat: {format}",
"description": "Generate code documentation",
"parameters": {
"language": {
"type": "string",
"description": "Programming language of the code"
},
"code": {
"type": "string",
"description": "The code to document"
},
"format": {
"type": "string",
"description": "Documentation format (e.g., 'markdown', 'docstring', 'jsdoc')",
"default": "markdown"
}
},
"default_model": "gpt-4o",
"use_openai": True,
"system_prompt": "You are a technical documentation expert. Generate clear, concise, and accurate documentation for the provided code.",
"metadata": {
"category": "development"
}
},
"explain_code": {
"template": "Explain how the following code works:\n\n```{language}\n{code}\n```\n\nDetail level: {detail_level}",
"description": "Explain code functionality",
"parameters": {
"language": {
"type": "string",
"description": "Programming language of the code"
},
"code": {
"type": "string",
"description": "The code to explain"
},
"detail_level": {
"type": "string",
"description": "Level of detail in the explanation (e.g., 'basic', 'intermediate', 'advanced')",
"default": "intermediate"
}
},
"default_model": "gpt-4o",
"use_openai": True,
"system_prompt": "You are a programming instructor. Explain the provided code clearly at the requested level of detail.",
"metadata": {
"category": "education"
}
},
"current_time": {
"template": "The current time is {{now:%Y-%m-%d %H:%M:%S}}.",
"description": "Get the current time",
"parameters": {},
"default_model": "gpt-4o",
"metadata": {
"category": "utility"
}
}
}
# Save default templates
try:
with open(templates_file, "w") as f:
json.dump(default_templates, f, indent=2)
except Exception as e:
logger.error(f"Error saving default prompt templates: {str(e)}")
return default_templates
def _save_prompt_templates(self):
"""Save prompt templates to file"""
templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
try:
with open(templates_file, "w") as f:
json.dump(self.prompt_templates, f, indent=2)
except Exception as e:
logger.error(f"Error saving prompt templates: {str(e)}")
async def _generate_with_openai(self, prompt: str, model: str, system_prompt: Optional[str] = None) -> tuple:
"""Generate context using OpenAI API"""
messages = []
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Add user prompt
messages.append({"role": "user", "content": prompt})
# Call OpenAI API
try:
response = await asyncio.to_thread(
self.openai_client.chat.completions.create,
model=model,
messages=messages,
temperature=0.0, # Use deterministic output for context generation
max_tokens=4000
)
# Extract content and usage
content = response.choices[0].message.content
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens
}
return content, usage
except Exception as e:
logger.error(f"OpenAI API error: {str(e)}")
raise ValueError(f"Error generating context with OpenAI: {str(e)}")
async def _stream_context(self, prompt: str, model: str, context_id: str, system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
"""Stream context generation using OpenAI API"""
messages = []
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Add user prompt
messages.append({"role": "user", "content": prompt})
# Initial event with context ID
yield f"data: {json.dumps({'context_id': context_id, 'event': 'start'})}\n\n"
try:
# Call OpenAI API with streaming
stream = await asyncio.to_thread(
self.openai_client.chat.completions.create,
model=model,
messages=messages,
temperature=0.0,
max_tokens=4000,
stream=True
)
full_content = ""
# Process the stream
for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
content_piece = chunk.choices[0].delta.content
full_content += content_piece
# Yield the content piece
yield f"data: {json.dumps({'content': content_piece, 'event': 'content'})}\n\n"
# Calculate token usage
prompt_tokens = len(self.tokenizer.encode(prompt))
completion_tokens = len(self.tokenizer.encode(full_content))
total_tokens = prompt_tokens + completion_tokens
# Track token usage
TOKEN_USAGE.labels(model=model, type="prompt").inc(prompt_tokens)
TOKEN_USAGE.labels(model=model, type="completion").inc(completion_tokens)
# Final event with complete context and usage
final_event = {
'event': 'end',
'context': full_content,
'usage': {
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens,
'total_tokens': total_tokens
}
}
yield f"data: {json.dumps(final_event)}\n\n"
except Exception as e:
logger.error(f"Error streaming context: {str(e)}")
yield f"data: {json.dumps({'event': 'error', 'error': str(e)})}\n\n"
def _process_template(self, template: str, parameters: Dict[str, Any]) -> str:
"""Process a template with parameters"""
try:
# Handle date/time formatting if needed
processed_params = parameters.copy()
for key, value in processed_params.items():
if isinstance(value, str) and value.startswith("{{now") and value.endswith("}}"):
# Extract format string if present
format_match = re.search(r"{{now:(.+)}}", value)
if format_match:
format_string = format_match.group(1)
processed_params[key] = datetime.now().strftime(format_string)
else:
processed_params[key] = datetime.now().isoformat()
return template.format(**processed_params)
except KeyError as e:
raise ValueError(f"Missing required parameter: {e}")
except Exception as e:
raise ValueError(f"Error processing template: {str(e)}")
def start(self, host: str = "127.0.0.1", port: int = 8000, reload: bool = False):
"""Start the MCP server"""
uvicorn.run(self.app, host=host, port=port, reload=reload)
def create_mcp_app():
"""Factory function for creating the FastAPI app"""
server = MCPServer()
return server.app
if __name__ == "__main__":
# Create data directory if it doesn't exist
os.makedirs(os.path.join(os.path.dirname(__file__), "data"), exist_ok=True)
# Start server
server = MCPServer()
server.start()
```
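A minimal client sketch against the endpoints defined above, assuming the server was started locally with its defaults (`python mcp_server.py`, serving on 127.0.0.1:8000), that the default `code_review` prompt template is loaded, and that the `requests` package is installed:
```python
# Sketch only: assumes a local server on 127.0.0.1:8000 with the default
# "code_review" template and the `requests` package available.
import requests

BASE_URL = "http://127.0.0.1:8000"

# List available prompt templates
prompts = requests.get(f"{BASE_URL}/prompts").json()
print(prompts["count"], "templates available")

# Request context generation from the code_review template
payload = {
    "prompt_id": "code_review",
    "parameters": {
        "language": "python",
        "code": "def add(a, b):\n    return a + b",
        "focus_areas": "readability",
    },
}
resp = requests.post(f"{BASE_URL}/context", json=payload)
resp.raise_for_status()
data = resp.json()
print(data["model"], data["usage"])
print(data["context"])
```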
--------------------------------------------------------------------------------
/claude_code/examples/claude_mcp_config.html:
--------------------------------------------------------------------------------
```html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Claude MCP Server Dashboard</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<style>
:root {
--primary-color: #1a73e8;
--primary-dark: #0d47a1;
--primary-light: #e8f0fe;
--secondary-color: #34a853;
--tertiary-color: #ea4335;
--neutral-color: #f5f5f5;
--success-color: #00c853;
--warning-color: #ffab00;
--danger-color: #f44336;
--info-color: #2196f3;
--text-color: #333;
--text-light: #767676;
--border-radius: 12px;
--box-shadow: 0 4px 8px rgba(0, 0, 0, 0.12);
--transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1);
--font-primary: 'SF Pro Display', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
--font-code: 'SF Mono', 'Cascadia Code', 'Fira Code', Consolas, 'Courier New', monospace;
--header-height: 70px;
}
body {
font-family: var(--font-primary);
line-height: 1.6;
color: var(--text-color);
margin: 0;
padding: 0;
background-color: #f9f9f9;
transition: all 0.4s ease;
overflow-x: hidden;
}
.dark-mode {
--primary-color: #4285f4;
--primary-dark: #5c9aff;
--primary-light: #1c2733;
--neutral-color: #2c2c2c;
--success-color: #00e676;
--warning-color: #ffc400;
--danger-color: #ff5252;
--info-color: #42a5f5;
--text-color: #e0e0e0;
--text-light: #b0b0b0;
background-color: #121212;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 30px;
padding-bottom: 15px;
border-bottom: 1px solid #e0e0e0;
}
.logo {
display: flex;
align-items: center;
gap: 10px;
}
.header-actions {
display: flex;
gap: 15px;
}
h1, h2, h3, h4 {
color: var(--primary-color);
margin-top: 0;
}
.card {
background-color: white;
border-radius: var(--border-radius);
box-shadow: var(--box-shadow);
padding: 25px;
margin-bottom: 25px;
transition: var(--transition);
position: relative;
overflow: hidden;
border: 1px solid rgba(0, 0, 0, 0.03);
}
.dark-mode .card {
background-color: #222222;
border-color: rgba(255, 255, 255, 0.05);
}
.card:hover {
box-shadow: 0 10px 25px rgba(0, 0, 0, 0.08);
transform: translateY(-3px);
}
.card-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
border-bottom: 1px solid rgba(0, 0, 0, 0.05);
padding-bottom: 15px;
}
.dark-mode .card-header {
border-bottom-color: rgba(255, 255, 255, 0.05);
}
.card-title {
font-size: 1.4rem;
font-weight: 600;
margin: 0;
color: var(--primary-color);
display: flex;
align-items: center;
gap: 10px;
}
.card-actions {
display: flex;
gap: 10px;
}
.card-accent {
position: absolute;
top: 0;
left: 0;
height: 4px;
width: 100%;
background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
}
.card-accent-primary {
background: linear-gradient(90deg, var(--primary-color), #5c9aff);
}
.card-accent-success {
background: linear-gradient(90deg, var(--success-color), #69f0ae);
}
.card-accent-warning {
background: linear-gradient(90deg, var(--warning-color), #ffecb3);
}
.card-accent-danger {
background: linear-gradient(90deg, var(--danger-color), #ff8a80);
}
.card-footer {
margin-top: 20px;
padding-top: 15px;
border-top: 1px solid rgba(0, 0, 0, 0.05);
display: flex;
justify-content: space-between;
align-items: center;
}
.dark-mode .card-footer {
border-top-color: rgba(255, 255, 255, 0.05);
}
.dashboard-grid {
display: grid;
grid-template-columns: 1fr 2fr;
gap: 20px;
}
@media (max-width: 768px) {
.dashboard-grid {
grid-template-columns: 1fr;
}
}
.sidebar {
display: flex;
flex-direction: column;
gap: 20px;
}
code {
background-color: var(--neutral-color);
padding: 2px 4px;
border-radius: 4px;
font-family: 'Courier New', Courier, monospace;
color: var(--text-color);
}
pre {
background-color: var(--neutral-color);
padding: 15px;
border-radius: 8px;
overflow-x: auto;
margin: 0;
font-family: 'Courier New', Courier, monospace;
color: var(--text-color);
}
.config-box {
background-color: var(--primary-light);
border: 1px solid var(--primary-color);
border-radius: var(--border-radius);
padding: 20px;
margin: 20px 0;
}
.note {
background-color: #fffde7;
border-left: 4px solid #ffca28;
padding: 10px 15px;
margin: 15px 0;
}
.dark-mode .note {
background-color: #332d00;
border-left-color: #ffca28;
}
.tab-container {
margin-bottom: 30px;
position: relative;
}
.tabs {
display: flex;
margin-bottom: 25px;
background-color: rgba(255, 255, 255, 0.8);
border-radius: var(--border-radius);
padding: 5px;
position: sticky;
top: 0;
z-index: 100;
backdrop-filter: blur(10px);
-webkit-backdrop-filter: blur(10px);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.06);
}
.dark-mode .tabs {
background-color: rgba(40, 40, 40, 0.8);
}
.tab {
padding: 12px 25px;
cursor: pointer;
border-radius: var(--border-radius);
transition: var(--transition);
position: relative;
display: flex;
align-items: center;
gap: 8px;
font-weight: 500;
}
.tab:hover {
background-color: rgba(0, 0, 0, 0.05);
}
.dark-mode .tab:hover {
background-color: rgba(255, 255, 255, 0.05);
}
.tab.active {
background-color: var(--primary-color);
color: white;
font-weight: 600;
box-shadow: 0 4px 12px rgba(26, 115, 232, 0.3);
}
.dark-mode .tab.active {
box-shadow: 0 4px 12px rgba(66, 133, 244, 0.3);
}
.tab-indicator {
width: 8px;
height: 8px;
border-radius: 50%;
background-color: var(--success-color);
position: absolute;
top: 10px;
right: 10px;
display: none;
}
.tab-indicator.active {
display: block;
animation: pulse 2s infinite;
}
.tab-content {
display: none;
animation: fadeIn 0.4s ease;
}
.tab-content.active {
display: block;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
button, .btn {
background-color: var(--primary-color);
color: white;
border: none;
padding: 12px 20px;
border-radius: 8px;
cursor: pointer;
font-size: 15px;
font-weight: 500;
transition: var(--transition);
display: inline-flex;
align-items: center;
justify-content: center;
gap: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
position: relative;
overflow: hidden;
}
button:after, .btn:after {
content: '';
position: absolute;
top: 0;
left: 0;
width: 0;
height: 100%;
background-color: rgba(255, 255, 255, 0.1);
transition: width 0.4s ease;
}
button:hover:after, .btn:hover:after {
width: 100%;
}
button:hover, .btn:hover {
background-color: var(--primary-dark);
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
}
button:active, .btn:active {
transform: translateY(1px);
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.btn-secondary {
background-color: var(--secondary-color);
background-image: linear-gradient(135deg, var(--secondary-color), #27ae60);
}
.btn-secondary:hover {
background-color: #2d904c;
}
.btn-danger {
background-color: var(--danger-color);
background-image: linear-gradient(135deg, var(--danger-color), #d32f2f);
}
.btn-danger:hover {
background-color: #c62828;
}
.btn-warning {
background-color: var(--warning-color);
background-image: linear-gradient(135deg, var(--warning-color), #ff8f00);
color: #212121;
}
.btn-warning:hover {
background-color: #ff8f00;
}
.btn-info {
background-color: var(--info-color);
background-image: linear-gradient(135deg, var(--info-color), #1976d2);
}
.btn-info:hover {
background-color: #1976d2;
}
.btn-ghost {
background-color: transparent;
background-image: none;
color: var(--primary-color);
border: 2px solid var(--primary-color);
box-shadow: none;
}
.btn-ghost:hover {
background-color: var(--primary-light);
color: var(--primary-dark);
box-shadow: 0 4px 12px rgba(26, 115, 232, 0.12);
}
.btn-icon {
width: 44px;
height: 44px;
padding: 0;
border-radius: 50%;
display: inline-flex;
align-items: center;
justify-content: center;
}
.btn-large {
padding: 14px 24px;
font-size: 16px;
}
.btn-small {
padding: 8px 16px;
font-size: 13px;
}
.status {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 10px;
}
.status-indicator {
display: inline-block;
width: 12px;
height: 12px;
border-radius: 50%;
}
.status-active {
background-color: #34a853;
box-shadow: 0 0 0 3px rgba(52, 168, 83, 0.2);
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { box-shadow: 0 0 0 0 rgba(52, 168, 83, 0.4); }
70% { box-shadow: 0 0 0 8px rgba(52, 168, 83, 0); }
100% { box-shadow: 0 0 0 0 rgba(52, 168, 83, 0); }
}
.tools-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 15px;
margin-top: 20px;
}
.tool-card {
background-color: white;
border-radius: var(--border-radius);
padding: 20px;
box-shadow: var(--box-shadow);
border-left: 4px solid var(--primary-color);
transition: var(--transition);
position: relative;
overflow: hidden;
display: flex;
flex-direction: column;
height: 100%;
}
.dark-mode .tool-card {
background-color: #222222;
}
.tool-card:hover {
transform: translateY(-5px);
box-shadow: 0 12px 20px rgba(0, 0, 0, 0.1);
}
.tool-card:after {
content: '';
position: absolute;
bottom: 0;
right: 0;
width: 50px;
height: 50px;
background-color: rgba(0, 0, 0, 0.02);
border-radius: 50% 0 0 0;
transform: scale(1.5);
z-index: 0;
transition: var(--transition);
}
.dark-mode .tool-card:after {
background-color: rgba(255, 255, 255, 0.02);
}
.tool-card:hover:after {
width: 150px;
height: 150px;
transform: scale(1);
}
.tool-card.tool-bash { border-left-color: #f44336; }
.tool-card.tool-view { border-left-color: #2196f3; }
.tool-card.tool-edit { border-left-color: #4caf50; }
.tool-card.tool-glob { border-left-color: #ff9800; }
.tool-card.tool-grep { border-left-color: #9c27b0; }
.tool-card.tool-ls { border-left-color: #00bcd4; }
.tool-icon {
width: 42px;
height: 42px;
border-radius: 50%;
background-color: var(--primary-light);
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 15px;
color: var(--primary-color);
font-size: 20px;
transition: var(--transition);
}
.tool-bash .tool-icon { background-color: rgba(244, 67, 54, 0.1); color: #f44336; }
.tool-view .tool-icon { background-color: rgba(33, 150, 243, 0.1); color: #2196f3; }
.tool-edit .tool-icon { background-color: rgba(76, 175, 80, 0.1); color: #4caf50; }
.tool-glob .tool-icon { background-color: rgba(255, 152, 0, 0.1); color: #ff9800; }
.tool-grep .tool-icon { background-color: rgba(156, 39, 176, 0.1); color: #9c27b0; }
.tool-ls .tool-icon { background-color: rgba(0, 188, 212, 0.1); color: #00bcd4; }
.tool-card:hover .tool-icon {
transform: scale(1.1);
}
.tool-name {
font-weight: 600;
font-size: 16px;
color: var(--text-color);
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 8px;
z-index: 1;
}
.tool-description {
font-size: 14px;
color: var(--text-light);
flex-grow: 1;
z-index: 1;
}
.tool-stats {
display: flex;
align-items: center;
justify-content: space-between;
margin-top: 15px;
padding-top: 10px;
border-top: 1px solid rgba(0, 0, 0, 0.05);
font-size: 13px;
color: var(--text-light);
z-index: 1;
}
.dark-mode .tool-stats {
border-top-color: rgba(255, 255, 255, 0.05);
}
.tool-usage {
display: flex;
align-items: center;
gap: 5px;
}
.tool-latency {
display: flex;
align-items: center;
gap: 5px;
}
.stats-container {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 15px;
margin-bottom: 20px;
}
.stat-card {
background-color: white;
border-radius: var(--border-radius);
padding: 20px;
box-shadow: var(--box-shadow);
text-align: center;
transition: var(--transition);
position: relative;
overflow: hidden;
display: flex;
flex-direction: column;
align-items: center;
border: 1px solid rgba(0, 0, 0, 0.03);
}
.dark-mode .stat-card {
background-color: #222222;
border-color: rgba(255, 255, 255, 0.05);
}
.stat-card:hover {
transform: translateY(-5px);
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.1);
}
.stat-card:before {
content: '';
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 4px;
background: linear-gradient(90deg, var(--primary-color), var(--secondary-color));
}
.stat-card.primary:before {
background: linear-gradient(90deg, var(--primary-color), #5c9aff);
}
.stat-card.success:before {
background: linear-gradient(90deg, var(--success-color), #69f0ae);
}
.stat-card.warning:before {
background: linear-gradient(90deg, var(--warning-color), #ffecb3);
}
.stat-card.danger:before {
background: linear-gradient(90deg, var(--danger-color), #ff8a80);
}
.stat-icon {
font-size: 24px;
height: 50px;
width: 50px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 50%;
margin-bottom: 10px;
color: white;
background-color: var(--primary-color);
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
}
.stat-card.primary .stat-icon {
background-color: var(--primary-color);
}
.stat-card.success .stat-icon {
background-color: var(--success-color);
}
.stat-card.warning .stat-icon {
background-color: var(--warning-color);
}
.stat-card.danger .stat-icon {
background-color: var(--danger-color);
}
.stat-value {
font-size: 32px;
font-weight: 700;
color: var(--text-color);
margin: 15px 0 5px;
line-height: 1;
}
.stat-card.primary .stat-value {
color: var(--primary-color);
}
.stat-card.success .stat-value {
color: var(--success-color);
}
.stat-card.warning .stat-value {
color: var(--warning-color);
}
.stat-card.danger .stat-value {
color: var(--danger-color);
}
.stat-label {
font-size: 14px;
font-weight: 500;
color: var(--text-light);
text-transform: uppercase;
letter-spacing: 0.5px;
}
.stat-trend {
display: flex;
align-items: center;
gap: 5px;
margin-top: 8px;
font-size: 13px;
}
.stat-trend.up {
color: var(--success-color);
}
.stat-trend.down {
color: var(--danger-color);
}
.chart-container {
position: relative;
height: 300px;
margin-top: 20px;
border-radius: var(--border-radius);
background-color: rgba(255, 255, 255, 0.5);
padding: 15px;
border: 1px solid rgba(0, 0, 0, 0.03);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.03);
transition: var(--transition);
}
.dark-mode .chart-container {
background-color: rgba(50, 50, 50, 0.2);
border-color: rgba(255, 255, 255, 0.05);
}
.chart-container:hover {
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.05);
}
.chart-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 10px;
}
.chart-title {
font-weight: 600;
color: var(--primary-color);
display: flex;
align-items: center;
gap: 8px;
}
.chart-controls {
display: flex;
gap: 10px;
}
.chart-legend {
display: flex;
gap: 15px;
margin-top: 10px;
flex-wrap: wrap;
}
.chart-legend-item {
display: flex;
align-items: center;
gap: 5px;
font-size: 13px;
}
.chart-legend-color {
width: 12px;
height: 12px;
border-radius: 3px;
}
.settings-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
gap: 20px;
}
.form-group {
margin-bottom: 15px;
}
label {
display: block;
margin-bottom: 5px;
font-weight: bold;
}
input, select, textarea {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 14px;
background-color: white;
color: var(--text-color);
}
.dark-mode input,
.dark-mode select,
.dark-mode textarea {
background-color: #333;
border-color: #444;
color: #e0e0e0;
}
.toggle-container {
display: flex;
align-items: center;
}
.toggle {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
margin-right: 10px;
}
.toggle input {
opacity: 0;
width: 0;
height: 0;
}
.toggle-slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: #ccc;
transition: .4s;
border-radius: 34px;
}
.toggle-slider:before {
position: absolute;
content: "";
height: 18px;
width: 18px;
left: 3px;
bottom: 3px;
background-color: white;
transition: .4s;
border-radius: 50%;
}
input:checked + .toggle-slider {
background-color: var(--primary-color);
}
input:checked + .toggle-slider:before {
transform: translateX(26px);
}
.modal {
display: none;
position: fixed;
z-index: 1000;
left: 0;
top: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.5);
animation: fadeIn 0.3s;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
.modal-content {
background-color: white;
margin: 10% auto;
padding: 25px;
border-radius: var(--border-radius);
max-width: 600px;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3);
position: relative;
animation: slideIn 0.3s;
}
.dark-mode .modal-content {
background-color: #2a2a2a;
}
@keyframes slideIn {
from { transform: translateY(-50px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
.close {
position: absolute;
right: 20px;
top: 15px;
font-size: 22px;
cursor: pointer;
color: var(--text-light);
}
.close:hover {
color: var(--text-color);
}
.loader {
display: inline-block;
width: 20px;
height: 20px;
border: 3px solid rgba(255,255,255,.3);
border-radius: 50%;
border-top-color: white;
animation: spin 1s ease-in-out infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.notification {
position: fixed;
bottom: 20px;
right: 20px;
padding: 15px 20px;
background-color: var(--primary-color);
color: white;
border-radius: var(--border-radius);
box-shadow: var(--box-shadow);
display: none;
z-index: 1000;
animation: slideUp 0.3s;
}
@keyframes slideUp {
from { transform: translateY(30px); opacity: 0; }
to { transform: translateY(0); opacity: 1; }
}
.code-with-copy {
position: relative;
}
.copy-button {
position: absolute;
top: 5px;
right: 5px;
padding: 5px;
background-color: rgba(255, 255, 255, 0.8);
border-radius: 4px;
cursor: pointer;
font-size: 12px;
color: var(--text-color);
border: none;
}
.dark-mode .copy-button {
background-color: rgba(50, 50, 50, 0.8);
color: var(--text-light);
}
.config-form {
margin-top: 20px;
}
.action-buttons {
display: flex;
gap: 10px;
margin-top: 20px;
}
.server-info {
display: flex;
flex-direction: column;
gap: 5px;
}
.info-row {
display: flex;
justify-content: space-between;
}
.info-label {
font-weight: bold;
color: var(--text-light);
}
.info-value {
color: var(--text-color);
}
.badge {
display: inline-block;
padding: 2px 8px;
border-radius: 12px;
font-size: 12px;
font-weight: bold;
}
.badge-primary {
background-color: var(--primary-light);
color: var(--primary-color);
}
.dark-mode .badge-primary {
background-color: rgba(66, 133, 244, 0.2);
}
.badge-success {
background-color: rgba(52, 168, 83, 0.1);
color: #34a853;
}
.badge-warning {
background-color: rgba(251, 188, 5, 0.1);
color: #fbbc05;
}
.badge-danger {
background-color: rgba(234, 67, 53, 0.1);
color: #ea4335;
}
.resource-list {
margin-top: 15px;
}
.resource-item {
padding: 10px;
border-radius: 4px;
margin-bottom: 5px;
background-color: var(--neutral-color);
display: flex;
justify-content: space-between;
align-items: center;
}
.dark-mode .resource-item {
background-color: #333;
}
.resource-uri {
font-family: 'Courier New', Courier, monospace;
color: var(--primary-color);
}
.theme-switch {
cursor: pointer;
color: var(--text-light);
transition: var(--transition);
}
.theme-switch:hover {
color: var(--primary-color);
}
.clipboard-success {
position: fixed;
top: 20px;
right: 20px;
background-color: var(--secondary-color);
color: white;
padding: 10px 15px;
border-radius: var(--border-radius);
box-shadow: var(--box-shadow);
animation: fadeInOut 2s;
z-index: 1000;
}
@keyframes fadeInOut {
0% { opacity: 0; transform: translateY(-20px); }
20% { opacity: 1; transform: translateY(0); }
80% { opacity: 1; transform: translateY(0); }
100% { opacity: 0; transform: translateY(-20px); }
}
</style>
</head>
<body>
<div class="container">
<header>
<div class="logo">
<i class="fas fa-robot fa-2x" style="color: #4285f4;"></i>
<div>
<h1>Claude MCP Server Dashboard</h1>
<p style="margin: 0;">Advanced Configuration and Monitoring</p>
</div>
</div>
<div class="header-actions">
<button id="refresh-btn" class="btn-ghost"><i class="fas fa-sync-alt"></i> Refresh</button>
<div class="theme-switch" id="theme-toggle">
<i class="fas fa-moon"></i>
</div>
</div>
</header>
<div class="status">
<span class="status-indicator status-active"></span>
<span id="status-text">MCP Server Running - <span id="uptime">10 minutes</span></span>
</div>
<div class="stats-container">
<div class="stat-card primary">
<div class="stat-icon">
<i class="fas fa-tools"></i>
</div>
<div id="tools-count" class="stat-value">7</div>
<div class="stat-label">Available Tools</div>
<div class="stat-trend up">
<i class="fas fa-arrow-up"></i>
<span>2 new</span>
</div>
</div>
<div class="stat-card success">
<div class="stat-icon">
<i class="fas fa-plug"></i>
</div>
<div id="connections-count" class="stat-value">1</div>
<div class="stat-label">Active Connections</div>
<div class="stat-trend up">
<i class="fas fa-arrow-up"></i>
<span>Active now</span>
</div>
</div>
<div class="stat-card danger">
<div class="stat-icon">
<i class="fas fa-exchange-alt"></i>
</div>
<div id="requests-count" class="stat-value">45</div>
<div class="stat-label">Total Requests</div>
<div class="stat-trend up">
<i class="fas fa-arrow-up"></i>
<span>+12 today</span>
</div>
</div>
<div class="stat-card warning">
<div class="stat-icon">
<i class="fas fa-box"></i>
</div>
<div id="resources-count" class="stat-value">3</div>
<div class="stat-label">Resources</div>
<div class="stat-trend">
<i class="fas fa-minus"></i>
<span>No change</span>
</div>
</div>
</div>
<div class="tab-container">
<div class="tabs">
<div class="tab active" data-tab="config">
<i class="fas fa-cog"></i>
Configuration
<span class="tab-indicator"></span>
</div>
<div class="tab" data-tab="monitoring">
<i class="fas fa-chart-line"></i>
Monitoring
<span class="tab-indicator active"></span>
</div>
<div class="tab" data-tab="tools">
<i class="fas fa-tools"></i>
Tools
<span class="tab-indicator"></span>
</div>
<div class="tab" data-tab="clients">
<i class="fas fa-users"></i>
Clients
<span class="tab-indicator"></span>
</div>
<div class="tab" data-tab="settings">
<i class="fas fa-sliders-h"></i>
Settings
<span class="tab-indicator"></span>
</div>
</div>
<div class="tab-content active" id="config-tab">
<div class="card">
<div class="card-accent card-accent-primary"></div>
<div class="card-header">
<h2 class="card-title"><i class="fas fa-cog"></i> Claude Desktop Configuration</h2>
<div class="card-actions">
<button class="btn-icon btn-ghost"><i class="fas fa-question-circle"></i></button>
<button class="btn-icon btn-ghost"><i class="fas fa-external-link-alt"></i></button>
</div>
</div>
<p>Connect your Claude Desktop client to this MCP server by following these steps:</p>
<ol>
<li>Open Claude Desktop application</li>
<li>Click on Settings (gear icon)</li>
<li>Navigate to "Model Context Protocol" section</li>
<li>Click "Add New Server"</li>
<li>Enter the configuration below or use the auto-configuration options</li>
</ol>
<div class="config-box">
<h3>Server Configuration</h3>
<div class="code-with-copy">
<pre><code id="server-config">{
"name": "Claude Code Tools",
"type": "local_process",
"command": "python",
"args": ["claude.py", "serve"],
"workingDirectory": "/path/to/claude-code-directory",
"environment": {},
"description": "A Model Context Protocol server for Claude Code tools"
}</code></pre>
<button class="copy-button" onclick="copyConfig()"><i class="fas fa-copy"></i></button>
</div>
</div>
<div class="action-buttons">
<a href="resource:config://json" download="claude_mcp_config.json" class="btn">
<i class="fas fa-download"></i> Download Configuration File
</a>
<button class="btn btn-secondary" onclick="openQRModal()">
<i class="fas fa-qrcode"></i> Show QR Code
</button>
</div>
<div class="note">
<h3><i class="fas fa-shield-alt"></i> Security Note</h3>
<p>The Claude Code MCP server provides access to your file system and command execution. Only connect to it from trusted clients and be cautious about the operations you perform.</p>
</div>
</div>
<div class="card">
<div class="card-accent card-accent-primary"></div>
<div class="card-header">
<h2 class="card-title"><i class="fas fa-wrench"></i> Advanced Client Options</h2>
<div class="card-actions">
<button class="btn-icon btn-ghost"><i class="fas fa-cog"></i></button>
</div>
</div>
<div class="config-form">
<h3>Customize Your Configuration</h3>
<div class="form-group">
<label for="server-name">Server Name</label>
<input type="text" id="server-name" value="Claude Code Tools">
</div>
<div class="form-group">
<label for="working-dir">Working Directory</label>
<input type="text" id="working-dir" value="/path/to/claude-code-directory">
</div>
<div class="form-group">
<label for="server-env">Environment Variables (JSON)</label>
<textarea id="server-env" rows="3">{}</textarea>
</div>
<button class="btn" onclick="updateConfig()"><i class="fas fa-sync"></i> Update Configuration</button>
</div>
<h3>Multi-Agent Configuration</h3>
<p>For complex problems, use multi-agent mode with synchronized agents:</p>
<div class="code-with-copy">
<pre><code>python claude.py mcp-multi-agent path/to/server.py --config examples/agents_config.json</code></pre>
<button class="copy-button" onclick="copyMultiAgentCmd()"><i class="fas fa-copy"></i></button>
</div>
<button class="btn btn-ghost" onclick="openAgentEditor()"><i class="fas fa-users-cog"></i> Configure Agent Roles</button>
</div>
</div>
<div class="tab-content" id="monitoring-tab">
<div class="dashboard-grid">
<div class="sidebar">
<div class="card">
<h3><i class="fas fa-info-circle"></i> Server Information</h3>
<div class="server-info">
<div class="info-row">
<span class="info-label">Status:</span>
<span class="info-value"><span class="badge badge-success">Running</span></span>
</div>
<div class="info-row">
<span class="info-label">Version:</span>
<span class="info-value">0.1.0</span>
</div>
<div class="info-row">
<span class="info-label">Host:</span>
<span class="info-value">localhost:8000</span>
</div>
<div class="info-row">
<span class="info-label">Uptime:</span>
<span class="info-value" id="server-uptime">10 minutes</span>
</div>
<div class="info-row">
<span class="info-label">Python:</span>
<span class="info-value">3.10.4</span>
</div>
<div class="info-row">
<span class="info-label">FastMCP:</span>
<span class="info-value">0.4.1</span>
</div>
</div>
</div>
<div class="card">
<h3><i class="fas fa-box"></i> Resources</h3>
<p>Available resources:</p>
<div class="resource-list">
<div class="resource-item">
<span class="resource-uri">system://info</span>
<span class="badge badge-primary">GET</span>
</div>
<div class="resource-item">
<span class="resource-uri">config://json</span>
<span class="badge badge-primary">GET</span>
</div>
<div class="resource-item">
<span class="resource-uri">filesystem://{path}</span>
<span class="badge badge-primary">GET</span>
</div>
<div class="resource-item">
<span class="resource-uri">file://{file_path}</span>
<span class="badge badge-primary">GET</span>
</div>
</div>
</div>
</div>
<div class="main-content">
<div class="card">
<div class="card-accent card-accent-primary"></div>
<div class="card-header">
<h2 class="card-title"><i class="fas fa-chart-line"></i> Request Activity</h2>
<div class="card-actions">
<button class="btn-icon btn-ghost"><i class="fas fa-expand"></i></button>
<button class="btn-icon btn-ghost"><i class="fas fa-download"></i></button>
<button class="btn-icon btn-ghost"><i class="fas fa-sync-alt"></i></button>
</div>
</div>
<div class="chart-container">
<div class="chart-header">
<div class="chart-title"><i class="fas fa-chart-line"></i> Real-time Request Activity</div>
<div class="chart-controls">
<button class="btn-small">Hourly</button>
<button class="btn-small">Daily</button>
<button class="btn-small">Weekly</button>
</div>
</div>
<canvas id="requestsChart"></canvas>
<div class="chart-legend">
<div class="chart-legend-item">
<div class="chart-legend-color" style="background-color: #4285f4;"></div>
<span>Tool Calls</span>
</div>
<div class="chart-legend-item">
<div class="chart-legend-color" style="background-color: #34a853;"></div>
<span>Resource Requests</span>
</div>
</div>
</div>
</div>
<div class="card">
<h3><i class="fas fa-clipboard-list"></i> Recent Activity</h3>
<div class="activity-log" id="activity-log">
<pre style="max-height: 300px; overflow-y: auto;">
[2025-03-07 13:45:20] Server started on localhost:8000
[2025-03-07 13:46:05] New connection from 127.0.0.1
[2025-03-07 13:46:10] Tool call: View
[2025-03-07 13:46:15] Resource request: system://info
[2025-03-07 13:46:30] Tool call: GlobTool
[2025-03-07 13:46:45] Tool call: GrepTool
[2025-03-07 13:47:00] Tool call: Bash
[2025-03-07 13:47:15] Resource request: file://README.md
</pre>
</div>
</div>
</div>
</div>
</div>
<div class="tab-content" id="tools-tab">
<div class="card">
<h2><i class="fas fa-tools"></i> Available Tools</h2>
<p>The Claude Code MCP Server provides access to the following tools:</p>
<div class="tools-grid">
<div class="tool-card tool-view">
<div class="tool-icon"><i class="fas fa-eye"></i></div>
<div class="tool-name">View</div>
<div class="tool-description">Read files with optional line limits and supports syntax highlighting for code files</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 32 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 45ms avg</div>
</div>
</div>
<div class="tool-card tool-edit">
<div class="tool-icon"><i class="fas fa-edit"></i></div>
<div class="tool-name">Edit</div>
<div class="tool-description">Edit files with precise text replacement and context-aware modifications</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 18 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 62ms avg</div>
</div>
</div>
<div class="tool-card tool-edit">
<div class="tool-icon"><i class="fas fa-file-alt"></i></div>
<div class="tool-name">Replace</div>
<div class="tool-description">Overwrite existing files or create new files with specified content</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 7 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 38ms avg</div>
</div>
</div>
<div class="tool-card tool-glob">
<div class="tool-icon"><i class="fas fa-search"></i></div>
<div class="tool-name">GlobTool</div>
<div class="tool-description">Find files by pattern matching with support for complex glob patterns</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 29 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 74ms avg</div>
</div>
</div>
<div class="tool-card tool-grep">
<div class="tool-icon"><i class="fas fa-search-plus"></i></div>
<div class="tool-name">GrepTool</div>
<div class="tool-description">Search file contents using powerful regular expressions with filters</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 24 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 112ms avg</div>
</div>
</div>
<div class="tool-card tool-ls">
<div class="tool-icon"><i class="fas fa-folder-open"></i></div>
<div class="tool-name">LS</div>
<div class="tool-description">List directory contents with optional filtering and ignore patterns</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 41 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 28ms avg</div>
</div>
</div>
<div class="tool-card tool-bash">
<div class="tool-icon"><i class="fas fa-terminal"></i></div>
<div class="tool-name">Bash</div>
<div class="tool-description">Execute shell commands with persistent state and timeout options</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 15 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 350ms avg</div>
</div>
</div>
<div class="tool-card">
<div class="tool-icon"><i class="fas fa-cog"></i></div>
<div class="tool-name">GetConfiguration</div>
<div class="tool-description">Get Claude Desktop configuration for easy setup and customization</div>
<div class="tool-stats">
<div class="tool-usage"><i class="fas fa-chart-bar"></i> 3 calls</div>
<div class="tool-latency"><i class="fas fa-clock"></i> 18ms avg</div>
</div>
</div>
</div>
<div class="chart-container" style="margin-top: 30px;">
<h3><i class="fas fa-chart-pie"></i> Tool Usage</h3>
<canvas id="toolUsageChart"></canvas>
</div>
</div>
</div>
<div class="tab-content" id="clients-tab">
<div class="card">
<h2><i class="fas fa-laptop-code"></i> Client Configuration</h2>
<h3>Using Claude Code as an MCP Client</h3>
<p>Claude Code can act as a client to connect to other MCP servers:</p>
<div class="code-with-copy">
<pre><code>python claude.py mcp-client path/to/server.py</code></pre>
<button class="copy-button" onclick="copyClientCmd()"><i class="fas fa-copy"></i></button>
</div>
<h3>Specify a Claude Model</h3>
<div class="code-with-copy">
<pre><code>python claude.py mcp-client path/to/server.py --model claude-3-5-sonnet-20241022</code></pre>
<button class="copy-button" onclick="copyModelCmd()"><i class="fas fa-copy"></i></button>
</div>
<h3>Example with Echo Server</h3>
<div class="note">
<p>Run these commands in separate terminals:</p>
<div class="code-with-copy">
<pre><code># Terminal 1: Start the server
python examples/echo_server.py
# Terminal 2: Connect with the client
python claude.py mcp-client examples/echo_server.py</code></pre>
<button class="copy-button" onclick="copyEchoExample()"><i class="fas fa-copy"></i></button>
</div>
</div>
</div>
<div class="card">
<h2><i class="fas fa-users"></i> Multi-Agent Mode</h2>
<p>For complex tasks, the multi-agent mode allows multiple specialized agents to collaborate:</p>
<h3>Quick Start</h3>
<div class="code-with-copy">
<pre><code>python claude.py mcp-multi-agent examples/echo_server.py --config examples/agents_config.json</code></pre>
<button class="copy-button" onclick="copyMultiAgentExample()"><i class="fas fa-copy"></i></button>
</div>
<h3>Agent Configuration</h3>
<p>The <code>agents_config.json</code> file contains these specialized roles:</p>
<ul>
<li><strong>Researcher:</strong> Finds information and analyzes data</li>
<li><strong>Coder:</strong> Writes and debugs code</li>
<li><strong>Critic:</strong> Evaluates solutions and suggests improvements</li>
</ul>
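                <!--
                    Illustrative sketch (not rendered): a single agent entry as it could appear in
                    examples/agents_config.json. Values mirror the default "Researcher" entry shown
                    in the agent editor modal further down; the system_prompt is abbreviated here.
                    [
                      {
                        "name": "Researcher",
                        "role": "research specialist",
                        "model": "claude-3-5-sonnet-20241022",
                        "system_prompt": "You are a research specialist participating in a multi-agent conversation..."
                      }
                    ]
                -->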
<h3>Multi-Agent Commands</h3>
<ul>
<li><code>/agents</code> - List all active agents</li>
                    <li><code>/talk &lt;agent&gt; &lt;message&gt;</code> - Send a direct message to a specific agent</li>
<li><code>/history</code> - Show message history</li>
<li><code>/help</code> - Show multi-agent help</li>
</ul>
</div>
</div>
<div class="tab-content" id="settings-tab">
<div class="card">
<h2><i class="fas fa-sliders-h"></i> Server Settings</h2>
<div class="settings-grid">
<div>
<h3>General</h3>
<div class="form-group">
<label for="server-port">Server Port</label>
<input type="number" id="server-port" value="8000">
</div>
<div class="form-group">
<label for="server-host">Server Host</label>
<input type="text" id="server-host" value="localhost">
</div>
<div class="form-group toggle-container">
<label class="toggle">
<input type="checkbox" id="dev-mode" checked>
<span class="toggle-slider"></span>
</label>
<span>Development Mode</span>
</div>
</div>
<div>
<h3>Advanced</h3>
<div class="form-group">
<label for="log-level">Log Level</label>
<select id="log-level">
<option>INFO</option>
<option>DEBUG</option>
<option>WARNING</option>
<option>ERROR</option>
</select>
</div>
<div class="form-group toggle-container">
<label class="toggle">
<input type="checkbox" id="enable-metrics" checked>
<span class="toggle-slider"></span>
</label>
<span>Enable Metrics Collection</span>
</div>
<div class="form-group toggle-container">
<label class="toggle">
<input type="checkbox" id="auto-reload">
<span class="toggle-slider"></span>
</label>
<span>Auto-reload on File Changes</span>
</div>
</div>
</div>
<div class="action-buttons">
<button class="btn" onclick="saveSettings()"><i class="fas fa-save"></i> Save Settings</button>
<button class="btn btn-secondary" onclick="restartServer()"><i class="fas fa-redo"></i> Restart Server</button>
<button class="btn btn-danger" onclick="resetSettings()"><i class="fas fa-trash-alt"></i> Reset to Defaults</button>
</div>
<h3 style="margin-top: 30px;"><i class="fas fa-chart-line"></i> Metrics Management</h3>
<p>Manage server metrics data collection and storage.</p>
<div class="action-buttons">
<button class="btn btn-danger" onclick="resetServerMetrics()"><i class="fas fa-eraser"></i> Reset All Metrics</button>
</div>
</div>
</div>
</div>
</div>
<!-- QR Code Modal -->
<div id="qr-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('qr-modal')">×</span>
<h2>Scan QR Code to Configure</h2>
<p>Use your Claude Desktop app to scan this QR code and automatically configure the connection:</p>
<div style="text-align: center; margin: 20px 0;">
<img id="qr-code" src="https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=https://example.com/config" alt="QR Code">
</div>
<p class="note">Note: This QR code contains the server configuration details needed to connect.</p>
</div>
</div>
<!-- Agent Editor Modal -->
<div id="agent-editor-modal" class="modal">
<div class="modal-content" style="max-width: 800px;">
<span class="close" onclick="closeModal('agent-editor-modal')">×</span>
<h2>Multi-Agent Configuration Editor</h2>
<p>Customize the roles and capabilities of your agents:</p>
<div id="agent-config-editor" style="margin: 20px 0;">
<div class="form-group">
<label for="agent-config-json">Agent Configuration (JSON)</label>
<textarea id="agent-config-json" rows="15" style="font-family: monospace; white-space: pre;">[
{
"name": "Researcher",
"role": "research specialist",
"model": "claude-3-5-sonnet-20241022",
"system_prompt": "You are a research specialist participating in a multi-agent conversation. Your primary role is to find information, analyze data, and provide well-researched answers. You should use tools to gather information and verify facts. Always cite your sources when possible."
},
{
"name": "Coder",
"role": "programming expert",
"model": "claude-3-5-sonnet-20241022",
"system_prompt": "You are a coding expert participating in a multi-agent conversation. Your primary role is to write, debug, and explain code. You should use tools to test your code and provide working solutions. Always prioritize clean, maintainable code with proper error handling. You can collaborate with other agents to solve complex problems."
},
{
"name": "Critic",
"role": "critical thinker",
"model": "claude-3-5-sonnet-20241022",
"system_prompt": "You are a critical thinker participating in a multi-agent conversation. Your primary role is to evaluate proposals, find potential issues, and suggest improvements. You should question assumptions, point out flaws, and help refine ideas. Be constructive in your criticism and suggest alternatives rather than just pointing out problems."
}
]</textarea>
</div>
<div class="action-buttons">
<button class="btn" onclick="saveAgentConfig()"><i class="fas fa-save"></i> Save Configuration</button>
<button class="btn btn-secondary" onclick="addNewAgent()"><i class="fas fa-plus"></i> Add Agent</button>
<button class="btn btn-ghost" onclick="validateAgentConfig()"><i class="fas fa-check"></i> Validate</button>
</div>
</div>
</div>
</div>
<div id="notification" class="notification"></div>
<script>
// Initialize charts when the page loads
document.addEventListener('DOMContentLoaded', function() {
initializeCharts();
initializeConfig();
setupTabNavigation();
setupThemeToggle();
// Start live updates
updateStats();
setInterval(updateStats, 5000);
// Set up refresh button
document.getElementById('refresh-btn').addEventListener('click', function() {
updateStats();
showNotification('Dashboard refreshed!');
});
});
// Tab navigation
function setupTabNavigation() {
const tabs = document.querySelectorAll('.tab');
tabs.forEach(tab => {
tab.addEventListener('click', () => {
// Remove active class from all tabs and content
document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
// Add active class to current tab and content
tab.classList.add('active');
const tabId = `${tab.dataset.tab}-tab`;
document.getElementById(tabId).classList.add('active');
});
});
}
// Dark mode toggle
function setupThemeToggle() {
const themeToggle = document.getElementById('theme-toggle');
const prefersDarkMode = window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches;
// Set initial theme based on user preference
if (prefersDarkMode) {
document.body.classList.add('dark-mode');
themeToggle.innerHTML = '<i class="fas fa-sun"></i>';
}
themeToggle.addEventListener('click', () => {
document.body.classList.toggle('dark-mode');
if (document.body.classList.contains('dark-mode')) {
themeToggle.innerHTML = '<i class="fas fa-sun"></i>';
} else {
themeToggle.innerHTML = '<i class="fas fa-moon"></i>';
}
});
}
// Copy configuration to clipboard
function copyConfig() {
const configText = document.getElementById('server-config').textContent;
navigator.clipboard.writeText(configText)
.then(() => showNotification('Configuration copied to clipboard!'))
.catch(err => console.error('Failed to copy: ', err));
}
// Copy other commands
function copyMultiAgentCmd() {
navigator.clipboard.writeText('python claude.py mcp-multi-agent path/to/server.py --config examples/agents_config.json')
.then(() => showNotification('Command copied to clipboard!'));
}
function copyClientCmd() {
navigator.clipboard.writeText('python claude.py mcp-client path/to/server.py')
.then(() => showNotification('Command copied to clipboard!'));
}
function copyModelCmd() {
navigator.clipboard.writeText('python claude.py mcp-client path/to/server.py --model claude-3-5-sonnet-20241022')
.then(() => showNotification('Command copied to clipboard!'));
}
function copyEchoExample() {
navigator.clipboard.writeText('# Terminal 1: Start the server\npython examples/echo_server.py\n\n# Terminal 2: Connect with the client\npython claude.py mcp-client examples/echo_server.py')
.then(() => showNotification('Example copied to clipboard!'));
}
function copyMultiAgentExample() {
navigator.clipboard.writeText('python claude.py mcp-multi-agent examples/echo_server.py --config examples/agents_config.json')
.then(() => showNotification('Example copied to clipboard!'));
}
// Show notification
function showNotification(message) {
const notification = document.getElementById('notification');
notification.textContent = message;
notification.style.display = 'block';
setTimeout(() => {
notification.style.display = 'none';
}, 3000);
}
// Modal handlers
function openQRModal() {
document.getElementById('qr-modal').style.display = 'block';
// In a real implementation, this would generate a QR code with the actual config
const config = encodeURIComponent(JSON.stringify({
name: "Claude Code Tools",
type: "local_process",
command: "python",
args: ["claude.py", "serve"],
workingDirectory: "/path/to/claude-code-directory"
}));
document.getElementById('qr-code').src = `https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=${config}`;
}
function openAgentEditor() {
document.getElementById('agent-editor-modal').style.display = 'block';
}
function closeModal(modalId) {
document.getElementById(modalId).style.display = 'none';
}
// Update configuration
function updateConfig() {
const serverName = document.getElementById('server-name').value;
const workingDir = document.getElementById('working-dir').value;
let serverEnv = {};
try {
serverEnv = JSON.parse(document.getElementById('server-env').value);
} catch (e) {
showNotification('Invalid JSON in environment variables');
return;
}
const config = {
name: serverName,
type: "local_process",
command: "python",
args: ["claude.py", "serve"],
workingDirectory: workingDir,
environment: serverEnv,
description: "A Model Context Protocol server for Claude Code tools"
};
document.getElementById('server-config').textContent = JSON.stringify(config, null, 2);
showNotification('Configuration updated!');
}
// Initialize configuration
function initializeConfig() {
// Fetch actual configuration
fetch('resource:config://json')
.then(response => response.json())
.then(config => {
document.getElementById('server-config').textContent = JSON.stringify(config, null, 2);
document.getElementById('working-dir').value = config.workingDirectory || '/path/to/claude-code-directory';
document.getElementById('server-name').value = config.name || 'Claude Code Tools';
document.getElementById('server-env').value = JSON.stringify(config.environment || {}, null, 2);
})
.catch(error => {
console.error('Error fetching configuration:', error);
});
// Fetch metrics data
fetch('resource:metrics://json')
.then(response => response.json())
.then(metricsData => {
// Update the stats
document.getElementById('uptime').textContent = metricsData.uptime;
document.getElementById('server-uptime').textContent = metricsData.uptime;
document.getElementById('tools-count').textContent = Object.keys(metricsData.tool_usage || {}).length;
document.getElementById('connections-count').textContent = metricsData.active_connections || 0;
document.getElementById('requests-count').textContent = (
Object.values(metricsData.tool_usage || {}).reduce((a, b) => a + b, 0) +
Object.values(metricsData.resource_usage || {}).reduce((a, b) => a + b, 0)
);
document.getElementById('resources-count').textContent = Object.keys(metricsData.resource_usage || {}).length;
// Update activity log
if (metricsData.recent_activity && metricsData.recent_activity.length > 0) {
const activityLog = document.getElementById('activity-log');
let logContent = '';
metricsData.recent_activity.forEach(event => {
const time = event.formatted_time;
if (event.type === 'tool') {
logContent += `[${time}] Tool call: ${event.name}\n`;
} else if (event.type === 'resource') {
logContent += `[${time}] Resource request: ${event.uri}\n`;
} else if (event.type === 'connection') {
const action = event.action === 'connect' ? 'connected' : 'disconnected';
logContent += `[${time}] Client ${event.client_id.substring(0, 8)} ${action}\n`;
} else if (event.type === 'error') {
logContent += `[${time}] Error (${event.error_type}): ${event.message}\n`;
}
});
activityLog.querySelector('pre').textContent = logContent;
}
// Update chart data if the charts are initialized
if (window.toolUsageChart && metricsData.tool_usage) {
const toolLabels = Object.keys(metricsData.tool_usage);
const toolData = Object.values(metricsData.tool_usage);
window.toolUsageChart.data.labels = toolLabels;
window.toolUsageChart.data.datasets[0].data = toolData;
window.toolUsageChart.update();
}
if (window.requestsChart && metricsData.time_series) {
// Update the time series data
if (metricsData.time_series.tool_calls) {
const labels = metricsData.time_series.tool_calls.map(d => d.formatted_time);
const toolCallData = metricsData.time_series.tool_calls.map(d => d.value);
const resourceData = metricsData.time_series.resource_calls.map(d => d.value);
window.requestsChart.data.labels = labels;
window.requestsChart.data.datasets[0].data = toolCallData;
window.requestsChart.data.datasets[1].data = resourceData;
window.requestsChart.update();
}
}
})
.catch(error => {
console.error('Error fetching metrics:', error);
});
}
// Save agent configuration
function saveAgentConfig() {
try {
const config = JSON.parse(document.getElementById('agent-config-json').value);
// In a real implementation, this would save the config to a file
showNotification('Agent configuration saved!');
} catch (e) {
showNotification('Invalid JSON configuration');
}
}
// Add new agent
function addNewAgent() {
try {
let config = JSON.parse(document.getElementById('agent-config-json').value);
config.push({
name: "New Agent",
role: "assistant",
model: "claude-3-5-sonnet-20241022",
system_prompt: "You are a helpful AI assistant participating in a multi-agent conversation."
});
document.getElementById('agent-config-json').value = JSON.stringify(config, null, 2);
showNotification('New agent added!');
} catch (e) {
showNotification('Invalid JSON configuration');
}
}
// Validate agent configuration
function validateAgentConfig() {
try {
const config = JSON.parse(document.getElementById('agent-config-json').value);
if (!Array.isArray(config)) {
throw new Error('Configuration must be an array');
}
for (const agent of config) {
if (!agent.name || !agent.role || !agent.model || !agent.system_prompt) {
throw new Error('Each agent must have name, role, model, and system_prompt');
}
}
showNotification('Configuration is valid!');
} catch (e) {
showNotification('Invalid configuration: ' + e.message);
}
}
// Settings handlers
function saveSettings() {
// In a real implementation, this would save settings to the server
showNotification('Settings saved successfully!');
}
function restartServer() {
// In a real implementation, this would restart the server
showNotification('Server restarting...');
setTimeout(() => {
showNotification('Server restarted successfully!');
}, 2000);
}
function resetSettings() {
// Reset settings to defaults
document.getElementById('server-port').value = '8000';
document.getElementById('server-host').value = 'localhost';
document.getElementById('dev-mode').checked = true;
document.getElementById('log-level').value = 'INFO';
document.getElementById('enable-metrics').checked = true;
document.getElementById('auto-reload').checked = false;
showNotification('Settings reset to defaults');
}
// Reset server metrics
function resetServerMetrics() {
if (confirm('Are you sure you want to reset all server metrics? This action cannot be undone.')) {
// Use the metrics reset tool
fetch('resource:metrics://json', { method: 'GET' })
.then(response => {
// We don't actually reset here, but in a real implementation
// this would use the ResetServerMetrics tool
showNotification('Server metrics have been reset!');
// Refresh the dashboard
updateStats();
})
.catch(error => {
console.error('Error resetting metrics:', error);
showNotification('Error resetting metrics');
});
}
}
// Initialize charts
function initializeCharts() {
// Request activity chart
const requestsCtx = document.getElementById('requestsChart').getContext('2d');
window.requestsChart = new Chart(requestsCtx, {
type: 'line',
data: {
labels: ['10m ago', '9m ago', '8m ago', '7m ago', '6m ago', '5m ago', '4m ago', '3m ago', '2m ago', '1m ago', 'Now'],
datasets: [{
label: 'Tool Calls',
data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
borderColor: '#4285f4',
backgroundColor: 'rgba(66, 133, 244, 0.1)',
tension: 0.4,
fill: true
}, {
label: 'Resource Requests',
data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
borderColor: '#34a853',
backgroundColor: 'rgba(52, 168, 83, 0.1)',
tension: 0.4,
fill: true
}]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: {
position: 'top',
},
title: {
display: false
}
},
scales: {
y: {
beginAtZero: true
}
}
}
});
// Tool usage chart
const toolsCtx = document.getElementById('toolUsageChart').getContext('2d');
window.toolUsageChart = new Chart(toolsCtx, {
type: 'doughnut',
data: {
labels: ['View', 'GlobTool', 'GrepTool', 'Bash', 'LS', 'Edit', 'Replace', 'GetConfiguration'],
datasets: [{
data: [0, 0, 0, 0, 0, 0, 0, 0],
backgroundColor: [
'#4285f4', '#ea4335', '#34a853', '#fbbc05',
'#ff6d01', '#46bdc6', '#7baaf7', '#b366f6',
'#9c27b0', '#673ab7', '#3f51b5', '#2196f3',
'#03a9f4', '#00bcd4', '#009688', '#4caf50'
],
borderWidth: 1
}]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: {
position: 'right',
}
}
}
});
}
// Update stats periodically
function updateStats() {
// Fetch real-time metrics data from the server
fetch('resource:metrics://json')
.then(response => response.json())
.then(metricsData => {
// Update the stats
document.getElementById('uptime').textContent = metricsData.uptime;
document.getElementById('server-uptime').textContent = metricsData.uptime;
document.getElementById('tools-count').textContent = Object.keys(metricsData.tool_usage || {}).length;
document.getElementById('connections-count').textContent = metricsData.active_connections || 0;
document.getElementById('requests-count').textContent = (
Object.values(metricsData.tool_usage || {}).reduce((a, b) => a + b, 0) +
Object.values(metricsData.resource_usage || {}).reduce((a, b) => a + b, 0)
);
document.getElementById('resources-count').textContent = Object.keys(metricsData.resource_usage || {}).length;
// Update activity log
if (metricsData.recent_activity && metricsData.recent_activity.length > 0) {
const activityLog = document.getElementById('activity-log');
let logContent = '';
metricsData.recent_activity.forEach(event => {
const time = event.formatted_time;
if (event.type === 'tool') {
logContent += `[${time}] Tool call: ${event.name}\n`;
} else if (event.type === 'resource') {
logContent += `[${time}] Resource request: ${event.uri}\n`;
} else if (event.type === 'connection') {
const action = event.action === 'connect' ? 'connected' : 'disconnected';
logContent += `[${time}] Client ${event.client_id.substring(0, 8)} ${action}\n`;
} else if (event.type === 'error') {
logContent += `[${time}] Error (${event.error_type}): ${event.message}\n`;
}
});
activityLog.querySelector('pre').textContent = logContent;
}
// Update chart data if the charts are initialized
if (window.toolUsageChart && metricsData.tool_usage) {
const toolLabels = Object.keys(metricsData.tool_usage);
const toolData = Object.values(metricsData.tool_usage);
window.toolUsageChart.data.labels = toolLabels;
window.toolUsageChart.data.datasets[0].data = toolData;
window.toolUsageChart.update();
}
if (window.requestsChart && metricsData.time_series) {
// Update the time series data
if (metricsData.time_series.tool_calls) {
const labels = metricsData.time_series.tool_calls.map(d => d.formatted_time);
const toolCallData = metricsData.time_series.tool_calls.map(d => d.value);
const resourceData = metricsData.time_series.resource_calls.map(d => d.value);
window.requestsChart.data.labels = labels;
window.requestsChart.data.datasets[0].data = toolCallData;
window.requestsChart.data.datasets[1].data = resourceData;
window.requestsChart.update();
}
}
})
.catch(error => {
console.error('Error fetching metrics:', error);
});
}
// Update server uptime - no longer needed as it's part of updateStats
function updateUptime() {
// This is now handled by updateStats which fetches the actual uptime from the server
updateStats();
}
</script>
</body>
</html>
```
--------------------------------------------------------------------------------
/claude_code/lib/rl/tool_optimizer.py:
--------------------------------------------------------------------------------
```python
"""
Advanced tool selection optimization for Claude Code Python.
This module implements a specialized reinforcement learning system for optimizing
tool selection based on user queries and context. It uses advanced RL techniques
combined with neural models to learn which tools work best for different types of
queries over time, featuring transfer learning, meta-learning, and causal reasoning.
"""
import numpy as np
import os
import json
import time
import math
import random
from typing import Dict, List, Any, Optional, Tuple, Callable, Union, Set
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque, defaultdict
try:
from sentence_transformers import SentenceTransformer
HAVE_SENTENCE_TRANSFORMERS = True
except ImportError:
HAVE_SENTENCE_TRANSFORMERS = False
try:
import networkx as nx
HAVE_NETWORKX = True
except ImportError:
HAVE_NETWORKX = False
try:
import faiss
HAVE_FAISS = True
except ImportError:
HAVE_FAISS = False
from .grpo import ToolSelectionGRPO
# Advanced streaming and reflection capabilities
class StreamingReflectionEngine:
"""Engine for real-time streaming of thoughts, self-correction, and reflection."""
def __init__(self, embedding_dim: int = 768, reflection_buffer_size: int = 1000):
"""Initialize the streaming reflection engine.
Args:
embedding_dim: Dimension of embeddings
reflection_buffer_size: Size of reflection buffer
"""
self.embedding_dim = embedding_dim
self.reflection_buffer_size = reflection_buffer_size
# Reflection memory buffer
self.reflection_buffer = deque(maxlen=reflection_buffer_size)
# Working memory for current thought stream
self.working_memory = []
# Long-term memory for learned reflections
self.reflection_patterns = {}
# Reflection critic neural network
self.reflection_critic = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim // 2),
nn.LayerNorm(embedding_dim // 2),
nn.ReLU(),
nn.Linear(embedding_dim // 2, 3) # 3 outputs: continue, revise, complete
)
# Thought revision network
self.thought_reviser = nn.Transformer(
d_model=embedding_dim,
nhead=8,
num_encoder_layers=3,
num_decoder_layers=3,
dim_feedforward=embedding_dim * 4,
dropout=0.1
)
# Self-correction performance metrics
self.correction_metrics = {
"total_corrections": 0,
"helpful_corrections": 0,
"correction_depth": [],
"avg_correction_time": 0.0,
"total_correction_time": 0.0
}
# Learning rate for reflection updates
self.reflection_lr = 0.001
# Optimizer for reflection models
self.optimizer = torch.optim.Adam(
list(self.reflection_critic.parameters()) +
list(self.thought_reviser.parameters()),
lr=self.reflection_lr
)
def start_reflection_stream(self, query_embedding: np.ndarray, context: Dict[str, Any]) -> str:
"""Start a new reflection stream for a query.
Args:
query_embedding: Embedding of the query
context: Additional context
Returns:
Stream ID for this reflection session
"""
stream_id = f"reflection_{int(time.time())}_{random.randint(0, 10000)}"
# Initialize working memory for this stream
self.working_memory = [
{
"type": "query",
"embedding": torch.FloatTensor(query_embedding),
"timestamp": time.time(),
"context": context,
"stream_id": stream_id
}
]
return stream_id
def add_thought(self,
stream_id: str,
thought_embedding: np.ndarray,
thought_text: str,
thought_type: str = "reasoning") -> Dict[str, Any]:
"""Add a thought to the reflection stream and get feedback.
Args:
stream_id: ID of the reflection stream
thought_embedding: Embedding of the thought
thought_text: Text of the thought
thought_type: Type of thought (reasoning, plan, action, etc.)
Returns:
Feedback on the thought
"""
# Convert to tensor
thought_tensor = torch.FloatTensor(thought_embedding)
# Create thought record
thought = {
"type": thought_type,
"embedding": thought_tensor,
"text": thought_text,
"timestamp": time.time(),
"stream_id": stream_id,
"depth": len(self.working_memory)
}
# Add to working memory
self.working_memory.append(thought)
# Get reflection feedback
feedback = self._reflect_on_thought(thought)
# Store in reflection buffer
self.reflection_buffer.append({
"thought": thought,
"feedback": feedback
})
return feedback
def _reflect_on_thought(self, thought: Dict[str, Any]) -> Dict[str, Any]:
"""Generate reflection on a thought.
Args:
thought: The thought to reflect on
Returns:
Reflection feedback
"""
# Get thought embedding
thought_embedding = thought["embedding"]
# Get critic prediction
with torch.no_grad():
critic_output = self.reflection_critic(thought_embedding.unsqueeze(0))
action_probs = F.softmax(critic_output, dim=1).squeeze(0)
# Actions: [continue, revise, complete]
action_idx = torch.argmax(action_probs).item()
action_confidence = action_probs[action_idx].item()
actions = ["continue", "revise", "complete"]
action = actions[action_idx]
# Check if similar to patterns we've seen before
pattern_matches = []
if len(self.working_memory) >= 3:
# Get sequence of last 3 thoughts
sequence = [t["embedding"] for t in self.working_memory[-3:]]
sequence_tensor = torch.stack(sequence)
# Compare to known patterns
for pattern_name, pattern_data in self.reflection_patterns.items():
if len(pattern_data["sequence"]) == 3:
# Compute similarity
pattern_tensor = torch.stack(pattern_data["sequence"])
similarity = F.cosine_similarity(
sequence_tensor.mean(dim=0).unsqueeze(0),
pattern_tensor.mean(dim=0).unsqueeze(0)
).item()
if similarity > 0.7: # High similarity threshold
pattern_matches.append({
"pattern": pattern_name,
"similarity": similarity,
"outcome": pattern_data["outcome"]
})
# Check for circular reasoning
is_circular = False
if len(self.working_memory) >= 5:
recent_thoughts = [t["embedding"] for t in self.working_memory[-5:]]
# Check if latest thought is very similar to any of the previous 4
latest = recent_thoughts[-1]
for prev in recent_thoughts[:-1]:
similarity = F.cosine_similarity(latest.unsqueeze(0), prev.unsqueeze(0)).item()
if similarity > 0.85: # Very high similarity threshold
is_circular = True
break
# Generate revision suggestion if needed
revision_suggestion = None
if action == "revise" or is_circular:
revision_suggestion = self._generate_revision(thought)
# Create feedback
feedback = {
"action": action,
"confidence": action_confidence,
"is_circular": is_circular,
"pattern_matches": pattern_matches,
"revision_suggestion": revision_suggestion,
"timestamp": time.time()
}
return feedback
def _generate_revision(self, thought: Dict[str, Any]) -> Dict[str, Any]:
"""Generate a revision for a thought.
Args:
thought: The thought to revise
Returns:
Revision suggestion
"""
# If we have fewer than 2 thoughts, can't generate meaningful revision
if len(self.working_memory) < 2:
return {
"type": "general",
"embedding": thought["embedding"].detach().numpy(),
"message": "Consider providing more specific reasoning"
}
# Get context from previous thoughts
context_embeddings = torch.stack([t["embedding"] for t in self.working_memory[:-1]])
# Create source and target sequences for transformer
src = context_embeddings.unsqueeze(1) # [seq_len, batch_size, embedding_dim]
tgt = thought["embedding"].unsqueeze(0).unsqueeze(1) # [1, batch_size, embedding_dim]
# Generate revision using transformer
with torch.no_grad():
# Create attention mask
src_mask = torch.zeros(src.shape[0], src.shape[0]).bool()
# Revised thought
revised_embedding = self.thought_reviser(
src,
tgt,
src_mask=src_mask,
tgt_mask=torch.zeros(1, 1).bool()
)
# Extract the output embedding
revised_embedding = revised_embedding[0, 0]
# Look for insights from reflection buffer
insights = []
# Find similar thoughts from reflection buffer
for entry in self.reflection_buffer:
past_thought = entry["thought"]
# Skip if from current stream
if past_thought.get("stream_id") == thought.get("stream_id"):
continue
# Compute similarity
similarity = F.cosine_similarity(
thought["embedding"].unsqueeze(0),
past_thought["embedding"].unsqueeze(0)
).item()
if similarity > 0.6: # Significant similarity
insights.append({
"type": "similar_thought",
"similarity": similarity,
"feedback": entry["feedback"]
})
# Create revision suggestion
revision = {
"type": "specific",
"embedding": revised_embedding.detach().numpy(),
"insights": insights[:3], # Top 3 insights
"message": "Consider revising this thought for more clarity and precision"
}
return revision
def complete_reflection(self, stream_id: str,
outcome: Dict[str, Any],
success: bool) -> Dict[str, Any]:
"""Complete a reflection stream and learn from it.
Args:
stream_id: ID of the reflection stream
outcome: Outcome of the actions taken based on reflections
success: Whether the outcome was successful
Returns:
Reflection summary and metrics
"""
# Filter working memory for this stream
stream_thoughts = [t for t in self.working_memory if t.get("stream_id") == stream_id]
if not stream_thoughts:
return {"status": "error", "message": "Stream not found"}
# Count corrections
corrections = sum(1 for t in stream_thoughts if t.get("type") == "correction")
# Update metrics
self.correction_metrics["total_corrections"] += corrections
if success:
self.correction_metrics["helpful_corrections"] += corrections
if corrections > 0:
self.correction_metrics["correction_depth"].append(len(stream_thoughts))
# Learn from this reflection session
self._learn_from_reflection(stream_thoughts, outcome, success)
# Extract and store useful thought patterns
if success and len(stream_thoughts) >= 3:
self._extract_thought_patterns(stream_thoughts, outcome)
# Compute summary stats
duration = time.time() - stream_thoughts[0]["timestamp"]
avg_thought_time = duration / len(stream_thoughts)
# Generate summary
summary = {
"stream_id": stream_id,
"num_thoughts": len(stream_thoughts),
"num_corrections": corrections,
"duration": duration,
"avg_thought_time": avg_thought_time,
"success": success,
"outcome_summary": outcome.get("summary", "No summary provided")
}
return summary
def _learn_from_reflection(self, thoughts: List[Dict[str, Any]],
outcome: Dict[str, Any],
success: bool) -> None:
"""Learn from a completed reflection stream.
Args:
thoughts: List of thoughts in the stream
outcome: Outcome of the actions
success: Whether the outcome was successful
"""
if not thoughts:
return
# Skip if too few thoughts to learn from
if len(thoughts) < 3:
return
# Create training examples for reflection critic
examples = []
for i in range(1, len(thoughts) - 1):
# Current thought
thought_embedding = thoughts[i]["embedding"]
# Determine correct action label based on what happened
# 0: continue, 1: revise, 2: complete
if i == len(thoughts) - 2:
# Second-to-last thought should have led to completion
label = 2
elif thoughts[i+1].get("type") == "correction":
# This thought was followed by a correction, should have been revised
label = 1
else:
# This thought was good to continue from
label = 0
# Create example
examples.append((thought_embedding, label))
# Skip training if too few examples
if not examples:
return
# Update reflection critic with these examples
self.optimizer.zero_grad()
critic_loss = 0.0
for embedding, label in examples:
# Forward pass
logits = self.reflection_critic(embedding.unsqueeze(0))
# Compute loss
target = torch.tensor([label], device=embedding.device)
loss = F.cross_entropy(logits, target)
critic_loss += loss
# Scale loss by number of examples
critic_loss /= len(examples)
# Backpropagation
critic_loss.backward()
# Update parameters
self.optimizer.step()
def _extract_thought_patterns(self, thoughts: List[Dict[str, Any]],
outcome: Dict[str, Any]) -> None:
"""Extract useful thought patterns from successful reflection streams.
Args:
thoughts: List of thoughts in the stream
outcome: Outcome information
"""
# Need at least 3 thoughts to form a meaningful pattern
if len(thoughts) < 3:
return
# Generate a name for this pattern
pattern_id = f"pattern_{len(self.reflection_patterns) + 1}"
# Extract sequences of 3 consecutive thoughts
for i in range(len(thoughts) - 2):
sequence = thoughts[i:i+3]
# Skip if any thought is a correction - we want clean sequences
if any(t.get("type") == "correction" for t in sequence):
continue
# Get embeddings for the sequence
sequence_embeddings = [t["embedding"] for t in sequence]
# Store the pattern
self.reflection_patterns[f"{pattern_id}_{i}"] = {
"sequence": sequence_embeddings,
"outcome": {
"success": outcome.get("success", True),
"context": outcome.get("context", {}),
"summary": outcome.get("summary", "")
},
"timestamp": time.time()
}
# Limit number of patterns to prevent memory issues
if len(self.reflection_patterns) > 100:
# Remove oldest pattern
oldest_key = min(self.reflection_patterns.keys(),
key=lambda k: self.reflection_patterns[k]["timestamp"])
del self.reflection_patterns[oldest_key]
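# Illustrative usage sketch for StreamingReflectionEngine (hypothetical wiring; `query_emb` and
# `thought_emb` stand in for embeddings produced elsewhere, e.g. by a SentenceTransformer):
#   engine = StreamingReflectionEngine(embedding_dim=768)
#   stream_id = engine.start_reflection_stream(query_emb, {"task": "tool_selection"})
#   feedback = engine.add_thought(stream_id, thought_emb, "Try GrepTool on *.py files first")
#   if feedback["action"] == "revise" or feedback["is_circular"]:
#       pass  # incorporate feedback["revision_suggestion"] before committing to an action
#   summary = engine.complete_reflection(stream_id, {"summary": "used GrepTool"}, success=True)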
# Active Learning and Self-Improvement
class ActiveLearningSystem:
"""Active learning system that identifies knowledge gaps and seeks targeted improvement."""
def __init__(self, embedding_dim: int = 768, exploration_rate: float = 0.2):
"""Initialize the active learning system.
Args:
embedding_dim: Dimension of embeddings
exploration_rate: Rate of exploration vs. exploitation
"""
self.embedding_dim = embedding_dim
self.exploration_rate = exploration_rate
# Knowledge graph for tracking what's known/unknown
self.knowledge_graph = nx.DiGraph() if HAVE_NETWORKX else None
# Uncertainty estimation model
self.uncertainty_estimator = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim // 2),
nn.LayerNorm(embedding_dim // 2),
nn.ReLU(),
nn.Linear(embedding_dim // 2, 2) # [confidence, uncertainty]
)
# Knowledge boundaries
self.knowledge_centroids = []
self.knowledge_radius = {}
# Learning curriculum
self.learning_targets = []
self.learning_progress = {}
# Exploration history
self.exploration_history = []
# Coreset for diversity
self.coreset = []
self.coreset_embeddings = []
# Faiss index for fast nearest neighbor search
self.index = None
if HAVE_FAISS:
self.index = faiss.IndexFlatL2(embedding_dim)
# Optimizer for uncertainty estimator
self.optimizer = torch.optim.Adam(self.uncertainty_estimator.parameters(), lr=0.001)
def estimate_uncertainty(self, query_embedding: np.ndarray) -> Dict[str, float]:
"""Estimate uncertainty for a query or state.
Args:
query_embedding: Embedding to evaluate
Returns:
Dictionary with confidence and uncertainty scores
"""
# Convert to tensor
query_tensor = torch.FloatTensor(query_embedding)
# Get uncertainty estimate
with torch.no_grad():
estimate = self.uncertainty_estimator(query_tensor.unsqueeze(0))
confidence, uncertainty = F.softmax(estimate, dim=1).squeeze(0).tolist()
# Compute distance-based uncertainty if we have knowledge centroids
distance_uncertainty = 0.0
if self.knowledge_centroids:
# Convert to numpy for distance calculation
centroid_array = np.vstack(self.knowledge_centroids)
query_array = query_embedding.reshape(1, -1)
# Compute distances to all centroids
distances = np.linalg.norm(centroid_array - query_array, axis=1)
# Get distance to nearest centroid
min_dist = np.min(distances)
min_idx = np.argmin(distances)
nearest_centroid = self.knowledge_centroids[min_idx]
# Radius of knowledge around this centroid
radius = self.knowledge_radius.get(tuple(nearest_centroid), 1.0)
# Normalize distance by radius to get uncertainty
distance_uncertainty = min(1.0, min_dist / radius)
# Combine model and distance uncertainty
combined_uncertainty = 0.7 * uncertainty + 0.3 * distance_uncertainty
combined_confidence = 1.0 - combined_uncertainty
return {
"confidence": combined_confidence,
"uncertainty": combined_uncertainty,
"model_confidence": confidence,
"model_uncertainty": uncertainty,
"distance_uncertainty": distance_uncertainty
}
def should_explore(self, query_embedding: np.ndarray, context: Dict[str, Any]) -> bool:
"""Determine if we should explore to gather new knowledge for this query.
Args:
query_embedding: Query embedding
context: Additional context
Returns:
Whether to explore
"""
# Estimate uncertainty
uncertainty_info = self.estimate_uncertainty(query_embedding)
# Always explore if uncertainty is very high
if uncertainty_info["uncertainty"] > 0.8:
return True
# Use epsilon-greedy strategy with adaptive exploration
# Higher uncertainty means more likely to explore
adaptive_rate = self.exploration_rate * (0.5 + uncertainty_info["uncertainty"])
# Apply epsilon-greedy
return random.random() < adaptive_rate
def add_knowledge(self, query_embedding: np.ndarray,
related_info: Dict[str, Any],
confidence: float) -> None:
"""Add knowledge to the system.
Args:
query_embedding: Query embedding
related_info: Related information (e.g., tool used, outcome)
confidence: Confidence in this knowledge
"""
# Add to knowledge graph
if self.knowledge_graph is not None:
# Create node for this query
query_key = f"query_{len(self.knowledge_graph.nodes)}"
self.knowledge_graph.add_node(query_key,
embedding=query_embedding,
confidence=confidence,
timestamp=time.time())
# Add related information as connected nodes
for key, value in related_info.items():
info_key = f"{key}_{len(self.knowledge_graph.nodes)}"
self.knowledge_graph.add_node(info_key, value=value)
self.knowledge_graph.add_edge(query_key, info_key, relation=key)
# Update knowledge centroids
self._update_knowledge_boundaries(query_embedding, confidence)
# Update coreset for diversity
self._update_coreset(query_embedding, related_info)
def _update_knowledge_boundaries(self, embedding: np.ndarray, confidence: float) -> None:
"""Update knowledge boundaries with new information.
Args:
embedding: Embedding of new knowledge
confidence: Confidence in this knowledge
"""
# If no centroids yet, add this as the first one
if not self.knowledge_centroids:
self.knowledge_centroids.append(embedding)
self.knowledge_radius[tuple(embedding)] = 1.0
return
# Find closest centroid
centroid_array = np.vstack(self.knowledge_centroids)
query_array = embedding.reshape(1, -1)
distances = np.linalg.norm(centroid_array - query_array, axis=1)
min_dist = np.min(distances)
min_idx = np.argmin(distances)
nearest_centroid = self.knowledge_centroids[min_idx]
nearest_centroid_tuple = tuple(nearest_centroid)
# Get current radius
current_radius = self.knowledge_radius.get(nearest_centroid_tuple, 1.0)
# If within current radius, update radius based on confidence
if min_dist < current_radius:
# Higher confidence shrinks radius (more precise knowledge)
# Lower confidence expands radius (more uncertainty)
new_radius = current_radius * (1.0 - 0.1 * confidence)
self.knowledge_radius[nearest_centroid_tuple] = new_radius
else:
# Outside known areas, add as new centroid
if len(self.knowledge_centroids) < 100: # Limit number of centroids
self.knowledge_centroids.append(embedding)
self.knowledge_radius[tuple(embedding)] = 1.0
# Otherwise, merge with nearest
else:
# Update nearest centroid with weighted average
updated_centroid = 0.8 * nearest_centroid + 0.2 * embedding
# Update centroid list
self.knowledge_centroids[min_idx] = updated_centroid
# Update radius dict
self.knowledge_radius[tuple(updated_centroid)] = current_radius
del self.knowledge_radius[nearest_centroid_tuple]
def _update_coreset(self, embedding: np.ndarray, info: Dict[str, Any]) -> None:
"""Update coreset of diverse examples.
Args:
embedding: New example embedding
info: Related information
"""
# Skip if no Faiss
if self.index is None:
return
# If coreset is empty, add first example
if not self.coreset_embeddings:
self.coreset.append(info)
self.coreset_embeddings.append(embedding)
            self.index.add(np.vstack([embedding]).astype('float32'))  # FAISS indexes require float32
return
# Check if this example is sufficiently different from existing examples
# Convert to correct shape for Faiss
query = embedding.reshape(1, -1).astype('float32')
# Search for nearest neighbors
distances, indices = self.index.search(query, 1)
# If sufficiently different, add to coreset
if distances[0][0] > 0.5: # Distance threshold
if len(self.coreset) < 100: # Limit coreset size
self.coreset.append(info)
self.coreset_embeddings.append(embedding)
self.index.add(query)
else:
# Replace most similar item
_, indices = self.index.search(query, len(self.coreset))
                most_similar_idx = indices[0][0]  # index 0 is the nearest (most similar) coreset item
# Remove from index (need to rebuild index)
self.coreset[most_similar_idx] = info
self.coreset_embeddings[most_similar_idx] = embedding
# Rebuild index
self.index = faiss.IndexFlatL2(self.embedding_dim)
self.index.add(np.vstack(self.coreset_embeddings).astype('float32'))
def identify_knowledge_gaps(self) -> List[Dict[str, Any]]:
"""Identify knowledge gaps for active learning.
Returns:
List of knowledge gap areas to explore
"""
gaps = []
# Skip if no knowledge graph
if self.knowledge_graph is None:
return gaps
# Find areas with low confidence
low_confidence_nodes = [
(node, data) for node, data in self.knowledge_graph.nodes(data=True)
if "confidence" in data and data["confidence"] < 0.5
]
# Group by embedding similarity
clusters = {}
for node, data in low_confidence_nodes:
if "embedding" not in data:
continue
# Find or create cluster
assigned = False
for cluster_id, cluster_data in clusters.items():
centroid = cluster_data["centroid"]
# Compute similarity
similarity = np.dot(data["embedding"], centroid) / (
np.linalg.norm(data["embedding"]) * np.linalg.norm(centroid)
)
if similarity > 0.7: # High similarity threshold
# Add to cluster
cluster_data["nodes"].append((node, data))
# Update centroid
new_centroid = (centroid * len(cluster_data["nodes"]) + data["embedding"]) / (
len(cluster_data["nodes"]) + 1
)
cluster_data["centroid"] = new_centroid
assigned = True
break
if not assigned:
# Create new cluster
cluster_id = f"cluster_{len(clusters)}"
clusters[cluster_id] = {
"centroid": data["embedding"],
"nodes": [(node, data)]
}
# Convert clusters to knowledge gaps
for cluster_id, cluster_data in clusters.items():
if len(cluster_data["nodes"]) >= 2: # Only consider significant clusters
related_info = {}
# Collect information about this cluster
for node, data in cluster_data["nodes"]:
# Get connected nodes
if self.knowledge_graph.has_node(node):
for _, neighbor, edge_data in self.knowledge_graph.out_edges(node, data=True):
neighbor_data = self.knowledge_graph.nodes[neighbor]
if "value" in neighbor_data:
relation = edge_data.get("relation", "related")
related_info[relation] = neighbor_data["value"]
# Create gap description
gap = {
"id": cluster_id,
"centroid": cluster_data["centroid"],
"num_instances": len(cluster_data["nodes"]),
"related_info": related_info,
"confidence": np.mean([d["confidence"] for _, d in cluster_data["nodes"] if "confidence" in d])
}
gaps.append(gap)
# Sort gaps by confidence (ascending) and size (descending)
gaps.sort(key=lambda x: (x["confidence"], -x["num_instances"]))
return gaps
def generate_exploration_query(self, gap: Dict[str, Any]) -> Dict[str, Any]:
"""Generate an exploration query for a knowledge gap.
Args:
gap: Knowledge gap information
Returns:
Exploration query
"""
# Create query from gap centroid
centroid = gap["centroid"]
# Find nearest examples in coreset for additional context
similar_examples = []
if self.coreset_embeddings and len(self.coreset) > 0:
# Convert centroid to correct shape
query = centroid.reshape(1, -1).astype('float32')
# Find nearest neighbors
if self.index is not None:
distances, indices = self.index.search(query, min(3, len(self.coreset)))
# Add nearest examples
for i, idx in enumerate(indices[0]):
if idx < len(self.coreset):
similar_examples.append({
"example": self.coreset[idx],
"distance": distances[0][i]
})
# Generate exploration query
exploration = {
"embedding": centroid,
"gap_id": gap["id"],
"related_info": gap["related_info"],
"confidence": gap["confidence"],
"similar_examples": similar_examples,
"timestamp": time.time()
}
return exploration
def update_from_exploration(self,
gap_id: str,
query_embedding: np.ndarray,
result: Dict[str, Any],
success: bool) -> None:
"""Update knowledge from exploration results.
Args:
gap_id: ID of the knowledge gap
query_embedding: Embedding of the exploration query
result: Result of the exploration
success: Whether the exploration was successful
"""
# Add to exploration history
self.exploration_history.append({
"gap_id": gap_id,
"embedding": query_embedding,
"result": result,
"success": success,
"timestamp": time.time()
})
# Update knowledge with exploration results
self.add_knowledge(
query_embedding=query_embedding,
related_info=result,
confidence=0.8 if success else 0.3
)
# Update uncertainty model from this exploration
self._update_uncertainty_model(query_embedding, result, success)
def _update_uncertainty_model(self,
query_embedding: np.ndarray,
result: Dict[str, Any],
success: bool) -> None:
"""Update uncertainty estimation model.
Args:
query_embedding: Query embedding
result: Exploration result
success: Whether exploration was successful
"""
# Convert to tensor
query_tensor = torch.FloatTensor(query_embedding)
# Target values for training
# If success, low uncertainty (high confidence)
# If failure, high uncertainty (low confidence)
if success:
target = torch.tensor([[0.9, 0.1]]) # [confidence, uncertainty]
else:
target = torch.tensor([[0.2, 0.8]]) # [confidence, uncertainty]
# Update model
self.optimizer.zero_grad()
# Forward pass
prediction = self.uncertainty_estimator(query_tensor.unsqueeze(0))
prediction = F.softmax(prediction, dim=1)
# Compute loss
loss = F.mse_loss(prediction, target)
# Backpropagation
loss.backward()
# Update parameters
self.optimizer.step()
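# Hypothetical usage sketch for the exploration loop implemented above. `explorer` stands
# in for an instance of the enclosing knowledge-exploration class; the gap dict mirrors the
# structure built in the gap-clustering code, and the centroid must match the explorer's
# embedding dimension. Illustrative only, not part of the original module.
def _demo_exploration_cycle(explorer, centroid: np.ndarray) -> None:
    gap = {
        "id": "cluster_0",
        "centroid": centroid,
        "num_instances": 3,
        "related_info": {},
        "confidence": 0.4,
    }
    # Turn the gap into an exploration query (nearest coreset examples are attached if available)
    query = explorer.generate_exploration_query(gap)
    # ... run the exploration externally, then feed the outcome back into the knowledge base ...
    explorer.update_from_exploration(
        gap_id=gap["id"],
        query_embedding=query["embedding"],
        result={"summary": "exploration outcome"},
        success=True,
    )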
# Multi-task learning system
class MultiTaskLearningSystem:
"""System for learning across multiple task types with shared knowledge and specialized adapters."""
def __init__(self, embedding_dim: int = 768, num_tasks: int = 5):
"""Initialize the multi-task learning system.
Args:
embedding_dim: Dimension of embeddings
num_tasks: Number of task types to support
"""
self.embedding_dim = embedding_dim
self.num_tasks = num_tasks
# Task type registry
self.task_types = {}
# Shared embedding model (backbone)
self.shared_model = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim),
nn.LayerNorm(embedding_dim)
)
# Task-specific adapter modules
self.task_adapters = nn.ModuleDict()
# Task projectors (for returning to original space)
self.task_projectors = nn.ModuleDict()
# Task-specific optimizers
self.task_optimizers = {}
# Multi-task performance metrics
self.task_metrics = {}
# Shared optimizer
self.shared_optimizer = torch.optim.Adam(self.shared_model.parameters(), lr=0.001)
def register_task_type(self, task_name: str,
initial_examples: List[Tuple[np.ndarray, np.ndarray]] = None) -> None:
"""Register a new task type.
Args:
task_name: Name of the task
initial_examples: Optional initial examples (input, output embeddings)
"""
if task_name in self.task_types:
return
# Register task
self.task_types[task_name] = {
"examples": [],
"difficulty": 0.5, # Initial difficulty estimate
"performance": 0.0, # Initial performance estimate
"timestamp": time.time()
}
# Create task adapter
self.task_adapters[task_name] = nn.Sequential(
nn.Linear(self.embedding_dim, self.embedding_dim // 2),
nn.LayerNorm(self.embedding_dim // 2),
nn.ReLU(),
nn.Linear(self.embedding_dim // 2, self.embedding_dim)
)
# Create projector back to original space
self.task_projectors[task_name] = nn.Linear(self.embedding_dim, self.embedding_dim)
# Create optimizer
self.task_optimizers[task_name] = torch.optim.Adam(
list(self.task_adapters[task_name].parameters()) +
list(self.task_projectors[task_name].parameters()),
lr=0.001
)
# Initialize metrics
self.task_metrics[task_name] = {
"examples_seen": 0,
"loss_history": [],
"accuracy_history": [],
"last_improvement": time.time()
}
# Add initial examples if provided
if initial_examples:
for input_emb, output_emb in initial_examples:
self.add_task_example(task_name, input_emb, output_emb)
def add_task_example(self, task_name: str,
input_embedding: np.ndarray,
output_embedding: np.ndarray) -> None:
"""Add an example for a specific task.
Args:
task_name: Name of the task
input_embedding: Input embedding
output_embedding: Target output embedding
"""
if task_name not in self.task_types:
self.register_task_type(task_name)
# Convert to tensors
input_tensor = torch.FloatTensor(input_embedding)
output_tensor = torch.FloatTensor(output_embedding)
# Add to examples
self.task_types[task_name]["examples"].append((input_tensor, output_tensor))
# Update metrics
self.task_metrics[task_name]["examples_seen"] += 1
# Limit number of examples stored
if len(self.task_types[task_name]["examples"]) > 100:
self.task_types[task_name]["examples"].pop(0)
# Update model with this example
self._update_model_with_example(task_name, input_tensor, output_tensor)
def _update_model_with_example(self, task_name: str,
input_tensor: torch.Tensor,
output_tensor: torch.Tensor) -> None:
"""Update models with a new example.
Args:
task_name: Name of the task
input_tensor: Input embedding tensor
output_tensor: Target output embedding tensor
"""
# Zero gradients
self.shared_optimizer.zero_grad()
self.task_optimizers[task_name].zero_grad()
# Forward pass through shared model
shared_features = self.shared_model(input_tensor.unsqueeze(0))
# Forward pass through task-specific adapter
task_features = self.task_adapters[task_name](shared_features)
# Project back to original space
predicted_output = self.task_projectors[task_name](task_features)
# Compute loss
loss = F.mse_loss(predicted_output.squeeze(0), output_tensor)
# Backpropagation
loss.backward()
# Update parameters
self.shared_optimizer.step()
self.task_optimizers[task_name].step()
# Update metrics
self.task_metrics[task_name]["loss_history"].append(loss.item())
# Calculate cosine similarity as a proxy for accuracy
with torch.no_grad():
cos_sim = F.cosine_similarity(predicted_output.squeeze(0), output_tensor.unsqueeze(0)).item()
self.task_metrics[task_name]["accuracy_history"].append(cos_sim)
# Check if this is an improvement
if len(self.task_metrics[task_name]["accuracy_history"]) > 1:
prev_best = max(self.task_metrics[task_name]["accuracy_history"][:-1])
if cos_sim > prev_best:
self.task_metrics[task_name]["last_improvement"] = time.time()
# Update overall performance metric
recent_accuracy = self.task_metrics[task_name]["accuracy_history"][-10:]
self.task_types[task_name]["performance"] = sum(recent_accuracy) / len(recent_accuracy)
def process_task(self, task_name: str, input_embedding: np.ndarray) -> np.ndarray:
"""Process an input through a specific task pipeline.
Args:
task_name: Name of the task
input_embedding: Input embedding
Returns:
Predicted output embedding
"""
if task_name not in self.task_types:
# Unknown task type, create new adapter
self.register_task_type(task_name)
# Convert to tensor
input_tensor = torch.FloatTensor(input_embedding)
# Process through model
with torch.no_grad():
# Shared features
shared_features = self.shared_model(input_tensor.unsqueeze(0))
# Task-specific processing
task_features = self.task_adapters[task_name](shared_features)
# Project to output space
output_embedding = self.task_projectors[task_name](task_features)
# Convert back to numpy
output = output_embedding.squeeze(0).numpy()
return output
def get_task_similarity(self, task_name1: str, task_name2: str) -> float:
"""Calculate similarity between two tasks based on adapter weights.
Args:
task_name1: First task name
task_name2: Second task name
Returns:
Similarity score (0-1)
"""
if task_name1 not in self.task_adapters or task_name2 not in self.task_adapters:
return 0.0
# Get adapter parameters as vectors
params1 = []
params2 = []
# Extract parameters
for p1, p2 in zip(self.task_adapters[task_name1].parameters(),
self.task_adapters[task_name2].parameters()):
params1.append(p1.view(-1))
params2.append(p2.view(-1))
# Concatenate all parameters
params1 = torch.cat(params1)
params2 = torch.cat(params2)
# Compute cosine similarity
similarity = F.cosine_similarity(params1.unsqueeze(0), params2.unsqueeze(0)).item()
return similarity
    def find_most_similar_task(self, input_embedding: np.ndarray) -> Optional[str]:
"""Find the most similar task for a new input.
Args:
input_embedding: Input embedding
Returns:
Most similar task name
"""
if not self.task_types:
return None
# Convert to tensor
input_tensor = torch.FloatTensor(input_embedding)
# Get shared features
with torch.no_grad():
shared_features = self.shared_model(input_tensor.unsqueeze(0))
# Try each task adapter and measure error on this input
task_errors = {}
for task_name in self.task_types:
# Get examples for this task
examples = self.task_types[task_name]["examples"]
if not examples:
continue
# Compute error for each example
errors = []
for ex_input, ex_output in examples:
# Process input with shared model
ex_shared = self.shared_model(ex_input.unsqueeze(0))
# Compute feature similarity
similarity = F.cosine_similarity(shared_features, ex_shared).item()
errors.append(1.0 - similarity) # Convert to error
# Average error for this task
if errors:
task_errors[task_name] = sum(errors) / len(errors)
if not task_errors:
return list(self.task_types.keys())[0] # Return first task if no errors computed
# Return task with lowest error
return min(task_errors.items(), key=lambda x: x[1])[0]
def transfer_knowledge(self, source_task: str, target_task: str, strength: float = 0.3) -> None:
"""Transfer knowledge from source task to target task.
Args:
source_task: Source task name
target_task: Target task name
strength: Strength of knowledge transfer (0-1)
"""
if source_task not in self.task_adapters or target_task not in self.task_adapters:
return
# Skip if tasks are identical
if source_task == target_task:
return
# Get source and target adapters
source_adapter = self.task_adapters[source_task]
target_adapter = self.task_adapters[target_task]
# Transfer knowledge through parameter interpolation
with torch.no_grad():
for source_param, target_param in zip(source_adapter.parameters(),
target_adapter.parameters()):
# Interpolate parameters
new_param = (1 - strength) * target_param + strength * source_param
target_param.copy_(new_param)
# Do the same for projectors
source_projector = self.task_projectors[source_task]
target_projector = self.task_projectors[target_task]
with torch.no_grad():
for source_param, target_param in zip(source_projector.parameters(),
target_projector.parameters()):
# Interpolate parameters
new_param = (1 - strength) * target_param + strength * source_param
target_param.copy_(new_param)
def get_task_metrics(self) -> Dict[str, Dict[str, Any]]:
"""Get performance metrics for all tasks.
Returns:
Dictionary of task metrics
"""
metrics = {}
for task_name, task_data in self.task_types.items():
task_metrics = self.task_metrics[task_name]
# Calculate recent performance
recent_acc = task_metrics["accuracy_history"][-10:] if task_metrics["accuracy_history"] else []
recent_perf = sum(recent_acc) / len(recent_acc) if recent_acc else 0.0
# Determine if task is improving
improving = False
if len(task_metrics["accuracy_history"]) >= 10:
first_half = task_metrics["accuracy_history"][-10:-5]
second_half = task_metrics["accuracy_history"][-5:]
if sum(second_half) / 5 > sum(first_half) / 5:
improving = True
# Collect metrics
metrics[task_name] = {
"examples_seen": task_metrics["examples_seen"],
"current_performance": recent_perf,
"registered_time": task_data["timestamp"],
"last_improvement": task_metrics["last_improvement"],
"improving": improving,
"difficulty": task_data["difficulty"]
}
# Compute task similarities
similarities = {}
for other_task in self.task_types:
if other_task != task_name:
similarity = self.get_task_similarity(task_name, other_task)
similarities[other_task] = similarity
metrics[task_name]["task_similarities"] = similarities
return metrics
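# Hypothetical usage sketch for MultiTaskLearningSystem (the 8-dim embeddings and the
# "search" task name are illustrative assumptions, not values used elsewhere in this module).
def _demo_multi_task_learning() -> np.ndarray:
    mtl = MultiTaskLearningSystem(embedding_dim=8)
    mtl.register_task_type("search")
    # Feed a few (input, target) embedding pairs; the shared backbone and the
    # task-specific adapter are both updated on every example.
    for _ in range(5):
        x = np.random.randn(8).astype("float32")
        y = np.random.randn(8).astype("float32")
        mtl.add_task_example("search", x, y)
    # Route a new query through the adapter of the most similar known task.
    query = np.random.randn(8).astype("float32")
    task = mtl.find_most_similar_task(query)
    return mtl.process_task(task, query)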
# Causal inference system for tool selection
class CausalToolSelectionModel:
"""Causal inference system for understanding tool cause-effect relationships."""
def __init__(self, embedding_dim: int = 768):
"""Initialize the causal inference system.
Args:
embedding_dim: Dimension of embeddings
"""
self.embedding_dim = embedding_dim
# Causal graph
self.graph = nx.DiGraph() if HAVE_NETWORKX else None
# Tool variables (nodes in the graph)
self.tool_nodes = set()
# Context variables
self.context_nodes = set()
# Structural equation models
self.models = {}
# Intervention effects
self.interventions = {}
# Counterfactual cache
self.counterfactuals = {}
# Neural estimator for complex relationships
self.neural_estimator = nn.Sequential(
nn.Linear(embedding_dim * 2, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim // 2),
nn.LayerNorm(embedding_dim // 2),
nn.ReLU(),
nn.Linear(embedding_dim // 2, 1),
nn.Sigmoid()
)
# Optimizer
self.optimizer = torch.optim.Adam(self.neural_estimator.parameters(), lr=0.001)
def add_tool_node(self, tool_name: str):
"""Add a tool as a node in the causal graph.
Args:
tool_name: Name of the tool
"""
self.tool_nodes.add(tool_name)
if self.graph is not None:
self.graph.add_node(tool_name, type="tool")
def add_context_node(self, context_name: str):
"""Add a context variable as a node.
Args:
context_name: Name of the context variable
"""
self.context_nodes.add(context_name)
if self.graph is not None:
self.graph.add_node(context_name, type="context")
def add_causal_link(self, cause: str, effect: str, strength: float = 0.5):
"""Add a causal link between nodes.
Args:
cause: Name of the cause node
effect: Name of the effect node
strength: Strength of the causal relationship (0-1)
"""
if self.graph is not None:
self.graph.add_edge(cause, effect, weight=strength)
def observe(self, query_embedding: np.ndarray, context: Dict[str, Any],
tool_sequence: List[str], outcomes: List[Dict[str, Any]]):
"""Record an observation of tool usage and outcomes.
Args:
query_embedding: Embedding of the query
context: Context variables
tool_sequence: Sequence of tools used
outcomes: Outcomes of each tool (success, result, etc.)
"""
# Convert embeddings to tensors
query_tensor = torch.FloatTensor(query_embedding)
# Process each tool in the sequence
for i, (tool, outcome) in enumerate(zip(tool_sequence, outcomes)):
# Add tool if not already in graph
if tool not in self.tool_nodes:
self.add_tool_node(tool)
# Add context variables
for ctx_name, ctx_value in context.items():
ctx_key = f"{ctx_name}:{ctx_value}" if isinstance(ctx_value, (str, int, bool)) else ctx_name
if ctx_key not in self.context_nodes:
self.add_context_node(ctx_key)
# Add causal link from context to tool
self.add_causal_link(ctx_key, tool, 0.3) # Initial strength estimate
# Add causal links between tools in sequence
if i > 0:
prev_tool = tool_sequence[i-1]
prev_outcome = outcomes[i-1]
# Link strength based on previous success
strength = 0.7 if prev_outcome.get("success", False) else 0.2
self.add_causal_link(prev_tool, tool, strength)
# Update neural estimator
            if i > 0 and "embedding" in prev_outcome and "embedding" in outcome:
# Training example for neural estimator
prev_embed = torch.FloatTensor(prev_outcome["embedding"])
curr_embed = torch.FloatTensor(outcome["embedding"])
combined = torch.cat([prev_embed, curr_embed])
target = torch.FloatTensor([strength])
# Update neural estimator
self.optimizer.zero_grad()
pred = self.neural_estimator(combined.unsqueeze(0))
loss = F.mse_loss(pred, target)
loss.backward()
self.optimizer.step()
def infer_effects(self, intervention_tool: str) -> Dict[str, float]:
"""Infer the effects of using a specific tool.
Args:
intervention_tool: The tool to intervene with
Returns:
Dictionary of effects on other tools/outcomes
"""
if self.graph is None:
return {}
# Use do-calculus to determine causal effects
effects = {}
# Create a modified graph for the intervention
intervention_graph = self.graph.copy()
# Remove incoming edges to the intervention tool (do-operator)
for pred in list(self.graph.predecessors(intervention_tool)):
intervention_graph.remove_edge(pred, intervention_tool)
# Compute effect on each tool
for tool in self.tool_nodes:
if tool == intervention_tool:
continue
# Check if there's a path from intervention to this tool
if nx.has_path(intervention_graph, intervention_tool, tool):
# Compute causal effect strength using path weights
paths = list(nx.all_simple_paths(intervention_graph, intervention_tool, tool))
effect = 0.0
for path in paths:
# Calculate path strength as product of edge weights
path_strength = 1.0
for i in range(len(path) - 1):
path_strength *= intervention_graph[path[i]][path[i+1]]["weight"]
effect += path_strength
# Normalize for multiple paths
if len(paths) > 0:
effect /= len(paths)
effects[tool] = effect
# Cache result
self.interventions[intervention_tool] = effects
return effects
def estimate_counterfactual(self, observed_tools: List[str],
alternative_tool: str) -> float:
"""Estimate the outcome difference if an alternative tool had been used.
Args:
observed_tools: The tools that were actually used
alternative_tool: The alternative tool to consider
Returns:
Estimated improvement (positive) or decline (negative) in outcome
"""
# Use nested counterfactual estimation
key = (tuple(observed_tools), alternative_tool)
if key in self.counterfactuals:
return self.counterfactuals[key]
if self.graph is None or not observed_tools:
return 0.0
# Find the position to replace
best_pos = 0
best_effect = -float('inf')
for i in range(len(observed_tools)):
# Consider replacing the tool at position i
tools_copy = observed_tools.copy()
original_tool = tools_copy[i]
tools_copy[i] = alternative_tool
# Estimate effect of this change
effect = 0.0
# Effect from replacing the original tool
if original_tool in self.interventions:
effect -= sum(self.interventions[original_tool].values())
# Effect from using the alternative tool
if alternative_tool in self.interventions:
effect += sum(self.interventions[alternative_tool].values())
# Check if this is the best position
if effect > best_effect:
best_effect = effect
best_pos = i
# Estimate the counterfactual difference
counterfactual = best_effect / len(observed_tools)
# Cache the result
self.counterfactuals[key] = counterfactual
return counterfactual
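# Hypothetical usage sketch for CausalToolSelectionModel (tool names, context keys and the
# 16-dim query embedding are illustrative assumptions; effects are empty if networkx is absent).
def _demo_causal_tool_model() -> float:
    model = CausalToolSelectionModel(embedding_dim=16)
    query_emb = np.random.randn(16).astype("float32")
    # Record one observed tool sequence together with its per-tool outcomes
    model.observe(
        query_embedding=query_emb,
        context={"language": "python"},
        tool_sequence=["GrepTool", "View"],
        outcomes=[{"success": True}, {"success": True}],
    )
    # Do-calculus style estimate of how intervening with "GrepTool" affects other tools
    effects = model.infer_effects("GrepTool")
    # Counterfactual: would swapping in "Bash" have improved the observed outcome?
    return model.estimate_counterfactual(["GrepTool", "View"], "Bash")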
# Advanced Graph Neural Network for modeling tool relationships
class ToolRelationshipGNN(nn.Module):
"""Graph Neural Network for modeling relationships between tools."""
def __init__(self, embedding_dim: int, hidden_dim: int, num_tools: int):
"""Initialize the GNN with appropriate dimensions.
Args:
embedding_dim: Dimension of input embeddings
hidden_dim: Dimension of hidden layers
num_tools: Number of tools in the system
"""
super(ToolRelationshipGNN, self).__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_tools = num_tools
# Node embedding layers
self.node_encoder = nn.Sequential(
nn.Linear(embedding_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)
# Edge embedding layers
self.edge_encoder = nn.Sequential(
nn.Linear(embedding_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU()
)
# Message passing layers
self.message_mlp = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# Node update layers
self.node_update = nn.GRUCell(hidden_dim, hidden_dim)
# Output projection
self.output_projection = nn.Linear(hidden_dim, embedding_dim)
# Attention mechanism for node aggregation
self.attention = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
def forward(self, node_embeddings: torch.Tensor, adjacency_matrix: torch.Tensor) -> torch.Tensor:
"""Forward pass through the GNN.
Args:
            node_embeddings: Batched tool embeddings tensor [batch_size, num_tools, embedding_dim]
adjacency_matrix: Tool relationship adjacency matrix [num_tools, num_tools]
Returns:
Updated node embeddings
"""
batch_size = node_embeddings.shape[0]
# Initial node encoding
node_hidden = self.node_encoder(node_embeddings) # [batch, num_tools, hidden_dim]
# Message passing (3 rounds)
for _ in range(3):
# Compute messages for each edge
messages = []
attention_weights = []
for i in range(self.num_tools):
for j in range(self.num_tools):
# Only consider edges that exist in the adjacency matrix
if adjacency_matrix[i, j] > 0:
# Combine source and destination node features
edge_features = torch.cat([node_hidden[:, i], node_hidden[:, j]], dim=1)
message = self.message_mlp(edge_features)
messages.append((j, message)) # Message to node j
# Compute attention weight
attn_input = torch.cat([node_hidden[:, j], message], dim=1)
weight = self.attention(attn_input)
attention_weights.append((j, weight))
# Aggregate messages for each node using attention
aggregated_messages = torch.zeros(batch_size, self.num_tools, self.hidden_dim,
device=node_embeddings.device)
# Group messages by destination node
node_messages = defaultdict(list)
node_weights = defaultdict(list)
for j, message in messages:
node_messages[j].append(message)
for j, weight in attention_weights:
node_weights[j].append(weight)
# Apply attention for each node
for j in range(self.num_tools):
if j in node_messages:
stacked_messages = torch.stack(node_messages[j], dim=1) # [batch, num_msgs, hidden]
stacked_weights = torch.stack(node_weights[j], dim=1) # [batch, num_msgs, 1]
# Apply softmax to get attention distribution
normalized_weights = F.softmax(stacked_weights, dim=1)
# Weighted sum of messages
node_message = torch.sum(stacked_messages * normalized_weights, dim=1)
aggregated_messages[:, j] = node_message
# Update node states using GRU
node_hidden_reshaped = node_hidden.view(batch_size * self.num_tools, self.hidden_dim)
aggregated_messages_reshaped = aggregated_messages.view(batch_size * self.num_tools, self.hidden_dim)
updated_hidden = self.node_update(aggregated_messages_reshaped, node_hidden_reshaped)
node_hidden = updated_hidden.view(batch_size, self.num_tools, self.hidden_dim)
# Project back to embedding space
output_embeddings = self.output_projection(node_hidden)
return output_embeddings
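# Hypothetical usage sketch for ToolRelationshipGNN: a batch of 1 with 4 tools and 16-dim
# embeddings (all sizes are illustrative assumptions).
def _demo_tool_relationship_gnn() -> torch.Tensor:
    num_tools, embedding_dim = 4, 16
    gnn = ToolRelationshipGNN(embedding_dim=embedding_dim, hidden_dim=32, num_tools=num_tools)
    node_embeddings = torch.randn(1, num_tools, embedding_dim)    # [batch, num_tools, embed]
    adjacency = (torch.rand(num_tools, num_tools) > 0.5).float()  # 0/1 relationship matrix
    return gnn(node_embeddings, adjacency)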
# Enhanced Meta-Learning System
class MetaLearningOptimizer:
"""Meta-learning system that learns to generalize across different types of tasks."""
def __init__(self, embedding_dim: int, num_tools: int, learning_rate: float = 0.001):
"""Initialize the meta-learning optimizer.
Args:
embedding_dim: Dimension of embeddings
num_tools: Number of tools in the system
learning_rate: Learning rate for meta-updates
"""
self.embedding_dim = embedding_dim
self.num_tools = num_tools
self.learning_rate = learning_rate
# Task type embeddings
self.task_embeddings = {}
# Tool parameter embeddings
self.tool_parameters = nn.ParameterDict()
# Meta-network for adaptation
self.meta_network = nn.Sequential(
nn.Linear(embedding_dim * 2, embedding_dim),
nn.LayerNorm(embedding_dim),
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim)
)
# Optimizer for meta-parameters
self.optimizer = torch.optim.Adam(self.meta_network.parameters(), lr=learning_rate)
# Task buffers for meta-learning
self.task_buffers = defaultdict(list)
self.max_buffer_size = 100
def register_tool(self, tool_name: str, initial_embedding: np.ndarray):
"""Register a tool with the meta-learning system.
Args:
tool_name: Name of the tool
initial_embedding: Initial embedding for the tool
"""
self.tool_parameters[tool_name] = nn.Parameter(
torch.FloatTensor(initial_embedding),
requires_grad=True
)
def add_task_example(self, task_type: str, query_embedding: np.ndarray,
selected_tool: str, success: bool, reward: float):
"""Add an example to a task buffer for meta-learning.
Args:
task_type: Type of task (e.g., "search", "explanation")
query_embedding: Embedding of the query
selected_tool: Name of the selected tool
success: Whether the tool was successful
reward: The reward received
"""
# Convert to tensor
query_tensor = torch.FloatTensor(query_embedding)
# Add to task buffer
self.task_buffers[task_type].append({
"query": query_tensor,
"tool": selected_tool,
"success": success,
"reward": reward
})
# Limit buffer size
if len(self.task_buffers[task_type]) > self.max_buffer_size:
self.task_buffers[task_type].pop(0)
def meta_update(self):
"""Perform a meta-update step to improve adaptation capability."""
if not self.task_buffers:
return
# Sample a batch of tasks
sampled_tasks = random.sample(list(self.task_buffers.keys()),
min(5, len(self.task_buffers)))
meta_loss = 0.0
for task_type in sampled_tasks:
# Skip tasks with too few examples
if len(self.task_buffers[task_type]) < 5:
continue
# Sample examples from this task
examples = random.sample(self.task_buffers[task_type],
min(10, len(self.task_buffers[task_type])))
# Compute task embedding if not already computed
if task_type not in self.task_embeddings:
# Average query embeddings as task embedding
query_tensors = [ex["query"] for ex in examples]
task_embedding = torch.stack(query_tensors).mean(dim=0)
self.task_embeddings[task_type] = task_embedding
# Create adapted tool parameters for this task
adapted_params = {}
for tool_name, param in self.tool_parameters.items():
# Concatenate task embedding with tool parameter
adaptation_input = torch.cat([self.task_embeddings[task_type], param])
# Generate adaptation
adaptation = self.meta_network(adaptation_input.unsqueeze(0)).squeeze(0)
# Apply adaptation
adapted_params[tool_name] = param + adaptation
# Compute loss for this task
task_loss = 0.0
for example in examples:
query = example["query"]
selected_tool = example["tool"]
reward = example["reward"]
# Compute scores for all tools
scores = {}
for tool_name, param in adapted_params.items():
score = torch.dot(query, param) / (query.norm() * param.norm())
scores[tool_name] = score
# Convert to probability distribution
logits = torch.stack(list(scores.values()))
probs = F.softmax(logits, dim=0)
# Get index of selected tool
tool_idx = list(scores.keys()).index(selected_tool)
# Negative log likelihood weighted by reward
nll = -torch.log(probs[tool_idx])
task_loss += nll * (1.0 - reward) # Lower loss for high rewards
# Add to meta loss
meta_loss += task_loss / len(examples)
        # If every sampled task was skipped, there is nothing to backpropagate
        if not isinstance(meta_loss, torch.Tensor):
            return None
        # Normalize by number of tasks
        meta_loss /= len(sampled_tasks)
# Update meta-parameters
self.optimizer.zero_grad()
meta_loss.backward()
self.optimizer.step()
return meta_loss.item()
def get_adapted_embeddings(self, task_type: str) -> Dict[str, np.ndarray]:
"""Get task-adapted embeddings for tools.
Args:
task_type: Type of task
Returns:
Dictionary of adapted tool embeddings
"""
# Return original embeddings if task type is unknown
if task_type not in self.task_embeddings:
return {name: param.detach().numpy() for name, param in self.tool_parameters.items()}
# Create adapted embeddings
adapted_embeddings = {}
for tool_name, param in self.tool_parameters.items():
# Concatenate task embedding with tool parameter
adaptation_input = torch.cat([self.task_embeddings[task_type], param])
# Generate adaptation
adaptation = self.meta_network(adaptation_input.unsqueeze(0)).squeeze(0)
# Apply adaptation
adapted_embeddings[tool_name] = (param + adaptation).detach().numpy()
return adapted_embeddings
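# Hypothetical usage sketch for MetaLearningOptimizer (tool names, the "search" task type
# and 8-dim embeddings are illustrative assumptions).
def _demo_meta_learning_optimizer() -> Dict[str, np.ndarray]:
    meta = MetaLearningOptimizer(embedding_dim=8, num_tools=2)
    meta.register_tool("GrepTool", np.random.randn(8).astype("float32"))
    meta.register_tool("View", np.random.randn(8).astype("float32"))
    # Log a handful of selections for one task type so a meta-update can run
    for _ in range(6):
        meta.add_task_example(
            task_type="search",
            query_embedding=np.random.randn(8).astype("float32"),
            selected_tool="GrepTool",
            success=True,
            reward=0.8,
        )
    meta.meta_update()
    # Task-conditioned tool embeddings adapted by the meta-network
    return meta.get_adapted_embeddings("search")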
@dataclass
class ToolUsageRecord:
"""Record of a tool usage for optimization."""
query: str
tool_name: str
execution_time: float
token_usage: Dict[str, int]
success: bool
timestamp: float
class ToolUsageTracker:
"""Tracks tool usage for optimization."""
def __init__(self, max_records: int = 10000):
"""
Initialize the tool usage tracker.
Args:
max_records: Maximum number of records to store
"""
self.records = deque(maxlen=max_records)
def add_record(self, record: ToolUsageRecord) -> None:
"""Add a record to the tracker."""
self.records.append(record)
def get_tool_stats(self) -> Dict[str, Dict[str, Any]]:
"""
Get statistics about tool usage.
Returns:
Dictionary of tool statistics
"""
stats = {}
# Group by tool
for record in self.records:
if record.tool_name not in stats:
stats[record.tool_name] = {
"count": 0,
"success_count": 0,
"total_time": 0,
"token_usage": {"prompt": 0, "completion": 0, "total": 0},
}
stats[record.tool_name]["count"] += 1
if record.success:
stats[record.tool_name]["success_count"] += 1
stats[record.tool_name]["total_time"] += record.execution_time
# Update token usage
for key, value in record.token_usage.items():
stats[record.tool_name]["token_usage"][key] += value
# Compute derived metrics
for tool_name, tool_stats in stats.items():
tool_stats["success_rate"] = tool_stats["success_count"] / tool_stats["count"] if tool_stats["count"] > 0 else 0
tool_stats["avg_time"] = tool_stats["total_time"] / tool_stats["count"] if tool_stats["count"] > 0 else 0
for key in tool_stats["token_usage"]:
tool_stats[f"avg_{key}_tokens"] = tool_stats["token_usage"][key] / tool_stats["count"] if tool_stats["count"] > 0 else 0
return stats
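# Hypothetical usage sketch for ToolUsageTracker (the query text, timing and token counts
# are made-up illustrative values).
def _demo_tool_usage_tracker() -> Dict[str, Dict[str, Any]]:
    tracker = ToolUsageTracker(max_records=100)
    tracker.add_record(ToolUsageRecord(
        query="read the config file",
        tool_name="View",
        execution_time=0.12,
        token_usage={"prompt": 250, "completion": 40, "total": 290},
        success=True,
        timestamp=time.time(),
    ))
    # Per-tool success rate, average execution time and average token usage
    return tracker.get_tool_stats()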
class ToolSelectionOptimizer:
"""
Optimizes tool selection based on user queries and context.
Uses reinforcement learning to improve tool selection over time.
"""
def __init__(
self,
tool_registry: Any,
data_dir: str = "./data/rl",
enable_rl: bool = True,
model_update_interval: int = 100,
embedding_model_name: str = "all-MiniLM-L6-v2",
embedding_cache_size: int = 1000,
):
"""
Initialize the tool selection optimizer.
Args:
tool_registry: Registry containing available tools
data_dir: Directory to store data and models
enable_rl: Whether to enable reinforcement learning
model_update_interval: How often to update models (in observations)
embedding_model_name: Name of the sentence embedding model
embedding_cache_size: Size of the embedding cache
"""
self.tool_registry = tool_registry
self.data_dir = data_dir
self.enable_rl = enable_rl
self.model_update_interval = model_update_interval
# Create data directory
os.makedirs(data_dir, exist_ok=True)
# Initialize tool usage tracker
self.tracker = ToolUsageTracker()
# Initialize embedding model if available
self.embedding_model = None
self.embedding_cache = {}
self.embedding_cache_keys = deque(maxlen=embedding_cache_size)
if HAVE_SENTENCE_TRANSFORMERS and enable_rl:
try:
self.embedding_model = SentenceTransformer(embedding_model_name)
except Exception as e:
print(f"Warning: Failed to load embedding model: {e}")
# Initialize RL system if enabled
self.rl_system = None
if enable_rl:
# Define a simple context evaluator
def context_evaluator(context):
# This is a placeholder - in a real system, we'd evaluate the quality
# based on metrics like response coherence, success rate, etc.
return 0.5
# Create RL system
self.rl_system = ToolSelectionGRPO(
tool_registry=tool_registry,
context_evaluator=context_evaluator,
update_interval=model_update_interval,
)
# Load existing models and data if available
self._load_data()
def select_tool(self, query: str, context: Dict[str, Any], visualizer=None) -> str:
"""
Select the best tool to use for a given query.
Args:
query: User query
context: Conversation context
visualizer: Optional visualizer to display the selection process
Returns:
Name of the selected tool
"""
# If RL is not enabled, use default selection logic
if not self.enable_rl or self.rl_system is None:
return self._default_tool_selection(query, context)
# Use RL system to select tool
try:
return self.rl_system.select_tool(query, context, visualizer=visualizer)
except Exception as e:
print(f"Error in RL tool selection: {e}")
return self._default_tool_selection(query, context)
def record_tool_usage(
self,
query: str,
tool_name: str,
execution_time: float,
token_usage: Dict[str, int],
success: bool,
context: Optional[Dict[str, Any]] = None,
result: Optional[Any] = None,
) -> None:
"""
Record tool usage for optimization.
Args:
query: User query
tool_name: Name of the tool used
execution_time: Time taken to execute the tool
token_usage: Token usage information
success: Whether the tool usage was successful
context: Conversation context (for RL)
result: Result of the tool usage (for RL)
"""
# Create and add record
record = ToolUsageRecord(
query=query,
tool_name=tool_name,
execution_time=execution_time,
token_usage=token_usage,
success=success,
timestamp=time.time(),
)
self.tracker.add_record(record)
# Update RL system if enabled
if self.enable_rl and self.rl_system is not None and context is not None:
try:
# Find the agent that made this selection
                for agent_id in self.rl_system.current_episode:
                    if self.rl_system.current_episode[agent_id]:
# Observe the result
self.rl_system.observe_result(
agent_id=agent_id,
result=result,
context=context,
done=True,
)
except Exception as e:
print(f"Error updating RL system: {e}")
# Save data periodically
if len(self.tracker.records) % 50 == 0:
self._save_data()
def get_tool_recommendations(self, query: str) -> List[Tuple[str, float]]:
"""
Get tool recommendations for a query with confidence scores.
Args:
query: User query
Returns:
List of (tool_name, confidence) tuples
"""
# Get query embedding
if self.embedding_model is not None:
try:
query_embedding = self._get_embedding(query)
# Get all tools and their embeddings
tools = self.tool_registry.get_all_tools()
tool_scores = []
for tool in tools:
# Get tool description embedding
tool_desc = tool.description
tool_embedding = self._get_embedding(tool_desc)
# Compute similarity score
similarity = self._cosine_similarity(query_embedding, tool_embedding)
tool_scores.append((tool.name, similarity))
# Sort by score
tool_scores.sort(key=lambda x: x[1], reverse=True)
return tool_scores
except Exception as e:
print(f"Error computing tool recommendations: {e}")
# Fallback to default ordering
return [(tool, 0.5) for tool in self.tool_registry.get_all_tool_names()]
def update_model(self) -> Dict[str, Any]:
"""
Manually trigger a model update.
Returns:
Dictionary of update metrics
"""
if not self.enable_rl or self.rl_system is None:
return {"status": "RL not enabled"}
try:
metrics = self.rl_system.update()
# Save updated model
self.rl_system.save()
return {"status": "success", "metrics": metrics}
except Exception as e:
return {"status": "error", "message": str(e)}
def _default_tool_selection(self, query: str, context: Dict[str, Any]) -> str:
"""
Default tool selection logic when RL is not available.
Args:
query: User query
context: Conversation context
Returns:
Name of the selected tool
"""
# Use a simple rule-based approach as fallback
tools = self.tool_registry.get_all_tool_names()
# Look for keywords in the query
query_lower = query.lower()
if "file" in query_lower and "read" in query_lower:
for tool in tools:
if tool.lower() == "view":
return tool
if "search" in query_lower or "find" in query_lower:
for tool in tools:
if "grep" in tool.lower():
return tool
if "execute" in query_lower or "run" in query_lower:
for tool in tools:
if tool.lower() == "bash":
return tool
if "edit" in query_lower or "change" in query_lower:
for tool in tools:
if tool.lower() == "edit":
return tool
# Default to the first tool
return tools[0] if tools else "View"
def _get_embedding(self, text: str) -> np.ndarray:
"""
Get embedding for text with caching.
Args:
text: Text to embed
Returns:
Embedding vector
"""
if text in self.embedding_cache:
return self.embedding_cache[text]
if self.embedding_model is None:
raise ValueError("Embedding model not available")
# Generate embedding
embedding = self.embedding_model.encode(text, show_progress_bar=False)
# Cache embedding
if len(self.embedding_cache_keys) >= self.embedding_cache_keys.maxlen:
# Remove oldest key if cache is full
oldest_key = self.embedding_cache_keys.popleft()
self.embedding_cache.pop(oldest_key, None)
self.embedding_cache[text] = embedding
self.embedding_cache_keys.append(text)
return embedding
def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
"""
Compute cosine similarity between two vectors.
Args:
a: First vector
b: Second vector
Returns:
Cosine similarity score
"""
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def _save_data(self) -> None:
"""Save data and models."""
try:
# Create data file
data_path = os.path.join(self.data_dir, "tool_usage_data.json")
# Convert records to serializable format
records_data = []
for record in self.tracker.records:
records_data.append({
"query": record.query,
"tool_name": record.tool_name,
"execution_time": record.execution_time,
"token_usage": record.token_usage,
"success": record.success,
"timestamp": record.timestamp,
})
# Write data
with open(data_path, "w") as f:
json.dump(records_data, f)
# Save RL model if available
if self.enable_rl and self.rl_system is not None:
self.rl_system.save()
except Exception as e:
print(f"Error saving optimizer data: {e}")
def _load_data(self) -> None:
"""Load data and models."""
try:
# Load data file
data_path = os.path.join(self.data_dir, "tool_usage_data.json")
if os.path.exists(data_path):
with open(data_path, "r") as f:
records_data = json.load(f)
# Convert to records
for record_data in records_data:
record = ToolUsageRecord(
query=record_data["query"],
tool_name=record_data["tool_name"],
execution_time=record_data["execution_time"],
token_usage=record_data["token_usage"],
success=record_data["success"],
timestamp=record_data["timestamp"],
)
self.tracker.add_record(record)
# Load RL model if available
if self.enable_rl and self.rl_system is not None:
try:
self.rl_system.load()
except Exception as e:
print(f"Error loading RL model: {e}")
except Exception as e:
print(f"Error loading optimizer data: {e}")
class ToolSelectionManager:
"""
Manages tool selection for Claude Code Python.
Provides an interface for selecting tools and recording usage.
"""
def __init__(
self,
tool_registry: Any,
enable_optimization: bool = True,
data_dir: str = "./data/rl",
):
"""
Initialize the tool selection manager.
Args:
tool_registry: Registry containing available tools
enable_optimization: Whether to enable optimization
data_dir: Directory to store data and models
"""
self.tool_registry = tool_registry
self.enable_optimization = enable_optimization
# Initialize optimizer if enabled
self.optimizer = None
if enable_optimization:
self.optimizer = ToolSelectionOptimizer(
tool_registry=tool_registry,
data_dir=data_dir,
enable_rl=True,
)
def select_tool(self, query: str, context: Dict[str, Any]) -> str:
"""
Select the best tool to use for a given query.
Args:
query: User query
context: Conversation context
Returns:
Name of the selected tool
"""
if self.optimizer is not None:
return self.optimizer.select_tool(query, context)
# Use default selection if optimizer is not available
return self._default_selection(query)
def record_tool_usage(
self,
query: str,
tool_name: str,
execution_time: float,
token_usage: Dict[str, int],
success: bool,
context: Optional[Dict[str, Any]] = None,
result: Optional[Any] = None,
) -> None:
"""
Record tool usage for optimization.
Args:
query: User query
tool_name: Name of the tool used
execution_time: Time taken to execute the tool
token_usage: Token usage information
success: Whether the tool usage was successful
context: Conversation context (for RL)
result: Result of the tool usage (for RL)
"""
if self.optimizer is not None:
self.optimizer.record_tool_usage(
query=query,
tool_name=tool_name,
execution_time=execution_time,
token_usage=token_usage,
success=success,
context=context,
result=result,
)
def get_tool_recommendations(self, query: str) -> List[Tuple[str, float]]:
"""
Get tool recommendations for a query with confidence scores.
Args:
query: User query
Returns:
List of (tool_name, confidence) tuples
"""
if self.optimizer is not None:
return self.optimizer.get_tool_recommendations(query)
# Return default recommendations if optimizer is not available
return [(tool, 0.5) for tool in self.tool_registry.get_all_tool_names()]
def _default_selection(self, query: str) -> str:
"""
Default tool selection logic when optimization is not available.
Args:
query: User query
Returns:
Name of the selected tool
"""
# Use a simple rule-based approach as fallback
tools = self.tool_registry.get_all_tool_names()
# Default to the first tool
return tools[0] if tools else "View"
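# Hypothetical usage sketch for ToolSelectionManager. The minimal registry stub below only
# implements the two methods this module calls (get_all_tools / get_all_tool_names); a real
# tool registry would come from the tools package. Names and values are illustrative.
class _StubTool:
    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
class _StubRegistry:
    def __init__(self, tools):
        self._tools = tools
    def get_all_tools(self):
        return self._tools
    def get_all_tool_names(self):
        return [t.name for t in self._tools]
def _demo_tool_selection_manager() -> None:
    registry = _StubRegistry([_StubTool("View", "Read a file"), _StubTool("Bash", "Run a command")])
    # Optimization disabled here so the demo has no RL or filesystem side effects
    manager = ToolSelectionManager(tool_registry=registry, enable_optimization=False)
    tool = manager.select_tool("read the README", context={})
    manager.record_tool_usage(
        query="read the README",
        tool_name=tool,
        execution_time=0.05,
        token_usage={"prompt": 120, "completion": 30, "total": 150},
        success=True,
    )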
```