This is page 4 of 4. Use http://codebase.md/arthurcolle/openai-mcp?page={x} to view the full context.
# Directory Structure
```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```
# Files
--------------------------------------------------------------------------------
/modal_mcp_server.py:
--------------------------------------------------------------------------------
```python
import modal
import logging
import time
import uuid
import json
import asyncio
import hashlib
import threading
import concurrent.futures
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union, AsyncIterator
from datetime import datetime, timedelta
from collections import deque
from fastapi import FastAPI, Request, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Create FastAPI app
api_app = FastAPI(
title="Advanced LLM Inference API",
description="Enterprise-grade OpenAI-compatible LLM serving API with multiple model support, streaming, and advanced caching",
version="1.1.0"
)
# Add CORS middleware
api_app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # For production, specify specific origins instead of wildcard
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Security setup
security = HTTPBearer()
# Token bucket rate limiter
class TokenBucket:
"""
Token bucket algorithm for rate limiting.
Each user gets a bucket that fills at a constant rate.
"""
def __init__(self):
self.buckets = {}
self.lock = threading.Lock()
def _get_bucket(self, user_id, rate_limit):
"""Get or create a bucket for a user"""
now = time.time()
if user_id not in self.buckets:
# Initialize with full bucket
self.buckets[user_id] = {
"tokens": rate_limit,
"last_refill": now,
"rate": rate_limit / 60.0 # tokens per second
}
return self.buckets[user_id]
bucket = self.buckets[user_id]
# Update rate if it changed
bucket["rate"] = rate_limit / 60.0
# Refill tokens based on time elapsed
elapsed = now - bucket["last_refill"]
new_tokens = elapsed * bucket["rate"]
bucket["tokens"] = min(rate_limit, bucket["tokens"] + new_tokens)
bucket["last_refill"] = now
return bucket
def consume(self, user_id, tokens=1, rate_limit=60):
"""
Consume tokens from a user's bucket.
Returns True if tokens were consumed, False otherwise.
"""
with self.lock:
bucket = self._get_bucket(user_id, rate_limit)
if bucket["tokens"] >= tokens:
bucket["tokens"] -= tokens
return True
return False
# Create rate limiter
rate_limiter = TokenBucket()
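# Illustrative usage sketch (not called anywhere in this module): how the token
# bucket above behaves for a hypothetical user allowed 2 requests per minute.
def _example_token_bucket_usage():
    limiter = TokenBucket()
    # A fresh bucket starts full, so the first two requests pass and the third is rejected.
    assert limiter.consume("demo-user", tokens=1, rate_limit=2)
    assert limiter.consume("demo-user", tokens=1, rate_limit=2)
    assert not limiter.consume("demo-user", tokens=1, rate_limit=2)
    # After ~30 seconds roughly one token has refilled (2 tokens per 60 seconds).
    time.sleep(30)
    assert limiter.consume("demo-user", tokens=1, rate_limit=2)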
# Define the container image with necessary dependencies
vllm_image = (
modal.Image.debian_slim(python_version="3.10")
.pip_install(
"vllm==0.7.3", # Updated version
"huggingface_hub[hf_transfer]==0.26.2",
"flashinfer-python==0.2.0.post2",
"fastapi>=0.95.0",
"uvicorn>=0.15.0",
"pydantic>=2.0.0",
"tiktoken>=0.5.1",
extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers
.env({"VLLM_USE_V1": "1"}) # Enable V1 engine for better performance
)
# Define llama.cpp image for alternative models
llama_cpp_image = (
modal.Image.debian_slim(python_version="3.10")
.apt_install("git", "build-essential", "cmake", "curl", "libcurl4-openssl-dev")
.pip_install(
"huggingface_hub==0.26.2",
"hf_transfer>=0.1.4",
"fastapi>=0.95.0",
"uvicorn>=0.15.0",
"pydantic>=2.0.0"
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_commands(
"git clone https://github.com/ggerganov/llama.cpp",
"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=ON",
"cmake --build llama.cpp/build --config Release -j --target llama-cli",
"cp llama.cpp/build/bin/llama-* /usr/local/bin/"
)
)
# Set up model configurations
MODELS_DIR = "/models"
VLLM_MODELS = {
"llama3-8b": {
"id": "llama3-8b",
"name": "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
"config": "config.json", # Ensure this file is present in the model directory
"revision": "a7c09948d9a632c2c840722f519672cd94af885d",
"max_tokens": 4096,
"loaded": False
},
"mistral-7b": {
"id": "mistral-7b",
"name": "mistralai/Mistral-7B-Instruct-v0.2",
"revision": "main",
"max_tokens": 4096,
"loaded": False
},
# Small model for quick loading
"tiny-llama-1.1b": {
"id": "tiny-llama-1.1b",
"name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"revision": "main",
"max_tokens": 2048,
"loaded": False
}
}
LLAMA_CPP_MODELS = {
"deepseek-r1": {
"id": "deepseek-r1",
"name": "unsloth/DeepSeek-R1-GGUF",
"quant": "UD-IQ1_S",
"pattern": "*UD-IQ1_S*",
"revision": "02656f62d2aa9da4d3f0cdb34c341d30dd87c3b6",
"gpu": "L40S:4",
"max_tokens": 4096,
"loaded": False
},
"phi-4": {
"id": "phi-4",
"name": "unsloth/phi-4-GGUF",
"quant": "Q2_K",
"pattern": "*Q2_K*",
"revision": None,
"gpu": "L40S:4", # Use GPU for better performance
"max_tokens": 4096,
"loaded": False
},
# Small model for quick loading
"phi-2": {
"id": "phi-2",
"name": "TheBloke/phi-2-GGUF",
"quant": "Q4_K_M",
"pattern": "*Q4_K_M.gguf",
"revision": "main",
"gpu": None, # Can run on CPU
"max_tokens": 2048,
"loaded": False
}
}
DEFAULT_MODEL = "phi-4"
# Create volumes for caching
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
llama_cpp_cache_vol = modal.Volume.from_name("llama-cpp-cache", create_if_missing=True)
results_vol = modal.Volume.from_name("model-results", create_if_missing=True)
# Create the Modal app
app = modal.App("openai-compatible-llm-server")
# Create shared data structures
model_stats_dict = modal.Dict.from_name("model-stats", create_if_missing=True)
user_usage_dict = modal.Dict.from_name("user-usage", create_if_missing=True)
request_queue = modal.Queue.from_name("request-queue", create_if_missing=True)
response_dict = modal.Dict.from_name("response-cache", create_if_missing=True)
api_keys_dict = modal.Dict.from_name("api-keys", create_if_missing=True)
stream_queues = modal.Dict.from_name("stream-queues", create_if_missing=True)
# Advanced caching system
class AdvancedCache:
"""
Advanced caching system with TTL and LRU eviction.
"""
def __init__(self, max_size=1000, default_ttl=3600):
self.cache = {}
self.ttl_map = {}
self.access_times = {}
self.max_size = max_size
self.default_ttl = default_ttl
self.lock = threading.Lock()
def get(self, key):
"""Get a value from the cache"""
with self.lock:
now = time.time()
# Check if key exists and is not expired
if key in self.cache:
# Check TTL
if key in self.ttl_map and self.ttl_map[key] < now:
# Expired
self._remove(key)
return None
# Update access time
self.access_times[key] = now
return self.cache[key]
return None
def set(self, key, value, ttl=None):
"""Set a value in the cache with optional TTL"""
with self.lock:
now = time.time()
# Evict if needed
if len(self.cache) >= self.max_size and key not in self.cache:
self._evict_lru()
# Set value
self.cache[key] = value
self.access_times[key] = now
# Set TTL
if ttl is not None:
self.ttl_map[key] = now + ttl
elif self.default_ttl > 0:
self.ttl_map[key] = now + self.default_ttl
def _remove(self, key):
"""Remove a key from the cache"""
if key in self.cache:
del self.cache[key]
if key in self.ttl_map:
del self.ttl_map[key]
if key in self.access_times:
del self.access_times[key]
def _evict_lru(self):
"""Evict least recently used item"""
if not self.access_times:
return
# Find oldest access time
oldest_key = min(self.access_times.items(), key=lambda x: x[1])[0]
self._remove(oldest_key)
def clear_expired(self):
"""Clear all expired entries"""
with self.lock:
now = time.time()
expired_keys = [k for k, v in self.ttl_map.items() if v < now]
for key in expired_keys:
self._remove(key)
# Constants
MAX_CACHE_AGE = 3600 # 1 hour in seconds
# Create memory cache
memory_cache = AdvancedCache(max_size=10000, default_ttl=MAX_CACHE_AGE)
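# Illustrative usage sketch (not called anywhere in this module): AdvancedCache
# expires entries after their TTL and evicts the least recently used key when full.
def _example_advanced_cache_usage():
    cache = AdvancedCache(max_size=2, default_ttl=3600)
    cache.set("a", 1, ttl=1)       # short TTL: expires after ~1 second
    cache.set("b", 2)              # uses the default TTL
    assert cache.get("a") == 1
    time.sleep(1.1)
    assert cache.get("a") is None  # expired entries are dropped on access
    cache.set("c", 3)
    cache.set("d", 4)              # cache is full, so the LRU key ("b") is evicted
    assert cache.get("b") is None
    assert cache.get("c") == 3 and cache.get("d") == 4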
# Initialize with default key if empty
if "default" not in api_keys_dict:
api_keys_dict["default"] = {
"key": "sk-modal-llm-api-key",
"rate_limit": 60, # requests per minute
"quota": 1000000, # tokens per day
"created_at": datetime.now().isoformat(),
"owner": "default"
}
# Add a default ADMIN API key
if "admin" not in api_keys_dict:
api_keys_dict["admin"] = {
"key": "sk-modal-admin-api-key",
"rate_limit": 1000, # Higher rate limit for admin
"quota": 10000000, # Higher quota for admin
"created_at": datetime.now().isoformat(),
"owner": "admin"
}
# Constants
DEFAULT_API_KEY = api_keys_dict["default"]["key"]
MINUTES = 60 # seconds
SERVER_PORT = 8000
CACHE_DIR = "/root/.cache"
RESULTS_DIR = "/root/results"
# Request/response models
class GenerationRequest(BaseModel):
request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
model_id: str
messages: List[Dict[str, str]]
temperature: float = 0.7
max_tokens: int = 1024
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
user: Optional[str] = None
stream: bool = False
timestamp: float = Field(default_factory=time.time)
api_key: str = DEFAULT_API_KEY
class StreamChunk(BaseModel):
"""Model for streaming response chunks"""
id: str
object: str = "chat.completion.chunk"
created: int
model: str
choices: List[Dict[str, Any]]
class StreamManager:
"""Manages streaming responses for clients"""
def __init__(self):
self.streams = {}
self.lock = threading.Lock()
def create_stream(self, request_id):
"""Create a new stream for a request"""
with self.lock:
self.streams[request_id] = {
"queue": asyncio.Queue(),
"finished": False,
"created_at": time.time()
}
def add_chunk(self, request_id, chunk):
"""Add a chunk to a stream"""
with self.lock:
if request_id in self.streams:
stream = self.streams[request_id]
if not stream["finished"]:
stream["queue"].put_nowait(chunk)
def finish_stream(self, request_id):
"""Mark a stream as finished"""
with self.lock:
if request_id in self.streams:
self.streams[request_id]["finished"] = True
# Add None to signal end of stream
self.streams[request_id]["queue"].put_nowait(None)
async def get_chunks(self, request_id):
"""Get chunks from a stream as an async generator"""
if request_id not in self.streams:
return
stream = self.streams[request_id]
queue = stream["queue"]
while True:
chunk = await queue.get()
if chunk is None: # End of stream
break
yield chunk
queue.task_done()
# Clean up after streaming is done
with self.lock:
if request_id in self.streams:
del self.streams[request_id]
def clean_old_streams(self, max_age=3600):
"""Clean up old streams"""
with self.lock:
now = time.time()
to_remove = []
for request_id, stream in self.streams.items():
if now - stream["created_at"] > max_age:
to_remove.append(request_id)
for request_id in to_remove:
if request_id in self.streams:
# Mark as finished to stop any ongoing processing
self.streams[request_id]["finished"] = True
# Add None to unblock any waiting consumers
self.streams[request_id]["queue"].put_nowait(None)
# Remove from streams
del self.streams[request_id]
# Create stream manager
stream_manager = StreamManager()
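# Illustrative usage sketch (not called anywhere in this module): the producer /
# consumer contract of StreamManager — a worker pushes chunks, the HTTP layer drains
# them until finish_stream() enqueues the None sentinel.
async def _example_stream_manager_usage():
    rid = "demo-request"
    stream_manager.create_stream(rid)
    stream_manager.add_chunk(rid, {"delta": "Hello"})
    stream_manager.add_chunk(rid, {"delta": " world"})
    stream_manager.finish_stream(rid)
    chunks = [chunk async for chunk in stream_manager.get_chunks(rid)]
    assert chunks == [{"delta": "Hello"}, {"delta": " world"}]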
# API Authentication dependency
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Verify that the API key in the authorization header is valid and check rate limits"""
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication scheme. Use Bearer",
)
api_key = credentials.credentials
valid_key = False
key_info = None
# Check if this is a known API key
for user_id, user_data in api_keys_dict.items():
if user_data.get("key") == api_key:
valid_key = True
key_info = user_data
break
if not valid_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
# Check rate limits
user_id = key_info.get("owner", "unknown")
rate_limit = key_info.get("rate_limit", 60) # Default: 60 requests per minute
# Get or initialize user usage tracking
if user_id not in user_usage_dict:
user_usage_dict[user_id] = {
"requests": [],
"tokens": {
"input": 0,
"output": 0,
"last_reset": datetime.now().isoformat()
}
}
usage = user_usage_dict[user_id]
# Check if user exceeded rate limit using token bucket algorithm
if not rate_limiter.consume(user_id, tokens=1, rate_limit=rate_limit):
        # Calculate Retry-After from the refill rate (at least 1 second until a token is available)
        retry_after = max(1, int(60 / rate_limit))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded. Maximum {rate_limit} requests per minute.",
headers={"Retry-After": str(retry_after)}
)
# Add current request timestamp for analytics
now = datetime.now()
usage["requests"].append(now.timestamp())
# Clean up old requests (older than 1 day) to prevent unbounded growth
day_ago = (now - timedelta(days=1)).timestamp()
usage["requests"] = [req for req in usage["requests"] if req > day_ago]
# Update usage dict
user_usage_dict[user_id] = usage
# Return the API key and user ID
return {"key": api_key, "user_id": user_id}
# API Endpoints
@api_app.get("/", response_class=HTMLResponse)
async def index():
"""Root endpoint that returns HTML with API information"""
return """
<html>
<head>
<title>Modal LLM Inference API</title>
<style>
body { font-family: system-ui, sans-serif; max-width: 800px; margin: 0 auto; padding: 2rem; }
h1 { color: #4a56e2; }
code { background: #f4f4f8; padding: 0.2rem 0.4rem; border-radius: 3px; }
</style>
</head>
<body>
<h1>Modal LLM Inference API</h1>
<p>This is an OpenAI-compatible API for LLM inference powered by Modal.</p>
<p>Use the following endpoints:</p>
<ul>
<li><a href="/docs">/docs</a> - API documentation</li>
<li><a href="/v1/models">/v1/models</a> - List available models</li>
<li><code>/v1/chat/completions</code> - Chat completions endpoint</li>
</ul>
</body>
</html>
"""
@api_app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
@api_app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
"""List all available models in OpenAI-compatible format"""
# Combine vLLM and llama.cpp models
all_models = []
for model_id, model_info in VLLM_MODELS.items():
all_models.append({
"id": model_info["id"],
"object": "model",
"created": 1677610602,
"owned_by": "modal",
"engine": "vllm",
"loaded": model_info.get("loaded", False)
})
for model_id, model_info in LLAMA_CPP_MODELS.items():
all_models.append({
"id": model_info["id"],
"object": "model",
"created": 1677610602,
"owned_by": "modal",
"engine": "llama.cpp",
"loaded": model_info.get("loaded", False)
})
return {"data": all_models, "object": "list"}
# Model management endpoints
class ModelLoadRequest(BaseModel):
"""Request model to load a specific model"""
model_id: str
force_reload: bool = False
class HFModelLoadRequest(BaseModel):
"""Request to load a model directly from Hugging Face"""
repo_id: str
model_type: str = "vllm" # "vllm" or "llama.cpp"
revision: Optional[str] = None
quant: Optional[str] = None # For llama.cpp models
max_tokens: int = 4096
gpu: Optional[str] = None # For llama.cpp models
@api_app.post("/admin/models/load", dependencies=[Depends(verify_api_key)])
async def load_model(request: ModelLoadRequest, background_tasks: BackgroundTasks):
"""Load a specific model into memory"""
model_id = request.model_id
force_reload = request.force_reload
# Check if model exists
if model_id in VLLM_MODELS:
model_type = "vllm"
model_info = VLLM_MODELS[model_id]
elif model_id in LLAMA_CPP_MODELS:
model_type = "llama.cpp"
model_info = LLAMA_CPP_MODELS[model_id]
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
# Check if model is already loaded
if model_info.get("loaded", False) and not force_reload:
return {
"status": "success",
"message": f"Model {model_id} is already loaded",
"model_id": model_id,
"model_type": model_type
}
# Start loading the model in the background
if model_type == "vllm":
# Start vLLM server for this model
        background_tasks.add_task(serve_vllm_model.spawn, model_id=model_id)
# Update model status
VLLM_MODELS[model_id]["loaded"] = True
else: # llama.cpp
# For llama.cpp models, we'll preload the model
background_tasks.add_task(preload_llama_cpp_model, model_id)
# Update model status
LLAMA_CPP_MODELS[model_id]["loaded"] = True
return {
"status": "success",
"message": f"Started loading model {model_id}",
"model_id": model_id,
"model_type": model_type
}
@api_app.post("/admin/models/load-from-hf", dependencies=[Depends(verify_api_key)])
async def load_model_from_hf(request: HFModelLoadRequest, background_tasks: BackgroundTasks):
"""Load a model directly from Hugging Face"""
repo_id = request.repo_id
model_type = request.model_type
revision = request.revision
# Generate a unique model_id based on the repo name
repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
model_id = f"hf-{repo_name}-{uuid.uuid4().hex[:6]}"
# Create model info based on type
if model_type.lower() == "vllm":
# Add to VLLM_MODELS
VLLM_MODELS[model_id] = {
"id": model_id,
"name": repo_id,
"revision": revision or "main",
"max_tokens": request.max_tokens,
"loaded": False,
"hf_direct": True # Mark as directly loaded from HF
}
# Start vLLM server for this model
        background_tasks.add_task(serve_vllm_model.spawn, model_id=model_id)
# Update model status
VLLM_MODELS[model_id]["loaded"] = True
elif model_type.lower() == "llama.cpp":
# For llama.cpp we need quant info
quant = request.quant or "Q4_K_M" # Default quantization
pattern = f"*{quant}*"
# Add to LLAMA_CPP_MODELS
LLAMA_CPP_MODELS[model_id] = {
"id": model_id,
"name": repo_id,
"quant": quant,
"pattern": pattern,
"revision": revision,
"gpu": request.gpu, # Can be None for CPU
"max_tokens": request.max_tokens,
"loaded": False,
"hf_direct": True # Mark as directly loaded from HF
}
# Preload the model
background_tasks.add_task(preload_llama_cpp_model, model_id)
# Update model status
LLAMA_CPP_MODELS[model_id]["loaded"] = True
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid model type: {model_type}. Must be 'vllm' or 'llama.cpp'"
)
return {
"status": "success",
"message": f"Started loading model {repo_id} as {model_id}",
"model_id": model_id,
"model_type": model_type,
"repo_id": repo_id
}
@api_app.post("/admin/models/unload", dependencies=[Depends(verify_api_key)])
async def unload_model(request: ModelLoadRequest):
"""Unload a specific model from memory"""
model_id = request.model_id
# Check if model exists
if model_id in VLLM_MODELS:
model_type = "vllm"
model_info = VLLM_MODELS[model_id]
elif model_id in LLAMA_CPP_MODELS:
model_type = "llama.cpp"
model_info = LLAMA_CPP_MODELS[model_id]
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
# Check if model is loaded
if not model_info.get("loaded", False):
return {
"status": "success",
"message": f"Model {model_id} is not loaded",
"model_id": model_id,
"model_type": model_type
}
# Update model status
if model_type == "vllm":
VLLM_MODELS[model_id]["loaded"] = False
else: # llama.cpp
LLAMA_CPP_MODELS[model_id]["loaded"] = False
return {
"status": "success",
"message": f"Unloaded model {model_id}",
"model_id": model_id,
"model_type": model_type
}
@api_app.get("/admin/models/status/{model_id}", dependencies=[Depends(verify_api_key)])
async def get_model_status(model_id: str):
"""Get the status of a specific model"""
# Check if model exists
if model_id in VLLM_MODELS:
model_type = "vllm"
model_info = VLLM_MODELS[model_id]
elif model_id in LLAMA_CPP_MODELS:
model_type = "llama.cpp"
model_info = LLAMA_CPP_MODELS[model_id]
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model_id} not found"
)
# Get model stats if available
model_stats = model_stats_dict.get(model_id, {})
# Include HF info if available
hf_info = {}
if model_info.get("hf_direct"):
hf_info = {
"repo_id": model_info.get("name"),
"revision": model_info.get("revision"),
}
if model_type == "llama.cpp":
hf_info["quant"] = model_info.get("quant")
return {
"model_id": model_id,
"model_type": model_type,
"loaded": model_info.get("loaded", False),
"stats": model_stats,
"hf_info": hf_info if hf_info else None
}
# Admin API endpoints
class APIKeyRequest(BaseModel):
user_id: str
rate_limit: int = 60
quota: int = 1000000
class APIKey(BaseModel):
key: str
user_id: str
rate_limit: int
quota: int
created_at: str
@api_app.post("/admin/api-keys", response_model=APIKey)
async def create_api_key(request: APIKeyRequest, auth_info: dict = Depends(verify_api_key)):
"""Create a new API key for a user (admin only)"""
# Check if this is an admin request
    if auth_info["user_id"] not in ("default", "admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin users can create API keys"
)
# Generate a new API key
new_key = f"sk-modal-{uuid.uuid4()}"
user_id = request.user_id
# Store the key
api_keys_dict[user_id] = {
"key": new_key,
"rate_limit": request.rate_limit,
"quota": request.quota,
"created_at": datetime.now().isoformat(),
"owner": user_id
}
# Initialize user usage
if not user_usage_dict.contains(user_id):
user_usage_dict[user_id] = {
"requests": [],
"tokens": {
"input": 0,
"output": 0,
"last_reset": datetime.now().isoformat()
}
}
return APIKey(
key=new_key,
user_id=user_id,
rate_limit=request.rate_limit,
quota=request.quota,
created_at=datetime.now().isoformat()
)
@api_app.get("/admin/api-keys")
async def list_api_keys(auth_info: dict = Depends(verify_api_key)):
"""List all API keys (admin only)"""
# Check if this is an admin request
    if auth_info["user_id"] not in ("default", "admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin users can list API keys"
)
# Return all keys (except the actual key values for security)
keys = []
for user_id, key_info in api_keys_dict.items():
keys.append({
"user_id": user_id,
"rate_limit": key_info.get("rate_limit", 60),
"quota": key_info.get("quota", 1000000),
"created_at": key_info.get("created_at", datetime.now().isoformat()),
# Mask the actual key
"key": key_info.get("key", "")[:8] + "..." if key_info.get("key") else "None"
})
return {"keys": keys}
@api_app.get("/admin/stats")
async def get_stats(auth_info: dict = Depends(verify_api_key)):
"""Get usage statistics (admin only)"""
# Check if this is an admin request
    if auth_info["user_id"] not in ("default", "admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin users can view stats"
)
# Get model stats
model_stats = {}
for model_id in list(VLLM_MODELS.keys()) + list(LLAMA_CPP_MODELS.keys()):
if model_id in model_stats_dict:
model_stats[model_id] = model_stats_dict[model_id]
# Get user stats
user_stats = {}
for user_id in user_usage_dict.keys():
usage = user_usage_dict[user_id]
# Don't include request timestamps for brevity
if "requests" in usage:
usage = usage.copy()
usage["request_count"] = len(usage["requests"])
del usage["requests"]
user_stats[user_id] = usage
# Get queue info
queue_info = {
"pending_requests": request_queue.len(),
"active_workers": model_stats_dict.get("workers_running", 0)
}
return {
"models": model_stats,
"users": user_stats,
"queue": queue_info,
"timestamp": datetime.now().isoformat()
}
@api_app.delete("/admin/api-keys/{user_id}")
async def delete_api_key(user_id: str, auth_info: dict = Depends(verify_api_key)):
"""Delete an API key (admin only)"""
# Check if this is an admin request
    if auth_info["user_id"] not in ("default", "admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin users can delete API keys"
)
# Check if the key exists
if not api_keys_dict.contains(user_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No API key found for user {user_id}"
)
# Can't delete the default key
if user_id == "default":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete the default API key"
)
# Delete the key
api_keys_dict.pop(user_id)
return {"status": "success", "message": f"API key deleted for user {user_id}"}
@api_app.post("/v1/chat/completions")
async def chat_completions(request: Request, background_tasks: BackgroundTasks, auth_info: dict = Depends(verify_api_key)):
"""OpenAI-compatible chat completions endpoint with request queueing, streaming and response caching"""
try:
json_data = await request.json()
# Extract model or use default
model_id = json_data.get("model", DEFAULT_MODEL)
messages = json_data.get("messages", [])
temperature = json_data.get("temperature", 0.7)
max_tokens = json_data.get("max_tokens", 1024)
stream = json_data.get("stream", False)
user = json_data.get("user", auth_info["user_id"])
# Calculate a cache key based on the request parameters
cache_key = calculate_cache_key(model_id, messages, temperature, max_tokens)
# Check if we have a cached response in memory cache first (faster)
cached_response = memory_cache.get(cache_key)
if cached_response and not stream: # Don't use cache for streaming requests
# Update stats
update_stats(model_id, "cache_hit")
return cached_response
# Check if we have a cached response in Modal's persistent cache
if not cached_response and cache_key in response_dict and not stream:
cached_response = response_dict[cache_key]
cache_age = time.time() - cached_response.get("timestamp", 0)
# Use cached response if it's fresh enough
if cache_age < MAX_CACHE_AGE:
# Update stats
update_stats(model_id, "cache_hit")
response_data = cached_response["response"]
# Also cache in memory for faster access next time
memory_cache.set(cache_key, response_data)
return response_data
# Select best model if "auto" is specified
if model_id == "auto" and len(messages) > 0:
# Get the last user message
last_message = None
for msg in reversed(messages):
if msg.get("role") == "user":
last_message = msg.get("content", "")
break
if last_message:
prompt = last_message
# Select best model based on prompt and parameters
model_id = select_best_model(prompt, max_tokens, temperature)
logging.info(f"Auto-selected model: {model_id} for prompt")
# Check if model exists
if model_id not in VLLM_MODELS and model_id not in LLAMA_CPP_MODELS:
# Default to the default model if specified model not found
logging.warning(f"Model {model_id} not found, using default: {DEFAULT_MODEL}")
model_id = DEFAULT_MODEL
# Create a unique request ID
request_id = str(uuid.uuid4())
# Create request object
gen_request = GenerationRequest(
request_id=request_id,
model_id=model_id,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=json_data.get("top_p", 1.0),
frequency_penalty=json_data.get("frequency_penalty", 0.0),
presence_penalty=json_data.get("presence_penalty", 0.0),
user=user,
stream=stream,
api_key=auth_info["key"]
)
# For streaming requests, set up streaming response
if stream:
# Create a new stream
stream_manager.create_stream(request_id)
# Put the request in the queue
await request_queue.put.aio(gen_request.model_dump())
# Update stats
update_stats(model_id, "request_count")
update_stats(model_id, "stream_count")
# Start a background worker to process the request if needed
background_tasks.add_task(ensure_worker_running)
# Return a streaming response using FastAPI's StreamingResponse
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
return FastAPIStreamingResponse(
content=stream_response(request_id, model_id, auth_info["user_id"]),
media_type="text/event-stream"
)
# For non-streaming, enqueue the request and wait for result
# Put the request in the queue
await request_queue.put.aio(gen_request.model_dump())
# Update stats
update_stats(model_id, "request_count")
# Start a background worker to process the request if needed
background_tasks.add_task(ensure_worker_running)
# Wait for the response with timeout
start_time = time.time()
timeout = 120 # 2-minute timeout for non-streaming requests
while time.time() - start_time < timeout:
# Check memory cache first (faster)
response_data = memory_cache.get(request_id)
if response_data:
# Update stats
update_stats(model_id, "success_count")
estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
# Save to persistent cache
response_dict[cache_key] = {
"response": response_data,
"timestamp": time.time()
}
# Clean up request-specific cache
memory_cache.set(request_id, None)
return response_data
# Check persistent cache
if response_dict.contains(request_id):
response_data = response_dict[request_id]
# Remove from response dict to save memory
try:
response_dict.pop(request_id)
except Exception:
pass
# Save to cache
response_dict[cache_key] = {
"response": response_data,
"timestamp": time.time()
}
# Also cache in memory
memory_cache.set(cache_key, response_data)
# Update stats
update_stats(model_id, "success_count")
estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
return response_data
# Wait a bit before checking again
await asyncio.sleep(0.1)
# If we get here, we timed out
update_stats(model_id, "timeout_count")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Request timed out. The model may be busy. Please try again later."
)
except Exception as e:
logging.error(f"Error in chat completions: {str(e)}")
# Update error stats
if "model_id" in locals():
update_stats(model_id, "error_count")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error generating response: {str(e)}"
)
async def stream_response(request_id: str, model_id: str, user_id: str) -> AsyncIterator[str]:
"""Stream response chunks to the client"""
try:
# Stream header
yield "data: " + json.dumps({"object": "chat.completion.chunk"}) + "\n\n"
# Stream chunks
async for chunk in stream_manager.get_chunks(request_id):
if chunk:
yield f"data: {json.dumps(chunk)}\n\n"
# Stream done
yield "data: [DONE]\n\n"
except Exception as e:
logging.error(f"Error streaming response: {str(e)}")
# Update error stats
update_stats(model_id, "stream_error_count")
# Send error as SSE
error_json = json.dumps({"error": str(e)})
yield f"data: {error_json}\n\n"
yield "data: [DONE]\n\n"
async def ensure_worker_running():
"""Ensure that a worker is running to process the queue"""
# Check if workers are already running via a sentinel in shared dict
workers_running_key = "workers_running"
if not model_stats_dict.contains(workers_running_key):
model_stats_dict[workers_running_key] = 0
current_workers = model_stats_dict[workers_running_key]
# If no workers or too few workers, start more
if current_workers < 3: # Keep up to 3 workers running
# Increment worker count
model_stats_dict[workers_running_key] = current_workers + 1
# Start a worker
await process_queue_worker.spawn.aio()
def calculate_cache_key(model_id: str, messages: List[dict], temperature: float, max_tokens: int) -> str:
"""Calculate a deterministic cache key for a request using SHA-256"""
# Create a simplified version of the request for cache key
cache_dict = {
"model": model_id,
"messages": messages,
"temperature": round(temperature, 2), # Round to reduce variations
"max_tokens": max_tokens
}
# Convert to a stable string representation and hash it with SHA-256
cache_str = json.dumps(cache_dict, sort_keys=True)
hash_obj = hashlib.sha256(cache_str.encode())
return f"cache:{hash_obj.hexdigest()[:16]}"
def update_stats(model_id: str, stat_type: str):
"""Update usage statistics for a model"""
if not model_stats_dict.contains(model_id):
model_stats_dict[model_id] = {
"request_count": 0,
"success_count": 0,
"error_count": 0,
"timeout_count": 0,
"cache_hit": 0,
"token_count": 0,
"avg_latency": 0
}
stats = model_stats_dict[model_id]
stats[stat_type] = stats.get(stat_type, 0) + 1
model_stats_dict[model_id] = stats
def estimate_tokens(messages: List[dict], response: dict, user_id: str, model_id: str):
"""Estimate token usage and update user quotas"""
# Very simple token estimation based on whitespace-split words * 1.3
input_tokens = 0
for msg in messages:
input_tokens += len(msg.get("content", "").split()) * 1.3
output_tokens = 0
if response and "choices" in response:
for choice in response["choices"]:
if "message" in choice and "content" in choice["message"]:
output_tokens += len(choice["message"]["content"].split()) * 1.3
# Update model stats
if model_stats_dict.contains(model_id):
stats = model_stats_dict[model_id]
stats["token_count"] = stats.get("token_count", 0) + input_tokens + output_tokens
model_stats_dict[model_id] = stats
# Update user usage
if user_id in user_usage_dict:
usage = user_usage_dict[user_id]
# Check if we need to reset daily counters
last_reset = datetime.fromisoformat(usage["tokens"]["last_reset"])
now = datetime.now()
if now.date() > last_reset.date():
# Reset daily counters
usage["tokens"]["input"] = 0
usage["tokens"]["output"] = 0
usage["tokens"]["last_reset"] = now.isoformat()
# Update token counts
usage["tokens"]["input"] += int(input_tokens)
usage["tokens"]["output"] += int(output_tokens)
user_usage_dict[user_id] = usage
def select_best_model(prompt: str, n_predict: int, temperature: float) -> str:
"""
Intelligently selects the best model based on input parameters.
Args:
prompt (str): The input prompt for the model.
n_predict (int): The number of tokens to predict.
temperature (float): The sampling temperature.
Returns:
str: The identifier of the best model to use.
"""
# Check for code generation patterns
code_indicators = ["```", "def ", "class ", "function", "import ", "from ", "<script", "<style",
"SELECT ", "CREATE TABLE", "const ", "let ", "var ", "function(", "=>"]
is_likely_code = any(indicator in prompt for indicator in code_indicators)
# Check for creative writing patterns
creative_indicators = ["story", "poem", "creative", "imagine", "fiction", "narrative",
"write a", "compose", "create a"]
is_creative_task = any(indicator in prompt.lower() for indicator in creative_indicators)
# Check for analytical/reasoning tasks
analytical_indicators = ["explain", "analyze", "compare", "contrast", "reason",
"evaluate", "assess", "why", "how does"]
is_analytical_task = any(indicator in prompt.lower() for indicator in analytical_indicators)
# Decision logic
if is_likely_code:
# For code generation, prefer phi-4 for all code tasks
return "phi-4" # Excellent for code generation
elif is_creative_task:
# For creative tasks, use models with higher creativity
if temperature > 0.8:
return "deepseek-r1" # More creative at high temperatures
else:
return "phi-4" # Good balance of creativity and coherence
elif is_analytical_task:
# For analytical tasks, use models with strong reasoning
return "phi-4" # Strong reasoning capabilities
# Length-based decisions
if len(prompt) > 2000:
# For very long prompts, use models with good context handling
return "llama3-8b"
elif len(prompt) < 1000:
# For shorter prompts, prefer phi-4
return "phi-4"
# Temperature-based decisions
if temperature < 0.5:
# For deterministic outputs
return "phi-4"
elif temperature > 0.9:
# For very creative outputs
return "deepseek-r1"
# Default to phi-4 instead of the standard model
return "phi-4"
# vLLM serving function
@app.function(
image=vllm_image,
gpu="H100:1",
allow_concurrent_inputs=100,
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
f"{CACHE_DIR}/vllm": vllm_cache_vol,
},
timeout=30 * MINUTES,
)
@modal.web_server(port=SERVER_PORT)
def serve_vllm_model(model_id: str = DEFAULT_MODEL):
"""
Serves a model using vLLM with an OpenAI-compatible API.
Args:
model_id (str): The identifier of the model to serve. Defaults to DEFAULT_MODEL.
Raises:
ValueError: If the specified model_id is not found in VLLM_MODELS.
"""
import subprocess
if model_id not in VLLM_MODELS:
available_models = list(VLLM_MODELS.keys())
logging.error(f"Error: Unknown model: {model_id}. Available models: {available_models}")
raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")
model_info = VLLM_MODELS[model_id]
model_name = model_info["name"]
revision = model_info["revision"]
logging.basicConfig(level=logging.INFO)
logging.info(f"Starting vLLM server with model: {model_name}")
cmd = [
"vllm",
"serve",
"--uvicorn-log-level=info",
model_name,
"--revision",
revision,
"--host",
"0.0.0.0",
"--port",
str(SERVER_PORT),
"--api-key",
DEFAULT_API_KEY,
]
    # Start the server as a background process; @modal.web_server waits for the
    # port to start accepting connections, so there is no need to block here.
    process = subprocess.Popen(cmd)
# Log that we've started the server
logging.info(f"Started vLLM server with PID {process.pid}")
# Define the worker that will process the queue
@app.function(
image=vllm_image,
gpu=None, # Worker will spawn GPU functions as needed
allow_concurrent_inputs=10,
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
},
timeout=30 * MINUTES,
)
async def process_queue_worker():
"""Worker function that processes requests from the queue"""
import asyncio
import time
try:
# Signal that we're starting a worker
worker_id = str(uuid.uuid4())[:8]
logging.info(f"Starting queue processing worker {worker_id}")
# Process requests until timeout or empty queue
empty_count = 0
max_empty_count = 10 # Stop after 10 consecutive empty polls
while empty_count < max_empty_count:
# Try to get a request from the queue
try:
request_dict = await request_queue.get.aio(timeout_ms=5000)
empty_count = 0 # Reset empty counter
# Process the request
try:
# Create request object
request_id = request_dict.get("request_id")
model_id = request_dict.get("model_id")
messages = request_dict.get("messages", [])
temperature = request_dict.get("temperature", 0.7)
max_tokens = request_dict.get("max_tokens", 1024)
api_key = request_dict.get("api_key", DEFAULT_API_KEY)
stream_mode = request_dict.get("stream", False)
logging.info(f"Worker {worker_id} processing request {request_id} for model {model_id}")
# Start time for latency calculation
start_time = time.time()
if stream_mode:
# Generate streaming response
await generate_streaming_response(
request_id=request_id,
model_id=model_id,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_key=api_key
)
else:
# Generate non-streaming response
response = await generate_response(
model_id=model_id,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_key=api_key
)
# Calculate latency
latency = time.time() - start_time
# Update latency stats
if model_stats_dict.contains(model_id):
stats = model_stats_dict[model_id]
old_avg = stats.get("avg_latency", 0)
old_count = stats.get("success_count", 0)
# Calculate new average (moving average)
if old_count > 0:
new_avg = (old_avg * old_count + latency) / (old_count + 1)
else:
new_avg = latency
stats["avg_latency"] = new_avg
model_stats_dict[model_id] = stats
# Store the response in both caches
memory_cache.set(request_id, response)
response_dict[request_id] = response
logging.info(f"Worker {worker_id} completed request {request_id} in {latency:.2f}s")
except Exception as e:
# Log error and move on
logging.error(f"Worker {worker_id} error processing request {request_id}: {str(e)}")
# Create error response
error_response = {
"error": {
"message": str(e),
"type": "internal_error",
"code": 500
}
}
# Store the error as a response
memory_cache.set(request_id, error_response)
response_dict[request_id] = error_response
# If streaming, send error and finish stream
if "stream_mode" in locals() and stream_mode:
stream_manager.add_chunk(request_id, {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": f"Error: {str(e)}"},
"finish_reason": "error"
}]
})
stream_manager.finish_stream(request_id)
except asyncio.TimeoutError:
# No requests in queue
empty_count += 1
logging.info(f"Worker {worker_id}: No requests in queue. Empty count: {empty_count}")
# Clean up expired cache entries and old streams
if empty_count % 5 == 0: # Every 5 empty polls
memory_cache.clear_expired()
stream_manager.clean_old_streams()
await asyncio.sleep(1) # Wait a bit before checking again
# If we get here, we've had too many consecutive empty polls
logging.info(f"Worker {worker_id} shutting down due to empty queue")
finally:
# Signal that this worker is done
workers_running_key = "workers_running"
if model_stats_dict.contains(workers_running_key):
current_workers = model_stats_dict[workers_running_key]
model_stats_dict[workers_running_key] = max(0, current_workers - 1)
logging.info(f"Worker {worker_id} shutdown. Workers remaining: {max(0, current_workers - 1)}")
async def generate_streaming_response(
request_id: str,
model_id: str,
messages: List[dict],
temperature: float,
max_tokens: int,
api_key: str
):
"""
Generate a streaming response and send chunks to the stream manager.
Args:
request_id: The unique ID for this request
model_id: The ID of the model to use
messages: The chat messages
temperature: The sampling temperature
max_tokens: The maximum tokens to generate
api_key: The API key for authentication
"""
import httpx
import time
import json
import asyncio
try:
# Create response ID
response_id = f"chatcmpl-{int(time.time())}"
        if model_id in VLLM_MODELS:
            # Start the vLLM server for this model without blocking: .remote() would
            # wait for the long-running server function to return (and is not
            # awaitable), so spawn it and then poll its health endpoint instead.
            await serve_vllm_model.spawn.aio(model_id=model_id)
            # Wait for server startup
            await wait_for_server(serve_vllm_model.web_url, timeout=120)
# Forward request to vLLM with streaming enabled
async with httpx.AsyncClient(timeout=120.0) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"Accept": "text/event-stream"
}
# Format request for vLLM OpenAI-compatible endpoint
vllm_request = {
"model": VLLM_MODELS[model_id]["name"],
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True
}
# Make streaming request
async with client.stream(
"POST",
f"{serve_vllm_model.web_url}/v1/chat/completions",
json=vllm_request,
headers=headers
) as response:
# Process streaming response
buffer = ""
async for chunk in response.aiter_text():
buffer += chunk
# Process complete SSE messages
while "\n\n" in buffer:
message, buffer = buffer.split("\n\n", 1)
if message.startswith("data: "):
data = message[6:] # Remove "data: " prefix
if data == "[DONE]":
# End of stream
stream_manager.finish_stream(request_id)
return
try:
# Parse JSON data
chunk_data = json.loads(data)
# Forward to client
stream_manager.add_chunk(request_id, chunk_data)
except json.JSONDecodeError:
logging.error(f"Invalid JSON in stream: {data}")
# Ensure stream is finished
stream_manager.finish_stream(request_id)
elif model_id in LLAMA_CPP_MODELS:
# For llama.cpp models, we need to simulate streaming
# First convert the chat format to a prompt
prompt = format_messages_to_prompt(messages)
# Run llama.cpp with the prompt
output = await run_llama_cpp_stream.remote(
model_id=model_id,
prompt=prompt,
n_predict=max_tokens,
temperature=temperature,
request_id=request_id
)
# Streaming is handled by the run_llama_cpp_stream function
# which directly adds chunks to the stream manager
# Wait for completion signal
while True:
if request_id in stream_queues and stream_queues[request_id] == "DONE":
# Clean up
stream_queues.pop(request_id)
break
await asyncio.sleep(0.1)
else:
raise ValueError(f"Unknown model: {model_id}")
except Exception as e:
logging.error(f"Error in streaming generation: {str(e)}")
# Send error chunk
stream_manager.add_chunk(request_id, {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": f"Error: {str(e)}"},
"finish_reason": "error"
}]
})
# Finish stream
stream_manager.finish_stream(request_id)
async def generate_response(model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str):
"""
Generate a response using the appropriate model based on model_id.
Args:
model_id: The ID of the model to use
messages: The chat messages
temperature: The sampling temperature
max_tokens: The maximum tokens to generate
api_key: The API key for authentication
Returns:
A response in OpenAI-compatible format
"""
import httpx
import time
import json
import asyncio
    if model_id in VLLM_MODELS:
        # Start the vLLM server for this model without blocking: .remote() would wait
        # for the long-running server function, so spawn it and poll its health
        # endpoint instead.
        await serve_vllm_model.spawn.aio(model_id=model_id)
        # Wait for server startup
        await wait_for_server(serve_vllm_model.web_url, timeout=120)
# Forward request to vLLM
async with httpx.AsyncClient(timeout=60.0) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# Format request for vLLM OpenAI-compatible endpoint
vllm_request = {
"model": VLLM_MODELS[model_id]["name"],
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens
}
response = await client.post(
f"{serve_vllm_model.web_url}/v1/chat/completions",
json=vllm_request,
headers=headers
)
return response.json()
elif model_id in LLAMA_CPP_MODELS:
# For llama.cpp models, use the run_llama_cpp function
# First convert the chat format to a prompt
prompt = format_messages_to_prompt(messages)
# Run llama.cpp with the prompt
output = await run_llama_cpp.remote(
model_id=model_id,
prompt=prompt,
n_predict=max_tokens,
temperature=temperature
)
# Format the response in the OpenAI format
completion_text = output.strip()
finish_reason = "stop" if len(completion_text) < max_tokens else "length"
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_id,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion_text
},
"finish_reason": finish_reason
}
],
"usage": {
"prompt_tokens": len(prompt) // 4, # Rough estimation
"completion_tokens": len(completion_text) // 4, # Rough estimation
"total_tokens": (len(prompt) + len(completion_text)) // 4 # Rough estimation
}
}
else:
raise ValueError(f"Unknown model: {model_id}")
def format_messages_to_prompt(messages: List[Dict[str, str]]) -> str:
"""
Convert chat messages to a text prompt format for llama.cpp.
Args:
messages: List of message dictionaries with role and content
Returns:
Formatted prompt string
"""
formatted_prompt = ""
for message in messages:
role = message.get("role", "").lower()
content = message.get("content", "")
if role == "system":
formatted_prompt += f"<|system|>\n{content}\n"
elif role == "user":
formatted_prompt += f"<|user|>\n{content}\n"
elif role == "assistant":
formatted_prompt += f"<|assistant|>\n{content}\n"
else:
# For unknown roles, treat as user
formatted_prompt += f"<|user|>\n{content}\n"
# Add final assistant marker to prompt the model to respond
formatted_prompt += "<|assistant|>\n"
return formatted_prompt
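# Illustrative usage sketch (not called anywhere in this module): the chat-to-prompt
# conversion used for llama.cpp models.
def _example_prompt_formatting():
    prompt = format_messages_to_prompt([
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi!"},
    ])
    assert prompt == "<|system|>\nYou are helpful.\n<|user|>\nHi!\n<|assistant|>\n"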
async def wait_for_server(url: str, timeout: int = 120, check_interval: int = 2):
"""
Wait for a server to be ready by checking its health endpoint.
Args:
url: The base URL of the server
timeout: Maximum time to wait in seconds
check_interval: Interval between checks in seconds
Returns:
True if server is ready, False otherwise
"""
import httpx
import asyncio
import time
start_time = time.time()
health_url = f"{url}/health"
logging.info(f"Waiting for server at {url} to be ready...")
while time.time() - start_time < timeout:
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(health_url)
if response.status_code == 200:
logging.info(f"Server at {url} is ready!")
return True
except Exception as e:
elapsed = time.time() - start_time
logging.info(f"Server not ready yet after {elapsed:.1f}s: {str(e)}")
await asyncio.sleep(check_interval)
logging.error(f"Timed out waiting for server at {url} after {timeout} seconds")
return False
@app.function(
image=llama_cpp_image,
gpu=None, # Will be set dynamically based on model
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
RESULTS_DIR: results_vol,
},
timeout=30 * MINUTES,
)
async def run_llama_cpp_stream(
model_id: str,
prompt: str,
n_predict: int = 1024,
temperature: float = 0.7,
request_id: str = None,
):
"""
Run streaming inference with llama.cpp for models like DeepSeek-R1 and Phi-4
"""
import subprocess
import os
import json
import time
import threading
from uuid import uuid4
from pathlib import Path
from huggingface_hub import snapshot_download
if model_id not in LLAMA_CPP_MODELS:
available_models = list(LLAMA_CPP_MODELS.keys())
error_msg = f"Unknown model: {model_id}. Available models: {available_models}"
logging.error(error_msg)
if request_id:
# Send error to stream
stream_manager.add_chunk(request_id, {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": f"Error: {error_msg}"},
"finish_reason": "error"
}]
})
stream_manager.finish_stream(request_id)
# Signal completion
stream_queues[request_id] = "DONE"
raise ValueError(error_msg)
model_info = LLAMA_CPP_MODELS[model_id]
repo_id = model_info["name"]
pattern = model_info["pattern"]
revision = model_info["revision"]
quant = model_info["quant"]
# Download model if not already cached
logging.info(f"Downloading model {repo_id} if not present")
try:
model_path = snapshot_download(
repo_id=repo_id,
revision=revision,
local_dir=f"{CACHE_DIR}/llama_cpp",
allow_patterns=[pattern],
)
except ValueError as e:
if "hf_transfer" in str(e):
# Fallback to standard download if hf_transfer fails
logging.warning("hf_transfer failed, falling back to standard download")
# Temporarily disable hf_transfer
import os
old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
try:
model_path = snapshot_download(
repo_id=repo_id,
revision=revision,
local_dir=f"{CACHE_DIR}/llama_cpp",
allow_patterns=[pattern],
)
finally:
# Restore original setting
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env
else:
raise
# Find the model file
model_files = list(Path(model_path).glob(pattern))
if not model_files:
error_msg = f"No model files found matching pattern {pattern}"
logging.error(error_msg)
if request_id:
# Send error to stream
stream_manager.add_chunk(request_id, {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": f"Error: {error_msg}"},
"finish_reason": "error"
}]
})
stream_manager.finish_stream(request_id)
# Signal completion
stream_queues[request_id] = "DONE"
raise FileNotFoundError(error_msg)
model_file = str(model_files[0])
logging.info(f"Using model file: {model_file}")
# Set up command
cmd = [
"llama-cli",
"--model", model_file,
"--prompt", prompt,
"--n-predict", str(n_predict),
"--temp", str(temperature),
"--ctx-size", "8192",
]
# Add GPU layers if needed
if model_info["gpu"] is not None:
cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU
# Run inference
result_id = str(uuid4())
logging.info(f"Running streaming inference with ID: {result_id}")
# Create response ID for streaming
response_id = f"chatcmpl-{int(time.time())}"
# Function to process output in real-time and send to stream
def process_output(process, request_id):
content_buffer = ""
last_send_time = time.time()
# Send initial chunk with role
if request_id:
stream_manager.add_chunk(request_id, {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"role": "assistant"},
}]
})
for line in iter(process.stdout.readline, b''):
try:
line_str = line.decode('utf-8', errors='replace')
# Skip llama.cpp info lines
if line_str.startswith("llama_"):
continue
# Add to buffer
content_buffer += line_str
# Send chunks at reasonable intervals or when buffer gets large
now = time.time()
if (now - last_send_time > 0.1 or len(content_buffer) > 20) and request_id:
# Send chunk
stream_manager.add_chunk(request_id, {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": content_buffer},
}]
})
# Reset buffer and time
content_buffer = ""
last_send_time = now
except Exception as e:
logging.error(f"Error processing output: {str(e)}")
# Send any remaining content
if content_buffer and request_id:
stream_manager.add_chunk(request_id, {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {"content": content_buffer},
}]
})
# Send final chunk with finish reason
if request_id:
stream_manager.add_chunk(request_id, {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_id,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
})
# Finish stream
stream_manager.finish_stream(request_id)
# Signal completion
stream_queues[request_id] = "DONE"
    # Start process (binary pipes; line buffering only applies in text mode, so bufsize is left at the default)
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=False
    )
# Start output processing thread if streaming
if request_id:
thread = threading.Thread(target=process_output, args=(process, request_id))
thread.daemon = True
thread.start()
# Return immediately for streaming
return "Streaming in progress"
else:
# For non-streaming, collect all output
stdout, stderr = collect_output(process)
# Save results
result_dir = Path(RESULTS_DIR) / result_id
result_dir.mkdir(parents=True, exist_ok=True)
(result_dir / "output.txt").write_text(stdout)
(result_dir / "stderr.txt").write_text(stderr)
(result_dir / "prompt.txt").write_text(prompt)
logging.info(f"Results saved to {result_dir}")
return stdout
@app.function(
image=llama_cpp_image,
gpu=None, # Will be set dynamically based on model
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
RESULTS_DIR: results_vol,
},
timeout=30 * MINUTES,
)
async def run_llama_cpp(
model_id: str,
prompt: str = "Tell me about Modal and how it helps with ML deployments.",
n_predict: int = 1024,
temperature: float = 0.7,
):
"""
Run inference with llama.cpp for models like DeepSeek-R1 and Phi-4
"""
import subprocess
import os
from uuid import uuid4
from pathlib import Path
from huggingface_hub import snapshot_download
if model_id not in LLAMA_CPP_MODELS:
available_models = list(LLAMA_CPP_MODELS.keys())
print(f"Error: Unknown model: {model_id}. Available models: {available_models}")
raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")
model_info = LLAMA_CPP_MODELS[model_id]
repo_id = model_info["name"]
pattern = model_info["pattern"]
revision = model_info["revision"]
quant = model_info["quant"]
# Download model if not already cached
logging.info(f"Downloading model {repo_id} if not present")
try:
model_path = snapshot_download(
repo_id=repo_id,
revision=revision,
local_dir=f"{CACHE_DIR}/llama_cpp",
allow_patterns=[pattern],
)
except ValueError as e:
if "hf_transfer" in str(e):
# Fallback to standard download if hf_transfer fails
logging.warning("hf_transfer failed, falling back to standard download")
# Temporarily disable hf_transfer
import os
old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
try:
model_path = snapshot_download(
repo_id=repo_id,
revision=revision,
local_dir=f"{CACHE_DIR}/llama_cpp",
allow_patterns=[pattern],
)
finally:
# Restore original setting
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env
else:
raise
# Find the model file
model_files = list(Path(model_path).glob(pattern))
if not model_files:
logging.error(f"No model files found matching pattern {pattern}")
raise FileNotFoundError(f"No model files found matching pattern {pattern}")
model_file = str(model_files[0])
print(f"Using model file: {model_file}")
# Set up command
cmd = [
"llama-cli",
"--model", model_file,
"--prompt", prompt,
"--n-predict", str(n_predict),
"--temp", str(temperature),
"--ctx-size", "8192",
]
# Add GPU layers if needed
if model_info["gpu"] is not None:
cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU
# Run inference
result_id = str(uuid4())
print(f"Running inference with ID: {result_id}")
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=False
)
stdout, stderr = collect_output(process)
# Save results
result_dir = Path(RESULTS_DIR) / result_id
result_dir.mkdir(parents=True, exist_ok=True)
(result_dir / "output.txt").write_text(stdout)
(result_dir / "stderr.txt").write_text(stderr)
(result_dir / "prompt.txt").write_text(prompt)
print(f"Results saved to {result_dir}")
return stdout
@app.function(
image=vllm_image,
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
},
)
def list_available_models():
"""
Lists available models that can be used with this server.
Returns:
dict: A dictionary containing lists of available vLLM and llama.cpp models.
"""
print("Available vLLM models:")
for model_id, model_info in VLLM_MODELS.items():
print(f"- {model_id}: {model_info['name']}")
print("\nAvailable llama.cpp models:")
for model_id, model_info in LLAMA_CPP_MODELS.items():
gpu_info = f"(GPU: {model_info['gpu']})" if model_info['gpu'] else "(CPU)"
print(f"- {model_id}: {model_info['name']} {gpu_info}")
return {
"vllm": list(VLLM_MODELS.keys()),
"llama_cpp": list(LLAMA_CPP_MODELS.keys())
}
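# Usage sketch (assumption: invoking a single Modal function from the CLI):
#   modal run modal_mcp_server.py::list_available_models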
def collect_output(process):
"""
Collect output from a process while streaming it.
Args:
process: The process from which to collect output.
Returns:
tuple: A tuple containing the collected stdout and stderr as strings.
"""
import sys
from queue import Queue
from threading import Thread
def stream_output(stream, queue, write_stream):
for line in iter(stream.readline, b""):
line_str = line.decode("utf-8", errors="replace")
write_stream.write(line_str)
write_stream.flush()
queue.put(line_str)
stream.close()
stdout_queue = Queue()
stderr_queue = Queue()
stdout_thread = Thread(target=stream_output, args=(process.stdout, stdout_queue, sys.stdout))
stderr_thread = Thread(target=stream_output, args=(process.stderr, stderr_queue, sys.stderr))
stdout_thread.start()
stderr_thread.start()
stdout_thread.join()
stderr_thread.join()
process.wait()
stdout_collected = "".join(list(stdout_queue.queue))
stderr_collected = "".join(list(stderr_queue.queue))
return stdout_collected, stderr_collected
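# Minimal usage sketch for collect_output (assumes a short-lived subprocess):
#   proc = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
#   out, err = collect_output(proc)  # mirrors output to stdout/stderr and returns both as strings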
# Main ASGI app for Modal
@app.function(
image=vllm_image,
gpu=None, # No GPU for the API frontend
allow_concurrent_inputs=100,
volumes={
f"{CACHE_DIR}/huggingface": hf_cache_vol,
},
)
@modal.asgi_app()
def inference_api():
"""The main ASGI app that serves the FastAPI application"""
return api_app
@app.local_entrypoint()
def main(
prompt: str = "What can you tell me about Modal?",
n_predict: int = 1024,
temperature: float = 0.7,
create_admin_key: bool = False,
stream: bool = False,
model: str = "auto",
load_model: str = None,
load_hf_model: str = None,
hf_model_type: str = "vllm",
):
"""
Main entrypoint for testing the API
"""
import json
import time
import urllib.request
# Initialize the API
print(f"Starting API at {inference_api.web_url}")
# Wait for API to be ready
print("Checking if API is ready...")
up, start, delay = False, time.time(), 10
while not up:
try:
with urllib.request.urlopen(inference_api.web_url + "/health") as response:
if response.getcode() == 200:
up = True
except Exception:
if time.time() - start > 5 * MINUTES:
break
time.sleep(delay)
assert up, f"Failed health check for API at {inference_api.web_url}"
print(f"API is up and running at {inference_api.web_url}")
# Create a test API key if requested
if create_admin_key:
print("Creating a test API key...")
key_request = {
"user_id": "test_user",
"rate_limit": 120,
"quota": 2000000
}
headers = {
"Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key
"Content-Type": "application/json",
}
req = urllib.request.Request(
inference_api.web_url + "/admin/api-keys",
data=json.dumps(key_request).encode("utf-8"),
headers=headers,
method="POST",
)
try:
with urllib.request.urlopen(req) as response:
result = json.loads(response.read().decode())
print("Created API key:")
print(json.dumps(result, indent=2))
# Use this key for the test message
test_key = result["key"]
except Exception as e:
print(f"Error creating API key: {str(e)}")
test_key = DEFAULT_API_KEY
else:
test_key = DEFAULT_API_KEY
# List available models
print("\nAvailable models:")
try:
headers = {
"Authorization": f"Bearer {test_key}",
"Content-Type": "application/json",
}
req = urllib.request.Request(
inference_api.web_url + "/v1/models",
headers=headers,
method="GET",
)
with urllib.request.urlopen(req) as response:
models = json.loads(response.read().decode())
print(json.dumps(models, indent=2))
except Exception as e:
print(f"Error listing models: {str(e)}")
    # Auto-select the best model for the prompt unless one was explicitly requested
    if model == "auto":
        model = select_best_model(prompt, n_predict, temperature)
# Send a test message
print(f"\nSending a sample message to {inference_api.web_url}")
messages = [{"role": "user", "content": prompt}]
headers = {
"Authorization": f"Bearer {test_key}",
"Content-Type": "application/json",
}
payload = json.dumps({
"messages": messages,
"model": model,
"temperature": temperature,
"max_tokens": n_predict,
"stream": stream
})
req = urllib.request.Request(
inference_api.web_url + "/v1/chat/completions",
data=payload.encode("utf-8"),
headers=headers,
method="POST",
)
try:
if stream:
print("Streaming response:")
with urllib.request.urlopen(req) as response:
for line in response:
line = line.decode('utf-8')
if line.startswith('data: '):
data = line[6:].strip()
if data == '[DONE]':
print("\n[DONE]")
else:
try:
chunk = json.loads(data)
if 'choices' in chunk and len(chunk['choices']) > 0:
if 'delta' in chunk['choices'][0] and 'content' in chunk['choices'][0]['delta']:
content = chunk['choices'][0]['delta']['content']
print(content, end='', flush=True)
except json.JSONDecodeError:
print(f"Error parsing: {data}")
else:
with urllib.request.urlopen(req) as response:
result = json.loads(response.read().decode())
print("Response:")
print(json.dumps(result, indent=2))
except Exception as e:
print(f"Error: {str(e)}")
# Check API stats
print("\nChecking API stats...")
headers = {
"Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key
"Content-Type": "application/json",
}
req = urllib.request.Request(
inference_api.web_url + "/admin/stats",
headers=headers,
method="GET",
)
    stats = {}
    try:
with urllib.request.urlopen(req) as response:
stats = json.loads(response.read().decode())
print("API Stats:")
print(json.dumps(stats, indent=2))
except Exception as e:
print(f"Error getting stats: {str(e)}")
# Start a worker if none running
try:
current_workers = stats.get("queue", {}).get("active_workers", 0)
if current_workers < 1:
print("\nStarting a queue worker...")
process_queue_worker.spawn()
except Exception as e:
print(f"Error starting worker: {str(e)}")
print(f"\nAPI is available at {inference_api.web_url}")
print(f"Documentation is at {inference_api.web_url}/docs")
print(f"Default Bearer token: {DEFAULT_API_KEY}")
if create_admin_key:
print(f"Test Bearer token: {test_key}")
# If a model was specified to load, load it
if load_model:
print(f"\nLoading model: {load_model}")
load_url = f"{inference_api.web_url}/admin/models/load"
headers = {
"Authorization": f"Bearer {test_key}",
"Content-Type": "application/json",
}
payload = json.dumps({
"model_id": load_model,
"force_reload": True
})
req = urllib.request.Request(
load_url,
data=payload.encode("utf-8"),
headers=headers,
method="POST",
)
try:
with urllib.request.urlopen(req) as response:
result = json.loads(response.read().decode())
print("Load response:")
print(json.dumps(result, indent=2))
# If it's a small model, wait a bit for it to load
if load_model in ["tiny-llama-1.1b", "phi-2"]:
print(f"Waiting for {load_model} to load...")
time.sleep(10)
# Check status
status_url = f"{inference_api.web_url}/admin/models/status/{load_model}"
status_req = urllib.request.Request(
status_url,
headers={"Authorization": f"Bearer {test_key}"},
method="GET",
)
with urllib.request.urlopen(status_req) as status_response:
status_result = json.loads(status_response.read().decode())
print("Model status:")
print(json.dumps(status_result, indent=2))
# Use this model for the test
model = load_model
except Exception as e:
print(f"Error loading model: {str(e)}")
# If a HF model was specified to load directly
if load_hf_model:
print(f"\nLoading HF model: {load_hf_model} with type {hf_model_type}")
load_url = f"{inference_api.web_url}/admin/models/load-from-hf"
headers = {
"Authorization": f"Bearer {test_key}",
"Content-Type": "application/json",
}
payload = json.dumps({
"repo_id": load_hf_model,
"model_type": hf_model_type,
"max_tokens": n_predict
})
req = urllib.request.Request(
load_url,
data=payload.encode("utf-8"),
headers=headers,
method="POST",
)
try:
with urllib.request.urlopen(req) as response:
result = json.loads(response.read().decode())
print("HF Load response:")
print(json.dumps(result, indent=2))
# Get the model_id from the response
hf_model_id = result.get("model_id")
# Wait a bit for it to start loading
print(f"Waiting for {load_hf_model} to start loading...")
time.sleep(5)
# Check status
if hf_model_id:
status_url = f"{inference_api.web_url}/admin/models/status/{hf_model_id}"
status_req = urllib.request.Request(
status_url,
headers={"Authorization": f"Bearer {test_key}"},
method="GET",
)
with urllib.request.urlopen(status_req) as status_response:
status_result = json.loads(status_response.read().decode())
print("Model status:")
print(json.dumps(status_result, indent=2))
# Use this model for the test
if hf_model_id:
model = hf_model_id
except Exception as e:
print(f"Error loading HF model: {str(e)}")
# Show curl examples
print("\nExample curl commands:")
# Regular completion
print(f"""# Regular completion:
curl -X POST {inference_api.web_url}/v1/chat/completions \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer {test_key}" \\
-d '{{
"model": "{model}",
"messages": [
{{
"role": "user",
"content": "Hello, how can you help me today?"
}}
],
"temperature": 0.7,
"max_tokens": 500
}}'""")
# Streaming completion
print(f"""\n# Streaming completion:
curl -X POST {inference_api.web_url}/v1/chat/completions \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer {test_key}" \\
-d '{{
"model": "{model}",
"messages": [
{{
"role": "user",
"content": "Write a short story about AI"
}}
],
"temperature": 0.8,
"max_tokens": 1000,
"stream": true
}}' --no-buffer""")
# List models
print(f"""\n# List available models:
curl -X GET {inference_api.web_url}/v1/models \\
-H "Authorization: Bearer {test_key}" """)
# Model management commands
print(f"""\n# Load a model:
curl -X POST {inference_api.web_url}/admin/models/load \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer {test_key}" \\
-d '{{
"model_id": "phi-2",
"force_reload": false
}}'""")
print(f"""\n# Load a model directly from Hugging Face:
curl -X POST {inference_api.web_url}/admin/models/load-from-hf \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer {test_key}" \\
-d '{{
"repo_id": "microsoft/phi-2",
"model_type": "vllm",
"max_tokens": 4096
}}'""")
print(f"""\n# Get model status:
curl -X GET {inference_api.web_url}/admin/models/status/phi-2 \\
-H "Authorization: Bearer {test_key}" """)
print(f"""\n# Unload a model:
curl -X POST {inference_api.web_url}/admin/models/unload \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer {test_key}" \\
-d '{{
"model_id": "phi-2"
}}'""")
async def preload_llama_cpp_model(model_id: str):
"""Preload a llama.cpp model to make inference faster on first request"""
if model_id not in LLAMA_CPP_MODELS:
logging.error(f"Unknown model: {model_id}")
return
try:
# Run a simple inference to load the model
logging.info(f"Preloading llama.cpp model: {model_id}")
        await run_llama_cpp.remote.aio(  # async variant of .remote() so the call can be awaited
model_id=model_id,
prompt="Hello, this is a test to preload the model.",
n_predict=10,
temperature=0.7
)
logging.info(f"Successfully preloaded llama.cpp model: {model_id}")
except Exception as e:
logging.error(f"Error preloading llama.cpp model {model_id}: {str(e)}")
# Mark as not loaded
LLAMA_CPP_MODELS[model_id]["loaded"] = False
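# Preload sketch (assumptions: called from inside a running Modal app context,
# and "deepseek-r1" stands in for an actual key of LLAMA_CPP_MODELS):
#   await preload_llama_cpp_model("deepseek-r1")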
```
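Since `inference_api` exposes an OpenAI-compatible `/v1/chat/completions` route, any OpenAI SDK client should be able to talk to it once deployed. A minimal sketch, assuming `https://<your-modal-app-url>` is the deployed `inference_api.web_url` and the bearer token is one issued via `/admin/api-keys` (both placeholders):

```python
from openai import OpenAI

# Placeholders: substitute the deployed Modal URL and a valid API key from /admin/api-keys.
client = OpenAI(base_url="https://<your-modal-app-url>/v1", api_key="<your-bearer-token>")

response = client.chat.completions.create(
    model="phi-2",  # any model id reported by /v1/models
    messages=[{"role": "user", "content": "Hello, how can you help me today?"}],
    max_tokens=200,
)
print(response.choices[0].message.content)
```

Streaming should work the same way by passing `stream=True`; the SDK then consumes the `data:` chunks emitted by the streaming path in the server code above.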
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# TODO: Refactor into modular structure similar to Claude Code (lib/, commands/, tools/ directories)
# TODO: Add support for multiple LLM providers (Azure OpenAI, Anthropic, etc.)
# TODO: Implement telemetry and usage tracking (optional, with consent)
import os
import sys
import json
import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.prompt import Prompt
from rich.panel import Panel
from rich.progress import Progress
from rich.syntax import Syntax
from rich.live import Live
from rich.layout import Layout
from rich.table import Table
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union, Callable
import asyncio
import concurrent.futures
from dotenv import load_dotenv
import time
import re
import traceback
import requests
import urllib.parse
from uuid import uuid4
import socket
import threading
import multiprocessing
import pickle
import hashlib
import logging
import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# Jina.ai client for search, fact-checking, and web reading
class JinaClient:
"""Client for interacting with Jina.ai endpoints"""
def __init__(self, token: Optional[str] = None):
"""Initialize with your Jina token"""
self.token = token or os.getenv("JINA_API_KEY", "")
self.headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json"
}
def search(self, query: str) -> dict:
"""
Search using s.jina.ai endpoint
Args:
query: Search term
Returns:
API response as dict
"""
encoded_query = urllib.parse.quote(query)
url = f"https://s.jina.ai/{encoded_query}"
response = requests.get(url, headers=self.headers)
return response.json()
def fact_check(self, query: str) -> dict:
"""
Get grounding info using g.jina.ai endpoint
Args:
query: Query to ground
Returns:
API response as dict
"""
encoded_query = urllib.parse.quote(query)
url = f"https://g.jina.ai/{encoded_query}"
response = requests.get(url, headers=self.headers)
return response.json()
def reader(self, url: str) -> dict:
"""
        Read and extract webpage content using the r.jina.ai endpoint
Args:
            url: URL of the page to read
Returns:
API response as dict
"""
encoded_url = urllib.parse.quote(url)
url = f"https://r.jina.ai/{encoded_url}"
response = requests.get(url, headers=self.headers)
return response.json()
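# Usage sketch for JinaClient (assumes JINA_API_KEY is set in the environment):
#   jina = JinaClient()
#   hits = jina.search("Modal serverless GPUs")
#   check = jina.fact_check("The Earth orbits the Sun")
#   page = jina.reader("https://modal.com/docs")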
# Check if RL tools are available
HAVE_RL_TOOLS = False
try:
# This is a placeholder for the actual import that would be used
from tool_optimizer import ToolSelectionManager
# If the import succeeds, set HAVE_RL_TOOLS to True
HAVE_RL_TOOLS = True
except ImportError:
# RL tools not available
# Define a dummy ToolSelectionManager to avoid NameError
class ToolSelectionManager:
def __init__(self, **kwargs):
self.optimizer = None
self.data_dir = kwargs.get('data_dir', '')
def record_tool_usage(self, **kwargs):
pass
# Load environment variables
load_dotenv()
# TODO: Add update checking similar to Claude Code's auto-update functionality
# TODO: Add configuration file support to store settings beyond environment variables
app = typer.Typer(help="OpenAI Code Assistant CLI")
console = Console()
# Global Constants
# TODO: Move these to a config file
DEFAULT_MODEL = "gpt-4o"
DEFAULT_TEMPERATURE = 0
MAX_TOKENS = 4096
TOKEN_LIMIT_WARNING = 0.8 # Warn when 80% of token limit is reached
# Models
# TODO: Implement more sophisticated schema validation similar to Zod in the original
# TODO: Add permission system for tools that require user approval
class ToolParameter(BaseModel):
name: str
description: str
type: str
required: bool = False
class Tool(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
function: Callable
# TODO: Add needs_permission flag for sensitive operations
# TODO: Add category for organizing tools (file, search, etc.)
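# Example sketch of registering an additional tool with this schema
# (hypothetical "Echo" tool, not part of the set registered below):
#   Tool(
#       name="Echo",
#       description="Returns the provided text unchanged",
#       parameters={
#           "type": "object",
#           "properties": {"text": {"type": "string", "description": "Text to echo"}},
#           "required": ["text"],
#       },
#       function=lambda text: text,
#   )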
class Message(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
name: Optional[str] = None
# TODO: Add timestamp for message tracking
# TODO: Add token count for better context management
class Conversation:
def __init__(self):
self.messages = []
# TODO: Implement retry logic with exponential backoff for API calls
# TODO: Add support for multiple LLM providers
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.model = os.getenv("OPENAI_MODEL", DEFAULT_MODEL)
self.temperature = float(os.getenv("OPENAI_TEMPERATURE", DEFAULT_TEMPERATURE))
self.tools = self._register_tools()
self.tool_map = {tool.name: tool.function for tool in self.tools}
self.conversation_id = str(uuid4())
self.session_start_time = time.time()
self.token_usage = {"prompt": 0, "completion": 0, "total": 0}
self.verbose = False
self.max_tool_iterations = int(os.getenv("MAX_TOOL_ITERATIONS", "10"))
# Initialize tool selection optimizer if available
self.tool_optimizer = None
if HAVE_RL_TOOLS:
try:
# Create a simple tool registry adapter for the optimizer
class ToolRegistryAdapter:
def __init__(self, tools):
self.tools = tools
def get_all_tools(self):
return self.tools
def get_all_tool_names(self):
return [tool.name for tool in self.tools]
# Initialize the tool selection manager
self.tool_optimizer = ToolSelectionManager(
tool_registry=ToolRegistryAdapter(self.tools),
enable_optimization=os.getenv("ENABLE_TOOL_OPTIMIZATION", "1") == "1",
data_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/rl")
)
if self.verbose:
print("Tool selection optimization enabled")
except Exception as e:
print(f"Warning: Failed to initialize tool optimizer: {e}")
# TODO: Implement context window management
# Jina.ai client for search, fact-checking, and web reading
def _init_jina_client(self):
"""Initialize the Jina.ai client"""
token = os.getenv("JINA_API_KEY", "")
return JinaClient(token)
def _jina_search(self, query: str) -> str:
"""Search the web using Jina.ai"""
try:
client = self._init_jina_client()
results = client.search(query)
if not results or not isinstance(results, dict):
return f"No search results found for '{query}'"
# Format the results
formatted_results = "Search Results:\n\n"
if "results" in results and isinstance(results["results"], list):
for i, result in enumerate(results["results"], 1):
title = result.get("title", "No title")
url = result.get("url", "No URL")
snippet = result.get("snippet", "No snippet")
formatted_results += f"{i}. {title}\n"
formatted_results += f" URL: {url}\n"
formatted_results += f" {snippet}\n\n"
else:
formatted_results += "Unexpected response format. Raw data:\n"
formatted_results += json.dumps(results, indent=2)[:1000]
return formatted_results
except Exception as e:
return f"Error performing search: {str(e)}"
def _jina_fact_check(self, statement: str) -> str:
"""Fact check a statement using Jina.ai"""
try:
client = self._init_jina_client()
results = client.fact_check(statement)
if not results or not isinstance(results, dict):
return f"No fact-checking results for '{statement}'"
# Format the results
formatted_results = "Fact Check Results:\n\n"
formatted_results += f"Statement: {statement}\n\n"
if "grounding" in results:
grounding = results["grounding"]
verdict = grounding.get("verdict", "Unknown")
confidence = grounding.get("confidence", 0)
formatted_results += f"Verdict: {verdict}\n"
formatted_results += f"Confidence: {confidence:.2f}\n\n"
if "sources" in grounding and isinstance(grounding["sources"], list):
formatted_results += "Sources:\n"
for i, source in enumerate(grounding["sources"], 1):
title = source.get("title", "No title")
url = source.get("url", "No URL")
formatted_results += f"{i}. {title}\n {url}\n\n"
else:
formatted_results += "Unexpected response format. Raw data:\n"
formatted_results += json.dumps(results, indent=2)[:1000]
return formatted_results
except Exception as e:
return f"Error performing fact check: {str(e)}"
def _jina_read_url(self, url: str) -> str:
"""Read and summarize a webpage using Jina.ai"""
try:
client = self._init_jina_client()
results = client.reader(url)
if not results or not isinstance(results, dict):
return f"No reading results for URL '{url}'"
# Format the results
formatted_results = f"Web Page Summary: {url}\n\n"
if "content" in results:
content = results["content"]
title = content.get("title", "No title")
summary = content.get("summary", "No summary available")
formatted_results += f"Title: {title}\n\n"
formatted_results += f"Summary:\n{summary}\n\n"
if "keyPoints" in content and isinstance(content["keyPoints"], list):
formatted_results += "Key Points:\n"
for i, point in enumerate(content["keyPoints"], 1):
formatted_results += f"{i}. {point}\n"
else:
formatted_results += "Unexpected response format. Raw data:\n"
formatted_results += json.dumps(results, indent=2)[:1000]
return formatted_results
except Exception as e:
return f"Error reading URL: {str(e)}"
def _register_tools(self) -> List[Tool]:
# TODO: Modularize tools into separate files
# TODO: Implement Tool decorators for easier registration
# TODO: Add more tools similar to Claude Code (ReadNotebook, NotebookEditCell, etc.)
# Define and register all tools
tools = [
Tool(
name="Weather",
description="Gets the current weather for a location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and optional state/country (e.g., 'San Francisco, CA' or 'London, UK')"
}
},
"required": ["location"]
},
function=self._get_weather
),
Tool(
name="View",
description="Reads a file from the local filesystem. The file_path parameter must be an absolute path, not a relative path.",
parameters={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read"
},
"limit": {
"type": "number",
"description": "The number of lines to read. Only provide if the file is too large to read at once."
},
"offset": {
"type": "number",
"description": "The line number to start reading from. Only provide if the file is too large to read at once"
}
},
"required": ["file_path"]
},
function=self._view_file
),
Tool(
name="Edit",
description="This is a tool for editing files.",
parameters={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to modify"
},
"old_string": {
"type": "string",
"description": "The text to replace"
},
"new_string": {
"type": "string",
"description": "The text to replace it with"
}
},
"required": ["file_path", "old_string", "new_string"]
},
function=self._edit_file
),
Tool(
name="Replace",
description="Write a file to the local filesystem. Overwrites the existing file if there is one.",
parameters={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to write"
},
"content": {
"type": "string",
"description": "The content to write to the file"
}
},
"required": ["file_path", "content"]
},
function=self._replace_file
),
Tool(
name="Bash",
description="Executes a given bash command in a persistent shell session.",
parameters={
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The command to execute"
},
"timeout": {
"type": "number",
"description": "Optional timeout in milliseconds (max 600000)"
}
},
"required": ["command"]
},
function=self._execute_bash
),
Tool(
name="GlobTool",
description="Fast file pattern matching tool that works with any codebase size.",
parameters={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The directory to search in. Defaults to the current working directory."
},
"pattern": {
"type": "string",
"description": "The glob pattern to match files against"
}
},
"required": ["pattern"]
},
function=self._glob_tool
),
Tool(
name="GrepTool",
description="Fast content search tool that works with any codebase size.",
parameters={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The directory to search in. Defaults to the current working directory."
},
"pattern": {
"type": "string",
"description": "The regular expression pattern to search for in file contents"
},
"include": {
"type": "string",
"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"
}
},
"required": ["pattern"]
},
function=self._grep_tool
),
Tool(
name="LS",
description="Lists files and directories in a given path.",
parameters={
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The absolute path to the directory to list"
},
"ignore": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of glob patterns to ignore"
}
},
"required": ["path"]
},
function=self._list_directory
),
Tool(
name="JinaSearch",
description="Search the web for information using Jina.ai",
parameters={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
}
},
"required": ["query"]
},
function=self._jina_search
),
Tool(
name="JinaFactCheck",
description="Fact check a statement using Jina.ai",
parameters={
"type": "object",
"properties": {
"statement": {
"type": "string",
"description": "The statement to fact check"
}
},
"required": ["statement"]
},
function=self._jina_fact_check
),
Tool(
name="JinaReadURL",
description="Read and summarize a webpage using Jina.ai",
parameters={
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL of the webpage to read"
}
},
"required": ["url"]
},
function=self._jina_read_url
)
]
return tools
# Tool implementations
# TODO: Add better error handling and user feedback
# TODO: Implement tool usage tracking and metrics
def _get_weather(self, location: str) -> str:
"""Get current weather for a location using OpenWeatherMap API"""
try:
# Get API key from environment or use a default for testing
api_key = os.getenv("OPENWEATHER_API_KEY", "")
if not api_key:
return "Error: OpenWeatherMap API key not found. Please set the OPENWEATHER_API_KEY environment variable."
# Prepare the API request
base_url = "https://api.openweathermap.org/data/2.5/weather"
params = {
"q": location,
"appid": api_key,
"units": "metric" # Use metric units (Celsius)
}
# Make the API request
response = requests.get(base_url, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.text
# Try to parse as JSON
try:
data = json.loads(data)
except json.JSONDecodeError:
return f"Error: Unable to parse weather data. Raw response: {data[:200]}..."
# Extract relevant weather information
weather_desc = data["weather"][0]["description"]
temp = data["main"]["temp"]
feels_like = data["main"]["feels_like"]
humidity = data["main"]["humidity"]
wind_speed = data["wind"]["speed"]
# Format the response
weather_info = (
f"Current weather in {location}:\n"
f"• Condition: {weather_desc.capitalize()}\n"
f"• Temperature: {temp}°C ({(temp * 9/5) + 32:.1f}°F)\n"
f"• Feels like: {feels_like}°C ({(feels_like * 9/5) + 32:.1f}°F)\n"
f"• Humidity: {humidity}%\n"
f"• Wind speed: {wind_speed} m/s ({wind_speed * 2.237:.1f} mph)"
)
return weather_info
else:
# Handle API errors
if response.status_code == 404:
return f"Error: Location '{location}' not found. Please check the spelling or try a different location."
elif response.status_code == 401:
return "Error: Invalid API key. Please check your OpenWeatherMap API key."
else:
return f"Error: Unable to fetch weather data. Status code: {response.status_code}"
except requests.exceptions.RequestException as e:
return f"Error: Network error when fetching weather data: {str(e)}"
except Exception as e:
return f"Error: Failed to get weather information: {str(e)}"
def _view_file(self, file_path: str, limit: Optional[int] = None, offset: Optional[int] = 0) -> str:
# TODO: Add special handling for binary files and images
# TODO: Add syntax highlighting for code files
try:
if not os.path.exists(file_path):
return f"Error: File not found: {file_path}"
# TODO: Handle file size limits better
with open(file_path, 'r') as f:
if limit is not None and offset is not None:
# Skip to offset
for _ in range(offset):
next(f, None)
# Read limited lines
lines = []
for _ in range(limit):
line = next(f, None)
if line is None:
break
lines.append(line)
content = ''.join(lines)
else:
content = f.read()
# TODO: Add file metadata like size, permissions, etc.
return content
except Exception as e:
return f"Error reading file: {str(e)}"
def _edit_file(self, file_path: str, old_string: str, new_string: str) -> str:
try:
            # Create parent directory if creating a new file
            parent_dir = os.path.dirname(file_path)
            if old_string == "" and parent_dir and not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)
if old_string == "" and not os.path.exists(file_path):
# Creating new file
with open(file_path, 'w') as f:
f.write(new_string)
return f"Created new file: {file_path}"
# Reading existing file
if not os.path.exists(file_path):
return f"Error: File not found: {file_path}"
with open(file_path, 'r') as f:
content = f.read()
# Replace string
if old_string not in content:
return f"Error: Could not find the specified text in {file_path}"
# Count occurrences to ensure uniqueness
occurrences = content.count(old_string)
if occurrences > 1:
return f"Error: Found {occurrences} occurrences of the specified text in {file_path}. Please provide more context to uniquely identify the text to replace."
new_content = content.replace(old_string, new_string)
# Write back to file
with open(file_path, 'w') as f:
f.write(new_content)
return f"Successfully edited {file_path}"
except Exception as e:
return f"Error editing file: {str(e)}"
def _replace_file(self, file_path: str, content: str) -> str:
try:
# Create directory if it doesn't exist
directory = os.path.dirname(file_path)
if directory and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
# Write content to file
with open(file_path, 'w') as f:
f.write(content)
return f"Successfully wrote to {file_path}"
except Exception as e:
return f"Error writing file: {str(e)}"
def _execute_bash(self, command: str, timeout: Optional[int] = None) -> str:
try:
import subprocess
import shlex
# Security check for banned commands
banned_commands = [
'alias', 'curl', 'curlie', 'wget', 'axel', 'aria2c', 'nc',
'telnet', 'lynx', 'w3m', 'links', 'httpie', 'xh', 'http-prompt',
'chrome', 'firefox', 'safari'
]
for banned in banned_commands:
if banned in command.split():
return f"Error: The command '{banned}' is not allowed for security reasons."
# Execute command
if timeout:
timeout_seconds = timeout / 1000 # Convert to seconds
else:
timeout_seconds = 1800 # 30 minutes default
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=timeout_seconds
)
output = result.stdout
if result.stderr:
output += f"\nErrors:\n{result.stderr}"
# Truncate if too long
if len(output) > 30000:
output = output[:30000] + "\n... (output truncated)"
return output
except subprocess.TimeoutExpired:
return f"Error: Command timed out after {timeout_seconds} seconds"
except Exception as e:
return f"Error executing command: {str(e)}"
def _glob_tool(self, pattern: str, path: Optional[str] = None) -> str:
try:
import glob
import os
if path is None:
path = os.getcwd()
# Build the full pattern path
if not os.path.isabs(path):
path = os.path.abspath(path)
full_pattern = os.path.join(path, pattern)
# Get matching files
matches = glob.glob(full_pattern, recursive=True)
# Sort by modification time (newest first)
matches.sort(key=os.path.getmtime, reverse=True)
if not matches:
return f"No files matching pattern '{pattern}' in {path}"
return "\n".join(matches)
except Exception as e:
return f"Error in glob search: {str(e)}"
def _grep_tool(self, pattern: str, path: Optional[str] = None, include: Optional[str] = None) -> str:
try:
import re
import os
import fnmatch
from concurrent.futures import ThreadPoolExecutor
if path is None:
path = os.getcwd()
if not os.path.isabs(path):
path = os.path.abspath(path)
# Compile regex pattern
regex = re.compile(pattern)
# Get all files
all_files = []
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
# Apply include filter if provided
if include:
if not fnmatch.fnmatch(file, include):
continue
all_files.append(file_path)
# Sort by modification time (newest first)
all_files.sort(key=os.path.getmtime, reverse=True)
matches = []
def search_file(file_path):
try:
with open(file_path, 'r', errors='ignore') as f:
content = f.read()
if regex.search(content):
return file_path
            except Exception:
# Skip files that can't be read
pass
return None
# Search files in parallel
with ThreadPoolExecutor(max_workers=10) as executor:
results = executor.map(search_file, all_files)
for result in results:
if result:
matches.append(result)
if not matches:
return f"No matches found for pattern '{pattern}' in {path}"
return "\n".join(matches)
except Exception as e:
return f"Error in grep search: {str(e)}"
def _list_directory(self, path: str, ignore: Optional[List[str]] = None) -> str:
try:
import os
import fnmatch
# If path is not absolute, make it absolute from current directory
if not os.path.isabs(path):
path = os.path.abspath(os.path.join(os.getcwd(), path))
if not os.path.exists(path):
return f"Error: Directory not found: {path}"
if not os.path.isdir(path):
return f"Error: Path is not a directory: {path}"
# List directory contents
items = os.listdir(path)
# Apply ignore patterns
if ignore:
for pattern in ignore:
items = [item for item in items if not fnmatch.fnmatch(item, pattern)]
# Sort items
items.sort()
# Format output
result = []
for item in items:
item_path = os.path.join(path, item)
if os.path.isdir(item_path):
result.append(f"{item}/")
else:
result.append(item)
if not result:
return f"Directory {path} is empty"
return "\n".join(result)
except Exception as e:
return f"Error listing directory: {str(e)}"
def add_message(self, role: str, content: str):
"""Legacy method to add messages - use direct append now"""
self.messages.append({"role": role, "content": content})
def process_tool_calls(self, tool_calls, query=None):
# TODO: Add tool call validation
# TODO: Add permission system for sensitive tools
# TODO: Add progress visualization for long-running tools
responses = []
# Process tool calls in parallel
from concurrent.futures import ThreadPoolExecutor
def process_single_tool(tool_call):
# Handle both object-style and dict-style tool calls
if isinstance(tool_call, dict):
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
tool_call_id = tool_call["id"]
else:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_call_id = tool_call.id
# Get the tool function
if function_name in self.tool_map:
# TODO: Add pre-execution validation
# TODO: Add permission check here
# Track start time for metrics
start_time = time.time()
try:
function = self.tool_map[function_name]
result = function(**function_args)
success = True
except Exception as e:
result = f"Error executing tool {function_name}: {str(e)}\n{traceback.format_exc()}"
success = False
# Calculate execution time
execution_time = time.time() - start_time
# Record tool usage for optimization if optimizer is available
if self.tool_optimizer is not None and query is not None:
try:
# Create current context snapshot
context = {
"messages": self.messages.copy(),
"conversation_id": self.conversation_id,
}
# Record tool usage
self.tool_optimizer.record_tool_usage(
query=query,
tool_name=function_name,
execution_time=execution_time,
token_usage=self.token_usage.copy(),
success=success,
context=context,
result=result
)
except Exception as e:
if self.verbose:
print(f"Warning: Failed to record tool usage: {e}")
return {
"tool_call_id": tool_call_id,
"function_name": function_name,
"result": result,
"name": function_name,
"execution_time": execution_time, # For metrics
"success": success
}
return None
# Process all tool calls in parallel
with ThreadPoolExecutor(max_workers=min(10, len(tool_calls))) as executor:
futures = [executor.submit(process_single_tool, tool_call) for tool_call in tool_calls]
for future in futures:
result = future.result()
if result:
# Add tool response to messages
self.messages.append({
"tool_call_id": result["tool_call_id"],
"role": "tool",
"name": result["name"],
"content": result["result"]
})
responses.append({
"tool_call_id": result["tool_call_id"],
"function_name": result["function_name"],
"result": result["result"]
})
# Log tool execution metrics if verbose
if self.verbose:
print(f"Tool {result['function_name']} executed in {result['execution_time']:.2f}s (success: {result['success']})")
# Return tool responses
return responses
def compact(self):
# TODO: Add more sophisticated compaction with token counting
# TODO: Implement selective retention of critical information
# TODO: Add option to save conversation history before compacting
system_prompt = next((m for m in self.messages if m["role"] == "system"), None)
user_messages = [m for m in self.messages if m["role"] == "user"]
if not user_messages:
return "No user messages to compact."
last_user_message = user_messages[-1]
# Create a compaction prompt
# TODO: Improve the compaction prompt with more guidance on what to retain
compact_prompt = (
"Summarize the conversation so far, focusing on the key points, decisions, and context. "
"Keep important details about the code and tasks. Retain critical file paths, commands, "
"and code snippets. The summary should be concise but complete enough to continue the "
"conversation effectively."
)
# Add compaction message
self.messages.append({"role": "user", "content": compact_prompt})
# Get compaction summary
# TODO: Add error handling for compaction API call
response = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
stream=False
)
summary = response.choices[0].message.content
# Reset conversation with summary
if system_prompt:
self.messages = [system_prompt]
else:
self.messages = []
self.messages.append({"role": "system", "content": f"This is a compacted conversation. Previous context: {summary}"})
self.messages.append({"role": "user", "content": last_user_message["content"]})
# TODO: Add metrics for compaction (tokens before/after)
return "Conversation compacted successfully."
def get_response(self, user_input: str, stream: bool = True):
# TODO: Add more special commands similar to Claude Code (e.g., /version, /status)
# TODO: Implement binary feedback mechanism for comparing responses
# Special commands
if user_input.strip() == "/compact":
return self.compact()
# Add a debug command to help diagnose issues
if user_input.strip() == "/debug":
debug_info = {
"model": self.model,
"temperature": self.temperature,
"message_count": len(self.messages),
"token_usage": self.token_usage,
"conversation_id": self.conversation_id,
"session_duration": time.time() - self.session_start_time,
"tools_count": len(self.tools),
"python_version": sys.version,
"openai_version": OpenAI.__version__ if hasattr(OpenAI, "__version__") else "Unknown"
}
return "Debug Information:\n" + json.dumps(debug_info, indent=2)
if user_input.strip() == "/help":
# Standard commands
commands = [
"/help - Show this help message",
"/compact - Compact the conversation to reduce token usage",
"/status - Show token usage and session information",
"/config - Show current configuration settings",
]
# RL-specific commands if available
if self.tool_optimizer is not None:
commands.extend([
"/rl-status - Show RL tool optimizer status",
"/rl-update - Update the RL model manually",
"/rl-stats - Show tool usage statistics",
])
return "Available commands:\n" + "\n".join(commands)
# Token usage and session stats
if user_input.strip() == "/status":
# Calculate session duration
session_duration = time.time() - self.session_start_time
hours, remainder = divmod(session_duration, 3600)
minutes, seconds = divmod(remainder, 60)
# Format message
status = (
f"Session ID: {self.conversation_id}\n"
f"Model: {self.model} (Temperature: {self.temperature})\n"
f"Session duration: {int(hours)}h {int(minutes)}m {int(seconds)}s\n\n"
f"Token usage:\n"
f" Prompt tokens: {self.token_usage['prompt']}\n"
f" Completion tokens: {self.token_usage['completion']}\n"
f" Total tokens: {self.token_usage['total']}\n"
)
return status
# Configuration settings
if user_input.strip() == "/config":
config_info = (
f"Current Configuration:\n"
f" Model: {self.model}\n"
f" Temperature: {self.temperature}\n"
f" Max tool iterations: {self.max_tool_iterations}\n"
f" Verbose mode: {self.verbose}\n"
f" RL optimization: {self.tool_optimizer is not None}\n"
)
# Provide instructions for changing settings
config_info += "\nTo change settings, use:\n"
config_info += " /config set <setting> <value>\n"
config_info += "Example: /config set max_tool_iterations 15"
return config_info
# Handle configuration changes
if user_input.strip().startswith("/config set "):
parts = user_input.strip().split(" ", 3)
if len(parts) != 4:
return "Invalid format. Use: /config set <setting> <value>"
setting = parts[2]
value = parts[3]
if setting == "max_tool_iterations":
try:
self.max_tool_iterations = int(value)
return f"Max tool iterations set to {self.max_tool_iterations}"
except ValueError:
return "Invalid value. Please provide a number."
elif setting == "temperature":
try:
self.temperature = float(value)
return f"Temperature set to {self.temperature}"
except ValueError:
return "Invalid value. Please provide a number."
elif setting == "verbose":
if value.lower() in ("true", "yes", "1", "on"):
self.verbose = True
return "Verbose mode enabled"
elif value.lower() in ("false", "no", "0", "off"):
self.verbose = False
return "Verbose mode disabled"
else:
return "Invalid value. Use 'true' or 'false'."
elif setting == "model":
self.model = value
return f"Model set to {self.model}"
else:
return f"Unknown setting: {setting}"
# RL-specific commands
if self.tool_optimizer is not None:
# RL status command
if user_input.strip() == "/rl-status":
return (
f"RL tool optimization is active\n"
f"Optimizer type: {type(self.tool_optimizer).__name__}\n"
f"Number of tools: {len(self.tools)}\n"
f"Data directory: {self.tool_optimizer.optimizer.data_dir if hasattr(self.tool_optimizer, 'optimizer') else 'N/A'}\n"
)
# RL update command
if user_input.strip() == "/rl-update":
try:
result = self.tool_optimizer.optimizer.update_model()
status = f"RL model update status: {result['status']}\n"
if 'metrics' in result:
status += "Metrics:\n" + "\n".join([f" {k}: {v}" for k, v in result['metrics'].items()])
return status
except Exception as e:
return f"Error updating RL model: {str(e)}"
# RL stats command
if user_input.strip() == "/rl-stats":
try:
if hasattr(self.tool_optimizer, 'optimizer') and hasattr(self.tool_optimizer.optimizer, 'tracker'):
stats = self.tool_optimizer.optimizer.tracker.get_tool_stats()
if not stats:
return "No tool usage data available yet."
result = "Tool Usage Statistics:\n\n"
for tool_name, tool_stats in stats.items():
result += f"{tool_name}:\n"
result += f" Count: {tool_stats['count']}\n"
result += f" Success rate: {tool_stats['success_rate']:.2f}\n"
result += f" Avg time: {tool_stats['avg_time']:.2f}s\n"
result += f" Avg tokens: {tool_stats['avg_total_tokens']:.1f}\n"
result += "\n"
return result
return "Tool usage statistics not available."
except Exception as e:
return f"Error getting tool statistics: {str(e)}"
# TODO: Add /version command to show version information
# Add user message
self.messages.append({"role": "user", "content": user_input})
# Initialize empty response
response_text = ""
# Create tools list for API
# TODO: Add dynamic tool availability based on context
api_tools = []
for tool in self.tools:
api_tools.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
}
})
if stream:
# TODO: Add retry mechanism for API failures
# TODO: Add token tracking for response
# TODO: Implement cancellation support
# Stream response
try:
# Add retry logic for API calls
max_retries = 3
retry_count = 0
while retry_count < max_retries:
try:
stream = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
tools=api_tools,
temperature=self.temperature,
stream=True
)
break # Success, exit retry loop
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
raise # Re-raise if we've exhausted retries
# Exponential backoff
wait_time = 2 ** retry_count
if self.verbose:
console.print(f"[yellow]API call failed, retrying in {wait_time}s... ({retry_count}/{max_retries})[/yellow]")
time.sleep(wait_time)
current_tool_calls = []
tool_call_chunks = {}
# Process streaming response outside of the status context
with Live("", refresh_per_second=10) as live:
for chunk in stream:
# If there's content, print it
if chunk.choices[0].delta.content:
content_piece = chunk.choices[0].delta.content
response_text += content_piece
# Update the live display with the accumulated response
live.update(response_text)
# Process tool calls
delta = chunk.choices[0].delta
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
# Initialize tool call in chunks dictionary if new
if tool_call_delta.index not in tool_call_chunks:
tool_call_chunks[tool_call_delta.index] = {
"id": "",
"function": {"name": "", "arguments": ""}
}
# Update tool call data
if tool_call_delta.id:
tool_call_chunks[tool_call_delta.index]["id"] = tool_call_delta.id
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_call_chunks[tool_call_delta.index]["function"]["name"] = tool_call_delta.function.name
if tool_call_delta.function.arguments:
tool_call_chunks[tool_call_delta.index]["function"]["arguments"] += tool_call_delta.function.arguments
# No need to print the response again as it was already streamed in the Live context
except Exception as e:
# TODO: Add better error handling and user feedback
console.print(f"[bold red]Error during API call:[/bold red] {str(e)}")
return f"Error during API call: {str(e)}"
# Convert tool call chunks to actual tool calls
for index, tool_call_data in tool_call_chunks.items():
current_tool_calls.append({
"id": tool_call_data["id"],
"function": {
"name": tool_call_data["function"]["name"],
"arguments": tool_call_data["function"]["arguments"]
}
})
# Process tool calls if any
if current_tool_calls:
try:
# Add assistant message with tool_calls to messages first
# Ensure each tool call has a "type" field set to "function"
processed_tool_calls = []
for tool_call in current_tool_calls:
processed_tool_call = tool_call.copy()
processed_tool_call["type"] = "function"
processed_tool_calls.append(processed_tool_call)
# Make sure we add the assistant message with tool calls before processing them
self.messages.append({
"role": "assistant",
"content": response_text,
"tool_calls": processed_tool_calls
})
# Now process the tool calls
with console.status("[bold green]Running tools..."):
tool_responses = self.process_tool_calls(current_tool_calls, query=user_input)
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {str(e)}")
console.print(traceback.format_exc())
return f"Error processing tool calls: {str(e)}"
# Continue the conversation with tool responses
# Implement looping function calls to allow for recursive tool usage
max_loop_iterations = self.max_tool_iterations # Use configurable setting
current_iteration = 0
while current_iteration < max_loop_iterations:
# Add retry logic for follow-up API calls
max_retries = 3
retry_count = 0
follow_up = None
while retry_count < max_retries:
try:
follow_up = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
tools=api_tools, # Pass tools to enable recursive function calling
stream=False
)
break # Success, exit retry loop
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
raise # Re-raise if we've exhausted retries
# Exponential backoff
wait_time = 2 ** retry_count
if self.verbose:
console.print(f"[yellow]Follow-up API call failed, retrying in {wait_time}s... ({retry_count}/{max_retries})[/yellow]")
time.sleep(wait_time)
# Check if the follow-up response contains more tool calls
assistant_message = follow_up.choices[0].message
follow_up_text = assistant_message.content or ""
# If there are no more tool calls, we're done with the loop
if not hasattr(assistant_message, 'tool_calls') or not assistant_message.tool_calls:
if follow_up_text:
console.print(Markdown(follow_up_text))
response_text += "\n" + follow_up_text
# Add the final assistant message to the conversation
self.messages.append({"role": "assistant", "content": follow_up_text})
break
# Process the new tool calls
current_tool_calls = []
for tool_call in assistant_message.tool_calls:
# Handle both object-style and dict-style tool calls
if isinstance(tool_call, dict):
processed_tool_call = tool_call.copy()
else:
# Convert object to dict
processed_tool_call = {
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
# Ensure type field is present
processed_tool_call["type"] = "function"
current_tool_calls.append(processed_tool_call)
# Add the assistant message with tool calls
self.messages.append({
"role": "assistant",
"content": follow_up_text,
"tool_calls": current_tool_calls
})
# Process the new tool calls
with console.status(f"[bold green]Running tools (iteration {current_iteration + 1})...[/bold green]"):
tool_responses = self.process_tool_calls(assistant_message.tool_calls, query=user_input)
# Increment the iteration counter
current_iteration += 1
# If we've reached the maximum number of iterations, add a warning
if current_iteration >= max_loop_iterations:
warning_message = f"[yellow]Warning: Reached maximum number of tool call iterations ({max_loop_iterations}). Some operations may be incomplete.[/yellow]"
console.print(warning_message)
response_text += f"\n\n{warning_message}"
# Add assistant response to messages if there were no tool calls
# (we already added it above if there were tool calls)
if not current_tool_calls:
self.messages.append({"role": "assistant", "content": response_text})
return response_text
else:
# Non-streaming response
response = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
tools=api_tools,
temperature=self.temperature,
stream=False
)
# Track token usage
if hasattr(response, 'usage'):
self.token_usage["prompt"] += response.usage.prompt_tokens
self.token_usage["completion"] += response.usage.completion_tokens
self.token_usage["total"] += response.usage.total_tokens
assistant_message = response.choices[0].message
response_text = assistant_message.content or ""
# Process tool calls if any
if assistant_message.tool_calls:
# Add assistant message with tool_calls to messages
# Convert tool_calls to a list of dictionaries with "type" field
processed_tool_calls = []
for tool_call in assistant_message.tool_calls:
# Handle both object-style and dict-style tool calls
if isinstance(tool_call, dict):
processed_tool_call = tool_call.copy()
else:
# Convert object to dict
processed_tool_call = {
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
# Ensure type field is present
processed_tool_call["type"] = "function"
processed_tool_calls.append(processed_tool_call)
self.messages.append({
"role": "assistant",
"content": response_text,
"tool_calls": processed_tool_calls
})
with console.status("[bold green]Running tools..."):
tool_responses = self.process_tool_calls(assistant_message.tool_calls, query=user_input)
# Continue the conversation with tool responses
# Implement looping function calls to allow for recursive tool usage
max_loop_iterations = self.max_tool_iterations # Use configurable setting
current_iteration = 0
while current_iteration < max_loop_iterations:
# Add retry logic for follow-up API calls
max_retries = 3
retry_count = 0
follow_up = None
while retry_count < max_retries:
try:
follow_up = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
tools=api_tools, # Pass tools to enable recursive function calling
stream=False
)
break # Success, exit retry loop
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
raise # Re-raise if we've exhausted retries
# Exponential backoff
wait_time = 2 ** retry_count
if self.verbose:
console.print(f"[yellow]Follow-up API call failed, retrying in {wait_time}s... ({retry_count}/{max_retries})[/yellow]")
time.sleep(wait_time)
# Check if the follow-up response contains more tool calls
assistant_message = follow_up.choices[0].message
follow_up_text = assistant_message.content or ""
# If there are no more tool calls, we're done with the loop
if not hasattr(assistant_message, 'tool_calls') or not assistant_message.tool_calls:
if follow_up_text:
console.print(Markdown(follow_up_text))
response_text += "\n" + follow_up_text
# Add the final assistant message to the conversation
self.messages.append({"role": "assistant", "content": follow_up_text})
break
# Process the new tool calls
current_tool_calls = []
for tool_call in assistant_message.tool_calls:
# Handle both object-style and dict-style tool calls
if isinstance(tool_call, dict):
processed_tool_call = tool_call.copy()
else:
# Convert object to dict
processed_tool_call = {
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
# Ensure type field is present
processed_tool_call["type"] = "function"
current_tool_calls.append(processed_tool_call)
# Add the assistant message with tool calls
self.messages.append({
"role": "assistant",
"content": follow_up_text,
"tool_calls": current_tool_calls
})
# Process the new tool calls
with console.status(f"[bold green]Running tools (iteration {current_iteration + 1})...[/bold green]"):
tool_responses = self.process_tool_calls(assistant_message.tool_calls, query=user_input)
# Increment the iteration counter
current_iteration += 1
# If we've reached the maximum number of iterations, add a warning
if current_iteration >= max_loop_iterations:
warning_message = f"[yellow]Warning: Reached maximum number of tool call iterations ({max_loop_iterations}). Some operations may be incomplete.[/yellow]"
console.print(warning_message)
response_text += f"\n\n{warning_message}"
else:
console.print(Markdown(response_text))
# Add assistant response to messages if not already added
# (we already added it above if there were tool calls)
if not assistant_message.tool_calls:
self.messages.append({"role": "assistant", "content": response_text})
return response_text
# TODO: Create a more flexible system prompt mechanism with customizable templates
def get_system_prompt():
return """You are OpenAI Code Assistant, a CLI tool that helps users with software engineering tasks and general information.
Use the available tools to assist the user with their requests.
# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command,
you should explain what the command does and why you are running it.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user.
Remember that your output will be displayed on a command line interface.
# Tool usage policy
- When doing file search, remember to search effectively with the available tools.
- Always use the appropriate tool for the task.
- Use parallel tool calls when appropriate to improve performance.
- NEVER commit changes unless the user explicitly asks you to.
- For weather queries, use the Weather tool to provide real-time information.
# Tasks
The user will primarily request you perform software engineering tasks:
1. Solving bugs
2. Adding new functionality
3. Refactoring code
4. Explaining code
5. Writing tests
For these tasks:
1. Use search tools to understand the codebase
2. Implement solutions using the available tools
3. Verify solutions with tests if possible
4. Run lint and typecheck commands when appropriate
The user may also ask for general information:
1. Weather conditions
2. Simple calculations
3. General knowledge questions
# Code style
- Follow the existing code style of the project
- Maintain consistent naming conventions
- Use appropriate libraries that are already in the project
- Add comments when code is complex or non-obvious
IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness,
quality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.
"""
# TODO: Add version information and CLI arguments
# TODO: Add logging configuration
# TODO: Create a proper CLI command structure with subcommands
# Hosting and replication capabilities
class HostingManager:
"""Manages hosting and replication of the assistant"""
def __init__(self, host="127.0.0.1", port=8000):
self.host = host
self.port = port
self.app = FastAPI(title="OpenAI Code Assistant API")
self.conversation_pool = {}
# Record creation time here so the /health endpoint works even when the app
# is built via create_app() without calling start()/start_background()
self.start_time = time.time()
self.setup_api()
def setup_api(self):
"""Configure the FastAPI application"""
# Add CORS middleware
self.app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, restrict this to specific domains
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define API routes
@self.app.get("/")
async def root():
return {"message": "OpenAI Code Assistant API", "status": "running"}
@self.app.post("/conversation")
async def create_conversation(
request: Request,
background_tasks: BackgroundTasks,
model: str = DEFAULT_MODEL,
temperature: float = DEFAULT_TEMPERATURE
):
"""Create a new conversation instance"""
conversation_id = str(uuid4())
# Initialize conversation in background
background_tasks.add_task(self._init_conversation, conversation_id, model, temperature)
return {
"conversation_id": conversation_id,
"status": "initializing",
"model": model
}
@self.app.post("/conversation/{conversation_id}/message")
async def send_message(
conversation_id: str,
request: Request
):
"""Send a message to a conversation"""
if conversation_id not in self.conversation_pool:
raise HTTPException(status_code=404, detail="Conversation not found")
data = await request.json()
user_input = data.get("message", "")
# Get conversation instance
conversation = self.conversation_pool[conversation_id]
# Process message
try:
response = conversation.get_response(user_input, stream=False)
return {
"conversation_id": conversation_id,
"response": response
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing message: {str(e)}")
@self.app.post("/conversation/{conversation_id}/message/stream")
async def stream_message(
conversation_id: str,
request: Request
):
"""Stream a message response from a conversation"""
if conversation_id not in self.conversation_pool:
raise HTTPException(status_code=404, detail="Conversation not found")
data = await request.json()
user_input = data.get("message", "")
# Get conversation instance
conversation = self.conversation_pool[conversation_id]
# Create async generator for streaming
async def response_generator():
# Add user message
conversation.messages.append({"role": "user", "content": user_input})
# Create tools list for API
api_tools = []
for tool in conversation.tools:
api_tools.append({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
}
})
# Stream response
try:
stream = conversation.client.chat.completions.create(
model=conversation.model,
messages=conversation.messages,
tools=api_tools,
temperature=conversation.temperature,
stream=True
)
current_tool_calls = []
tool_call_chunks = {}
response_text = ""
for chunk in stream:
# If there's content, yield it
if chunk.choices[0].delta.content:
content_piece = chunk.choices[0].delta.content
response_text += content_piece
yield json.dumps({"type": "content", "content": content_piece}) + "\n"
# Process tool calls
delta = chunk.choices[0].delta
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
# Initialize tool call in chunks dictionary if new
if tool_call_delta.index not in tool_call_chunks:
tool_call_chunks[tool_call_delta.index] = {
"id": "",
"function": {"name": "", "arguments": ""}
}
# Update tool call data
if tool_call_delta.id:
tool_call_chunks[tool_call_delta.index]["id"] = tool_call_delta.id
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_call_chunks[tool_call_delta.index]["function"]["name"] = tool_call_delta.function.name
if tool_call_delta.function.arguments:
tool_call_chunks[tool_call_delta.index]["function"]["arguments"] += tool_call_delta.function.arguments
# Convert tool call chunks to actual tool calls
for index, tool_call_data in tool_call_chunks.items():
current_tool_calls.append({
"id": tool_call_data["id"],
"function": {
"name": tool_call_data["function"]["name"],
"arguments": tool_call_data["function"]["arguments"]
}
})
# Process tool calls if any
if current_tool_calls:
# Add assistant message with tool_calls to messages
processed_tool_calls = []
for tool_call in current_tool_calls:
processed_tool_call = tool_call.copy()
processed_tool_call["type"] = "function"
processed_tool_calls.append(processed_tool_call)
conversation.messages.append({
"role": "assistant",
"content": response_text,
"tool_calls": processed_tool_calls
})
# Notify client that tools are running
yield json.dumps({"type": "status", "status": "running_tools"}) + "\n"
# Process tool calls
tool_responses = conversation.process_tool_calls(current_tool_calls, query=user_input)
# Notify client of tool results
for response in tool_responses:
yield json.dumps({
"type": "tool_result",
"tool": response["function_name"],
"result": response["result"]
}) + "\n"
# Continue the conversation with tool responses
max_loop_iterations = conversation.max_tool_iterations
current_iteration = 0
while current_iteration < max_loop_iterations:
follow_up = conversation.client.chat.completions.create(
model=conversation.model,
messages=conversation.messages,
tools=api_tools,
# Keep sampling settings consistent with the initial streaming request
temperature=conversation.temperature,
stream=False
)
# Check if the follow-up response contains more tool calls
assistant_message = follow_up.choices[0].message
follow_up_text = assistant_message.content or ""
# If there are no more tool calls, we're done with the loop
if not hasattr(assistant_message, 'tool_calls') or not assistant_message.tool_calls:
if follow_up_text:
yield json.dumps({"type": "content", "content": follow_up_text}) + "\n"
# Add the final assistant message to the conversation
conversation.messages.append({"role": "assistant", "content": follow_up_text})
break
# Process the new tool calls
current_tool_calls = []
for tool_call in assistant_message.tool_calls:
if isinstance(tool_call, dict):
processed_tool_call = tool_call.copy()
else:
processed_tool_call = {
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments
}
}
processed_tool_call["type"] = "function"
current_tool_calls.append(processed_tool_call)
# Add the assistant message with tool calls
conversation.messages.append({
"role": "assistant",
"content": follow_up_text,
"tool_calls": current_tool_calls
})
# Notify client that tools are running
yield json.dumps({
"type": "status",
"status": f"running_tools_iteration_{current_iteration + 1}"
}) + "\n"
# Execute the new tool calls and collect their responses
tool_responses = conversation.process_tool_calls(assistant_message.tool_calls, query=user_input)
# Notify client of tool results
for response in tool_responses:
yield json.dumps({
"type": "tool_result",
"tool": response["function_name"],
"result": response["result"]
}) + "\n"
# Increment the iteration counter
current_iteration += 1
# If we've reached the maximum number of iterations, add a warning
if current_iteration >= max_loop_iterations:
warning_message = f"Warning: Reached maximum number of tool call iterations ({max_loop_iterations}). Some operations may be incomplete."
yield json.dumps({"type": "warning", "warning": warning_message}) + "\n"
else:
# Add assistant response to messages
conversation.messages.append({"role": "assistant", "content": response_text})
# Signal completion
yield json.dumps({"type": "status", "status": "complete"}) + "\n"
except Exception as e:
yield json.dumps({"type": "error", "error": str(e)}) + "\n"
return StreamingResponse(response_generator(), media_type="text/event-stream")
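# Illustrative sketch of how a client could consume the newline-delimited JSON
# stream produced above (each line is one event whose "type" is content,
# status, tool_result, warning, or error). The requests-based snippet is an
# assumption for illustration, not part of this server.
#
#   import json, requests
#   url = "http://127.0.0.1:8000/conversation/<conversation_id>/message/stream"
#   with requests.post(url, json={"message": "Summarize this repo"}, stream=True) as resp:
#       for line in resp.iter_lines():
#           if not line:
#               continue
#           event = json.loads(line)
#           if event["type"] == "content":
#               print(event["content"], end="", flush=True)
#           elif event["type"] == "tool_result":
#               print(f"\n[{event['tool']}] {event['result']}")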
@self.app.get("/conversation/{conversation_id}")
async def get_conversation(conversation_id: str):
"""Get conversation details"""
if conversation_id not in self.conversation_pool:
raise HTTPException(status_code=404, detail="Conversation not found")
conversation = self.conversation_pool[conversation_id]
return {
"conversation_id": conversation_id,
"model": conversation.model,
"temperature": conversation.temperature,
"message_count": len(conversation.messages),
"token_usage": conversation.token_usage
}
@self.app.delete("/conversation/{conversation_id}")
async def delete_conversation(conversation_id: str):
"""Delete a conversation"""
if conversation_id not in self.conversation_pool:
raise HTTPException(status_code=404, detail="Conversation not found")
del self.conversation_pool[conversation_id]
return {"status": "deleted", "conversation_id": conversation_id}
@self.app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"active_conversations": len(self.conversation_pool),
"uptime": time.time() - self.start_time
}
async def _init_conversation(self, conversation_id, model, temperature):
"""Initialize a conversation instance"""
conversation = Conversation()
conversation.model = model
conversation.temperature = temperature
conversation.messages.append({"role": "system", "content": get_system_prompt()})
self.conversation_pool[conversation_id] = conversation
def start(self):
"""Start the API server"""
self.start_time = time.time()
uvicorn.run(self.app, host=self.host, port=self.port)
def start_background(self):
"""Start the API server in a background thread"""
self.start_time = time.time()
thread = threading.Thread(target=uvicorn.run, args=(self.app,),
kwargs={"host": self.host, "port": self.port})
thread.daemon = True
thread.start()
return thread
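# Minimal usage sketch (illustrative, not executed on import): run the API in a
# background thread and drive it over HTTP. Assumes the requests library is
# installed; the endpoint paths match the routes registered above.
#
#   manager = HostingManager(host="127.0.0.1", port=8000)
#   manager.start_background()
#   import requests
#   conv = requests.post("http://127.0.0.1:8000/conversation").json()
#   reply = requests.post(
#       f"http://127.0.0.1:8000/conversation/{conv['conversation_id']}/message",
#       json={"message": "Explain this repository"},
#   ).json()
#   print(reply["response"])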
class ReplicationManager:
"""Manages replication across multiple instances"""
def __init__(self, primary=True, sync_interval=60):
self.primary = primary
self.sync_interval = sync_interval
self.peers = []
self.conversation_cache = {}
self.last_sync = time.time()
self.sync_lock = threading.Lock()
def add_peer(self, host, port):
"""Add a peer instance to replicate with"""
peer = {"host": host, "port": port}
if peer not in self.peers:
self.peers.append(peer)
return True
return False
def remove_peer(self, host, port):
"""Remove a peer instance"""
peer = {"host": host, "port": port}
if peer in self.peers:
self.peers.remove(peer)
return True
return False
def sync_conversation(self, conversation_id, conversation):
"""Sync a conversation to all peers"""
if not self.peers:
return
# Serialize conversation
try:
serialized = pickle.dumps(conversation)
# Calculate hash for change detection
conversation_hash = hashlib.md5(serialized).hexdigest()
# Check if conversation has changed
if conversation_id in self.conversation_cache:
if self.conversation_cache[conversation_id] == conversation_hash:
return # No changes, skip sync
# Update cache
self.conversation_cache[conversation_id] = conversation_hash
# Sync to peers
for peer in self.peers:
try:
url = f"http://{peer['host']}:{peer['port']}/sync/conversation/{conversation_id}"
requests.post(url, data=serialized,
headers={"Content-Type": "application/octet-stream"})
except Exception as e:
logging.error(f"Failed to sync with peer {peer['host']}:{peer['port']}: {e}")
except Exception as e:
logging.error(f"Error serializing conversation: {e}")
def start_sync_thread(self, conversation_pool):
"""Start background thread for periodic syncing"""
def sync_worker():
while True:
time.sleep(self.sync_interval)
with self.sync_lock:
for conversation_id, conversation in conversation_pool.items():
self.sync_conversation(conversation_id, conversation)
thread = threading.Thread(target=sync_worker)
thread.daemon = True
thread.start()
return thread
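# Usage sketch (illustrative): pair a ReplicationManager with a HostingManager's
# conversation pool so conversations are pushed to peers every sync_interval
# seconds. The peer address below is a placeholder.
#
#   replicator = ReplicationManager(primary=True, sync_interval=30)
#   replicator.add_peer("10.0.0.2", 8000)
#   hosting = HostingManager()
#   replicator.start_sync_thread(hosting.conversation_pool)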
@app.command()
def serve(
host: str = typer.Option("127.0.0.1", "--host", help="Host address to bind to"),
port: int = typer.Option(8000, "--port", "-p", help="Port to listen on"),
workers: int = typer.Option(1, "--workers", "-w", help="Number of worker processes"),
enable_replication: bool = typer.Option(False, "--enable-replication", help="Enable replication across instances"),
primary: bool = typer.Option(True, "--primary/--secondary", help="Whether this is a primary or secondary instance"),
peers: List[str] = typer.Option([], "--peer", help="Peer instances to replicate with (host:port)")
):
"""
Start the OpenAI Code Assistant as a web service
"""
console.print(Panel.fit(
f"[bold green]OpenAI Code Assistant API Server[/bold green]\n"
f"Host: {host}\n"
f"Port: {port}\n"
f"Workers: {workers}\n"
f"Replication: {'Enabled' if enable_replication else 'Disabled'}\n"
f"Role: {'Primary' if primary else 'Secondary'}\n"
f"Peers: {', '.join(peers) if peers else 'None'}",
title="Server Starting",
border_style="green"
))
# Check API key
if not os.getenv("OPENAI_API_KEY"):
console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
return
# Start server
if workers > 1:
# Use multiprocessing for multiple workers
console.print(f"Starting server with {workers} workers...")
uvicorn.run(
"cli:create_app",
host=host,
port=port,
workers=workers,
factory=True
)
else:
# Single process mode
hosting_manager = HostingManager(host=host, port=port)
# Setup replication if enabled
if enable_replication:
replication_manager = ReplicationManager(primary=primary)
# Add peers
for peer in peers:
try:
peer_host, peer_port = peer.split(":")
replication_manager.add_peer(peer_host, int(peer_port))
except ValueError:
console.print(f"[yellow]Warning: Invalid peer format: {peer}. Use host:port format.[/yellow]")
# Start sync thread
replication_manager.start_sync_thread(hosting_manager.conversation_pool)
console.print(f"Replication enabled with {len(replication_manager.peers)} peers")
# Start server
hosting_manager.start()
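# Example invocations (illustrative; assumes this module is the cli.py entry
# point referenced elsewhere in the repo):
#   python cli.py serve --host 0.0.0.0 --port 8000
#   python cli.py serve --port 8000 --enable-replication --peer 10.0.0.2:8000 --peer 10.0.0.3:8000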
def create_app():
"""Factory function for creating the FastAPI app (used with multiple workers)"""
hosting_manager = HostingManager()
return hosting_manager.app
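# With --workers > 1 the serve command points uvicorn at this factory. The
# equivalent manual invocation (assuming the module is named cli) would be:
#   uvicorn cli:create_app --factory --host 0.0.0.0 --port 8000 --workers 4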
@app.command()
def mcp_serve(
host: str = typer.Option("127.0.0.1", "--host", help="Host address to bind to"),
port: int = typer.Option(8000, "--port", "-p", help="Port to listen on"),
dev_mode: bool = typer.Option(False, "--dev", help="Enable development mode with additional logging"),
dependencies: List[str] = typer.Option([], "--dependencies", help="Additional Python dependencies to install"),
env_file: str = typer.Option(None, "--env-file", help="Path to .env file with environment variables"),
cache_type: str = typer.Option("memory", "--cache", help="Cache type: 'memory' or 'redis'"),
redis_url: str = typer.Option(None, "--redis-url", help="Redis URL for cache (if cache_type is 'redis')"),
reload: bool = typer.Option(False, "--reload", help="Enable auto-reload on code changes")
):
"""
Start the OpenAI Code Assistant as an MCP (Model Context Protocol) server
This allows the assistant to be used as a context provider for MCP clients
like Claude Desktop or other MCP-compatible applications.
"""
# Load environment variables from file if specified
if env_file:
if os.path.exists(env_file):
load_dotenv(env_file)
console.print(f"[green]Loaded environment variables from {env_file}[/green]")
else:
console.print(f"[yellow]Warning: Environment file {env_file} not found[/yellow]")
# Install additional dependencies if specified
required_deps = ["prometheus-client", "tiktoken"]
if cache_type == "redis":
required_deps.append("redis")
all_deps = required_deps + list(dependencies)
if all_deps:
console.print(f"[bold]Installing dependencies: {', '.join(all_deps)}[/bold]")
try:
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", *all_deps])
console.print("[green]Dependencies installed successfully[/green]")
except Exception as e:
console.print(f"[red]Error installing dependencies: {str(e)}[/red]")
return
# Configure logging for development mode
if dev_mode:
import logging
logging.basicConfig(level=logging.DEBUG)
console.print("[yellow]Development mode enabled with debug logging[/yellow]")
# Print server information
cache_info = f"Cache: {cache_type}"
if cache_type == "redis" and redis_url:
cache_info += f" ({redis_url})"
console.print(Panel.fit(
f"[bold green]OpenAI Code Assistant MCP Server[/bold green]\n"
f"Host: {host}\n"
f"Port: {port}\n"
f"Development Mode: {'Enabled' if dev_mode else 'Disabled'}\n"
f"Auto-reload: {'Enabled' if reload else 'Disabled'}\n"
f"{cache_info}\n"
f"API Key: {'Configured' if os.getenv('OPENAI_API_KEY') else 'Not Configured'}",
title="MCP Server Starting",
border_style="green"
))
# Check API key
if not os.getenv("OPENAI_API_KEY"):
console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
return
# Create required directories
base_dir = os.path.dirname(os.path.abspath(__file__))
os.makedirs(os.path.join(base_dir, "data"), exist_ok=True)
os.makedirs(os.path.join(base_dir, "templates"), exist_ok=True)
os.makedirs(os.path.join(base_dir, "static"), exist_ok=True)
try:
# Import the MCP server module
from mcp_server import MCPServer
# Start the MCP server
server = MCPServer(cache_type=cache_type, redis_url=redis_url)
server.start(host=host, port=port, reload=reload)
except ImportError:
console.print("[bold red]Error:[/bold red] MCP server module not found. Make sure mcp_server.py is in the same directory.")
except Exception as e:
console.print(f"[bold red]Error starting MCP server:[/bold red] {str(e)}")
if dev_mode:
import traceback
console.print(traceback.format_exc())
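# Example invocations (illustrative; assumes the cli.py entry point and that
# Typer exposes this command as "mcp-serve"):
#   python cli.py mcp-serve --port 8000 --dev
#   python cli.py mcp-serve --cache redis --redis-url redis://localhost:6379 --env-file .env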
@app.command()
def mcp_client(
server_path: str = typer.Argument(..., help="Path to the MCP server script or module"),
model: str = typer.Option("gpt-4o", "--model", "-m", help="Model to use for reasoning"),
host: str = typer.Option("127.0.0.1", "--host", help="Host address for the MCP server"),
port: int = typer.Option(8000, "--port", "-p", help="Port for the MCP server")
):
"""
Connect to an MCP server using OpenAI Code Assistant as the reasoning engine
This allows using the assistant to interact with any MCP-compatible server.
"""
console.print(Panel.fit(
f"[bold green]OpenAI Code Assistant MCP Client[/bold green]\n"
f"Server: {server_path}\n"
f"Model: {model}\n"
f"Host: {host}\n"
f"Port: {port}",
title="MCP Client Starting",
border_style="green"
))
# Check if server path exists
if not os.path.exists(server_path):
console.print(f"[bold red]Error:[/bold red] Server script not found at {server_path}")
return
# Check API key
if not os.getenv("OPENAI_API_KEY"):
console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
return
try:
# Start the server in a subprocess
import subprocess
import signal
# Start server process
console.print(f"[bold]Starting MCP server from {server_path}...[/bold]")
server_process = subprocess.Popen(
[sys.executable, server_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
# Wait for server to start
time.sleep(2)
# Check if server started successfully
if server_process.poll() is not None:
console.print("[bold red]Error:[/bold red] Failed to start MCP server")
stdout, stderr = server_process.communicate()
console.print(f"[red]Server output:[/red]\n{stdout}\n{stderr}")
return
console.print("[green]MCP server started successfully[/green]")
# Initialize conversation
conversation = Conversation()
conversation.model = model
# Add system prompt
conversation.messages.append({
"role": "system",
"content": "You are an MCP client connecting to a Model Context Protocol server. "
"Use the available tools to interact with the server and help the user."
})
# Register MCP-specific tools
mcp_tools = [
Tool(
name="MCPGetContext",
description="Get context from the MCP server using a prompt template",
parameters={
"type": "object",
"properties": {
"prompt_id": {
"type": "string",
"description": "ID of the prompt template to use"
},
"parameters": {
"type": "object",
"description": "Parameters for the prompt template"
}
},
"required": ["prompt_id"]
},
function=lambda prompt_id, parameters=None: _mcp_get_context(host, port, prompt_id, parameters or {})
),
Tool(
name="MCPListPrompts",
description="List available prompt templates from the MCP server",
parameters={
"type": "object",
"properties": {}
},
function=lambda: _mcp_list_prompts(host, port)
),
Tool(
name="MCPGetPrompt",
description="Get details of a specific prompt template from the MCP server",
parameters={
"type": "object",
"properties": {
"prompt_id": {
"type": "string",
"description": "ID of the prompt template to get"
}
},
"required": ["prompt_id"]
},
function=lambda prompt_id: _mcp_get_prompt(host, port, prompt_id)
)
]
# Add MCP tools to conversation
conversation.tools.extend(mcp_tools)
for tool in mcp_tools:
conversation.tool_map[tool.name] = tool.function
# Main interaction loop
console.print("[bold]MCP Client ready. Type your questions or commands.[/bold]")
console.print("[bold]Type 'exit' to quit.[/bold]")
while True:
try:
user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")
# Handle exit
if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
console.print("[bold yellow]Shutting down MCP client...[/bold yellow]")
break
# Get response
conversation.get_response(user_input)
except KeyboardInterrupt:
console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
if Prompt.ask("[bold]Exit?[/bold]", choices=["y", "n"], default="n") == "y":
break
continue
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {str(e)}")
# Clean up
console.print("[bold]Stopping MCP server...[/bold]")
server_process.terminate()
server_process.wait(timeout=5)
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {str(e)}")
# MCP client helper functions
def _mcp_get_context(host, port, prompt_id, parameters):
"""Get context from MCP server"""
try:
url = f"http://{host}:{port}/context"
response = requests.post(
url,
json={
"prompt_id": prompt_id,
"parameters": parameters
}
)
if response.status_code != 200:
return f"Error: {response.status_code} - {response.text}"
data = response.json()
return f"Context (ID: {data['context_id']}):\n\n{data['context']}"
except Exception as e:
return f"Error connecting to MCP server: {str(e)}"
def _mcp_list_prompts(host, port):
"""List available prompt templates from MCP server"""
try:
url = f"http://{host}:{port}/prompts"
response = requests.get(url)
if response.status_code != 200:
return f"Error: {response.status_code} - {response.text}"
data = response.json()
prompts = data.get("prompts", [])
if not prompts:
return "No prompt templates available"
result = "Available prompt templates:\n\n"
for prompt in prompts:
result += f"ID: {prompt['id']}\n"
result += f"Description: {prompt['description']}\n"
result += f"Parameters: {', '.join(prompt.get('parameters', {}).keys())}\n\n"
return result
except Exception as e:
return f"Error connecting to MCP server: {str(e)}"
def _mcp_get_prompt(host, port, prompt_id):
"""Get details of a specific prompt template"""
try:
url = f"http://{host}:{port}/prompts/{prompt_id}"
response = requests.get(url)
if response.status_code != 200:
return f"Error: {response.status_code} - {response.text}"
prompt = response.json()
result = f"Prompt Template: {prompt['id']}\n\n"
result += f"Description: {prompt['description']}\n\n"
result += "Parameters:\n"
for param_name, param_info in prompt.get("parameters", {}).items():
result += f"- {param_name}: {param_info.get('description', '')}\n"
result += f"\nTemplate:\n{prompt['template']}\n"
return result
except Exception as e:
return f"Error connecting to MCP server: {str(e)}"
@app.command()
def mcp_multi_agent(
server_path: str = typer.Argument(..., help="Path to the MCP server script or module"),
config: str = typer.Option(None, "--config", "-c", help="Path to agent configuration JSON file"),
host: str = typer.Option("127.0.0.1", "--host", help="Host address for the MCP server"),
port: int = typer.Option(8000, "--port", "-p", help="Port for the MCP server")
):
"""
Start a multi-agent MCP client with multiple specialized agents
This allows using multiple agents with different roles to collaborate
on complex tasks by connecting to an MCP server.
"""
# Load configuration
if config:
if not os.path.exists(config):
console.print(f"[bold red]Error:[/bold red] Configuration file not found at {config}")
return
try:
with open(config, 'r') as f:
config_data = json.load(f)
except Exception as e:
console.print(f"[bold red]Error loading configuration:[/bold red] {str(e)}")
return
else:
# Default configuration
config_data = {
"agents": [
{
"name": "Primary",
"role": "primary",
"system_prompt": "You are a helpful assistant that uses an MCP server to provide information.",
"model": "gpt-4o",
"temperature": 0.0
}
],
"coordination": {
"strategy": "single",
"primary_agent": "Primary"
},
"settings": {
"max_turns_per_agent": 1,
"enable_agent_reflection": False,
"enable_cross_agent_communication": False,
"enable_user_selection": False
}
}
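# A file passed via --config mirrors the default structure above, e.g.
# (illustrative values):
#   {
#     "agents": [
#       {"name": "Researcher", "role": "research", "system_prompt": "...", "model": "gpt-4o", "temperature": 0.2},
#       {"name": "Primary", "role": "primary", "system_prompt": "...", "model": "gpt-4o", "temperature": 0.0}
#     ],
#     "coordination": {"strategy": "round_robin", "primary_agent": "Primary"},
#     "settings": {"max_turns_per_agent": 1, "enable_agent_reflection": false,
#                  "enable_cross_agent_communication": false, "enable_user_selection": false}
#   }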
# Display configuration
agent_names = [agent["name"] for agent in config_data["agents"]]
console.print(Panel.fit(
f"[bold green]OpenAI Code Assistant Multi-Agent MCP Client[/bold green]\n"
f"Server: {server_path}\n"
f"Host: {host}:{port}\n"
f"Agents: {', '.join(agent_names)}\n"
f"Coordination: {config_data['coordination']['strategy']}\n"
f"Primary Agent: {config_data['coordination']['primary_agent']}",
title="Multi-Agent MCP Client Starting",
border_style="green"
))
# Check if server path exists
if not os.path.exists(server_path):
console.print(f"[bold red]Error:[/bold red] Server script not found at {server_path}")
return
# Check API key
if not os.getenv("OPENAI_API_KEY"):
console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
return
try:
# Start the server in a subprocess
import subprocess
import signal
# Start server process
console.print(f"[bold]Starting MCP server from {server_path}...[/bold]")
server_process = subprocess.Popen(
[sys.executable, server_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
# Wait for server to start
time.sleep(2)
# Check if server started successfully
if server_process.poll() is not None:
console.print("[bold red]Error:[/bold red] Failed to start MCP server")
stdout, stderr = server_process.communicate()
console.print(f"[red]Server output:[/red]\n{stdout}\n{stderr}")
return
console.print("[green]MCP server started successfully[/green]")
# Initialize agents
agents = {}
for agent_config in config_data["agents"]:
# Create conversation for agent
agent = Conversation()
agent.model = agent_config.get("model", "gpt-4o")
agent.temperature = agent_config.get("temperature", 0.0)
# Add system prompt
agent.messages.append({
"role": "system",
"content": agent_config.get("system_prompt", "You are a helpful assistant.")
})
# Register MCP-specific tools
mcp_tools = [
Tool(
name="MCPGetContext",
description="Get context from the MCP server using a prompt template",
parameters={
"type": "object",
"properties": {
"prompt_id": {
"type": "string",
"description": "ID of the prompt template to use"
},
"parameters": {
"type": "object",
"description": "Parameters for the prompt template"
}
},
"required": ["prompt_id"]
},
function=lambda prompt_id, parameters=None: _mcp_get_context(host, port, prompt_id, parameters or {})
),
Tool(
name="MCPListPrompts",
description="List available prompt templates from the MCP server",
parameters={
"type": "object",
"properties": {}
},
function=lambda: _mcp_list_prompts(host, port)
),
Tool(
name="MCPGetPrompt",
description="Get details of a specific prompt template from the MCP server",
parameters={
"type": "object",
"properties": {
"prompt_id": {
"type": "string",
"description": "ID of the prompt template to get"
}
},
"required": ["prompt_id"]
},
function=lambda prompt_id: _mcp_get_prompt(host, port, prompt_id)
)
]
# Add MCP tools to agent
agent.tools.extend(mcp_tools)
for tool in mcp_tools:
agent.tool_map[tool.name] = tool.function
# Add agent to agents dictionary
agents[agent_config["name"]] = {
"config": agent_config,
"conversation": agent,
"history": []
}
# Get primary agent
primary_agent_name = config_data["coordination"]["primary_agent"]
if primary_agent_name not in agents:
console.print(f"[bold red]Error:[/bold red] Primary agent '{primary_agent_name}' not found in configuration")
return
# Main interaction loop
console.print("[bold]Multi-Agent MCP Client ready. Type your questions or commands.[/bold]")
console.print("[bold]Special commands:[/bold]")
console.print(" [blue]/agents[/blue] - List available agents")
console.print(" [blue]/talk <agent_name> <message>[/blue] - Send message to specific agent")
console.print(" [blue]/history[/blue] - Show conversation history")
console.print(" [blue]/exit[/blue] - Exit the client")
conversation_history = []
while True:
try:
user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")
# Handle exit
if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
console.print("[bold yellow]Shutting down multi-agent MCP client...[/bold yellow]")
break
# Handle special commands
if user_input.startswith("/agents"):
console.print("[bold]Available Agents:[/bold]")
for name, agent_data in agents.items():
role = agent_data["config"]["role"]
model = agent_data["config"]["model"]
console.print(f" [green]{name}[/green] ({role}, {model})")
continue
if user_input.startswith("/history"):
console.print("[bold]Conversation History:[/bold]")
for i, entry in enumerate(conversation_history, 1):
if entry["role"] == "user":
console.print(f"[blue]{i}. User:[/blue] {entry['content']}")
else:
console.print(f"[green]{i}. {entry['agent']}:[/green] {entry['content']}")
continue
if user_input.startswith("/talk "):
parts = user_input.split(" ", 2)
if len(parts) < 3:
console.print("[yellow]Usage: /talk <agent_name> <message>[/yellow]")
continue
agent_name = parts[1]
message = parts[2]
if agent_name not in agents:
console.print(f"[yellow]Agent '{agent_name}' not found. Use /agents to see available agents.[/yellow]")
continue
# Add message to history
conversation_history.append({
"role": "user",
"content": message,
"target_agent": agent_name
})
# Get response from specific agent
console.print(f"[bold]Asking {agent_name}...[/bold]")
agent = agents[agent_name]["conversation"]
response = agent.get_response(message)
# Add response to history
conversation_history.append({
"role": "assistant",
"agent": agent_name,
"content": response
})
# Add to agent's history
agents[agent_name]["history"].append({
"role": "user",
"content": message
})
agents[agent_name]["history"].append({
"role": "assistant",
"content": response
})
continue
# Regular message - use coordination strategy
strategy = config_data["coordination"]["strategy"]
# Add message to history
conversation_history.append({
"role": "user",
"content": user_input
})
if strategy == "single" or strategy == "primary":
# Just use the primary agent
agent = agents[primary_agent_name]["conversation"]
response = agent.get_response(user_input)
# Add response to history
conversation_history.append({
"role": "assistant",
"agent": primary_agent_name,
"content": response
})
# Add to agent's history
agents[primary_agent_name]["history"].append({
"role": "user",
"content": user_input
})
agents[primary_agent_name]["history"].append({
"role": "assistant",
"content": response
})
elif strategy == "round_robin":
# Ask each agent in turn
console.print("[bold]Consulting all agents...[/bold]")
for agent_name, agent_data in agents.items():
console.print(f"[bold]Response from {agent_name}:[/bold]")
agent = agent_data["conversation"]
response = agent.get_response(user_input)
# Add response to history
conversation_history.append({
"role": "assistant",
"agent": agent_name,
"content": response
})
# Add to agent's history
agent_data["history"].append({
"role": "user",
"content": user_input
})
agent_data["history"].append({
"role": "assistant",
"content": response
})
elif strategy == "voting":
# Ask all agents and show all responses
console.print("[bold]Collecting responses from all agents...[/bold]")
responses = {}
for agent_name, agent_data in agents.items():
agent = agent_data["conversation"]
response = agent.get_response(user_input)
responses[agent_name] = response
# Add to agent's history
agent_data["history"].append({
"role": "user",
"content": user_input
})
agent_data["history"].append({
"role": "assistant",
"content": response
})
# Display all responses
for agent_name, response in responses.items():
console.print(f"[bold]Response from {agent_name}:[/bold]")
console.print(response)
# Add response to history
conversation_history.append({
"role": "assistant",
"agent": agent_name,
"content": response
})
else:
console.print(f"[yellow]Unknown coordination strategy: {strategy}[/yellow]")
# Default to primary agent
agent = agents[primary_agent_name]["conversation"]
response = agent.get_response(user_input)
# Add response to history
conversation_history.append({
"role": "assistant",
"agent": primary_agent_name,
"content": response
})
# Add to agent's history
agents[primary_agent_name]["history"].append({
"role": "user",
"content": user_input
})
agents[primary_agent_name]["history"].append({
"role": "assistant",
"content": response
})
except KeyboardInterrupt:
console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
if Prompt.ask("[bold]Exit?[/bold]", choices=["y", "n"], default="n") == "y":
break
continue
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {str(e)}")
# Clean up
console.print("[bold]Stopping MCP server...[/bold]")
server_process.terminate()
server_process.wait(timeout=5)
except Exception as e:
console.print(f"[bold red]Error:[/bold red] {str(e)}")
@app.command()
def main(
model: str = typer.Option(DEFAULT_MODEL, "--model", "-m", help="Specify the model to use"),
temperature: float = typer.Option(DEFAULT_TEMPERATURE, "--temperature", "-t", help="Set temperature for response generation"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose diagnostic output"),
enable_rl: bool = typer.Option(True, "--enable-rl/--disable-rl", help="Enable/disable reinforcement learning for tool optimization"),
rl_update: bool = typer.Option(False, "--rl-update", help="Manually trigger an update of the RL model"),
):
"""
OpenAI Code Assistant - A command-line coding assistant
that uses OpenAI APIs with function calling and streaming
"""
# TODO: Check for updates on startup
# TODO: Add environment setup verification
# Create welcome panel with more details
rl_status = "enabled" if enable_rl else "disabled"
console.print(Panel.fit(
f"[bold green]OpenAI Code Assistant[/bold green]\n"
f"Model: {model} (Temperature: {temperature})\n"
f"Reinforcement Learning: {rl_status}\n"
"Type your questions or commands. Use /help for available commands.",
title="Welcome",
border_style="green"
))
# Check API key
if not os.getenv("OPENAI_API_KEY"):
console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
console.print("You can create a .env file with your API key or set it in your environment.")
return
# Initialize conversation
conversation = Conversation()
# Override model and temperature if specified
if model != DEFAULT_MODEL:
conversation.model = model
conversation.temperature = temperature
# Configure verbose mode
conversation.verbose = verbose
# Configure RL mode
if not enable_rl and hasattr(conversation, 'tool_optimizer') and conversation.tool_optimizer is not None:
os.environ["ENABLE_TOOL_OPTIMIZATION"] = "0"
conversation.tool_optimizer = None
console.print("[yellow]Reinforcement learning disabled[/yellow]")
# Handle manual RL update if requested
if rl_update and hasattr(conversation, 'tool_optimizer') and conversation.tool_optimizer is not None:
try:
with console.status("[bold blue]Updating RL model...[/bold blue]"):
result = conversation.tool_optimizer.optimizer.update_model()
console.print(f"[green]RL model update result:[/green] {result['status']}")
if 'metrics' in result:
console.print(Panel.fit(
"\n".join([f"{k}: {v}" for k, v in result['metrics'].items()]),
title="RL Metrics",
border_style="blue"
))
except Exception as e:
console.print(f"[red]Error updating RL model:[/red] {e}")
# Add system prompt
conversation.messages.append({"role": "system", "content": get_system_prompt()})
# TODO: Add context collection for file system and git information
# TODO: Add session persistence to allow resuming conversations
# Main interaction loop
while True:
try:
user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")
# Handle exit
if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
console.print("[bold yellow]Goodbye![/bold yellow]")
break
# Get response without wrapping it in a status indicator
# This allows the streaming to work properly
try:
conversation.get_response(user_input)
except Exception as e:
console.print(f"[bold red]Error during response generation:[/bold red] {str(e)}")
# Provide more helpful error messages for common issues
if "api_key" in str(e).lower():
console.print("[yellow]Hint: Check your OpenAI API key.[/yellow]")
elif "rate limit" in str(e).lower():
console.print("[yellow]Hint: You've hit a rate limit. Try again in a moment.[/yellow]")
elif "context_length_exceeded" in str(e).lower() or "maximum context length" in str(e).lower():
console.print("[yellow]Hint: The conversation is too long. Try using /compact to reduce its size.[/yellow]")
elif "Missing required parameter" in str(e):
console.print("[yellow]Hint: There's an API format issue. Try restarting the conversation.[/yellow]")
# Offer recovery options
recovery_choice = Prompt.ask(
"[bold]Would you like to:[/bold]",
choices=["continue", "debug", "compact", "restart", "exit"],
default="continue"
)
if recovery_choice == "debug":
# Show debug information
debug_info = {
"model": conversation.model,
"temperature": conversation.temperature,
"message_count": len(conversation.messages),
"token_usage": conversation.token_usage,
"conversation_id": conversation.conversation_id,
"session_duration": time.time() - conversation.session_start_time,
"tools_count": len(conversation.tools),
"python_version": sys.version,
"openai_version": OpenAI.__version__ if hasattr(OpenAI, "__version__") else "Unknown"
}
console.print(Panel(json.dumps(debug_info, indent=2), title="Debug Information", border_style="yellow"))
elif recovery_choice == "compact":
# Compact the conversation
result = conversation.compact()
console.print(f"[green]{result}[/green]")
elif recovery_choice == "restart":
# Restart the conversation
conversation = Conversation()
conversation.model = model
conversation.temperature = temperature
conversation.verbose = verbose
conversation.messages.append({"role": "system", "content": get_system_prompt()})
console.print("[green]Conversation restarted.[/green]")
elif recovery_choice == "exit":
console.print("[bold yellow]Goodbye![/bold yellow]")
break
except KeyboardInterrupt:
console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
# Offer options after cancellation
cancel_choice = Prompt.ask(
"[bold]Would you like to:[/bold]",
choices=["continue", "exit"],
default="continue"
)
if cancel_choice == "exit":
console.print("[bold yellow]Goodbye![/bold yellow]")
break
continue
except Exception as e:
console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
import traceback
console.print(traceback.format_exc())
# Ask if user wants to continue despite the error
if Prompt.ask("[bold]Continue?[/bold]", choices=["y", "n"], default="y") == "n":
break
if __name__ == "__main__":
app()
```