#
tokens: 43355/50000 2/56 files (page 4/4)
lines: off (toggle) GitHub
raw markdown copy
This is page 4 of 4. Use http://codebase.md/arthurcolle/openai-mcp?page={x} to view the full context.

# Directory Structure

```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```

# Files

--------------------------------------------------------------------------------
/modal_mcp_server.py:
--------------------------------------------------------------------------------

```python
import modal
import logging
import time
import uuid
import json
import asyncio
import hashlib
import threading
import concurrent.futures
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union, AsyncIterator
from datetime import datetime, timedelta
from collections import deque

from fastapi import FastAPI, Request, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configure logging
# Root-logger configuration for the whole process (logging.basicConfig is a
# no-op if the root logger already has handlers).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Create FastAPI app
# The ASGI application that every route below is registered on.
api_app = FastAPI(
    title="Advanced LLM Inference API", 
    description="Enterprise-grade OpenAI-compatible LLM serving API with multiple model support, streaming, and advanced caching",
    version="1.1.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin calls. Auth here is a Bearer header, not
# cookies, so credentials are probably unnecessary — confirm before tightening.
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify specific origins instead of wildcard
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security setup
# Extracts "Authorization: Bearer <token>" for verify_api_key. With FastAPI's
# default auto_error=True, requests lacking a Bearer header are rejected by
# this scheme before the dependency body runs.
security = HTTPBearer()

# Token bucket rate limiter
class TokenBucket:
    """
    Per-user token-bucket rate limiter.

    Each user id owns a bucket that holds at most ``rate_limit`` tokens and
    refills continuously at ``rate_limit / 60`` tokens per second, so
    ``rate_limit`` reads as "requests per minute". Buckets live in
    ``self.buckets`` and every access goes through ``self.lock``.
    """

    def __init__(self):
        self.buckets = {}
        self.lock = threading.Lock()

    def _get_bucket(self, user_id, rate_limit):
        """
        Return ``user_id``'s bucket, creating it or topping it up first.

        A new bucket starts full and is returned as-is. An existing bucket has
        its refill rate reset from ``rate_limit`` and gains tokens for the time
        elapsed since its last refill, capped at ``rate_limit``. Only called
        from ``consume`` while ``self.lock`` is held.
        """
        now = time.time()
        per_second = rate_limit / 60.0
        bucket = self.buckets.get(user_id)

        if bucket is None:
            bucket = {"tokens": rate_limit, "last_refill": now, "rate": per_second}
            self.buckets[user_id] = bucket
            return bucket

        # The limit may have changed since the bucket was created.
        bucket["rate"] = per_second
        refill = (now - bucket["last_refill"]) * bucket["rate"]
        bucket["tokens"] = min(rate_limit, bucket["tokens"] + refill)
        bucket["last_refill"] = now
        return bucket

    def consume(self, user_id, tokens=1, rate_limit=60):
        """
        Try to spend ``tokens`` from ``user_id``'s bucket.

        Returns True (and debits the bucket) when enough tokens are available,
        otherwise False with the bucket left as refilled.
        """
        with self.lock:
            bucket = self._get_bucket(user_id, rate_limit)
            if bucket["tokens"] < tokens:
                return False
            bucket["tokens"] -= tokens
            return True

# Create rate limiter
rate_limiter = TokenBucket()

# Define the container image with necessary dependencies
# Image for the vLLM-served models (VLLM_MODELS). Package versions are pinned;
# flashinfer wheels come from the extra index built for CUDA 12.4 / torch 2.5.
vllm_image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install(
        "vllm==0.7.3",  # pinned vLLM release
        "huggingface_hub[hf_transfer]==0.26.2",
        "flashinfer-python==0.2.0.post2",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0",
        "tiktoken>=0.5.1",
        extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # use hf_transfer for faster model downloads
    .env({"VLLM_USE_V1": "1"})  # opt in to vLLM's V1 engine
)

# Define llama.cpp image for alternative models
# Image for the GGUF models served with llama.cpp (LLAMA_CPP_MODELS).
# llama.cpp is compiled from source inside the image; only the `llama-cli`
# target is built, and the resulting llama-* binaries are copied onto PATH.
# NOTE(review): the clone is not pinned to a commit/tag, so rebuilding the
# image can silently pick up a different llama.cpp revision.
llama_cpp_image = (
    modal.Image.debian_slim(python_version="3.10")
    # Build toolchain; libcurl4-openssl-dev presumably backs -DLLAMA_CURL=ON below.
    .apt_install("git", "build-essential", "cmake", "curl", "libcurl4-openssl-dev")
    .pip_install(
        "huggingface_hub==0.26.2",
        "hf_transfer>=0.1.4",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_commands(
        "git clone https://github.com/ggerganov/llama.cpp",
        "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=ON",
        "cmake --build llama.cpp/build --config Release -j --target llama-cli",
        "cp llama.cpp/build/bin/llama-* /usr/local/bin/"
    )
)

# Set up model configurations
# Mount point for model files (not referenced in the visible part of this file).
MODELS_DIR = "/models"
# Registry of models served with vLLM, keyed by the public model id clients
# pass as "model". "name" is the Hugging Face repo id and "revision" the git
# revision to fetch. "loaded" is flipped by the /admin/models/* endpoints.
# Because this is a plain module-level dict, that flag is only visible within
# the current Python process.
VLLM_MODELS = {
    "llama3-8b": {
        "id": "llama3-8b",
        "name": "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
        "config": "config.json",  # Ensure this file is present in the model directory
        "revision": "a7c09948d9a632c2c840722f519672cd94af885d",
        "max_tokens": 4096,
        "loaded": False
    },
    "mistral-7b": {
        "id": "mistral-7b",
        "name": "mistralai/Mistral-7B-Instruct-v0.2",
        "revision": "main",
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "tiny-llama-1.1b": {
        "id": "tiny-llama-1.1b",
        "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "revision": "main",
        "max_tokens": 2048,
        "loaded": False
    }
}

# Registry of GGUF models served with llama.cpp, same keying as VLLM_MODELS.
# "quant" names the quantization and "pattern" is presumably the glob used to
# pick the matching GGUF files from the repo. "gpu" looks like a Modal GPU spec
# ("type:count"), and None means CPU. "revision" None means no pinned revision.
LLAMA_CPP_MODELS = {
    "deepseek-r1": {
        "id": "deepseek-r1",
        "name": "unsloth/DeepSeek-R1-GGUF",
        "quant": "UD-IQ1_S",
        "pattern": "*UD-IQ1_S*",
        "revision": "02656f62d2aa9da4d3f0cdb34c341d30dd87c3b6",
        "gpu": "L40S:4",
        "max_tokens": 4096,
        "loaded": False
    },
    "phi-4": {
        "id": "phi-4",
        "name": "unsloth/phi-4-GGUF",
        "quant": "Q2_K",
        "pattern": "*Q2_K*",
        "revision": None,
        "gpu": "L40S:4",  # Use GPU for better performance
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "phi-2": {
        "id": "phi-2",
        "name": "TheBloke/phi-2-GGUF",
        "quant": "Q4_K_M",
        "pattern": "*Q4_K_M.gguf",
        "revision": "main",
        "gpu": None,  # Can run on CPU
        "max_tokens": 2048,
        "loaded": False
    }
}

# Model used by /v1/chat/completions when the request omits "model".
DEFAULT_MODEL = "phi-4"

# Create volumes for caching
# Named, persistent Modal Volumes (created on first lookup). Only handles are
# obtained here; where each is mounted is decided by the functions that use
# them, outside this view.
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
llama_cpp_cache_vol = modal.Volume.from_name("llama-cpp-cache", create_if_missing=True)
results_vol = modal.Volume.from_name("model-results", create_if_missing=True)

# Create the Modal app
app = modal.App("openai-compatible-llm-server")

# Create shared data structures
# Named Modal Dicts/Queue, shared by every process that looks them up by name:
#   model_stats_dict - per-model stats, plus a "workers_running" entry (see /admin/stats)
#   user_usage_dict  - per-user request timestamps and token counters
#   request_queue    - pending requests (its length is reported as "pending_requests")
#   response_dict    - persistent response cache, keyed by calculate_cache_key(...)
#   api_keys_dict    - user id -> API key record ("key", "rate_limit", "quota", ...)
#   stream_queues    - not used in the visible part of this file
model_stats_dict = modal.Dict.from_name("model-stats", create_if_missing=True)
user_usage_dict = modal.Dict.from_name("user-usage", create_if_missing=True)
request_queue = modal.Queue.from_name("request-queue", create_if_missing=True)
response_dict = modal.Dict.from_name("response-cache", create_if_missing=True)
api_keys_dict = modal.Dict.from_name("api-keys", create_if_missing=True)
stream_queues = modal.Dict.from_name("stream-queues", create_if_missing=True)

# Advanced caching system
class AdvancedCache:
    """
    Thread-safe in-memory cache with per-entry TTL and LRU eviction.

    ``self.cache`` maps keys to values and is kept in recency order (least
    recently used first). ``self.ttl_map`` holds absolute expiry timestamps and
    ``self.access_times`` the last access time of each key. When the cache is
    full, inserting a new key evicts the least recently used entry in O(1).

    Args:
        max_size: number of entries allowed before LRU eviction kicks in.
        default_ttl: TTL in seconds used when ``set`` gets no explicit ``ttl``.
            A value <= 0 means entries do not expire by default.
    """

    def __init__(self, max_size=1000, default_ttl=3600):
        self.cache = {}  # dict insertion order doubles as recency order
        self.ttl_map = {}
        self.access_times = {}
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.lock = threading.Lock()

    def get(self, key):
        """Return the value cached under ``key``, or None if absent or expired."""
        with self.lock:
            if key not in self.cache:
                return None

            now = time.time()
            expiry = self.ttl_map.get(key)
            if expiry is not None and expiry < now:
                self._remove(key)
                return None

            # Move the key to the most-recently-used end of the dict.
            self.cache[key] = self.cache.pop(key)
            self.access_times[key] = now
            return self.cache[key]

    def set(self, key, value, ttl=None):
        """
        Store ``value`` under ``key``.

        ``ttl`` (seconds) overrides ``default_ttl``. When neither applies, the
        entry never expires, and any expiry left over from a previous value
        stored under the same key is discarded.
        """
        with self.lock:
            now = time.time()

            if key in self.cache:
                # Drop the old slot so the re-insert lands at the MRU end.
                del self.cache[key]
            elif len(self.cache) >= self.max_size:
                self._evict_lru()

            self.cache[key] = value
            self.access_times[key] = now

            if ttl is not None:
                self.ttl_map[key] = now + ttl
            elif self.default_ttl > 0:
                self.ttl_map[key] = now + self.default_ttl
            else:
                # Without this, an overwritten entry kept its old expiry.
                self.ttl_map.pop(key, None)

    def _remove(self, key):
        """Drop ``key`` from all bookkeeping maps (caller holds the lock)."""
        self.cache.pop(key, None)
        self.ttl_map.pop(key, None)
        self.access_times.pop(key, None)

    def _evict_lru(self):
        """Evict the least recently used entry (caller holds the lock)."""
        if self.cache:
            # First key in iteration order is the least recently used one.
            self._remove(next(iter(self.cache)))

    def clear_expired(self):
        """Remove every entry whose expiry time has passed."""
        with self.lock:
            now = time.time()
            expired_keys = [k for k, v in self.ttl_map.items() if v < now]
            for key in expired_keys:
                self._remove(key)

# Constants
MAX_CACHE_AGE = 3600  # 1 hour in seconds; TTL for memory_cache and max age of response_dict hits

# Create memory cache
# Process-local cache consulted before the shared response_dict.
memory_cache = AdvancedCache(max_size=10000, default_ttl=MAX_CACHE_AGE)

# Initialize with default key if empty
# Seeds the shared api_keys_dict with a "default" user only when that entry is
# missing; an existing record is never overwritten. This runs at import time,
# in every process that imports this module.
# NOTE(review): the key literal is hard-coded in source, so anyone with the
# source can authenticate with it. Consider sourcing it from a Modal Secret.
if "default" not in api_keys_dict:
    api_keys_dict["default"] = {
        "key": "sk-modal-llm-api-key",
        "rate_limit": 60,  # requests per minute
        "quota": 1000000,  # tokens per day (no enforcement visible in this file section)
        "created_at": datetime.now().isoformat(),
        "owner": "default"
    }

# Add a default ADMIN API key
# Same seeding rule for the "admin" user (owner "admin").
# NOTE(review): also hard-coded in source — see the note above.
if "admin" not in api_keys_dict:
    api_keys_dict["admin"] = {
        "key": "sk-modal-admin-api-key",
        "rate_limit": 1000,  # Higher rate limit for admin
        "quota": 10000000,  # Higher quota for admin
        "created_at": datetime.now().isoformat(),
        "owner": "admin"
    }

# Constants
# Read back from the shared dict, so this reflects whatever key is actually
# stored for "default" (which may predate the literal above).
DEFAULT_API_KEY = api_keys_dict["default"]["key"]
MINUTES = 60  # seconds
SERVER_PORT = 8000
# Container paths; not referenced in the visible part of this file.
CACHE_DIR = "/root/.cache"
RESULTS_DIR = "/root/results"

# Request/response models
class GenerationRequest(BaseModel):
    # A single generation job: OpenAI chat-completion parameters plus
    # bookkeeping fields (id, timestamp, caller key).
    # NOTE(review): not referenced in the visible part of this file; presumably
    # the payload placed on request_queue — confirm.
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))  # fresh UUID per instance
    model_id: str
    messages: List[Dict[str, str]]  # presumably OpenAI-style {"role": ..., "content": ...} dicts
    temperature: float = 0.7
    max_tokens: int = 1024
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    user: Optional[str] = None
    stream: bool = False
    timestamp: float = Field(default_factory=time.time)  # creation time, epoch seconds
    # NOTE(review): defaults to the shared "default" API key read at import time.
    api_key: str = DEFAULT_API_KEY
    
class StreamChunk(BaseModel):
    """Model for streaming response chunks"""
    # Shape of one OpenAI-compatible "chat.completion.chunk" event.
    id: str
    object: str = "chat.completion.chunk"
    created: int  # presumably epoch seconds, as in the OpenAI API
    model: str
    choices: List[Dict[str, Any]]
    
class StreamManager:
    """Manages streaming responses for clients"""
    def __init__(self):
        self.streams = {}
        self.lock = threading.Lock()
    
    def create_stream(self, request_id):
        """Create a new stream for a request"""
        with self.lock:
            self.streams[request_id] = {
                "queue": asyncio.Queue(),
                "finished": False,
                "created_at": time.time()
            }
    
    def add_chunk(self, request_id, chunk):
        """Add a chunk to a stream"""
        with self.lock:
            if request_id in self.streams:
                stream = self.streams[request_id]
                if not stream["finished"]:
                    stream["queue"].put_nowait(chunk)
    
    def finish_stream(self, request_id):
        """Mark a stream as finished"""
        with self.lock:
            if request_id in self.streams:
                self.streams[request_id]["finished"] = True
                # Add None to signal end of stream
                self.streams[request_id]["queue"].put_nowait(None)
    
    async def get_chunks(self, request_id):
        """Get chunks from a stream as an async generator"""
        if request_id not in self.streams:
            return
        
        stream = self.streams[request_id]
        queue = stream["queue"]
        
        while True:
            chunk = await queue.get()
            if chunk is None:  # End of stream
                break
            yield chunk
            queue.task_done()
        
        # Clean up after streaming is done
        with self.lock:
            if request_id in self.streams:
                del self.streams[request_id]
    
    def clean_old_streams(self, max_age=3600):
        """Clean up old streams"""
        with self.lock:
            now = time.time()
            to_remove = []
            
            for request_id, stream in self.streams.items():
                if now - stream["created_at"] > max_age:
                    to_remove.append(request_id)
            
            for request_id in to_remove:
                if request_id in self.streams:
                    # Mark as finished to stop any ongoing processing
                    self.streams[request_id]["finished"] = True
                    # Add None to unblock any waiting consumers
                    self.streams[request_id]["queue"].put_nowait(None)
                    # Remove from streams
                    del self.streams[request_id]

# Create stream manager
stream_manager = StreamManager()

# API Authentication dependency
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Authenticate the bearer token and enforce the caller's rate limit.

    Used as a FastAPI dependency. The presented key is matched against the
    records in ``api_keys_dict``. The key's owner then spends one token from
    ``rate_limiter``, and the request timestamp is appended to that owner's
    record in ``user_usage_dict``; timestamps older than one day are pruned.

    Returns:
        dict: ``{"key": <presented API key>, "user_id": <owner of the key>}``.

    Raises:
        HTTPException: 401 for a non-Bearer scheme or an unknown key. 429 when
            the rate limit is exhausted, with a ``Retry-After`` header of at
            least one second.
    """
    import math
    import secrets

    # Auth-scheme names are case-insensitive (RFC 7235), so accept "bearer" too.
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use Bearer",
        )

    api_key = credentials.credentials
    presented = api_key.encode("utf-8")
    key_info = None

    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed key matched a real one.
    for _user_id, user_data in api_keys_dict.items():
        stored = user_data.get("key")
        if isinstance(stored, str) and secrets.compare_digest(stored.encode("utf-8"), presented):
            key_info = user_data
            break

    if key_info is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )

    # Check rate limits
    user_id = key_info.get("owner", "unknown")
    rate_limit = key_info.get("rate_limit", 60)  # Default: 60 requests per minute

    # Get or initialize user usage tracking
    if user_id not in user_usage_dict:
        user_usage_dict[user_id] = {
            "requests": [],
            "tokens": {
                "input": 0,
                "output": 0,
                "last_reset": datetime.now().isoformat()
            }
        }

    usage = user_usage_dict[user_id]

    # Check if user exceeded rate limit using token bucket algorithm
    if not rate_limiter.consume(user_id, tokens=1, rate_limit=rate_limit):
        # The bucket refills at rate_limit/60 tokens per second, so one token
        # needs 60/rate_limit seconds. Round up and never advertise 0 (the old
        # int() truncation gave 0 for limits above 60/min). Avoid dividing by
        # zero for keys whose limit is 0, since such a bucket never refills.
        if rate_limit > 0:
            retry_after = max(1, math.ceil(60 / rate_limit))
        else:
            retry_after = 60

        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded. Maximum {rate_limit} requests per minute.",
            headers={"Retry-After": str(retry_after)}
        )

    # Add current request timestamp for analytics
    now = datetime.now()
    usage["requests"].append(now.timestamp())

    # Clean up old requests (older than 1 day) to prevent unbounded growth
    day_ago = (now - timedelta(days=1)).timestamp()
    usage["requests"] = [req for req in usage["requests"] if req > day_ago]

    # Write back: the local copy is not automatically persisted to the Modal Dict.
    user_usage_dict[user_id] = usage

    # Return the API key and user ID
    return {"key": api_key, "user_id": user_id}

# Static landing page for "/". Kept byte-for-byte identical to the former
# inline literal, including its leading newline and indentation.
_INDEX_HTML = """
    <html>
        <head>
            <title>Modal LLM Inference API</title>
            <style>
                body { font-family: system-ui, sans-serif; max-width: 800px; margin: 0 auto; padding: 2rem; }
                h1 { color: #4a56e2; }
                code { background: #f4f4f8; padding: 0.2rem 0.4rem; border-radius: 3px; }
            </style>
        </head>
        <body>
            <h1>Modal LLM Inference API</h1>
            <p>This is an OpenAI-compatible API for LLM inference powered by Modal.</p>
            <p>Use the following endpoints:</p>
            <ul>
                <li><a href="/docs">/docs</a> - API documentation</li>
                <li><a href="/v1/models">/v1/models</a> - List available models</li>
                <li><code>/v1/chat/completions</code> - Chat completions endpoint</li>
            </ul>
        </body>
    </html>
    """

# API Endpoints
@api_app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the static HTML landing page linking to the main endpoints."""
    return _INDEX_HTML

@api_app.get("/health")
async def health_check():
    """Liveness probe; unauthenticated and always reports healthy."""
    healthy = {"status": "healthy"}
    return healthy

@api_app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """
    List every registered model in the OpenAI ``/v1/models`` format.

    vLLM models come first, then llama.cpp models. Each entry also carries
    the serving ``engine`` and the in-process ``loaded`` flag.
    """
    registries = (("vllm", VLLM_MODELS), ("llama.cpp", LLAMA_CPP_MODELS))
    entries = [
        {
            "id": info["id"],
            "object": "model",
            "created": 1677610602,
            "owned_by": "modal",
            "engine": engine,
            "loaded": info.get("loaded", False),
        }
        for engine, registry in registries
        for info in registry.values()
    ]
    return {"data": entries, "object": "list"}

# Model management endpoints
class ModelLoadRequest(BaseModel):
    """Request model to load a specific model"""
    # Also used as the body of /admin/models/unload (force_reload is ignored there).
    model_id: str  # key into VLLM_MODELS or LLAMA_CPP_MODELS
    force_reload: bool = False  # reload even if the model is already flagged as loaded
    
class HFModelLoadRequest(BaseModel):
    """Request to load a model directly from Hugging Face"""
    repo_id: str  # Hugging Face repo id, e.g. "owner/name"
    model_type: str = "vllm"  # "vllm" or "llama.cpp"
    revision: Optional[str] = None  # vLLM falls back to "main" when omitted
    quant: Optional[str] = None  # For llama.cpp models
    max_tokens: int = 4096
    gpu: Optional[str] = None  # For llama.cpp models

@api_app.post("/admin/models/load", dependencies=[Depends(verify_api_key)])
async def load_model(request: ModelLoadRequest, background_tasks: BackgroundTasks):
    """
    Start loading a registered model in the background.

    vLLM models are started through ``serve_vllm_model.remote``; llama.cpp
    models through ``preload_llama_cpp_model``. A model already flagged as
    loaded is left alone unless ``force_reload`` is set. Unknown ids give 404.

    NOTE(review): the ``loaded`` flag is set as soon as the task is queued,
    before (and regardless of whether) the background load succeeds.
    """
    model_id = request.model_id
    force_reload = request.force_reload

    if model_id in VLLM_MODELS:
        model_type, registry = "vllm", VLLM_MODELS
    elif model_id in LLAMA_CPP_MODELS:
        model_type, registry = "llama.cpp", LLAMA_CPP_MODELS
    else:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found"
        )

    if registry[model_id].get("loaded", False) and not force_reload:
        return {
            "status": "success",
            "message": f"Model {model_id} is already loaded",
            "model_id": model_id,
            "model_type": model_type
        }

    if model_type == "vllm":
        background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
    else:
        background_tasks.add_task(preload_llama_cpp_model, model_id)
    registry[model_id]["loaded"] = True

    return {
        "status": "success",
        "message": f"Started loading model {model_id}",
        "model_id": model_id,
        "model_type": model_type
    }

@api_app.post("/admin/models/load-from-hf", dependencies=[Depends(verify_api_key)])
async def load_model_from_hf(request: HFModelLoadRequest, background_tasks: BackgroundTasks):
    """
    Register a Hugging Face repo as a new model and start loading it.

    A fresh id ``hf-<repo name>-<6 hex chars>`` is minted, so the same repo
    can be registered more than once. ``model_type`` is matched
    case-insensitively against "vllm" and "llama.cpp"; anything else is a 400
    and nothing is registered. For llama.cpp the quantization defaults to
    ``Q4_K_M`` and the file glob is ``*<quant>*``. The response echoes
    ``model_type`` exactly as supplied.
    """
    repo_id = request.repo_id
    model_type = request.model_type
    revision = request.revision

    # Last path segment of "owner/name" (the whole string when there is no "/").
    repo_name = repo_id.rsplit("/", 1)[-1]
    model_id = f"hf-{repo_name}-{uuid.uuid4().hex[:6]}"
    engine = model_type.lower()

    if engine == "vllm":
        VLLM_MODELS[model_id] = {
            "id": model_id,
            "name": repo_id,
            "revision": revision or "main",
            "max_tokens": request.max_tokens,
            "loaded": False,
            "hf_direct": True  # Mark as directly loaded from HF
        }
        background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
        VLLM_MODELS[model_id]["loaded"] = True

    elif engine == "llama.cpp":
        quant = request.quant or "Q4_K_M"  # Default quantization
        LLAMA_CPP_MODELS[model_id] = {
            "id": model_id,
            "name": repo_id,
            "quant": quant,
            "pattern": f"*{quant}*",
            "revision": revision,
            "gpu": request.gpu,  # Can be None for CPU
            "max_tokens": request.max_tokens,
            "loaded": False,
            "hf_direct": True  # Mark as directly loaded from HF
        }
        background_tasks.add_task(preload_llama_cpp_model, model_id)
        LLAMA_CPP_MODELS[model_id]["loaded"] = True

    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid model type: {model_type}. Must be 'vllm' or 'llama.cpp'"
        )

    return {
        "status": "success",
        "message": f"Started loading model {repo_id} as {model_id}",
        "model_id": model_id,
        "model_type": model_type,
        "repo_id": repo_id
    }

@api_app.post("/admin/models/unload", dependencies=[Depends(verify_api_key)])
async def unload_model(request: ModelLoadRequest):
    """
    Mark a registered model as unloaded.

    Only the in-process ``loaded`` flag is cleared; nothing in this handler
    stops a running server or frees memory. Unknown ids give 404, and a model
    that is not loaded returns success without changes.
    """
    model_id = request.model_id

    for model_type, registry in (("vllm", VLLM_MODELS), ("llama.cpp", LLAMA_CPP_MODELS)):
        if model_id in registry:
            break
    else:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found"
        )

    if not registry[model_id].get("loaded", False):
        return {
            "status": "success",
            "message": f"Model {model_id} is not loaded",
            "model_id": model_id,
            "model_type": model_type
        }

    registry[model_id]["loaded"] = False

    return {
        "status": "success",
        "message": f"Unloaded model {model_id}",
        "model_id": model_id,
        "model_type": model_type
    }

@api_app.get("/admin/models/status/{model_id}", dependencies=[Depends(verify_api_key)])
async def get_model_status(model_id: str):
    """
    Report the engine, loaded flag and stored stats for one model.

    ``hf_info`` is populated only for models registered through
    ``/admin/models/load-from-hf``; it includes ``quant`` for llama.cpp models
    and is None otherwise. Unknown ids give 404.
    """
    for model_type, registry in (("vllm", VLLM_MODELS), ("llama.cpp", LLAMA_CPP_MODELS)):
        if model_id in registry:
            break
    else:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found"
        )

    model_info = registry[model_id]
    model_stats = model_stats_dict.get(model_id, {})

    hf_info = None
    if model_info.get("hf_direct"):
        hf_info = {
            "repo_id": model_info.get("name"),
            "revision": model_info.get("revision"),
        }
        if model_type == "llama.cpp":
            hf_info["quant"] = model_info.get("quant")

    return {
        "model_id": model_id,
        "model_type": model_type,
        "loaded": model_info.get("loaded", False),
        "stats": model_stats,
        "hf_info": hf_info
    }

# Admin API endpoints
class APIKeyRequest(BaseModel):
    # Body of POST /admin/api-keys.
    user_id: str  # also stored as the key's "owner"
    rate_limit: int = 60  # requests per minute (fed to the token-bucket limiter)
    quota: int = 1000000  # stored with the key; no enforcement visible in this file section
    
class APIKey(BaseModel):
    # Response of POST /admin/api-keys; "key" is the full, unmasked secret.
    key: str
    user_id: str
    rate_limit: int
    quota: int
    created_at: str  # ISO-8601 timestamp

@api_app.post("/admin/api-keys", response_model=APIKey)
async def create_api_key(request: APIKeyRequest, auth_info: dict = Depends(verify_api_key)):
    """
    Create (or replace) the API key for ``request.user_id``. Admin only.

    Any existing key for the same user id is overwritten. A usage record is
    created for the user if none exists yet.

    Returns:
        APIKey: the new key in full, with the same ``created_at`` as stored.

    Raises:
        HTTPException: 403 if the caller's key is not owned by "admin".
    """
    # The admin key is registered with owner "admin". The previous check
    # compared against "default", which locked the admin key out and gave
    # admin rights to the shared default client key.
    if auth_info["user_id"] != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admin users can create API keys"
        )

    new_key = f"sk-modal-{uuid.uuid4()}"
    user_id = request.user_id
    # Single timestamp so the stored record and the response agree.
    created_at = datetime.now().isoformat()

    api_keys_dict[user_id] = {
        "key": new_key,
        "rate_limit": request.rate_limit,
        "quota": request.quota,
        "created_at": created_at,
        "owner": user_id
    }

    # Initialize user usage (same membership test as verify_api_key)
    if user_id not in user_usage_dict:
        user_usage_dict[user_id] = {
            "requests": [],
            "tokens": {
                "input": 0,
                "output": 0,
                "last_reset": created_at
            }
        }

    return APIKey(
        key=new_key,
        user_id=user_id,
        rate_limit=request.rate_limit,
        quota=request.quota,
        created_at=created_at
    )

@api_app.get("/admin/api-keys")
async def list_api_keys(auth_info: dict = Depends(verify_api_key)):
    """
    List every API key record with the key value masked. Admin only.

    Each entry shows the first 8 characters of the key followed by "...",
    or "None" when the record has no key.

    Raises:
        HTTPException: 403 if the caller's key is not owned by "admin".
    """
    # The admin key's owner is "admin" (not "default", as previously checked).
    if auth_info["user_id"] != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admin users can list API keys"
        )

    keys = []
    for user_id, key_info in api_keys_dict.items():
        raw_key = key_info.get("key")
        keys.append({
            "user_id": user_id,
            "rate_limit": key_info.get("rate_limit", 60),
            "quota": key_info.get("quota", 1000000),
            "created_at": key_info.get("created_at", datetime.now().isoformat()),
            # Never return the full secret.
            "key": raw_key[:8] + "..." if raw_key else "None"
        })

    return {"keys": keys}

@api_app.get("/admin/stats")
async def get_stats(auth_info: dict = Depends(verify_api_key)):
    """
    Return model stats, per-user usage and queue info. Admin only.

    User records are returned with their raw ``requests`` timestamp list
    replaced by ``request_count``.

    Raises:
        HTTPException: 403 if the caller's key is not owned by "admin".
    """
    # The admin key's owner is "admin" (not "default", as previously checked).
    if auth_info["user_id"] != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admin users can view stats"
        )

    # Stats for every registered model that has an entry in the shared dict.
    model_stats = {}
    for model_id in [*VLLM_MODELS, *LLAMA_CPP_MODELS]:
        if model_id in model_stats_dict:
            model_stats[model_id] = model_stats_dict[model_id]

    # items() fetches each record once instead of a key scan plus a lookup.
    user_stats = {}
    for user_id, usage in user_usage_dict.items():
        # Don't include request timestamps for brevity
        if "requests" in usage:
            usage = usage.copy()
            usage["request_count"] = len(usage["requests"])
            del usage["requests"]
        user_stats[user_id] = usage

    queue_info = {
        "pending_requests": request_queue.len(),
        "active_workers": model_stats_dict.get("workers_running", 0)
    }

    return {
        "models": model_stats,
        "users": user_stats,
        "queue": queue_info,
        "timestamp": datetime.now().isoformat()
    }

@api_app.delete("/admin/api-keys/{user_id}")
async def delete_api_key(user_id: str, auth_info: dict = Depends(verify_api_key)):
    """
    Delete the API key record for ``user_id``. Admin only.

    The built-in "default" and "admin" keys cannot be deleted; removing the
    admin key would lock every admin endpoint out.

    Raises:
        HTTPException: 403 if the caller is not admin, 404 if no key exists
            for ``user_id``, 400 for the protected built-in keys.
    """
    # The admin key's owner is "admin" (not "default", as previously checked).
    if auth_info["user_id"] != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only admin users can delete API keys"
        )

    if user_id not in api_keys_dict:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No API key found for user {user_id}"
        )

    # For "default" this yields the same message as before.
    if user_id in ("default", "admin"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot delete the {user_id} API key"
        )

    api_keys_dict.pop(user_id)

    return {"status": "success", "message": f"API key deleted for user {user_id}"}

@api_app.post("/v1/chat/completions")
async def chat_completions(request: Request, background_tasks: BackgroundTasks, auth_info: dict = Depends(verify_api_key)):
    """OpenAI-compatible chat completions endpoint with request queueing, streaming and response caching.

    Flow:
      1. Non-streaming requests are answered from the in-memory cache, then from
         the persistent ``response_dict`` cache if the entry is younger than
         ``MAX_CACHE_AGE``.
      2. ``model="auto"`` is resolved with ``select_best_model`` using the last
         user message; unknown models fall back to ``DEFAULT_MODEL``.
      3. A ``GenerationRequest`` is pushed onto ``request_queue`` and a queue
         worker is (re)started as a background task.
      4. Streaming requests get an SSE ``StreamingResponse`` fed by
         ``stream_manager``; non-streaming requests poll both caches for the
         worker's result for up to 120 seconds.

    Raises:
        HTTPException: 504 when no result arrives within the timeout; 500 when
            the worker reports an error payload or anything else fails.
    """
    model_id = None  # set before the try so the error handler can attribute stats

    try:
        json_data = await request.json()

        # Extract model or use default
        model_id = json_data.get("model", DEFAULT_MODEL)
        messages = json_data.get("messages", [])
        temperature = json_data.get("temperature", 0.7)
        max_tokens = json_data.get("max_tokens", 1024)
        stream = json_data.get("stream", False)
        user = json_data.get("user", auth_info["user_id"])

        # Computed before "auto" resolution, so auto-routed requests are cached
        # under the literal model name "auto".
        cache_key = calculate_cache_key(model_id, messages, temperature, max_tokens)

        # Streaming requests always go through a worker; only non-streaming
        # requests are served from the caches.
        if not stream:
            # In-memory cache first (faster)
            cached_response = memory_cache.get(cache_key)
            if cached_response:
                update_stats(model_id, "cache_hit")
                return cached_response

            # Then Modal's persistent cache, if the entry is fresh enough
            if cache_key in response_dict:
                cached_entry = response_dict[cache_key]
                cache_age = time.time() - cached_entry.get("timestamp", 0)
                if cache_age < MAX_CACHE_AGE:
                    update_stats(model_id, "cache_hit")
                    response_data = cached_entry["response"]
                    # Promote to the memory cache for faster access next time
                    memory_cache.set(cache_key, response_data)
                    return response_data

        # Select best model if "auto" is specified
        if model_id == "auto" and messages:
            # Content of the most recent user message, if any
            last_message = next(
                (msg.get("content", "") for msg in reversed(messages) if msg.get("role") == "user"),
                None,
            )
            # select_best_model works on plain text; multimodal (list) content
            # would crash it, so such requests fall through to DEFAULT_MODEL below.
            if isinstance(last_message, str) and last_message:
                model_id = select_best_model(last_message, max_tokens, temperature)
                logging.info(f"Auto-selected model: {model_id} for prompt")

        # Check if model exists
        if model_id not in VLLM_MODELS and model_id not in LLAMA_CPP_MODELS:
            # Default to the default model if specified model not found
            logging.warning(f"Model {model_id} not found, using default: {DEFAULT_MODEL}")
            model_id = DEFAULT_MODEL

        # Create a unique request ID
        request_id = str(uuid.uuid4())

        # Create request object
        gen_request = GenerationRequest(
            request_id=request_id,
            model_id=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=json_data.get("top_p", 1.0),
            frequency_penalty=json_data.get("frequency_penalty", 0.0),
            presence_penalty=json_data.get("presence_penalty", 0.0),
            user=user,
            stream=stream,
            api_key=auth_info["key"]
        )

        # For streaming requests, set up streaming response
        if stream:
            # Create the stream before enqueueing so the worker can write to it
            stream_manager.create_stream(request_id)

            await request_queue.put.aio(gen_request.model_dump())

            update_stats(model_id, "request_count")
            update_stats(model_id, "stream_count")

            # Start a background worker to process the request if needed
            background_tasks.add_task(ensure_worker_running)

            from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
            return FastAPIStreamingResponse(
                content=stream_response(request_id, model_id, auth_info["user_id"]),
                media_type="text/event-stream"
            )

        # For non-streaming, enqueue the request and wait for result
        await request_queue.put.aio(gen_request.model_dump())
        update_stats(model_id, "request_count")
        background_tasks.add_task(ensure_worker_running)

        def _discard_request_entry(key: str) -> None:
            """Best-effort removal of a per-request result from response_dict."""
            try:
                response_dict.pop(key)
            except Exception as exc:  # already gone / transient store error: harmless
                logging.debug(f"Could not remove {key} from response_dict: {exc}")

        def _deliver(response_data):
            """Cache and return a successful worker result, or raise for an error payload."""
            # The worker stores failures as {"error": {...}} under the request id.
            # Surface them as an HTTP error and never cache them; otherwise every
            # identical request would receive the error until MAX_CACHE_AGE.
            if isinstance(response_data, dict) and "error" in response_data:
                update_stats(model_id, "error_count")
                error = response_data["error"]
                message = error.get("message", str(error)) if isinstance(error, dict) else str(error)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Error generating response: {message}"
                )

            # Save to persistent and memory caches
            response_dict[cache_key] = {
                "response": response_data,
                "timestamp": time.time()
            }
            memory_cache.set(cache_key, response_data)

            update_stats(model_id, "success_count")
            estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
            return response_data

        # Wait for the response with timeout
        timeout = 120  # 2-minute timeout for non-streaming requests
        deadline = time.time() + timeout

        while time.time() < deadline:
            # Check memory cache first (faster)
            response_data = memory_cache.get(request_id)
            if response_data:
                memory_cache.set(request_id, None)
                # The worker also writes the result to response_dict; drop that
                # copy so per-request entries don't accumulate forever.
                _discard_request_entry(request_id)
                return _deliver(response_data)

            # Check persistent store
            if response_dict.contains(request_id):
                response_data = response_dict[request_id]
                _discard_request_entry(request_id)
                return _deliver(response_data)

            # Wait a bit before checking again
            await asyncio.sleep(0.1)

        # If we get here, we timed out
        update_stats(model_id, "timeout_count")
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail="Request timed out. The model may be busy. Please try again later."
        )

    except HTTPException:
        # Already carries the intended status (e.g. 504 on timeout); re-wrapping
        # it below would turn every such response into a generic 500.
        raise
    except Exception as e:
        logging.error(f"Error in chat completions: {str(e)}")
        if model_id is not None:
            update_stats(model_id, "error_count")

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )

async def stream_response(request_id: str, model_id: str, user_id: str) -> AsyncIterator[str]:
    """Yield Server-Sent Events for the stream identified by ``request_id``.

    Emits an initial ``chat.completion.chunk`` header event, one ``data:`` event
    per non-empty chunk from ``stream_manager.get_chunks``, then ``[DONE]``.
    On failure, the ``stream_error_count`` stat for ``model_id`` is incremented
    and an ``{"error": ...}`` event followed by ``[DONE]`` is emitted instead.
    ``user_id`` is accepted but currently unused.
    """
    def _sse(payload: str) -> str:
        # One SSE event: "data: <payload>" terminated by a blank line
        return f"data: {payload}\n\n"

    try:
        yield _sse(json.dumps({"object": "chat.completion.chunk"}))

        async for piece in stream_manager.get_chunks(request_id):
            if not piece:
                continue
            yield _sse(json.dumps(piece))

        yield _sse("[DONE]")

    except Exception as exc:
        logging.error(f"Error streaming response: {str(exc)}")
        update_stats(model_id, "stream_error_count")

        yield _sse(json.dumps({"error": str(exc)}))
        yield _sse("[DONE]")
        
async def ensure_worker_running(max_workers: int = 3):
    """Spawn a queue worker unless ``max_workers`` are already counted as running.

    The running-worker count lives in ``model_stats_dict["workers_running"]``;
    ``process_queue_worker`` decrements it in its ``finally`` block on exit.

    Args:
        max_workers: Upper bound on concurrently running workers (default 3,
            the previously hard-coded limit).

    Raises:
        Whatever ``process_queue_worker.spawn.aio()`` raises. The reserved slot
        is released first.

    Note:
        The read-increment-write on the shared dict is not atomic, so
        concurrent callers may briefly overshoot ``max_workers``.
    """
    workers_running_key = "workers_running"

    if not model_stats_dict.contains(workers_running_key):
        model_stats_dict[workers_running_key] = 0

    current_workers = model_stats_dict[workers_running_key]
    if current_workers >= max_workers:
        return

    # Reserve the slot before spawning so concurrent callers can see it.
    model_stats_dict[workers_running_key] = current_workers + 1
    try:
        await process_queue_worker.spawn.aio()
    except Exception:
        # The worker never started, so its finally-block will never release the
        # slot. Roll the reservation back here; otherwise the counter leaks and
        # eventually no new workers are ever spawned.
        model_stats_dict[workers_running_key] = max(0, model_stats_dict[workers_running_key] - 1)
        raise

def calculate_cache_key(model_id: str, messages: List[dict], temperature: float, max_tokens: int) -> str:
    """Return a deterministic cache key for a chat request.

    The request's model, messages, temperature (rounded to 2 decimals so that
    near-identical temperatures share an entry) and max_tokens are serialized
    as key-sorted JSON and hashed with SHA-256. The key is ``"cache:"``
    followed by the first 16 hex digits (64 bits) of that digest.
    """
    canonical = json.dumps(
        {
            "max_tokens": max_tokens,
            "messages": messages,
            "model": model_id,
            "temperature": round(temperature, 2),
        },
        sort_keys=True,
    )
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    return "cache:" + digest[:16]

def update_stats(model_id: str, stat_type: str):
    """Increment the ``stat_type`` counter of ``model_id`` in ``model_stats_dict``.

    A zeroed stats record is created the first time a model is seen; counters
    that are not part of that record start from 0. The record is read,
    modified and written back as a whole, which is not an atomic update.
    """
    if not model_stats_dict.contains(model_id):
        default_fields = (
            "request_count",
            "success_count",
            "error_count",
            "timeout_count",
            "cache_hit",
            "token_count",
            "avg_latency",
        )
        model_stats_dict[model_id] = dict.fromkeys(default_fields, 0)

    record = model_stats_dict[model_id]
    record[stat_type] = record.get(stat_type, 0) + 1
    model_stats_dict[model_id] = record
    
def estimate_tokens(messages: List[dict], response: dict, user_id: str, model_id: str):
    """Estimate token usage of a completed request and record it.

    Tokens are approximated as whitespace-separated words * 1.3, for both the
    request messages and the content of each response choice's message.

    Side effects:
        - Adds the (integer) estimate to ``model_stats_dict[model_id]["token_count"]``
          if stats exist for the model.
        - Adds input/output estimates to ``user_usage_dict[user_id]["tokens"]``
          if the user has a usage record. Those counters are reset first when
          ``last_reset`` falls on an earlier calendar day.

    Args:
        messages: Request messages in OpenAI chat format.
        response: OpenAI-style completion payload (may be empty or None).
        user_id: Key into ``user_usage_dict``.
        model_id: Key into ``model_stats_dict``.
    """
    def _text(content) -> str:
        # `content` may be None (e.g. assistant messages carrying only tool
        # calls) or a list of typed parts (multimodal input). Calling .split()
        # on those directly raised and turned a successful request into a 500.
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(
                part["text"] for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        return ""

    input_tokens = sum(len(_text(msg.get("content")).split()) * 1.3 for msg in messages)

    output_tokens = 0
    if response and "choices" in response:
        for choice in response["choices"]:
            message = choice.get("message") if isinstance(choice, dict) else None
            if isinstance(message, dict):
                output_tokens += len(_text(message.get("content")).split()) * 1.3

    # Update model stats (whole tokens, matching the per-user counters below)
    if model_stats_dict.contains(model_id):
        stats = model_stats_dict[model_id]
        stats["token_count"] = stats.get("token_count", 0) + int(input_tokens) + int(output_tokens)
        model_stats_dict[model_id] = stats

    # Update user usage
    if user_id in user_usage_dict:
        usage = user_usage_dict[user_id]
        now = datetime.now()

        # Tolerate a usage record without a token section instead of raising KeyError
        tokens = usage.setdefault("tokens", {})
        tokens.setdefault("input", 0)
        tokens.setdefault("output", 0)
        tokens.setdefault("last_reset", now.isoformat())

        # Reset daily counters on the first request of a new day
        last_reset = datetime.fromisoformat(tokens["last_reset"])
        if now.date() > last_reset.date():
            tokens["input"] = 0
            tokens["output"] = 0
            tokens["last_reset"] = now.isoformat()

        tokens["input"] += int(input_tokens)
        tokens["output"] += int(output_tokens)
        user_usage_dict[user_id] = usage

def select_best_model(prompt: str, n_predict: int, temperature: float) -> str:
    """
    Choose a model id for a prompt using keyword and size heuristics.

    Rules (first match wins):
      1. Code markers, matched case-sensitively -> "phi-4".
      2. Creative-writing keywords -> "deepseek-r1" if temperature > 0.8, else "phi-4".
      3. Analytical keywords -> "phi-4".
      4. Prompt longer than 2000 chars -> "llama3-8b"; shorter than 1000 chars -> "phi-4".
      5. Otherwise "deepseek-r1" if temperature > 0.9, else "phi-4".

    Args:
        prompt (str): The input prompt for the model.
        n_predict (int): The number of tokens to predict (not used by the heuristic).
        temperature (float): The sampling temperature.

    Returns:
        str: The identifier of the best model to use.
    """
    code_markers = (
        "```", "def ", "class ", "function", "import ", "from ", "<script", "<style",
        "SELECT ", "CREATE TABLE", "const ", "let ", "var ", "function(", "=>",
    )
    creative_words = (
        "story", "poem", "creative", "imagine", "fiction", "narrative",
        "write a", "compose", "create a",
    )
    analytical_words = (
        "explain", "analyze", "compare", "contrast", "reason",
        "evaluate", "assess", "why", "how does",
    )

    looks_like_code = any(marker in prompt for marker in code_markers)
    lowered = prompt.lower()
    looks_creative = any(word in lowered for word in creative_words)
    looks_analytical = any(word in lowered for word in analytical_words)

    if looks_like_code:
        return "phi-4"
    if looks_creative:
        return "deepseek-r1" if temperature > 0.8 else "phi-4"
    if looks_analytical:
        return "phi-4"

    prompt_length = len(prompt)
    if prompt_length > 2000:
        return "llama3-8b"
    if prompt_length < 1000:
        return "phi-4"

    # 1000-2000 chars: low and medium temperatures both map to phi-4, so only
    # very high temperatures switch model.
    return "deepseek-r1" if temperature > 0.9 else "phi-4"

# vLLM serving function
@app.function(
    image=vllm_image,
    gpu="H100:1",
    allow_concurrent_inputs=100,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/vllm": vllm_cache_vol,
    },
    timeout=30 * MINUTES,
)
@modal.web_server(port=SERVER_PORT)
def serve_vllm_model(model_id: str = DEFAULT_MODEL):
    """
    Serves a model using vLLM with an OpenAI-compatible API.

    Launches ``vllm serve`` for ``VLLM_MODELS[model_id]`` (pinned to its
    ``revision``) as a background subprocess listening on
    ``0.0.0.0:SERVER_PORT``, then returns without waiting for it.

    Args:
        model_id (str): The identifier of the model to serve. Defaults to DEFAULT_MODEL.
            NOTE(review): as a ``@modal.web_server`` function this is presumably
            started by HTTP traffic to its URL, in which case callers cannot
            pass ``model_id`` and the default is always served — confirm.

    Raises:
        ValueError: If the specified model_id is not found in VLLM_MODELS.

    Note:
        vLLM is started with ``--api-key DEFAULT_API_KEY``, so every request to
        this server must authenticate with that key.
    """
    import subprocess
    
    if model_id not in VLLM_MODELS:
        available_models = list(VLLM_MODELS.keys())
        logging.error(f"Error: Unknown model: {model_id}. Available models: {available_models}")
        raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")
    
    model_info = VLLM_MODELS[model_id]
    model_name = model_info["name"]
    revision = model_info["revision"]
    
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Starting vLLM server with model: {model_name}")
    
    # argv list (no shell) for `vllm serve`; the API key restricts access to
    # clients holding DEFAULT_API_KEY.
    cmd = [
        "vllm",
        "serve",
        "--uvicorn-log-level=info",
        model_name,
        "--revision",
        revision,
        "--host",
        "0.0.0.0",
        "--port",
        str(SERVER_PORT),
        "--api-key",
        DEFAULT_API_KEY,
    ]

    # Launch vLLM in the background with Popen (argv list, no shell=True) and
    # return immediately; this function does not wait for the server to be
    # ready. NOTE(review): readiness presumably relies on @modal.web_server
    # waiting for SERVER_PORT to accept connections — confirm its startup timeout.
    process = subprocess.Popen(cmd)
    
    # Log that we've started the server
    logging.info(f"Started vLLM server with PID {process.pid}")

# Define the worker that will process the queue
@app.function(
    image=vllm_image,
    gpu=None,  # Worker will spawn GPU functions as needed
    allow_concurrent_inputs=10,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
    timeout=30 * MINUTES,
)
async def process_queue_worker():
    """Worker function that processes requests from the queue.

    Repeatedly takes a request dict (as produced by
    ``GenerationRequest.model_dump()``) from ``request_queue`` and runs it:

    - streaming requests go through ``generate_streaming_response``, which
      writes to ``stream_manager``. Failures are sent as a final error chunk.
    - non-streaming requests go through ``generate_response``. The result, or
      an ``{"error": {...}}`` payload on failure, is stored under the request
      id in both ``memory_cache`` and ``response_dict`` for the API to pick up.
      On success the model's ``avg_latency`` stat is updated.

    The worker exits after ``max_empty_count`` consecutive empty polls. It
    always decrements ``model_stats_dict["workers_running"]`` on the way out.
    """
    import asyncio
    import queue
    import time

    # Assigned before the try so the finally-block can always reference it.
    worker_id = str(uuid.uuid4())[:8]

    try:
        logging.info(f"Starting queue processing worker {worker_id}")

        # Process requests until timeout or empty queue
        empty_count = 0
        max_empty_count = 10  # Stop after 10 consecutive empty polls

        while empty_count < max_empty_count:
            # Modal's Queue.get takes `timeout` in seconds and raises queue.Empty
            # when it expires; it has no `timeout_ms` parameter. Passing one
            # raised TypeError and killed the worker before it processed
            # anything. A None result is treated as "empty" too.
            try:
                request_dict = await request_queue.get.aio(timeout=5)
            except (queue.Empty, asyncio.TimeoutError):
                request_dict = None

            if request_dict is None:
                empty_count += 1
                logging.info(f"Worker {worker_id}: No requests in queue. Empty count: {empty_count}")

                # Clean up expired cache entries and old streams
                if empty_count % 5 == 0:  # Every 5 empty polls
                    memory_cache.clear_expired()
                    stream_manager.clean_old_streams()

                await asyncio.sleep(1)  # Wait a bit before checking again
                continue

            empty_count = 0  # Reset empty counter

            if not isinstance(request_dict, dict):
                logging.error(f"Worker {worker_id} skipping malformed queue item: {request_dict!r}")
                continue

            # Read identifying fields outside the try so the error handler can use them.
            request_id = request_dict.get("request_id")
            model_id = request_dict.get("model_id")
            stream_mode = request_dict.get("stream", False)

            try:
                messages = request_dict.get("messages", [])
                temperature = request_dict.get("temperature", 0.7)
                max_tokens = request_dict.get("max_tokens", 1024)
                api_key = request_dict.get("api_key", DEFAULT_API_KEY)

                logging.info(f"Worker {worker_id} processing request {request_id} for model {model_id}")

                # Start time for latency calculation
                start_time = time.time()

                if stream_mode:
                    await generate_streaming_response(
                        request_id=request_id,
                        model_id=model_id,
                        messages=messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        api_key=api_key
                    )
                else:
                    response = await generate_response(
                        model_id=model_id,
                        messages=messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        api_key=api_key
                    )

                    latency = time.time() - start_time

                    # Update the running average latency
                    if model_stats_dict.contains(model_id):
                        stats = model_stats_dict[model_id]
                        old_avg = stats.get("avg_latency", 0)
                        old_count = stats.get("success_count", 0)

                        if old_count > 0:
                            new_avg = (old_avg * old_count + latency) / (old_count + 1)
                        else:
                            new_avg = latency

                        stats["avg_latency"] = new_avg
                        model_stats_dict[model_id] = stats

                    # Store the response in both caches for the API to pick up
                    memory_cache.set(request_id, response)
                    response_dict[request_id] = response

                    logging.info(f"Worker {worker_id} completed request {request_id} in {latency:.2f}s")

            except Exception as e:
                # Log error and move on to the next request
                logging.error(f"Worker {worker_id} error processing request {request_id}: {str(e)}")

                if stream_mode:
                    # Streaming clients only read the stream, never response_dict,
                    # so report the failure there and don't leave an orphaned entry.
                    stream_manager.add_chunk(request_id, {
                        "id": f"chatcmpl-{int(time.time())}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model_id,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": f"Error: {str(e)}"},
                            "finish_reason": "error"
                        }]
                    })
                    stream_manager.finish_stream(request_id)
                else:
                    error_response = {
                        "error": {
                            "message": str(e),
                            "type": "internal_error",
                            "code": 500
                        }
                    }
                    memory_cache.set(request_id, error_response)
                    response_dict[request_id] = error_response

        # If we get here, we've had too many consecutive empty polls
        logging.info(f"Worker {worker_id} shutting down due to empty queue")

    finally:
        # Signal that this worker is done
        workers_running_key = "workers_running"
        if model_stats_dict.contains(workers_running_key):
            current_workers = model_stats_dict[workers_running_key]
            model_stats_dict[workers_running_key] = max(0, current_workers - 1)
            logging.info(f"Worker {worker_id} shutdown. Workers remaining: {max(0, current_workers - 1)}")

async def generate_streaming_response(
    request_id: str,
    model_id: str,
    messages: List[dict],
    temperature: float,
    max_tokens: int,
    api_key: str,
    done_timeout: float = 600.0,
):
    """
    Generate a streaming response and send chunks to the stream manager.

    vLLM models: proxies an SSE request to the ``serve_vllm_model`` web server
    and forwards every parsed ``data:`` JSON payload with
    ``stream_manager.add_chunk``.

    llama.cpp models: delegates to ``run_llama_cpp_stream``, which adds chunks
    to the stream manager itself, then waits for the ``"DONE"`` marker in
    ``stream_queues``.

    Any failure is reported to the client as a final error chunk, and the
    stream is finished.

    Args:
        request_id: The unique ID for this request
        model_id: The ID of the model to use
        messages: The chat messages
        temperature: The sampling temperature
        max_tokens: The maximum tokens to generate
        api_key: The caller's API key. It is not forwarded upstream: the vLLM
            server is launched with ``--api-key DEFAULT_API_KEY`` and only
            accepts that key. Kept for interface compatibility.
        done_timeout: Maximum seconds to wait for the llama.cpp completion
            marker before giving up (previously unbounded).
    """
    import httpx
    import time
    import json
    import asyncio

    # Created before the try so the error handler can always use it.
    response_id = f"chatcmpl-{int(time.time())}"

    try:
        if model_id in VLLM_MODELS:
            # serve_vllm_model is a @modal.web_server function. It is started by
            # HTTP traffic to its URL and cannot be invoked (or awaited) via
            # `.remote()`. Health polling both wakes it and waits for readiness.
            # NOTE(review): the web server starts with its default model_id, so
            # non-default vLLM models may not actually be served — confirm.
            base_url = serve_vllm_model.web_url
            if not await wait_for_server(base_url, timeout=120):
                raise RuntimeError(f"vLLM server for model {model_id} did not become ready in time")

            async with httpx.AsyncClient(timeout=120.0) as client:
                headers = {
                    "Authorization": f"Bearer {DEFAULT_API_KEY}",
                    "Content-Type": "application/json",
                    "Accept": "text/event-stream"
                }

                # Format request for vLLM OpenAI-compatible endpoint
                vllm_request = {
                    "model": VLLM_MODELS[model_id]["name"],
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True
                }

                async with client.stream(
                    "POST",
                    f"{base_url}/v1/chat/completions",
                    json=vllm_request,
                    headers=headers
                ) as response:
                    # An error status has no SSE body; fail loudly instead of
                    # silently finishing an empty stream.
                    response.raise_for_status()

                    buffer = ""
                    async for chunk in response.aiter_text():
                        buffer += chunk

                        # Process every complete SSE event (blank-line separated)
                        while "\n\n" in buffer:
                            message, buffer = buffer.split("\n\n", 1)

                            if not message.startswith("data: "):
                                continue
                            data = message[6:]  # Remove "data: " prefix

                            if data == "[DONE]":
                                stream_manager.finish_stream(request_id)
                                return

                            try:
                                chunk_data = json.loads(data)
                            except json.JSONDecodeError:
                                logging.error(f"Invalid JSON in stream: {data}")
                                continue
                            stream_manager.add_chunk(request_id, chunk_data)

                    # Upstream closed without [DONE]: still terminate the client stream
                    stream_manager.finish_stream(request_id)

        elif model_id in LLAMA_CPP_MODELS:
            # llama.cpp takes a flat prompt rather than chat messages
            prompt = format_messages_to_prompt(messages)

            # `.remote.aio` is Modal's awaitable form of `.remote`. Chunks are
            # pushed to the stream manager by run_llama_cpp_stream itself.
            await run_llama_cpp_stream.remote.aio(
                model_id=model_id,
                prompt=prompt,
                n_predict=max_tokens,
                temperature=temperature,
                request_id=request_id
            )

            # Wait (bounded) for the completion signal so a missing marker
            # cannot hang the worker until its function timeout.
            deadline = time.time() + done_timeout
            while not (request_id in stream_queues and stream_queues[request_id] == "DONE"):
                if time.time() >= deadline:
                    raise TimeoutError(f"Timed out waiting for llama.cpp stream {request_id} to finish")
                await asyncio.sleep(0.1)
            stream_queues.pop(request_id)

        else:
            raise ValueError(f"Unknown model: {model_id}")

    except Exception as e:
        logging.error(f"Error in streaming generation: {str(e)}")
        # Send error chunk
        stream_manager.add_chunk(request_id, {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [{
                "index": 0,
                "delta": {"content": f"Error: {str(e)}"},
                "finish_reason": "error"
            }]
        })
        # Finish stream
        stream_manager.finish_stream(request_id)

async def generate_response(model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str):
    """
    Generate a response using the appropriate model based on model_id.

    vLLM models are proxied to the ``serve_vllm_model`` web server's
    ``/v1/chat/completions``. llama.cpp models are run via ``run_llama_cpp`` on
    a flattened prompt, and the output is wrapped in an OpenAI-style
    ``chat.completion`` object.

    Args:
        model_id: The ID of the model to use
        messages: The chat messages
        temperature: The sampling temperature
        max_tokens: The maximum tokens to generate
        api_key: The caller's API key. It is not forwarded upstream: the vLLM
            server is launched with ``--api-key DEFAULT_API_KEY`` and only
            accepts that key. Kept for interface compatibility.

    Returns:
        A response in OpenAI-compatible format

    Raises:
        RuntimeError: If the vLLM server does not become ready in time.
        httpx.HTTPStatusError: If the vLLM server answers with an error status.
        ValueError: If ``model_id`` is not a known model.
    """
    import httpx
    import time

    if model_id in VLLM_MODELS:
        # serve_vllm_model is a @modal.web_server function. It is started by
        # HTTP traffic to its URL and cannot be invoked (or awaited) via
        # `.remote()`. Health polling both wakes it and waits for readiness.
        # NOTE(review): the web server starts with its default model_id, so
        # non-default vLLM models may not actually be served — confirm.
        base_url = serve_vllm_model.web_url
        if not await wait_for_server(base_url, timeout=120):
            raise RuntimeError(f"vLLM server for model {model_id} did not become ready in time")

        async with httpx.AsyncClient(timeout=60.0) as client:
            headers = {
                "Authorization": f"Bearer {DEFAULT_API_KEY}",
                "Content-Type": "application/json"
            }

            # Format request for vLLM OpenAI-compatible endpoint
            vllm_request = {
                "model": VLLM_MODELS[model_id]["name"],
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }

            response = await client.post(
                f"{base_url}/v1/chat/completions",
                json=vllm_request,
                headers=headers
            )
            # Surface upstream failures (auth, unknown model, overload) as
            # exceptions instead of returning the error body as a completion.
            response.raise_for_status()
            return response.json()

    if model_id in LLAMA_CPP_MODELS:
        # llama.cpp takes a flat prompt rather than chat messages
        prompt = format_messages_to_prompt(messages)

        # `.remote.aio` is Modal's awaitable form of `.remote`
        output = await run_llama_cpp.remote.aio(
            model_id=model_id,
            prompt=prompt,
            n_predict=max_tokens,
            temperature=temperature
        )

        completion_text = (output or "").strip()

        # Rough token estimates (~4 characters per token). finish_reason
        # compares tokens with max_tokens; it used to compare characters.
        prompt_tokens = len(prompt) // 4
        completion_tokens = len(completion_text) // 4
        finish_reason = "length" if completion_tokens >= max_tokens else "stop"
        created = int(time.time())

        return {
            "id": f"chatcmpl-{created}",
            "object": "chat.completion",
            "created": created,
            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text
                    },
                    "finish_reason": finish_reason
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }

    raise ValueError(f"Unknown model: {model_id}")

def format_messages_to_prompt(messages: List[Dict[str, str]]) -> str:
    """
    Convert chat messages to a text prompt format for llama.cpp.

    Each message becomes ``<|role|>\\n{content}\\n``, with role tags
    ``system``, ``user`` and ``assistant`` (case-insensitive). Any other or
    missing role is treated as ``user``. A final ``<|assistant|>\\n`` marker
    prompts the model to respond.

    ``None`` content (e.g. assistant tool-call messages) renders as an empty
    body. List content (typed parts) is reduced to its text parts joined by
    newlines, instead of rendering as "None" or a Python repr.

    Args:
        messages: List of message dictionaries with role and content

    Returns:
        Formatted prompt string
    """
    role_tags = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }

    segments = []
    for message in messages:
        # `or ""` also covers an explicit None role, which would crash .lower()
        role = (message.get("role") or "").lower()
        content = message.get("content")
        if content is None:
            content = ""
        elif isinstance(content, list):
            content = "\n".join(
                part["text"] for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )

        # For unknown roles, treat as user
        tag = role_tags.get(role, "<|user|>")
        segments.append(f"{tag}\n{content}\n")

    # Add final assistant marker to prompt the model to respond
    segments.append("<|assistant|>\n")

    return "".join(segments)

async def wait_for_server(url: str, timeout: int = 120, check_interval: int = 2):
    """
    Wait for a server to be ready by polling its ``/health`` endpoint.

    Args:
        url: The base URL of the server; a trailing slash is tolerated.
        timeout: Maximum total time to wait, in seconds.
        check_interval: Pause between polls, in seconds.

    Returns:
        True as soon as ``/health`` answers HTTP 200, False once ``timeout``
        seconds have elapsed without that happening.
    """
    import asyncio
    import time

    import httpx

    health_url = f"{url.rstrip('/')}/health"
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    start_time = time.monotonic()
    deadline = start_time + timeout

    logging.info(f"Waiting for server at {url} to be ready...")

    # One client for all polls instead of a new one per attempt.
    async with httpx.AsyncClient(timeout=5.0) as client:
        while time.monotonic() < deadline:
            try:
                response = await client.get(health_url)
                if response.status_code == 200:
                    logging.info(f"Server at {url} is ready!")
                    return True
                # Non-200 answers used to be ignored silently; report them.
                elapsed = time.monotonic() - start_time
                logging.info(
                    f"Server not ready yet after {elapsed:.1f}s: "
                    f"health check returned HTTP {response.status_code}"
                )
            except Exception as e:
                # Connection errors/timeouts are expected while the server boots.
                elapsed = time.monotonic() - start_time
                logging.info(f"Server not ready yet after {elapsed:.1f}s: {str(e)}")

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            await asyncio.sleep(min(check_interval, remaining))

    logging.error(f"Timed out waiting for server at {url} after {timeout} seconds")
    return False

@app.function(
    image=llama_cpp_image,
    gpu=None,  # Will be set dynamically based on model
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
        RESULTS_DIR: results_vol,
    },
    timeout=30 * MINUTES,
)
async def run_llama_cpp_stream(
    model_id: str,
    prompt: str,
    n_predict: int = 1024,
    temperature: float = 0.7,
    request_id: str = None,
):
    """
    Run llama.cpp (``llama-cli``) inference, streaming output when requested.

    With a ``request_id``, stdout is pushed to ``stream_manager`` as
    ``chat.completion.chunk`` dicts from a background thread, the stream is
    always finished (``stream_queues[request_id] = "DONE"``), and this returns
    ``"Streaming in progress"`` immediately. Without one, it blocks until
    llama-cli exits, saves stdout/stderr/prompt under ``RESULTS_DIR/<uuid>``
    and returns the decoded stdout.

    Args:
        model_id: Key into ``LLAMA_CPP_MODELS``.
        prompt: Prompt text passed to ``llama-cli --prompt``.
        n_predict: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        request_id: Stream identifier; enables streaming mode when set.

    Raises:
        ValueError: ``model_id`` is unknown.
        FileNotFoundError: No downloaded file matches the model's pattern.
        Download or process-launch errors are re-raised. In streaming mode an
        error chunk is sent and the stream closed first.
    """
    import os
    import subprocess
    import threading
    import time
    from pathlib import Path
    from uuid import uuid4

    from huggingface_hub import snapshot_download

    def make_chunk(chunk_id, delta, finish_reason=None):
        """Build one ``chat.completion.chunk`` dict for this model."""
        choice = {"index": 0, "delta": delta}
        if finish_reason is not None:
            choice["finish_reason"] = finish_reason
        return {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [choice],
        }

    def fail_stream(error_msg):
        """Send an error chunk and close the stream (no-op when not streaming)."""
        if not request_id:
            return
        stream_manager.add_chunk(
            request_id,
            make_chunk(
                f"chatcmpl-{int(time.time())}",
                {"content": f"Error: {error_msg}"},
                "error",
            ),
        )
        stream_manager.finish_stream(request_id)
        # Signal completion
        stream_queues[request_id] = "DONE"

    if model_id not in LLAMA_CPP_MODELS:
        available_models = list(LLAMA_CPP_MODELS.keys())
        error_msg = f"Unknown model: {model_id}. Available models: {available_models}"
        logging.error(error_msg)
        fail_stream(error_msg)
        raise ValueError(error_msg)

    model_info = LLAMA_CPP_MODELS[model_id]
    repo_id = model_info["name"]
    pattern = model_info["pattern"]
    revision = model_info["revision"]
    download_kwargs = dict(
        repo_id=repo_id,
        revision=revision,
        local_dir=f"{CACHE_DIR}/llama_cpp",
        allow_patterns=[pattern],
    )

    def download_model():
        """Download the model snapshot, retrying once without hf_transfer."""
        try:
            return snapshot_download(**download_kwargs)
        except ValueError as e:
            if "hf_transfer" not in str(e):
                raise
            # Fallback to standard download if hf_transfer fails
            logging.warning("hf_transfer failed, falling back to standard download")
            # NOTE(review): recent huggingface_hub versions read this variable
            # once at import time, so toggling it here may have no effect — confirm.
            previous = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
            try:
                return snapshot_download(**download_kwargs)
            finally:
                # Restore the original setting, removing the variable if it was unset.
                if previous is None:
                    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
                else:
                    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = previous

    # Download model if not already cached
    logging.info(f"Downloading model {repo_id} if not present")
    try:
        model_path = download_model()
    except Exception as e:
        logging.error(f"Failed to download model {repo_id}: {e}")
        # Close the stream so a waiting consumer is not left hanging.
        fail_stream(f"Failed to download model {repo_id}: {e}")
        raise

    # Find the model file
    model_files = list(Path(model_path).glob(pattern))
    if not model_files:
        error_msg = f"No model files found matching pattern {pattern}"
        logging.error(error_msg)
        fail_stream(error_msg)
        raise FileNotFoundError(error_msg)

    model_file = str(model_files[0])
    logging.info(f"Using model file: {model_file}")

    # Set up command
    cmd = [
        "llama-cli",
        "--model", model_file,
        "--prompt", prompt,
        "--n-predict", str(n_predict),
        "--temp", str(temperature),
        "--ctx-size", "8192",
    ]

    # Add GPU layers if needed
    if model_info["gpu"] is not None:
        cmd.extend(["--n-gpu-layers", "9999"])  # Use all layers on GPU

    # Run inference
    result_id = str(uuid4())
    logging.info(f"Running streaming inference with ID: {result_id}")

    # Create response ID for streaming
    response_id = f"chatcmpl-{int(time.time())}"

    def drain_stderr(stream):
        """Consume llama-cli's stderr (logged at DEBUG level) until EOF."""
        try:
            for raw in iter(stream.readline, b""):
                logging.debug(f"llama-cli: {raw.decode('utf-8', errors='replace').rstrip()}")
        finally:
            stream.close()

    def process_output(process, request_id):
        """Forward llama-cli stdout to the stream, then always close the stream."""
        content_buffer = ""
        last_send_time = time.time()

        def send_content(text):
            stream_manager.add_chunk(request_id, make_chunk(response_id, {"content": text}))

        try:
            # Send initial chunk with role
            stream_manager.add_chunk(request_id, make_chunk(response_id, {"role": "assistant"}))

            for line in iter(process.stdout.readline, b""):
                try:
                    line_str = line.decode("utf-8", errors="replace")

                    # Skip llama.cpp info lines
                    if line_str.startswith("llama_"):
                        continue

                    content_buffer += line_str

                    # Send chunks at reasonable intervals or when buffer gets large
                    now = time.time()
                    if now - last_send_time > 0.1 or len(content_buffer) > 20:
                        send_content(content_buffer)
                        content_buffer = ""
                        last_send_time = now
                except Exception as e:
                    logging.error(f"Error processing output: {str(e)}")

            # Send any remaining content
            if content_buffer:
                send_content(content_buffer)

            # Send final chunk with finish reason
            stream_manager.add_chunk(request_id, make_chunk(response_id, {}, "stop"))
        except Exception:
            logging.exception("Streaming llama-cli output failed")
            # Nobody reads stdout any more; stop the child so it cannot block forever.
            process.kill()
        finally:
            # Reap the child so it does not linger as a zombie.
            return_code = process.wait()
            if return_code != 0:
                logging.warning(f"llama-cli exited with code {return_code} (request {request_id})")
            # Always finish the stream, even if sending a chunk failed above.
            stream_manager.finish_stream(request_id)
            # Signal completion
            stream_queues[request_id] = "DONE"

    # Start process. The default buffering is used: line buffering (bufsize=1)
    # is not supported for binary pipes and only produced a RuntimeWarning.
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except OSError as e:
        error_msg = f"Failed to start llama-cli: {e}"
        logging.error(error_msg)
        fail_stream(error_msg)
        raise

    # Start output processing thread if streaming
    if request_id:
        # llama.cpp logs to stderr; drain it so a full pipe can never stall the child.
        threading.Thread(target=drain_stderr, args=(process.stderr,), daemon=True).start()

        thread = threading.Thread(target=process_output, args=(process, request_id))
        thread.daemon = True
        thread.start()

        # Return immediately for streaming
        return "Streaming in progress"

    # For non-streaming, collect all output (both pipes are drained concurrently)
    stdout, stderr = collect_output(process)

    # Save results
    result_dir = Path(RESULTS_DIR) / result_id
    result_dir.mkdir(parents=True, exist_ok=True)

    (result_dir / "output.txt").write_text(stdout)
    (result_dir / "stderr.txt").write_text(stderr)
    (result_dir / "prompt.txt").write_text(prompt)

    logging.info(f"Results saved to {result_dir}")
    return stdout

@app.function(
    image=llama_cpp_image,
    gpu=None,  # Will be set dynamically based on model
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
        RESULTS_DIR: results_vol,
    },
    timeout=30 * MINUTES,
)
async def run_llama_cpp(
    model_id: str,
    prompt: str = "Tell me about Modal and how it helps with ML deployments.",
    n_predict: int = 1024,
    temperature: float = 0.7,
):
    """
    Run blocking llama.cpp (``llama-cli``) inference for a ``LLAMA_CPP_MODELS`` entry.

    Downloads the model file if needed, runs llama-cli while echoing its output,
    saves stdout/stderr/prompt under ``RESULTS_DIR/<uuid>`` and returns stdout.

    Args:
        model_id: Key into ``LLAMA_CPP_MODELS``.
        prompt: Prompt text passed to ``llama-cli --prompt``.
        n_predict: Maximum number of tokens to generate.
        temperature: Sampling temperature.

    Raises:
        ValueError: ``model_id`` is unknown.
        FileNotFoundError: No downloaded file matches the model's pattern.
    """
    import os
    import subprocess
    from pathlib import Path
    from uuid import uuid4

    from huggingface_hub import snapshot_download

    if model_id not in LLAMA_CPP_MODELS:
        available_models = list(LLAMA_CPP_MODELS.keys())
        print(f"Error: Unknown model: {model_id}. Available models: {available_models}")
        raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")

    model_info = LLAMA_CPP_MODELS[model_id]
    repo_id = model_info["name"]
    pattern = model_info["pattern"]
    revision = model_info["revision"]
    download_kwargs = dict(
        repo_id=repo_id,
        revision=revision,
        local_dir=f"{CACHE_DIR}/llama_cpp",
        allow_patterns=[pattern],
    )

    # Download model if not already cached
    logging.info(f"Downloading model {repo_id} if not present")
    try:
        model_path = snapshot_download(**download_kwargs)
    except ValueError as e:
        if "hf_transfer" not in str(e):
            raise
        # Fallback to standard download if hf_transfer fails
        logging.warning("hf_transfer failed, falling back to standard download")
        # NOTE(review): recent huggingface_hub versions read this variable once
        # at import time, so toggling it here may have no effect — confirm.
        previous = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
        try:
            model_path = snapshot_download(**download_kwargs)
        finally:
            # Restore the original setting; remove the variable if it was unset
            # (previously it was left set to "1" in that case).
            if previous is None:
                os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
            else:
                os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = previous

    # Find the model file
    model_files = list(Path(model_path).glob(pattern))
    if not model_files:
        logging.error(f"No model files found matching pattern {pattern}")
        raise FileNotFoundError(f"No model files found matching pattern {pattern}")

    model_file = str(model_files[0])
    print(f"Using model file: {model_file}")

    # Set up command
    cmd = [
        "llama-cli",
        "--model", model_file,
        "--prompt", prompt,
        "--n-predict", str(n_predict),
        "--temp", str(temperature),
        "--ctx-size", "8192",
    ]

    # Add GPU layers if needed
    if model_info["gpu"] is not None:
        cmd.extend(["--n-gpu-layers", "9999"])  # Use all layers on GPU

    # Run inference
    result_id = str(uuid4())
    print(f"Running inference with ID: {result_id}")

    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=False
    )

    # Drains both pipes concurrently while echoing them, then waits for exit.
    stdout, stderr = collect_output(process)

    # Save results
    result_dir = Path(RESULTS_DIR) / result_id
    result_dir.mkdir(parents=True, exist_ok=True)

    (result_dir / "output.txt").write_text(stdout)
    (result_dir / "stderr.txt").write_text(stderr)
    (result_dir / "prompt.txt").write_text(prompt)

    print(f"Results saved to {result_dir}")
    return stdout

@app.function(
    image=vllm_image,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
)
def list_available_models():
    """
    Print the model catalog of both backends and return their ids.

    vLLM entries print as ``- <id>: <name>``. llama.cpp entries add
    ``(GPU: <gpu>)``, or ``(CPU)`` when their ``gpu`` value is falsy.

    Returns:
        dict: ``{"vllm": [...], "llama_cpp": [...]}`` listing the model ids of
        each backend in registry order.
    """
    catalog = {
        "vllm": list(VLLM_MODELS.keys()),
        "llama_cpp": list(LLAMA_CPP_MODELS.keys()),
    }

    print("Available vLLM models:")
    for key, spec in VLLM_MODELS.items():
        print(f"- {key}: {spec['name']}")

    print("\nAvailable llama.cpp models:")
    for key, spec in LLAMA_CPP_MODELS.items():
        if spec['gpu']:
            placement = f"(GPU: {spec['gpu']})"
        else:
            placement = "(CPU)"
        print(f"- {key}: {spec['name']} {placement}")

    return catalog

def collect_output(process):
    """
    Collect a process's stdout and stderr while echoing them live.

    Each captured pipe is drained line by line on its own thread, so a chatty
    stderr cannot fill its pipe buffer and stall the child while stdout is
    being read. Lines are echoed to ``sys.stdout`` / ``sys.stderr`` as they
    arrive.

    Args:
        process: A ``subprocess.Popen`` instance. Binary pipes are decoded as
            UTF-8 (undecodable bytes become U+FFFD); text-mode pipes are used
            as-is. A pipe that was not captured (``None``) yields ``""``.

    Returns:
        tuple: ``(stdout, stderr)`` as strings, collected after both pipes
        reached EOF and the process exited.
    """
    import sys
    from threading import Thread

    def pump(stream, sink, echo):
        """Copy ``stream`` into the list ``sink`` and onto ``echo`` until EOF."""
        try:
            # Iterating the file object stops at EOF for both bytes and text
            # pipes (a readline loop with a b"" sentinel never stops on text).
            for raw_line in stream:
                line = (
                    raw_line.decode("utf-8", errors="replace")
                    if isinstance(raw_line, bytes)
                    else raw_line
                )
                echo.write(line)
                echo.flush()
                sink.append(line)
        finally:
            # Close our end even if echoing fails, so the child is not left
            # blocked writing into a pipe nobody reads.
            stream.close()

    stdout_lines = []
    stderr_lines = []
    # Each thread owns its own list, which is only read after join(), so no
    # queue or locking is needed (this also avoids Queue's private .queue).
    pumps = [
        Thread(target=pump, args=(stream, sink, echo))
        for stream, sink, echo in (
            (process.stdout, stdout_lines, sys.stdout),
            (process.stderr, stderr_lines, sys.stderr),
        )
        if stream is not None
    ]
    for worker in pumps:
        worker.start()
    for worker in pumps:
        worker.join()
    process.wait()

    return "".join(stdout_lines), "".join(stderr_lines)

# Main ASGI app for Modal: the CPU-only web frontend in front of the workers.
@app.function(
    image=vllm_image,
    gpu=None,  # No GPU for the API frontend
    allow_concurrent_inputs=100,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
)
@modal.asgi_app()
def inference_api():
    """
    Hand the module-level FastAPI application (``api_app``) to Modal.

    The body does no work of its own; ``@modal.asgi_app()`` serves the
    returned ASGI app at this function's ``web_url``.
    """
    asgi_application = api_app
    return asgi_application

@app.local_entrypoint()
def main(
    prompt: str = "What can you tell me about Modal?",
    n_predict: int = 1024,
    temperature: float = 0.7,
    create_admin_key: bool = False,
    stream: bool = False,
    model: str = "auto",
    load_model: str = None,
    load_hf_model: str = None,
    hf_model_type: str = "vllm",
):
    """
    Main entrypoint for smoke-testing the deployed API from the local machine.

    Steps:
    1. Wait for ``/health``.
    2. Optionally create a test API key.
    3. List the available models.
    4. Send one chat completion, streamed or not.
    5. Print the admin stats, and spawn a queue worker if the stats show none
       is active.
    6. Optionally load a catalog model (``load_model``) or a Hugging Face repo
       (``load_hf_model``).
    7. Print example curl commands.

    Args:
        prompt: User message for the test completion.
        n_predict: ``max_tokens`` for the test completion (and for HF loads).
        temperature: Sampling temperature for the test completion.
        create_admin_key: Create a new key via ``/admin/api-keys`` and use it.
        stream: Request a streamed (SSE) completion.
        model: Model id for the test completion; ``"auto"`` lets
            ``select_best_model`` choose.
        load_model: Catalog model id to load via ``/admin/models/load``.
        load_hf_model: Hugging Face repo id to load via
            ``/admin/models/load-from-hf``.
        hf_model_type: Backend type sent along with ``load_hf_model``.
    """
    import json
    import time
    import urllib.request

    # Initialize the API
    print(f"Starting API at {inference_api.web_url}")

    # Wait for API to be ready: poll /health every `delay` seconds for up to 5 minutes.
    print("Checking if API is ready...")
    up, start, delay = False, time.time(), 10
    while True:
        try:
            with urllib.request.urlopen(inference_api.web_url + "/health") as response:
                up = response.getcode() == 200
        except Exception:
            up = False
        # Check the deadline after every failed attempt (not only on exceptions),
        # so a non-200 success status cannot spin this loop forever.
        if up or time.time() - start > 5 * MINUTES:
            break
        time.sleep(delay)

    assert up, f"Failed health check for API at {inference_api.web_url}"
    print(f"API is up and running at {inference_api.web_url}")

    # Create a test API key if requested
    if create_admin_key:
        print("Creating a test API key...")
        key_request = {
            "user_id": "test_user",
            "rate_limit": 120,
            "quota": 2000000
        }
        headers = {
            "Authorization": f"Bearer {DEFAULT_API_KEY}",  # Admin key
            "Content-Type": "application/json",
        }
        req = urllib.request.Request(
            inference_api.web_url + "/admin/api-keys",
            data=json.dumps(key_request).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode())
                print("Created API key:")
                print(json.dumps(result, indent=2))
                # Use this key for the test message
                test_key = result["key"]
        except Exception as e:
            print(f"Error creating API key: {str(e)}")
            test_key = DEFAULT_API_KEY
    else:
        test_key = DEFAULT_API_KEY

    # List available models
    print("\nAvailable models:")
    try:
        headers = {
            "Authorization": f"Bearer {test_key}",
            "Content-Type": "application/json",
        }
        req = urllib.request.Request(
            inference_api.web_url + "/v1/models",
            headers=headers,
            method="GET",
        )
        with urllib.request.urlopen(req) as response:
            models = json.loads(response.read().decode())
            print(json.dumps(models, indent=2))
    except Exception as e:
        print(f"Error listing models: {str(e)}")

    # Only auto-select a model when asked to; previously an explicit --model
    # value was silently overwritten here.
    if not model or model == "auto":
        model = select_best_model(prompt, n_predict, temperature)

    # Send a test message
    print(f"\nSending a sample message to {inference_api.web_url}")
    messages = [{"role": "user", "content": prompt}]

    headers = {
        "Authorization": f"Bearer {test_key}",
        "Content-Type": "application/json",
    }
    payload = json.dumps({
        "messages": messages,
        "model": model,
        "temperature": temperature,
        "max_tokens": n_predict,
        "stream": stream
    })
    req = urllib.request.Request(
        inference_api.web_url + "/v1/chat/completions",
        data=payload.encode("utf-8"),
        headers=headers,
        method="POST",
    )

    try:
        if stream:
            print("Streaming response:")
            with urllib.request.urlopen(req) as response:
                for line in response:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data = line[6:].strip()
                        if data == '[DONE]':
                            print("\n[DONE]")
                        else:
                            try:
                                chunk = json.loads(data)
                                if 'choices' in chunk and len(chunk['choices']) > 0:
                                    if 'delta' in chunk['choices'][0] and 'content' in chunk['choices'][0]['delta']:
                                        content = chunk['choices'][0]['delta']['content']
                                        print(content, end='', flush=True)
                            except json.JSONDecodeError:
                                print(f"Error parsing: {data}")
        else:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode())
                print("Response:")
                print(json.dumps(result, indent=2))
    except Exception as e:
        print(f"Error: {str(e)}")

    # Check API stats. `stats` stays None if the request fails, so the worker
    # check below is skipped instead of hitting a NameError.
    print("\nChecking API stats...")
    stats = None
    headers = {
        "Authorization": f"Bearer {DEFAULT_API_KEY}",  # Admin key
        "Content-Type": "application/json",
    }
    req = urllib.request.Request(
        inference_api.web_url + "/admin/stats",
        headers=headers,
        method="GET",
    )
    try:
        with urllib.request.urlopen(req) as response:
            stats = json.loads(response.read().decode())
            print("API Stats:")
            print(json.dumps(stats, indent=2))
    except Exception as e:
        print(f"Error getting stats: {str(e)}")

    # Start a worker if none running (only decidable when stats were retrieved)
    if stats is not None:
        try:
            current_workers = stats.get("queue", {}).get("active_workers", 0)
            if current_workers < 1:
                print("\nStarting a queue worker...")
                process_queue_worker.spawn()
        except Exception as e:
            print(f"Error starting worker: {str(e)}")

    print(f"\nAPI is available at {inference_api.web_url}")
    print(f"Documentation is at {inference_api.web_url}/docs")
    print(f"Default Bearer token: {DEFAULT_API_KEY}")

    if create_admin_key:
        print(f"Test Bearer token: {test_key}")

    # If a model was specified to load, load it.
    # NOTE(review): the /admin/* model endpoints below are called with test_key,
    # while /admin/stats uses DEFAULT_API_KEY; a freshly created (non-admin) key
    # may be rejected here — confirm the server's auth rules.
    if load_model:
        print(f"\nLoading model: {load_model}")
        load_url = f"{inference_api.web_url}/admin/models/load"
        headers = {
            "Authorization": f"Bearer {test_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps({
            "model_id": load_model,
            "force_reload": True
        })
        req = urllib.request.Request(
            load_url,
            data=payload.encode("utf-8"),
            headers=headers,
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode())
                print("Load response:")
                print(json.dumps(result, indent=2))

                # If it's a small model, wait a bit for it to load
                if load_model in ["tiny-llama-1.1b", "phi-2"]:
                    print(f"Waiting for {load_model} to load...")
                    time.sleep(10)

                    # Check status
                    status_url = f"{inference_api.web_url}/admin/models/status/{load_model}"
                    status_req = urllib.request.Request(
                        status_url,
                        headers={"Authorization": f"Bearer {test_key}"},
                        method="GET",
                    )
                    with urllib.request.urlopen(status_req) as status_response:
                        status_result = json.loads(status_response.read().decode())
                        print("Model status:")
                        print(json.dumps(status_result, indent=2))

                # Use this model in the example commands printed below
                # (the test request above has already been sent).
                model = load_model
        except Exception as e:
            print(f"Error loading model: {str(e)}")

    # If a HF model was specified to load directly
    if load_hf_model:
        print(f"\nLoading HF model: {load_hf_model} with type {hf_model_type}")
        load_url = f"{inference_api.web_url}/admin/models/load-from-hf"
        headers = {
            "Authorization": f"Bearer {test_key}",
            "Content-Type": "application/json",
        }
        payload = json.dumps({
            "repo_id": load_hf_model,
            "model_type": hf_model_type,
            "max_tokens": n_predict
        })
        req = urllib.request.Request(
            load_url,
            data=payload.encode("utf-8"),
            headers=headers,
            method="POST",
        )
        try:
            with urllib.request.urlopen(req) as response:
                result = json.loads(response.read().decode())
                print("HF Load response:")
                print(json.dumps(result, indent=2))

                # Get the model_id from the response
                hf_model_id = result.get("model_id")

                # Wait a bit for it to start loading
                print(f"Waiting for {load_hf_model} to start loading...")
                time.sleep(5)

                # Check status
                if hf_model_id:
                    status_url = f"{inference_api.web_url}/admin/models/status/{hf_model_id}"
                    status_req = urllib.request.Request(
                        status_url,
                        headers={"Authorization": f"Bearer {test_key}"},
                        method="GET",
                    )
                    with urllib.request.urlopen(status_req) as status_response:
                        status_result = json.loads(status_response.read().decode())
                        print("Model status:")
                        print(json.dumps(status_result, indent=2))

                # Use this model in the example commands printed below
                if hf_model_id:
                    model = hf_model_id
        except Exception as e:
            print(f"Error loading HF model: {str(e)}")

    # Show curl examples
    print("\nExample curl commands:")

    # Regular completion
    print(f"""# Regular completion:
curl -X POST {inference_api.web_url}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model": "{model}",
    "messages": [
      {{
        "role": "user",
        "content": "Hello, how can you help me today?"
      }}
    ],
    "temperature": 0.7,
    "max_tokens": 500
  }}'""")

    # Streaming completion
    print(f"""\n# Streaming completion:
curl -X POST {inference_api.web_url}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model": "{model}",
    "messages": [
      {{
        "role": "user",
        "content": "Write a short story about AI"
      }}
    ],
    "temperature": 0.8,
    "max_tokens": 1000,
    "stream": true
  }}' --no-buffer""")

    # List models
    print(f"""\n# List available models:
curl -X GET {inference_api.web_url}/v1/models \\
  -H "Authorization: Bearer {test_key}" """)

    # Model management commands
    print(f"""\n# Load a model:
curl -X POST {inference_api.web_url}/admin/models/load \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model_id": "phi-2",
    "force_reload": false
  }}'""")

    print(f"""\n# Load a model directly from Hugging Face:
curl -X POST {inference_api.web_url}/admin/models/load-from-hf \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "repo_id": "microsoft/phi-2",
    "model_type": "vllm",
    "max_tokens": 4096
  }}'""")

    print(f"""\n# Get model status:
curl -X GET {inference_api.web_url}/admin/models/status/phi-2 \\
  -H "Authorization: Bearer {test_key}" """)

    print(f"""\n# Unload a model:
curl -X POST {inference_api.web_url}/admin/models/unload \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model_id": "phi-2"
  }}'""")
async def preload_llama_cpp_model(model_id: str):
    """
    Preload a llama.cpp model to make inference faster on first request.

    Runs a tiny remote ``run_llama_cpp`` inference with the model. Unknown ids
    are logged and ignored. On failure the error is logged and
    ``LLAMA_CPP_MODELS[model_id]["loaded"]`` is set to False; nothing is
    re-raised.
    """
    if model_id not in LLAMA_CPP_MODELS:
        logging.error(f"Unknown model: {model_id}")
        return

    try:
        # Run a simple inference to load the model
        logging.info(f"Preloading llama.cpp model: {model_id}")
        # Modal's plain `.remote()` is the blocking call and returns the result
        # itself (awaiting that string raised TypeError, so every preload was
        # reported as failed); `.remote.aio()` is the awaitable variant.
        await run_llama_cpp.remote.aio(
            model_id=model_id,
            prompt="Hello, this is a test to preload the model.",
            n_predict=10,
            temperature=0.7,
        )
        logging.info(f"Successfully preloaded llama.cpp model: {model_id}")
    except Exception as e:
        logging.error(f"Error preloading llama.cpp model {model_id}: {str(e)}")
        # Mark as not loaded
        LLAMA_CPP_MODELS[model_id]["loaded"] = False

```

--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
# TODO: Refactor into modular structure similar to Claude Code (lib/, commands/, tools/ directories)
# TODO: Add support for multiple LLM providers (Azure OpenAI, Anthropic, etc.)
# TODO: Implement telemetry and usage tracking (optional, with consent)
import os
import sys
import json
import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.prompt import Prompt
from rich.panel import Panel
from rich.progress import Progress
from rich.syntax import Syntax
from rich.live import Live
from rich.layout import Layout
from rich.table import Table
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union, Callable
import asyncio
import concurrent.futures
from dotenv import load_dotenv
import time
import re
import traceback
import requests
import urllib.parse
from uuid import uuid4
import socket
import threading
import multiprocessing
import pickle
import hashlib
import logging
import fastapi
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# Jina.ai client for search, fact-checking, and web reading
class JinaClient:
    """Client for the Jina.ai search (s.), grounding (g.) and reader (r.) endpoints.

    Each method URL-encodes its argument, appends it to the endpoint URL,
    issues a GET request and returns the parsed JSON body. The HTTP status is
    not checked, so error payloads are returned as-is.
    """

    def __init__(self, token: Optional[str] = None, timeout: Optional[float] = 60.0):
        """Initialize with your Jina token.

        Args:
            token: Jina API key; defaults to the ``JINA_API_KEY`` environment
                variable. When neither is set, no Authorization header is sent.
            timeout: Seconds to wait for each request; ``None`` waits forever
                (the previous behaviour).
        """
        self.token = token or os.getenv("JINA_API_KEY", "")
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            # The endpoints answer with plain text/markdown unless JSON is
            # requested, and every method below parses the body as JSON.
            "Accept": "application/json",
        }
        if self.token:
            # An empty "Bearer " credential is malformed; omit it when anonymous.
            self.headers["Authorization"] = f"Bearer {self.token}"

    def _get(self, endpoint: str, value: str) -> dict:
        """GET ``endpoint/<url-encoded value>`` and return the decoded JSON."""
        url = f"{endpoint}/{urllib.parse.quote(value)}"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        return response.json()

    def search(self, query: str) -> dict:
        """
        Search using s.jina.ai endpoint
        Args:
            query: Search term
        Returns:
            API response as dict
        """
        return self._get("https://s.jina.ai", query)

    def fact_check(self, query: str) -> dict:
        """
        Get grounding info using g.jina.ai endpoint
        Args:
            query: Query to ground
        Returns:
            API response as dict
        """
        return self._get("https://g.jina.ai", query)

    def reader(self, url: str) -> dict:
        """
        Read a web page's content using the r.jina.ai endpoint
        Args:
            url: URL of the page to read
        Returns:
            API response as dict
        """
        return self._get("https://r.jina.ai", url)

# Optional reinforcement-learning tool selection. When the `tool_optimizer`
# module cannot be imported, a no-op stand-in is defined so later references
# to ToolSelectionManager do not raise NameError.
HAVE_RL_TOOLS = False
try:
    # This is a placeholder for the actual import that would be used
    from tool_optimizer import ToolSelectionManager
except ImportError:
    class ToolSelectionManager:
        """No-op fallback used when the RL tool optimizer is unavailable."""

        def __init__(self, **kwargs):
            self.optimizer = None
            self.data_dir = kwargs.get('data_dir', '')

        def record_tool_usage(self, **kwargs):
            """Accept and discard a tool-usage record."""
            return None
else:
    # Import succeeded: the real RL tooling is available.
    HAVE_RL_TOOLS = True

# Global Constants
# TODO: Move these to a config file
DEFAULT_MODEL = "gpt-4o"  # fallback when OPENAI_MODEL is not set
DEFAULT_TEMPERATURE = 0  # fallback when OPENAI_TEMPERATURE is not set
MAX_TOKENS = 4096
TOKEN_LIMIT_WARNING = 0.8  # Warn when 80% of token limit is reached

# Load environment variables from .env before anything below is constructed.
load_dotenv()

# TODO: Add update checking similar to Claude Code's auto-update functionality
# TODO: Add configuration file support to store settings beyond environment variables

# Typer application hosting the CLI commands, plus the shared Rich console.
app = typer.Typer(help="OpenAI Code Assistant CLI")
console = Console()

# Models
# TODO: Implement more sophisticated schema validation similar to Zod in the original
# TODO: Add permission system for tools that require user approval

class ToolParameter(BaseModel):
    """Description of a single tool parameter.

    NOTE(review): not referenced in this part of the file -- ``Tool.parameters``
    holds a raw JSON-schema dict instead. Confirm whether it is used elsewhere.
    """
    name: str  # parameter name
    description: str  # human-readable description of the parameter
    type: str  # type name as a string; presumably a JSON-schema type like "string" -- confirm
    required: bool = False  # parameters are optional unless marked otherwise

class Tool(BaseModel):
    """A tool the assistant can call.

    ``parameters`` is a JSON-schema object describing the arguments (see
    ``Conversation._register_tools``); ``function`` is called with the parsed
    arguments as keyword arguments when the tool is invoked.
    """
    name: str  # unique key; becomes the key in Conversation.tool_map
    description: str  # free-text description of what the tool does
    parameters: Dict[str, Any]  # JSON-schema "object" with "properties" and "required"
    function: Callable  # implementation, e.g. a bound Conversation method
    # TODO: Add needs_permission flag for sensitive operations
    # TODO: Add category for organizing tools (file, search, etc.)

class Message(BaseModel):
    """A chat message with the same keys used for entries of ``Conversation.messages``.

    NOTE(review): in this part of the file Conversation stores plain dicts,
    not Message instances -- confirm whether this model is used elsewhere.
    """
    role: str  # e.g. "system", "user" or "tool" (values seen in Conversation)
    content: Optional[str] = None  # message text; None when absent
    tool_calls: Optional[List[Dict[str, Any]]] = None  # tool calls, each as a dict
    tool_call_id: Optional[str] = None  # for role "tool": id of the call being answered
    name: Optional[str] = None  # for role "tool": name of the tool that produced it
    # TODO: Add timestamp for message tracking
    # TODO: Add token count for better context management

class Conversation:
    def __init__(self):
        """Set up an empty conversation bound to an OpenAI client.

        Model, temperature and the tool-loop cap come from the OPENAI_MODEL,
        OPENAI_TEMPERATURE and MAX_TOOL_ITERATIONS environment variables.
        When the RL tooling imported successfully, a ToolSelectionManager is
        attached as ``self.tool_optimizer``; otherwise it stays None.
        """
        # TODO: Implement retry logic with exponential backoff for API calls
        # TODO: Add support for multiple LLM providers
        # TODO: Implement context window management
        self.messages = []
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.model = os.getenv("OPENAI_MODEL", DEFAULT_MODEL)
        self.temperature = float(os.getenv("OPENAI_TEMPERATURE", DEFAULT_TEMPERATURE))

        registered = self._register_tools()
        self.tools = registered
        # Name -> callable lookup used when the model requests a tool.
        self.tool_map = dict((entry.name, entry.function) for entry in registered)

        self.conversation_id = str(uuid4())
        self.session_start_time = time.time()
        self.token_usage = {"prompt": 0, "completion": 0, "total": 0}
        self.verbose = False
        self.max_tool_iterations = int(os.getenv("MAX_TOOL_ITERATIONS", "10"))

        self.tool_optimizer = None
        if not HAVE_RL_TOOLS:
            return

        try:
            class ToolRegistryAdapter:
                """Minimal registry view over the registered tools for the optimizer."""

                def __init__(self, tools):
                    self.tools = tools

                def get_all_tools(self):
                    return self.tools

                def get_all_tool_names(self):
                    return [tool.name for tool in self.tools]

            optimization_on = os.getenv("ENABLE_TOOL_OPTIMIZATION", "1") == "1"
            rl_data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/rl")
            self.tool_optimizer = ToolSelectionManager(
                tool_registry=ToolRegistryAdapter(self.tools),
                enable_optimization=optimization_on,
                data_dir=rl_data_dir,
            )
            # verbose was set to False just above, so this never prints here.
            if self.verbose:
                print("Tool selection optimization enabled")
        except Exception as e:
            print(f"Warning: Failed to initialize tool optimizer: {e}")
        
    # Jina.ai client for search, fact-checking, and web reading
    def _init_jina_client(self):
        """Return a fresh JinaClient using $JINA_API_KEY (empty string when unset)."""
        api_token = os.getenv("JINA_API_KEY", "")
        client = JinaClient(api_token)
        return client
    
    def _jina_search(self, query: str) -> str:
        """Search the web using Jina.ai"""
        try:
            client = self._init_jina_client()
            results = client.search(query)
            
            if not results or not isinstance(results, dict):
                return f"No search results found for '{query}'"
            
            # Format the results
            formatted_results = "Search Results:\n\n"
            
            if "results" in results and isinstance(results["results"], list):
                for i, result in enumerate(results["results"], 1):
                    title = result.get("title", "No title")
                    url = result.get("url", "No URL")
                    snippet = result.get("snippet", "No snippet")
                    
                    formatted_results += f"{i}. {title}\n"
                    formatted_results += f"   URL: {url}\n"
                    formatted_results += f"   {snippet}\n\n"
            else:
                formatted_results += "Unexpected response format. Raw data:\n"
                formatted_results += json.dumps(results, indent=2)[:1000]
                
            return formatted_results
        except Exception as e:
            return f"Error performing search: {str(e)}"
    
    def _jina_fact_check(self, statement: str) -> str:
        """Fact check a statement using Jina.ai"""
        try:
            client = self._init_jina_client()
            results = client.fact_check(statement)
            
            if not results or not isinstance(results, dict):
                return f"No fact-checking results for '{statement}'"
            
            # Format the results
            formatted_results = "Fact Check Results:\n\n"
            formatted_results += f"Statement: {statement}\n\n"
            
            if "grounding" in results:
                grounding = results["grounding"]
                verdict = grounding.get("verdict", "Unknown")
                confidence = grounding.get("confidence", 0)
                
                formatted_results += f"Verdict: {verdict}\n"
                formatted_results += f"Confidence: {confidence:.2f}\n\n"
                
                if "sources" in grounding and isinstance(grounding["sources"], list):
                    formatted_results += "Sources:\n"
                    for i, source in enumerate(grounding["sources"], 1):
                        title = source.get("title", "No title")
                        url = source.get("url", "No URL")
                        formatted_results += f"{i}. {title}\n   {url}\n\n"
            else:
                formatted_results += "Unexpected response format. Raw data:\n"
                formatted_results += json.dumps(results, indent=2)[:1000]
                
            return formatted_results
        except Exception as e:
            return f"Error performing fact check: {str(e)}"
    
    def _jina_read_url(self, url: str) -> str:
        """Read and summarize a webpage using Jina.ai"""
        try:
            client = self._init_jina_client()
            results = client.reader(url)
            
            if not results or not isinstance(results, dict):
                return f"No reading results for URL '{url}'"
            
            # Format the results
            formatted_results = f"Web Page Summary: {url}\n\n"
            
            if "content" in results:
                content = results["content"]
                title = content.get("title", "No title")
                summary = content.get("summary", "No summary available")
                
                formatted_results += f"Title: {title}\n\n"
                formatted_results += f"Summary:\n{summary}\n\n"
                
                if "keyPoints" in content and isinstance(content["keyPoints"], list):
                    formatted_results += "Key Points:\n"
                    for i, point in enumerate(content["keyPoints"], 1):
                        formatted_results += f"{i}. {point}\n"
            else:
                formatted_results += "Unexpected response format. Raw data:\n"
                formatted_results += json.dumps(results, indent=2)[:1000]
                
            return formatted_results
        except Exception as e:
            return f"Error reading URL: {str(e)}"
    
    def _register_tools(self) -> List[Tool]:
        """Build the list of tools available to the model.

        Each Tool pairs a JSON-schema ``parameters`` object (presumably sent to
        the model as the tool definition -- the API call is not in this
        section) with the bound method that implements it. ``__init__`` turns
        this list into ``self.tool_map`` keyed by ``name``, and
        ``process_tool_calls`` calls the method with the decoded arguments as
        keyword arguments, so property names must match parameter names.

        Returns:
            The Tool definitions.
        """
        # TODO: Modularize tools into separate files
        # TODO: Implement Tool decorators for easier registration
        # TODO: Add more tools similar to Claude Code (ReadNotebook, NotebookEditCell, etc.)

        # Define and register all tools
        tools = [
            Tool(
                name="Weather",
                description="Gets the current weather for a location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and optional state/country (e.g., 'San Francisco, CA' or 'London, UK')"
                        }
                    },
                    "required": ["location"]
                },
                function=self._get_weather
            ),
            Tool(
                name="View",
                description="Reads a file from the local filesystem. The file_path parameter must be an absolute path, not a relative path.",
                parameters={
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "The absolute path to the file to read"
                        },
                        # NOTE(review): JSON "number" also admits fractional values
                        # (e.g. 10.5); these are line counts, so "integer" may fit better.
                        "limit": {
                            "type": "number",
                            "description": "The number of lines to read. Only provide if the file is too large to read at once."
                        },
                        "offset": {
                            "type": "number",
                            "description": "The line number to start reading from. Only provide if the file is too large to read at once"
                        }
                    },
                    "required": ["file_path"]
                },
                function=self._view_file
            ),
            Tool(
                name="Edit",
                description="This is a tool for editing files.",
                parameters={
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "The absolute path to the file to modify"
                        },
                        # _edit_file treats an empty old_string as "create a new file".
                        "old_string": {
                            "type": "string",
                            "description": "The text to replace"
                        },
                        "new_string": {
                            "type": "string",
                            "description": "The text to replace it with"
                        }
                    },
                    "required": ["file_path", "old_string", "new_string"]
                },
                function=self._edit_file
            ),
            Tool(
                name="Replace",
                description="Write a file to the local filesystem. Overwrites the existing file if there is one.",
                parameters={
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "The absolute path to the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "The content to write to the file"
                        }
                    },
                    "required": ["file_path", "content"]
                },
                function=self._replace_file
            ),
            # NOTE(review): _execute_bash starts a new subprocess.run per call, so
            # shell state (cwd, variables) does not persist despite "persistent"
            # in this description.
            Tool(
                name="Bash",
                description="Executes a given bash command in a persistent shell session.",
                parameters={
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "The command to execute"
                        },
                        "timeout": {
                            "type": "number",
                            "description": "Optional timeout in milliseconds (max 600000)"
                        }
                    },
                    "required": ["command"]
                },
                function=self._execute_bash
            ),
            Tool(
                name="GlobTool",
                description="Fast file pattern matching tool that works with any codebase size.",
                parameters={
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "The directory to search in. Defaults to the current working directory."
                        },
                        "pattern": {
                            "type": "string",
                            "description": "The glob pattern to match files against"
                        }
                    },
                    "required": ["pattern"]
                },
                function=self._glob_tool
            ),
            # _grep_tool returns the paths of matching files, not the matching lines.
            Tool(
                name="GrepTool",
                description="Fast content search tool that works with any codebase size.",
                parameters={
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "The directory to search in. Defaults to the current working directory."
                        },
                        "pattern": {
                            "type": "string",
                            "description": "The regular expression pattern to search for in file contents"
                        },
                        "include": {
                            "type": "string",
                            "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"
                        }
                    },
                    "required": ["pattern"]
                },
                function=self._grep_tool
            ),
            Tool(
                name="LS",
                description="Lists files and directories in a given path.",
                parameters={
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "The absolute path to the directory to list"
                        },
                        "ignore": {
                            "type": "array",
                            "items": {
                                "type": "string"
                            },
                            "description": "List of glob patterns to ignore"
                        }
                    },
                    "required": ["path"]
                },
                function=self._list_directory
            ),
            # The three Jina tools share one client factory (_init_jina_client).
            Tool(
                name="JinaSearch",
                description="Search the web for information using Jina.ai",
                parameters={
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                },
                function=self._jina_search
            ),
            Tool(
                name="JinaFactCheck",
                description="Fact check a statement using Jina.ai",
                parameters={
                    "type": "object",
                    "properties": {
                        "statement": {
                            "type": "string",
                            "description": "The statement to fact check"
                        }
                    },
                    "required": ["statement"]
                },
                function=self._jina_fact_check
            ),
            Tool(
                name="JinaReadURL",
                description="Read and summarize a webpage using Jina.ai",
                parameters={
                    "type": "object",
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "The URL of the webpage to read"
                        }
                    },
                    "required": ["url"]
                },
                function=self._jina_read_url
            )
        ]
        return tools
    
    # Tool implementations
    # TODO: Add better error handling and user feedback
    # TODO: Implement tool usage tracking and metrics
    
    def _get_weather(self, location: str) -> str:
        """Get current weather for a location using OpenWeatherMap API"""
        try:
            # Get API key from environment or use a default for testing
            api_key = os.getenv("OPENWEATHER_API_KEY", "")
            if not api_key:
                return "Error: OpenWeatherMap API key not found. Please set the OPENWEATHER_API_KEY environment variable."
            
            # Prepare the API request
            base_url = "https://api.openweathermap.org/data/2.5/weather"
            params = {
                "q": location,
                "appid": api_key,
                "units": "metric"  # Use metric units (Celsius)
            }
            
            # Make the API request
            response = requests.get(base_url, params=params)
            
            # Check if the request was successful
            if response.status_code == 200:
                data = response.text
                # Try to parse as JSON
                try:
                    data = json.loads(data)
                except json.JSONDecodeError:
                    return f"Error: Unable to parse weather data. Raw response: {data[:200]}..."
                
                # Extract relevant weather information
                weather_desc = data["weather"][0]["description"]
                temp = data["main"]["temp"]
                feels_like = data["main"]["feels_like"]
                humidity = data["main"]["humidity"]
                wind_speed = data["wind"]["speed"]
                
                # Format the response
                weather_info = (
                    f"Current weather in {location}:\n"
                    f"• Condition: {weather_desc.capitalize()}\n"
                    f"• Temperature: {temp}°C ({(temp * 9/5) + 32:.1f}°F)\n"
                    f"• Feels like: {feels_like}°C ({(feels_like * 9/5) + 32:.1f}°F)\n"
                    f"• Humidity: {humidity}%\n"
                    f"• Wind speed: {wind_speed} m/s ({wind_speed * 2.237:.1f} mph)"
                )
                return weather_info
            else:
                # Handle API errors
                if response.status_code == 404:
                    return f"Error: Location '{location}' not found. Please check the spelling or try a different location."
                elif response.status_code == 401:
                    return "Error: Invalid API key. Please check your OpenWeatherMap API key."
                else:
                    return f"Error: Unable to fetch weather data. Status code: {response.status_code}"
        
        except requests.exceptions.RequestException as e:
            return f"Error: Network error when fetching weather data: {str(e)}"
        except Exception as e:
            return f"Error: Failed to get weather information: {str(e)}"
    
    def _view_file(self, file_path: str, limit: Optional[int] = None, offset: Optional[int] = 0) -> str:
        # TODO: Add special handling for binary files and images
        # TODO: Add syntax highlighting for code files
        try:
            if not os.path.exists(file_path):
                return f"Error: File not found: {file_path}"
            
            # TODO: Handle file size limits better
            
            with open(file_path, 'r') as f:
                if limit is not None and offset is not None:
                    # Skip to offset
                    for _ in range(offset):
                        next(f, None)
                    
                    # Read limited lines
                    lines = []
                    for _ in range(limit):
                        line = next(f, None)
                        if line is None:
                            break
                        lines.append(line)
                    content = ''.join(lines)
                else:
                    content = f.read()
            
            # TODO: Add file metadata like size, permissions, etc.
            return content
        except Exception as e:
            return f"Error reading file: {str(e)}"
    
    def _edit_file(self, file_path: str, old_string: str, new_string: str) -> str:
        try:
            # Create directory if creating new file
            if not os.path.exists(os.path.dirname(file_path)) and old_string == "":
                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                
            if old_string == "" and not os.path.exists(file_path):
                # Creating new file
                with open(file_path, 'w') as f:
                    f.write(new_string)
                return f"Created new file: {file_path}"
            
            # Reading existing file
            if not os.path.exists(file_path):
                return f"Error: File not found: {file_path}"
            
            with open(file_path, 'r') as f:
                content = f.read()
            
            # Replace string
            if old_string not in content:
                return f"Error: Could not find the specified text in {file_path}"
            
            # Count occurrences to ensure uniqueness
            occurrences = content.count(old_string)
            if occurrences > 1:
                return f"Error: Found {occurrences} occurrences of the specified text in {file_path}. Please provide more context to uniquely identify the text to replace."
            
            new_content = content.replace(old_string, new_string)
            
            # Write back to file
            with open(file_path, 'w') as f:
                f.write(new_content)
            
            return f"Successfully edited {file_path}"
        
        except Exception as e:
            return f"Error editing file: {str(e)}"
    
    def _replace_file(self, file_path: str, content: str) -> str:
        try:
            # Create directory if it doesn't exist
            directory = os.path.dirname(file_path)
            if directory and not os.path.exists(directory):
                os.makedirs(directory, exist_ok=True)
            
            # Write content to file
            with open(file_path, 'w') as f:
                f.write(content)
            
            return f"Successfully wrote to {file_path}"
        
        except Exception as e:
            return f"Error writing file: {str(e)}"
    
    def _execute_bash(self, command: str, timeout: Optional[int] = None) -> str:
        try:
            import subprocess
            import shlex
            
            # Security check for banned commands
            banned_commands = [
                'alias', 'curl', 'curlie', 'wget', 'axel', 'aria2c', 'nc', 
                'telnet', 'lynx', 'w3m', 'links', 'httpie', 'xh', 'http-prompt', 
                'chrome', 'firefox', 'safari'
            ]
            
            for banned in banned_commands:
                if banned in command.split():
                    return f"Error: The command '{banned}' is not allowed for security reasons."
            
            # Execute command
            if timeout:
                timeout_seconds = timeout / 1000  # Convert to seconds
            else:
                timeout_seconds = 1800  # 30 minutes default
            
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                timeout=timeout_seconds
            )
            
            output = result.stdout
            if result.stderr:
                output += f"\nErrors:\n{result.stderr}"
            
            # Truncate if too long
            if len(output) > 30000:
                output = output[:30000] + "\n... (output truncated)"
            
            return output
        
        except subprocess.TimeoutExpired:
            return f"Error: Command timed out after {timeout_seconds} seconds"
        except Exception as e:
            return f"Error executing command: {str(e)}"
    
    def _glob_tool(self, pattern: str, path: Optional[str] = None) -> str:
        try:
            import glob
            import os
            
            if path is None:
                path = os.getcwd()
            
            # Build the full pattern path
            if not os.path.isabs(path):
                path = os.path.abspath(path)
            
            full_pattern = os.path.join(path, pattern)
            
            # Get matching files
            matches = glob.glob(full_pattern, recursive=True)
            
            # Sort by modification time (newest first)
            matches.sort(key=os.path.getmtime, reverse=True)
            
            if not matches:
                return f"No files matching pattern '{pattern}' in {path}"
            
            return "\n".join(matches)
        
        except Exception as e:
            return f"Error in glob search: {str(e)}"
    
    def _grep_tool(self, pattern: str, path: Optional[str] = None, include: Optional[str] = None) -> str:
        try:
            import re
            import os
            import fnmatch
            from concurrent.futures import ThreadPoolExecutor
            
            if path is None:
                path = os.getcwd()
            
            if not os.path.isabs(path):
                path = os.path.abspath(path)
            
            # Compile regex pattern
            regex = re.compile(pattern)
            
            # Get all files
            all_files = []
            for root, _, files in os.walk(path):
                for file in files:
                    file_path = os.path.join(root, file)
                    
                    # Apply include filter if provided
                    if include:
                        if not fnmatch.fnmatch(file, include):
                            continue
                    
                    all_files.append(file_path)
            
            # Sort by modification time (newest first)
            all_files.sort(key=os.path.getmtime, reverse=True)
            
            matches = []
            
            def search_file(file_path):
                try:
                    with open(file_path, 'r', errors='ignore') as f:
                        content = f.read()
                        if regex.search(content):
                            return file_path
                except:
                    # Skip files that can't be read
                    pass
                return None
            
            # Search files in parallel
            with ThreadPoolExecutor(max_workers=10) as executor:
                results = executor.map(search_file, all_files)
                
                for result in results:
                    if result:
                        matches.append(result)
            
            if not matches:
                return f"No matches found for pattern '{pattern}' in {path}"
            
            return "\n".join(matches)
        
        except Exception as e:
            return f"Error in grep search: {str(e)}"
    
    def _list_directory(self, path: str, ignore: Optional[List[str]] = None) -> str:
        try:
            import os
            import fnmatch
            
            # If path is not absolute, make it absolute from current directory
            if not os.path.isabs(path):
                path = os.path.abspath(os.path.join(os.getcwd(), path))
            
            if not os.path.exists(path):
                return f"Error: Directory not found: {path}"
            
            if not os.path.isdir(path):
                return f"Error: Path is not a directory: {path}"
            
            # List directory contents
            items = os.listdir(path)
            
            # Apply ignore patterns
            if ignore:
                for pattern in ignore:
                    items = [item for item in items if not fnmatch.fnmatch(item, pattern)]
            
            # Sort items
            items.sort()
            
            # Format output
            result = []
            for item in items:
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path):
                    result.append(f"{item}/")
                else:
                    result.append(item)
            
            if not result:
                return f"Directory {path} is empty"
            
            return "\n".join(result)
        
        except Exception as e:
            return f"Error listing directory: {str(e)}"
    
    def add_message(self, role: str, content: str):
        """Legacy method to add messages - use direct append now"""
        self.messages.append({"role": role, "content": content})
    
    def process_tool_calls(self, tool_calls, query=None):
        # TODO: Add tool call validation
        # TODO: Add permission system for sensitive tools
        # TODO: Add progress visualization for long-running tools
        responses = []
        
        # Process tool calls in parallel
        from concurrent.futures import ThreadPoolExecutor
        
        def process_single_tool(tool_call):
            # Handle both object-style and dict-style tool calls
            if isinstance(tool_call, dict):
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])
                tool_call_id = tool_call["id"]
            else:
                function_name = tool_call.function.name
                function_args = json.loads(tool_call.function.arguments)
                tool_call_id = tool_call.id
            
            # Get the tool function
            if function_name in self.tool_map:
                # TODO: Add pre-execution validation
                # TODO: Add permission check here
                
                # Track start time for metrics
                start_time = time.time()
                
                try:
                    function = self.tool_map[function_name]
                    result = function(**function_args)
                    success = True
                except Exception as e:
                    result = f"Error executing tool {function_name}: {str(e)}\n{traceback.format_exc()}"
                    success = False
                
                # Calculate execution time
                execution_time = time.time() - start_time
                
                # Record tool usage for optimization if optimizer is available
                if self.tool_optimizer is not None and query is not None:
                    try:
                        # Create current context snapshot
                        context = {
                            "messages": self.messages.copy(),
                            "conversation_id": self.conversation_id,
                        }
                        
                        # Record tool usage
                        self.tool_optimizer.record_tool_usage(
                            query=query,
                            tool_name=function_name,
                            execution_time=execution_time,
                            token_usage=self.token_usage.copy(),
                            success=success,
                            context=context,
                            result=result
                        )
                    except Exception as e:
                        if self.verbose:
                            print(f"Warning: Failed to record tool usage: {e}")
                
                return {
                    "tool_call_id": tool_call_id,
                    "function_name": function_name,
                    "result": result,
                    "name": function_name,
                    "execution_time": execution_time,  # For metrics
                    "success": success
                }
            return None
        
        # Process all tool calls in parallel
        with ThreadPoolExecutor(max_workers=min(10, len(tool_calls))) as executor:
            futures = [executor.submit(process_single_tool, tool_call) for tool_call in tool_calls]
            for future in futures:
                result = future.result()
                if result:
                    # Add tool response to messages
                    self.messages.append({
                        "tool_call_id": result["tool_call_id"],
                        "role": "tool",
                        "name": result["name"],
                        "content": result["result"]
                    })
                    
                    responses.append({
                        "tool_call_id": result["tool_call_id"],
                        "function_name": result["function_name"],
                        "result": result["result"]
                    })
                    
                    # Log tool execution metrics if verbose
                    if self.verbose:
                        print(f"Tool {result['function_name']} executed in {result['execution_time']:.2f}s (success: {result['success']})")
        
        # Return tool responses
        return responses
    
    def compact(self):
        """Replace the conversation history with a model-written summary.

        Asks the model to summarise the conversation so far, then rebuilds
        ``self.messages`` as: the first system message (if any), a system
        message carrying the summary, and the most recent user message.

        The summarisation request is sent on a copy of the history. If the API
        call fails, ``self.messages`` is left exactly as it was and an error
        string is returned instead of raising.

        Returns:
            A human-readable status string describing the outcome.
        """
        # TODO: Add more sophisticated compaction with token counting
        # TODO: Implement selective retention of critical information
        # TODO: Add option to save conversation history before compacting

        system_prompt = next((m for m in self.messages if m["role"] == "system"), None)
        last_user_message = next(
            (m for m in reversed(self.messages) if m["role"] == "user"), None
        )
        if last_user_message is None:
            return "No user messages to compact."

        # TODO: Improve the compaction prompt with more guidance on what to retain
        compact_prompt = (
            "Summarize the conversation so far, focusing on the key points, decisions, and context. "
            "Keep important details about the code and tasks. Retain critical file paths, commands, "
            "and code snippets. The summary should be concise but complete enough to continue the "
            "conversation effectively."
        )

        # Build the request from a copy: appending the prompt to self.messages
        # first would leave it stranded in the history if the request failed.
        request_messages = self.messages + [{"role": "user", "content": compact_prompt}]
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=request_messages,
                stream=False
            )
        except Exception as e:
            return f"Error during compaction: {str(e)}"

        # The summarisation call consumes tokens too; count them.
        usage = getattr(response, "usage", None)
        if usage is not None:
            self.token_usage["prompt"] += usage.prompt_tokens or 0
            self.token_usage["completion"] += usage.completion_tokens or 0
            self.token_usage["total"] += usage.total_tokens or 0

        # A missing summary should not be rendered as the literal text "None".
        summary = response.choices[0].message.content or ""

        # Reset conversation with summary
        self.messages = [system_prompt] if system_prompt else []
        self.messages.append({"role": "system", "content": f"This is a compacted conversation. Previous context: {summary}"})
        self.messages.append({"role": "user", "content": last_user_message["content"]})

        # TODO: Add metrics for compaction (tokens before/after)

        return "Conversation compacted successfully."
    
    def get_response(self, user_input: str, stream: bool = True):
        """Answer one user turn, either via a local slash command or via the model.

        Recognised slash commands (``/help``, ``/status``, ``/config``, ...) are
        handled locally and their reply text is returned without calling the
        API. Any other input is appended to the history as a user message and
        sent to the model with the registered tools. Tool calls requested by the
        model are executed and their results fed back until the model stops
        requesting tools or ``self.max_tool_iterations`` follow-up rounds ran.

        Args:
            user_input: Raw text typed by the user.
            stream: When True the first completion is streamed into a live
                console view; otherwise one blocking request is made.

        Returns:
            The text to show the user. A failed streaming request (after
            retries) or a failed tool-processing step yields an error
            description instead of an exception.

        Raises:
            Exception: Errors from the non-streaming request or from follow-up
                requests are re-raised once retries are exhausted.
        """
        # TODO: Add more special commands similar to Claude Code (e.g., /version)
        # TODO: Implement binary feedback mechanism for comparing responses
        command_reply = self._dispatch_slash_command(user_input)
        if command_reply is not None:
            return command_reply

        self.messages.append({"role": "user", "content": user_input})
        api_tools = self._api_tool_specs()

        if stream:
            return self._streamed_response(user_input, api_tools)
        return self._blocking_response(user_input, api_tools)

    def _dispatch_slash_command(self, user_input):
        """Return the reply for a recognised slash command, or None to fall through to the model."""
        command = user_input.strip()

        if command == "/compact":
            return self.compact()

        # Debug command to help diagnose issues
        if command == "/debug":
            debug_info = {
                "model": self.model,
                "temperature": self.temperature,
                "message_count": len(self.messages),
                "token_usage": self.token_usage,
                "conversation_id": self.conversation_id,
                "session_duration": time.time() - self.session_start_time,
                "tools_count": len(self.tools),
                "python_version": sys.version,
                # NOTE(review): __version__ is usually defined on the openai module rather
                # than the OpenAI class; confirm, otherwise this always reports "Unknown".
                "openai_version": OpenAI.__version__ if hasattr(OpenAI, "__version__") else "Unknown"
            }
            return "Debug Information:\n" + json.dumps(debug_info, indent=2)

        if command == "/help":
            commands = [
                "/help - Show this help message",
                "/compact - Compact the conversation to reduce token usage",
                "/status - Show token usage and session information",
                "/config - Show current configuration settings",
            ]
            # RL-specific commands are only available with an optimizer attached.
            if self.tool_optimizer is not None:
                commands.extend([
                    "/rl-status - Show RL tool optimizer status",
                    "/rl-update - Update the RL model manually",
                    "/rl-stats - Show tool usage statistics",
                ])
            return "Available commands:\n" + "\n".join(commands)

        if command == "/status":
            session_duration = time.time() - self.session_start_time
            hours, remainder = divmod(session_duration, 3600)
            minutes, seconds = divmod(remainder, 60)
            return (
                f"Session ID: {self.conversation_id}\n"
                f"Model: {self.model} (Temperature: {self.temperature})\n"
                f"Session duration: {int(hours)}h {int(minutes)}m {int(seconds)}s\n\n"
                f"Token usage:\n"
                f"  Prompt tokens: {self.token_usage['prompt']}\n"
                f"  Completion tokens: {self.token_usage['completion']}\n"
                f"  Total tokens: {self.token_usage['total']}\n"
            )

        if command == "/config":
            config_info = (
                f"Current Configuration:\n"
                f"  Model: {self.model}\n"
                f"  Temperature: {self.temperature}\n"
                f"  Max tool iterations: {self.max_tool_iterations}\n"
                f"  Verbose mode: {self.verbose}\n"
                f"  RL optimization: {self.tool_optimizer is not None}\n"
            )
            config_info += "\nTo change settings, use:\n"
            config_info += "  /config set <setting> <value>\n"
            config_info += "Example: /config set max_tool_iterations 15"
            return config_info

        if command.startswith("/config set "):
            return self._apply_config_setting(command)

        if self.tool_optimizer is not None:
            if command == "/rl-status":
                return (
                    f"RL tool optimization is active\n"
                    f"Optimizer type: {type(self.tool_optimizer).__name__}\n"
                    f"Number of tools: {len(self.tools)}\n"
                    f"Data directory: {self.tool_optimizer.optimizer.data_dir if hasattr(self.tool_optimizer, 'optimizer') else 'N/A'}\n"
                )

            if command == "/rl-update":
                try:
                    update = self.tool_optimizer.optimizer.update_model()
                    status = f"RL model update status: {update['status']}\n"
                    if 'metrics' in update:
                        status += "Metrics:\n" + "\n".join([f"  {k}: {v}" for k, v in update['metrics'].items()])
                    return status
                except Exception as e:
                    return f"Error updating RL model: {str(e)}"

            if command == "/rl-stats":
                try:
                    if hasattr(self.tool_optimizer, 'optimizer') and hasattr(self.tool_optimizer.optimizer, 'tracker'):
                        stats = self.tool_optimizer.optimizer.tracker.get_tool_stats()
                        if not stats:
                            return "No tool usage data available yet."

                        report = "Tool Usage Statistics:\n\n"
                        for tool_name, tool_stats in stats.items():
                            report += f"{tool_name}:\n"
                            report += f"  Count: {tool_stats['count']}\n"
                            report += f"  Success rate: {tool_stats['success_rate']:.2f}\n"
                            report += f"  Avg time: {tool_stats['avg_time']:.2f}s\n"
                            report += f"  Avg tokens: {tool_stats['avg_total_tokens']:.1f}\n"
                            report += "\n"
                        return report
                    return "Tool usage statistics not available."
                except Exception as e:
                    return f"Error getting tool statistics: {str(e)}"

        return None

    def _apply_config_setting(self, command):
        """Apply a ``/config set <setting> <value>`` command and return a status message."""
        parts = command.split(" ", 3)
        if len(parts) != 4:
            return "Invalid format. Use: /config set <setting> <value>"
        setting, value = parts[2], parts[3]

        if setting == "max_tool_iterations":
            try:
                self.max_tool_iterations = int(value)
            except ValueError:
                return "Invalid value. Please provide a number."
            return f"Max tool iterations set to {self.max_tool_iterations}"
        if setting == "temperature":
            try:
                self.temperature = float(value)
            except ValueError:
                return "Invalid value. Please provide a number."
            return f"Temperature set to {self.temperature}"
        if setting == "verbose":
            flag = value.lower()
            if flag in ("true", "yes", "1", "on"):
                self.verbose = True
                return "Verbose mode enabled"
            if flag in ("false", "no", "0", "off"):
                self.verbose = False
                return "Verbose mode disabled"
            return "Invalid value. Use 'true' or 'false'."
        if setting == "model":
            self.model = value
            return f"Model set to {self.model}"
        return f"Unknown setting: {setting}"

    def _api_tool_specs(self):
        """Describe ``self.tools`` in the chat-completions ``tools`` format."""
        # TODO: Add dynamic tool availability based on context
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters
                }
            }
            for tool in self.tools
        ]

    def _completion_with_retry(self, api_tools, *, stream, label="API call", max_retries=3):
        """Create a chat completion over ``self.messages``, retrying with exponential backoff.

        Sleeps 2s, 4s, ... between attempts and re-raises the last error once
        ``max_retries`` attempts have failed. ``tools`` is omitted when empty,
        since the API is believed to reject an empty tools array (confirm).
        """
        request = {
            "model": self.model,
            "messages": self.messages,
            "temperature": self.temperature,
            "stream": stream,
        }
        if api_tools:
            request["tools"] = api_tools

        retry_count = 0
        while True:
            try:
                return self.client.chat.completions.create(**request)
            except Exception:
                retry_count += 1
                if retry_count >= max_retries:
                    raise  # Re-raise if we've exhausted retries
                wait_time = 2 ** retry_count
                if self.verbose:
                    console.print(f"[yellow]{label} failed, retrying in {wait_time}s... ({retry_count}/{max_retries})[/yellow]")
                time.sleep(wait_time)

    def _accumulate_usage(self, response):
        """Add a response's token usage to ``self.token_usage``; no-op when usage is absent or None."""
        usage = getattr(response, "usage", None)
        if usage is None:
            return
        self.token_usage["prompt"] += usage.prompt_tokens or 0
        self.token_usage["completion"] += usage.completion_tokens or 0
        self.token_usage["total"] += usage.total_tokens or 0

    @staticmethod
    def _serialize_tool_calls(tool_calls):
        """Convert tool calls (SDK objects or dicts) into message dicts tagged ``type: function``."""
        serialized = []
        for tool_call in tool_calls:
            if isinstance(tool_call, dict):
                entry = tool_call.copy()
            else:
                entry = {
                    "id": tool_call.id,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                }
            entry["type"] = "function"
            serialized.append(entry)
        return serialized

    def _streamed_response(self, user_input, api_tools):
        """Stream the first completion into a live view, then run any requested tools.

        Returns the accumulated assistant text (plus follow-up text), or an
        error string if the streaming request or tool processing fails.
        """
        # TODO: Implement cancellation support
        response_text = ""
        # Streamed tool calls arrive in fragments keyed by index; merge them here.
        tool_call_chunks = {}

        try:
            response_stream = self._completion_with_retry(api_tools, stream=True)
            with Live("", refresh_per_second=10) as live:
                for chunk in response_stream:
                    if not chunk.choices:
                        # A chunk may carry no choices (e.g. a usage-only chunk);
                        # indexing choices[0] on it would raise IndexError.
                        self._accumulate_usage(chunk)
                        continue

                    delta = chunk.choices[0].delta
                    if delta.content:
                        response_text += delta.content
                        live.update(response_text)

                    for tool_call_delta in delta.tool_calls or []:
                        pending = tool_call_chunks.setdefault(
                            tool_call_delta.index,
                            {"id": "", "function": {"name": "", "arguments": ""}},
                        )
                        if tool_call_delta.id:
                            pending["id"] = tool_call_delta.id
                        fragment = tool_call_delta.function
                        if fragment:
                            if fragment.name:
                                pending["function"]["name"] = fragment.name
                            if fragment.arguments:
                                pending["function"]["arguments"] += fragment.arguments
        except Exception as e:
            # TODO: Add better error handling and user feedback
            console.print(f"[bold red]Error during API call:[/bold red] {str(e)}")
            return f"Error during API call: {str(e)}"

        current_tool_calls = [
            {
                "id": pending["id"],
                "function": {
                    "name": pending["function"]["name"],
                    "arguments": pending["function"]["arguments"]
                }
            }
            for pending in tool_call_chunks.values()
        ]

        if not current_tool_calls:
            self.messages.append({"role": "assistant", "content": response_text})
            return response_text

        try:
            # The assistant message announcing the tool calls must precede their results.
            self.messages.append({
                "role": "assistant",
                "content": response_text,
                "tool_calls": self._serialize_tool_calls(current_tool_calls)
            })
            with console.status("[bold green]Running tools..."):
                self.process_tool_calls(current_tool_calls, query=user_input)
        except Exception as e:
            console.print(f"[bold red]Error:[/bold red] {str(e)}")
            console.print(traceback.format_exc())
            return f"Error processing tool calls: {str(e)}"

        return self._follow_up_tool_loop(user_input, api_tools, response_text)

    def _blocking_response(self, user_input, api_tools):
        """Make one non-streaming completion request, then run any requested tools."""
        response = self._completion_with_retry(api_tools, stream=False)
        self._accumulate_usage(response)

        assistant_message = response.choices[0].message
        response_text = assistant_message.content or ""

        if not assistant_message.tool_calls:
            console.print(Markdown(response_text))
            self.messages.append({"role": "assistant", "content": response_text})
            return response_text

        self.messages.append({
            "role": "assistant",
            "content": response_text,
            "tool_calls": self._serialize_tool_calls(assistant_message.tool_calls)
        })
        with console.status("[bold green]Running tools..."):
            self.process_tool_calls(assistant_message.tool_calls, query=user_input)

        # The follow-up loop records the final assistant message itself, so it
        # must not be appended again here (doing so duplicated it previously).
        return self._follow_up_tool_loop(user_input, api_tools, response_text)

    def _follow_up_tool_loop(self, user_input, api_tools, response_text):
        """Feed tool results back to the model until it stops requesting tools.

        Runs at most ``self.max_tool_iterations`` rounds. The model's final text
        is printed and appended both to the history and to ``response_text``.
        If the round limit is reached, a warning is printed and appended instead.

        Returns:
            The updated ``response_text``.
        """
        max_loop_iterations = self.max_tool_iterations
        current_iteration = 0

        while current_iteration < max_loop_iterations:
            follow_up = self._completion_with_retry(api_tools, stream=False, label="Follow-up API call")
            self._accumulate_usage(follow_up)

            assistant_message = follow_up.choices[0].message
            follow_up_text = assistant_message.content or ""
            tool_calls = getattr(assistant_message, "tool_calls", None)

            if not tool_calls:
                if follow_up_text:
                    console.print(Markdown(follow_up_text))
                    response_text += "\n" + follow_up_text
                self.messages.append({"role": "assistant", "content": follow_up_text})
                return response_text

            self.messages.append({
                "role": "assistant",
                "content": follow_up_text,
                "tool_calls": self._serialize_tool_calls(tool_calls)
            })
            with console.status(f"[bold green]Running tools (iteration {current_iteration + 1})...[/bold green]"):
                self.process_tool_calls(tool_calls, query=user_input)
            current_iteration += 1

        warning_message = f"[yellow]Warning: Reached maximum number of tool call iterations ({max_loop_iterations}). Some operations may be incomplete.[/yellow]"
        console.print(warning_message)
        return response_text + f"\n\n{warning_message}"

# TODO: Create a more flexible system prompt mechanism with customizable templates
def get_system_prompt():
    """Return the default system prompt used to seed a conversation.

    The prompt is a fixed string. Nothing is interpolated into it. It covers
    the assistant's identity, tone and style, tool-usage policy, the task
    types it handles, and code-style expectations.
    """
    prompt_text = """You are OpenAI Code Assistant, a CLI tool that helps users with software engineering tasks and general information.
Use the available tools to assist the user with their requests.

# Tone and style
You should be concise, direct, and to the point. When you run a non-trivial bash command, 
you should explain what the command does and why you are running it.
Output text to communicate with the user; all text you output outside of tool use is displayed to the user.
Remember that your output will be displayed on a command line interface.

# Tool usage policy
- When doing file search, remember to search effectively with the available tools.
- Always use the appropriate tool for the task.
- Use parallel tool calls when appropriate to improve performance.
- NEVER commit changes unless the user explicitly asks you to.
- For weather queries, use the Weather tool to provide real-time information.

# Tasks
The user will primarily request you perform software engineering tasks:
1. Solving bugs
2. Adding new functionality 
3. Refactoring code
4. Explaining code
5. Writing tests

For these tasks:
1. Use search tools to understand the codebase
2. Implement solutions using the available tools
3. Verify solutions with tests if possible
4. Run lint and typecheck commands when appropriate

The user may also ask for general information:
1. Weather conditions
2. Simple calculations
3. General knowledge questions

# Code style
- Follow the existing code style of the project
- Maintain consistent naming conventions
- Use appropriate libraries that are already in the project
- Add comments when code is complex or non-obvious

IMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, 
quality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.
"""
    return prompt_text

# TODO: Add version information and CLI arguments
# TODO: Add logging configuration
# TODO: Create a proper CLI command structure with subcommands

# Hosting and replication capabilities
class HostingManager:
    """Manages hosting and replication of the assistant.

    Wraps a FastAPI application that exposes a small REST API over an in-memory
    pool of ``Conversation`` objects, keyed by UUID strings and owned by this
    instance.
    """
    
    def __init__(self, host="127.0.0.1", port=8000):
        """Create the FastAPI app and register its routes.

        Args:
            host: Address passed to uvicorn by start()/start_background().
            port: Port passed to uvicorn by start()/start_background().
        """
        self.host = host
        self.port = port
        # Set here as well as in start()/start_background(): when the app is
        # served through a factory that never calls start(), /health would
        # otherwise fail with AttributeError (HTTP 500).
        self.start_time = time.time()
        self.app = FastAPI(title="OpenAI Code Assistant API")
        # conversation_id (str) -> Conversation
        self.conversation_pool = {}
        self.setup_api()

    @staticmethod
    async def _read_user_message(request):
        """Return the ``"message"`` field of a JSON request body ("" if absent).

        Raises:
            HTTPException: 400 when the body is not valid JSON or not a JSON object.
        """
        try:
            data = await request.json()
        except ValueError:
            raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        return data.get("message", "")
        
    def setup_api(self):
        """Configure the FastAPI application (middleware and routes)."""
        # Add CORS middleware
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # In production, restrict this to specific domains
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        
        # Define API routes
        @self.app.get("/")
        async def root():
            return {"message": "OpenAI Code Assistant API", "status": "running"}
        
        @self.app.post("/conversation")
        async def create_conversation(
            request: Request,
            model: str = DEFAULT_MODEL,
            temperature: float = DEFAULT_TEMPERATURE
        ):
            """Create a new conversation instance"""
            conversation_id = str(uuid4())
            
            # Initialise before responding. A background task runs only after the
            # response is sent, so a client that used the returned id right away
            # would get a 404.
            await self._init_conversation(conversation_id, model, temperature)
            
            return {
                "conversation_id": conversation_id,
                "status": "ready",
                "model": model
            }
        
        @self.app.post("/conversation/{conversation_id}/message")
        async def send_message(
            conversation_id: str,
            request: Request
        ):
            """Send a message to a conversation"""
            if conversation_id not in self.conversation_pool:
                raise HTTPException(status_code=404, detail="Conversation not found")
                
            user_input = await self._read_user_message(request)
            
            # Get conversation instance
            conversation = self.conversation_pool[conversation_id]
            
            # Process message.
            # NOTE(review): get_response() is called synchronously inside an async
            # handler, so it blocks the event loop for the duration of the call.
            try:
                response = conversation.get_response(user_input, stream=False)
                return {
                    "conversation_id": conversation_id,
                    "response": response
                }
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error processing message: {str(e)}")
        
        @self.app.post("/conversation/{conversation_id}/message/stream")
        async def stream_message(
            conversation_id: str,
            request: Request
        ):
            """Stream a message response from a conversation"""
            if conversation_id not in self.conversation_pool:
                raise HTTPException(status_code=404, detail="Conversation not found")
                
            user_input = await self._read_user_message(request)
            
            # Get conversation instance
            conversation = self.conversation_pool[conversation_id]
            
            # Async generator yielding one JSON object per line.
            # NOTE(review): the OpenAI client and tool execution below are
            # synchronous calls, so they block the event loop while streaming.
            async def response_generator():
                # Add user message
                conversation.messages.append({"role": "user", "content": user_input})
                
                # Tool schemas in the function-calling format sent to the API
                api_tools = [
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters
                        }
                    }
                    for tool in conversation.tools
                ]
                
                # Stream response
                try:
                    stream = conversation.client.chat.completions.create(
                        model=conversation.model,
                        messages=conversation.messages,
                        tools=api_tools,
                        temperature=conversation.temperature,
                        stream=True
                    )
                    
                    current_tool_calls = []
                    # tool-call index -> partially accumulated tool call
                    tool_call_chunks = {}
                    response_text = ""
                    
                    for chunk in stream:
                        # Some chunks can carry an empty `choices` list (e.g. a
                        # trailing usage chunk); indexing [0] would raise.
                        if not chunk.choices:
                            continue
                        delta = chunk.choices[0].delta
                        
                        # If there's content, yield it
                        if delta.content:
                            response_text += delta.content
                            yield json.dumps({"type": "content", "content": delta.content}) + "\n"
                        
                        # Tool calls arrive in fragments keyed by index; the
                        # arguments string is concatenated across chunks.
                        if delta.tool_calls:
                            for tool_call_delta in delta.tool_calls:
                                entry = tool_call_chunks.setdefault(
                                    tool_call_delta.index,
                                    {"id": "", "function": {"name": "", "arguments": ""}}
                                )
                                
                                if tool_call_delta.id:
                                    entry["id"] = tool_call_delta.id
                                
                                if tool_call_delta.function:
                                    if tool_call_delta.function.name:
                                        entry["function"]["name"] = tool_call_delta.function.name
                                    
                                    if tool_call_delta.function.arguments:
                                        entry["function"]["arguments"] += tool_call_delta.function.arguments
                    
                    # Convert tool call chunks to actual tool calls, in index order
                    for _, tool_call_data in sorted(tool_call_chunks.items()):
                        current_tool_calls.append({
                            "id": tool_call_data["id"],
                            "function": {
                                "name": tool_call_data["function"]["name"],
                                "arguments": tool_call_data["function"]["arguments"]
                            }
                        })
                    
                    # Process tool calls if any
                    if current_tool_calls:
                        # Add assistant message with tool_calls to messages
                        processed_tool_calls = []
                        for tool_call in current_tool_calls:
                            processed_tool_call = tool_call.copy()
                            processed_tool_call["type"] = "function"
                            processed_tool_calls.append(processed_tool_call)
                        
                        conversation.messages.append({
                            "role": "assistant", 
                            "content": response_text,
                            "tool_calls": processed_tool_calls
                        })
                        
                        # Notify client that tools are running
                        yield json.dumps({"type": "status", "status": "running_tools"}) + "\n"
                        
                        # Process tool calls
                        tool_responses = conversation.process_tool_calls(current_tool_calls, query=user_input)
                        
                        # Notify client of tool results
                        for response in tool_responses:
                            yield json.dumps({
                                "type": "tool_result", 
                                "tool": response["function_name"],
                                "result": response["result"]
                            }) + "\n"
                        
                        # Continue the conversation with tool responses
                        max_loop_iterations = conversation.max_tool_iterations
                        current_iteration = 0
                        
                        while current_iteration < max_loop_iterations:
                            # Same sampling temperature as the initial streamed call
                            follow_up = conversation.client.chat.completions.create(
                                model=conversation.model,
                                messages=conversation.messages,
                                tools=api_tools,
                                temperature=conversation.temperature,
                                stream=False
                            )
                            
                            # Check if the follow-up response contains more tool calls
                            assistant_message = follow_up.choices[0].message
                            follow_up_text = assistant_message.content or ""
                            
                            # If there are no more tool calls, we're done with the loop
                            if not getattr(assistant_message, "tool_calls", None):
                                if follow_up_text:
                                    yield json.dumps({"type": "content", "content": follow_up_text}) + "\n"
                                
                                # Add the final assistant message to the conversation
                                conversation.messages.append({"role": "assistant", "content": follow_up_text})
                                break
                            
                            # Normalise the new tool calls (SDK objects or dicts) into dicts
                            current_tool_calls = []
                            for tool_call in assistant_message.tool_calls:
                                if isinstance(tool_call, dict):
                                    processed_tool_call = tool_call.copy()
                                else:
                                    processed_tool_call = {
                                        "id": tool_call.id,
                                        "function": {
                                            "name": tool_call.function.name,
                                            "arguments": tool_call.function.arguments
                                        }
                                    }
                                
                                processed_tool_call["type"] = "function"
                                current_tool_calls.append(processed_tool_call)
                            
                            # Add the assistant message with tool calls
                            conversation.messages.append({
                                "role": "assistant",
                                "content": follow_up_text,
                                "tool_calls": current_tool_calls
                            })
                            
                            # Notify client that tools are running
                            yield json.dumps({
                                "type": "status", 
                                "status": f"running_tools_iteration_{current_iteration + 1}"
                            }) + "\n"
                            
                            # Process the new tool calls
                            tool_responses = conversation.process_tool_calls(assistant_message.tool_calls, query=user_input)
                            
                            # Notify client of tool results
                            for response in tool_responses:
                                yield json.dumps({
                                    "type": "tool_result", 
                                    "tool": response["function_name"],
                                    "result": response["result"]
                                }) + "\n"
                            
                            current_iteration += 1
                        
                        # The loop exits with current_iteration == max only when it never hit `break`
                        if current_iteration >= max_loop_iterations:
                            warning_message = f"Warning: Reached maximum number of tool call iterations ({max_loop_iterations}). Some operations may be incomplete."
                            yield json.dumps({"type": "warning", "warning": warning_message}) + "\n"
                    else:
                        # Add assistant response to messages
                        conversation.messages.append({"role": "assistant", "content": response_text})
                    
                    # Signal completion
                    yield json.dumps({"type": "status", "status": "complete"}) + "\n"
                    
                except Exception as e:
                    yield json.dumps({"type": "error", "error": str(e)}) + "\n"
            
            # NOTE(review): the body is newline-delimited JSON, not SSE `data:` frames,
            # despite the text/event-stream media type; kept for client compatibility.
            return StreamingResponse(response_generator(), media_type="text/event-stream")
        
        @self.app.get("/conversation/{conversation_id}")
        async def get_conversation(conversation_id: str):
            """Get conversation details"""
            if conversation_id not in self.conversation_pool:
                raise HTTPException(status_code=404, detail="Conversation not found")
            
            conversation = self.conversation_pool[conversation_id]
            
            return {
                "conversation_id": conversation_id,
                "model": conversation.model,
                "temperature": conversation.temperature,
                "message_count": len(conversation.messages),
                "token_usage": conversation.token_usage
            }
        
        @self.app.delete("/conversation/{conversation_id}")
        async def delete_conversation(conversation_id: str):
            """Delete a conversation"""
            if conversation_id not in self.conversation_pool:
                raise HTTPException(status_code=404, detail="Conversation not found")
            
            del self.conversation_pool[conversation_id]
            
            return {"status": "deleted", "conversation_id": conversation_id}
        
        @self.app.get("/health")
        async def health_check():
            """Health check endpoint"""
            return {
                "status": "healthy",
                "active_conversations": len(self.conversation_pool),
                "uptime": time.time() - self.start_time
            }
    
    async def _init_conversation(self, conversation_id, model, temperature):
        """Create a Conversation with the system prompt and register it in the pool."""
        conversation = Conversation()
        conversation.model = model
        conversation.temperature = temperature
        conversation.messages.append({"role": "system", "content": get_system_prompt()})
        
        self.conversation_pool[conversation_id] = conversation
    
    def start(self):
        """Start the API server (blocks until uvicorn exits)."""
        self.start_time = time.time()
        uvicorn.run(self.app, host=self.host, port=self.port)
    
    def start_background(self):
        """Start the API server in a daemon thread and return that thread."""
        self.start_time = time.time()
        thread = threading.Thread(target=uvicorn.run, args=(self.app,), 
                                 kwargs={"host": self.host, "port": self.port})
        thread.daemon = True
        thread.start()
        return thread

class ReplicationManager:
    """Manages replication across multiple instances.

    Periodically pushes pickled conversations to peer instances with HTTP POST
    to ``/sync/conversation/{conversation_id}``.

    NOTE(review): the payload is a raw ``pickle``. Whatever endpoint receives it
    must never ``pickle.loads`` data from an untrusted peer, because unpickling
    can execute arbitrary code. Conversations that cannot be pickled are logged
    and skipped.
    """
    
    def __init__(self, primary=True, sync_interval=60, request_timeout=10):
        """
        Args:
            primary: Stored role flag (not consulted by this class itself).
            sync_interval: Seconds to sleep between background sync passes.
            request_timeout: Per-request timeout (seconds) for pushes to peers,
                so an unresponsive peer cannot stall the sync thread forever.
        """
        self.primary = primary
        self.sync_interval = sync_interval
        self.request_timeout = request_timeout
        self.peers = []
        # conversation_id -> digest of the last payload pushed successfully to all peers
        self.conversation_cache = {}
        self.last_sync = time.time()
        self.sync_lock = threading.Lock()
    
    def add_peer(self, host, port):
        """Add a peer instance to replicate with; return False if already present."""
        peer = {"host": host, "port": port}
        if peer not in self.peers:
            self.peers.append(peer)
            return True
        return False
    
    def remove_peer(self, host, port):
        """Remove a peer instance; return False if it was not registered."""
        peer = {"host": host, "port": port}
        if peer in self.peers:
            self.peers.remove(peer)
            return True
        return False
    
    def sync_conversation(self, conversation_id, conversation):
        """Push one conversation to every peer if it changed since the last successful push.

        The digest is recorded only when every peer accepted the payload, so a
        failed push is retried on the next pass instead of being skipped
        forever. Errors are logged and never raised.
        """
        if not self.peers:
            return
        
        try:
            serialized = pickle.dumps(conversation)
        except Exception as e:
            logging.error(f"Error serializing conversation: {e}")
            return
        
        # Digest is used only for change detection, not for security
        conversation_hash = hashlib.md5(serialized).hexdigest()
        if self.conversation_cache.get(conversation_id) == conversation_hash:
            return  # No changes, skip sync
        
        all_succeeded = True
        for peer in list(self.peers):
            url = f"http://{peer['host']}:{peer['port']}/sync/conversation/{conversation_id}"
            try:
                response = requests.post(
                    url,
                    data=serialized,
                    headers={"Content-Type": "application/octet-stream"},
                    timeout=self.request_timeout,
                )
                response.raise_for_status()
            except Exception as e:
                all_succeeded = False
                logging.error(f"Failed to sync with peer {peer['host']}:{peer['port']}: {e}")
        
        if all_succeeded:
            self.conversation_cache[conversation_id] = conversation_hash

    def sync_all(self, conversation_pool):
        """Run one sync pass over ``conversation_pool`` (conversation_id -> conversation)."""
        with self.sync_lock:
            # Snapshot: the pool is mutated concurrently by request handlers, and
            # iterating the live dict can raise "dictionary changed size".
            snapshot = dict(conversation_pool)
            # Forget digests of deleted conversations so the cache does not grow unbounded
            for stale_id in [cid for cid in self.conversation_cache if cid not in snapshot]:
                del self.conversation_cache[stale_id]
            for conversation_id, conversation in snapshot.items():
                self.sync_conversation(conversation_id, conversation)
            self.last_sync = time.time()
    
    def start_sync_thread(self, conversation_pool):
        """Start a daemon thread that calls sync_all() every ``sync_interval`` seconds."""
        def sync_worker():
            while True:
                time.sleep(self.sync_interval)
                # Keep the worker alive: one failed pass must not end replication
                try:
                    self.sync_all(conversation_pool)
                except Exception as e:
                    logging.error(f"Replication sync pass failed: {e}")
        
        thread = threading.Thread(target=sync_worker, daemon=True)
        thread.start()
        return thread

@app.command()
def serve(
    host: str = typer.Option("127.0.0.1", "--host", help="Host address to bind to"),
    port: int = typer.Option(8000, "--port", "-p", help="Port to listen on"),
    workers: int = typer.Option(1, "--workers", "-w", help="Number of worker processes"),
    enable_replication: bool = typer.Option(False, "--enable-replication", help="Enable replication across instances"),
    primary: bool = typer.Option(True, "--primary/--secondary", help="Whether this is a primary or secondary instance"),
    peers: List[str] = typer.Option([], "--peer", help="Peer instances to replicate with (host:port)")
):
    """
    Start the OpenAI Code Assistant as a web service
    """
    # (The docstring above doubles as typer's --help text, so it is kept as-is.)
    console.print(Panel.fit(
        f"[bold green]OpenAI Code Assistant API Server[/bold green]\n"
        f"Host: {host}\n"
        f"Port: {port}\n"
        f"Workers: {workers}\n"
        f"Replication: {'Enabled' if enable_replication else 'Disabled'}\n"
        f"Role: {'Primary' if primary else 'Secondary'}\n"
        f"Peers: {', '.join(peers) if peers else 'None'}",
        title="Server Starting",
        border_style="green"
    ))
    
    # Check API key
    if not os.getenv("OPENAI_API_KEY"):
        console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
        return
    
    if workers > 1:
        # Each worker builds its own app through create_app(). Replication is
        # wired up only in the single-process branch below, so tell the user
        # instead of silently ignoring --enable-replication.
        if enable_replication:
            console.print("[yellow]Warning: Replication is not supported with multiple workers; it will not be enabled.[/yellow]")
        console.print(f"Starting server with {workers} workers...")
        # NOTE(review): the import string assumes this module is importable as `cli`.
        uvicorn.run(
            "cli:create_app",
            host=host,
            port=port,
            workers=workers,
            factory=True
        )
        return
    
    # Single process mode
    hosting_manager = HostingManager(host=host, port=port)
    
    if enable_replication:
        replication_manager = ReplicationManager(primary=primary)
        
        for peer in peers:
            # Split on the LAST colon so bracketed IPv6 literals such as
            # "[::1]:8000" are accepted alongside plain "host:port".
            peer_host, _, peer_port = peer.rpartition(":")
            try:
                port_number = int(peer_port)
            except ValueError:
                port_number = 0
            if not peer_host or not 0 < port_number < 65536:
                console.print(f"[yellow]Warning: Invalid peer format: {peer}. Use host:port format.[/yellow]")
                continue
            replication_manager.add_peer(peer_host, port_number)
        
        # Start sync thread
        replication_manager.start_sync_thread(hosting_manager.conversation_pool)
        
        console.print(f"Replication enabled with {len(replication_manager.peers)} peers")
    
    # Start server (blocks)
    hosting_manager.start()

def create_app():
    """Factory function for creating the FastAPI app (used with multiple workers).

    Returns:
        The FastAPI application of a freshly constructed HostingManager.
    """
    hosting_manager = HostingManager()
    # The factory path never calls start()/start_background(), which is where
    # start_time is set; set it here so the /health endpoint can compute uptime.
    hosting_manager.start_time = time.time()
    return hosting_manager.app

@app.command()
def mcp_serve(
    host: str = typer.Option("127.0.0.1", "--host", help="Host address to bind to"),
    port: int = typer.Option(8000, "--port", "-p", help="Port to listen on"),
    dev_mode: bool = typer.Option(False, "--dev", help="Enable development mode with additional logging"),
    dependencies: List[str] = typer.Option([], "--dependencies", help="Additional Python dependencies to install"),
    env_file: str = typer.Option(None, "--env-file", help="Path to .env file with environment variables"),
    cache_type: str = typer.Option("memory", "--cache", help="Cache type: 'memory' or 'redis'"),
    redis_url: str = typer.Option(None, "--redis-url", help="Redis URL for cache (if cache_type is 'redis')"),
    reload: bool = typer.Option(False, "--reload", help="Enable auto-reload on code changes")
):
    """
    Start the OpenAI Code Assistant as an MCP (Model Context Protocol) server
    
    This allows the assistant to be used as a context provider for MCP clients
    like Claude Desktop or other MCP-compatible applications.
    """
    # (The docstring above doubles as typer's --help text, so it is kept as-is.)
    import importlib.util
    
    # Load environment variables from file if specified
    if env_file:
        if os.path.exists(env_file):
            load_dotenv(env_file)
            console.print(f"[green]Loaded environment variables from {env_file}[/green]")
        else:
            console.print(f"[yellow]Warning: Environment file {env_file} not found[/yellow]")
    
    # Required packages (pip name -> import name) are installed only when their
    # module cannot be found, so a normal start does not shell out to pip.
    # User-supplied --dependencies are always passed to pip as given.
    required_modules = {"prometheus-client": "prometheus_client", "tiktoken": "tiktoken"}
    if cache_type == "redis":
        required_modules["redis"] = "redis"
    missing_deps = [
        package for package, module in required_modules.items()
        if importlib.util.find_spec(module) is None
    ]
    
    all_deps = missing_deps + list(dependencies)
    
    if all_deps:
        console.print(f"[bold]Installing dependencies: {', '.join(all_deps)}[/bold]")
        try:
            import subprocess
            subprocess.check_call([sys.executable, "-m", "pip", "install", *all_deps])
            console.print("[green]Dependencies installed successfully[/green]")
        except Exception as e:
            console.print(f"[red]Error installing dependencies: {str(e)}[/red]")
            return
    
    # Configure logging for development mode
    if dev_mode:
        import logging
        logging.basicConfig(level=logging.DEBUG)
        console.print("[yellow]Development mode enabled with debug logging[/yellow]")
    
    # Print server information
    cache_info = f"Cache: {cache_type}"
    if cache_type == "redis" and redis_url:
        cache_info += f" ({redis_url})"
    
    console.print(Panel.fit(
        f"[bold green]OpenAI Code Assistant MCP Server[/bold green]\n"
        f"Host: {host}\n"
        f"Port: {port}\n"
        f"Development Mode: {'Enabled' if dev_mode else 'Disabled'}\n"
        f"Auto-reload: {'Enabled' if reload else 'Disabled'}\n"
        f"{cache_info}\n"
        f"API Key: {'Configured' if os.getenv('OPENAI_API_KEY') else 'Not Configured'}",
        title="MCP Server Starting",
        border_style="green"
    ))
    
    # Check API key
    if not os.getenv("OPENAI_API_KEY"):
        console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
        return
    
    # Create required directories next to this file
    base_dir = os.path.dirname(os.path.abspath(__file__))
    os.makedirs(os.path.join(base_dir, "data"), exist_ok=True)
    os.makedirs(os.path.join(base_dir, "templates"), exist_ok=True)
    os.makedirs(os.path.join(base_dir, "static"), exist_ok=True)
    
    # Import separately from startup. Otherwise an ImportError raised *inside*
    # mcp_server (e.g. a missing dependency) or by server.start() would be
    # misreported as "module not found".
    try:
        from mcp_server import MCPServer
    except ImportError as e:
        if getattr(e, "name", None) == "mcp_server":
            console.print("[bold red]Error:[/bold red] MCP server module not found. Make sure mcp_server.py is in the same directory.")
        else:
            console.print(f"[bold red]Error importing MCP server module:[/bold red] {str(e)}")
        if dev_mode:
            import traceback
            console.print(traceback.format_exc())
        return
    
    try:
        server = MCPServer(cache_type=cache_type, redis_url=redis_url)
        server.start(host=host, port=port, reload=reload)
    except Exception as e:
        console.print(f"[bold red]Error starting MCP server:[/bold red] {str(e)}")
        if dev_mode:
            import traceback
            console.print(traceback.format_exc())

@app.command()
def mcp_client(
    server_path: str = typer.Argument(..., help="Path to the MCP server script or module"),
    model: str = typer.Option("gpt-4o", "--model", "-m", help="Model to use for reasoning"),
    host: str = typer.Option("127.0.0.1", "--host", help="Host address for the MCP server"),
    port: int = typer.Option(8000, "--port", "-p", help="Port for the MCP server")
):
    """
    Connect to an MCP server using OpenAI Code Assistant as the reasoning engine
    
    This allows using the assistant to interact with any MCP-compatible server.
    """
    # (The docstring above doubles as typer's --help text, so it is kept as-is.)
    # Flow: spawn `server_path` with the current interpreter, register HTTP-backed
    # MCP tools that target host:port, run a REPL, and always stop the server on exit.
    console.print(Panel.fit(
        f"[bold green]OpenAI Code Assistant MCP Client[/bold green]\n"
        f"Server: {server_path}\n"
        f"Model: {model}\n"
        f"Host: {host}\n"
        f"Port: {port}",
        title="MCP Client Starting",
        border_style="green"
    ))
    
    # Check if server path exists
    if not os.path.exists(server_path):
        console.print(f"[bold red]Error:[/bold red] Server script not found at {server_path}")
        return
    
    # Check API key
    if not os.getenv("OPENAI_API_KEY"):
        console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
        return
    
    import subprocess
    import tempfile
    
    server_process = None
    # Server output goes to an anonymous temp file rather than a PIPE. Nothing
    # reads a PIPE while the REPL runs, so a chatty server could fill the OS pipe
    # buffer and block.
    server_log = tempfile.TemporaryFile()
    try:
        console.print(f"[bold]Starting MCP server from {server_path}...[/bold]")
        server_process = subprocess.Popen(
            [sys.executable, server_path],
            stdout=server_log,
            stderr=subprocess.STDOUT
        )
        
        # Wait for server to start
        time.sleep(2)
        
        # Check if server started successfully
        if server_process.poll() is not None:
            console.print("[bold red]Error:[/bold red] Failed to start MCP server")
            server_log.seek(0)
            output = server_log.read().decode(errors="replace")
            console.print(f"[red]Server output:[/red]\n{output}")
            return
        
        console.print("[green]MCP server started successfully[/green]")
        
        # Initialize conversation
        conversation = Conversation()
        conversation.model = model
        
        # Add system prompt
        conversation.messages.append({
            "role": "system", 
            "content": "You are an MCP client connecting to a Model Context Protocol server. "
                      "Use the available tools to interact with the server and help the user."
        })
        
        # Register MCP-specific tools
        mcp_tools = [
            Tool(
                name="MCPGetContext",
                description="Get context from the MCP server using a prompt template",
                parameters={
                    "type": "object",
                    "properties": {
                        "prompt_id": {
                            "type": "string",
                            "description": "ID of the prompt template to use"
                        },
                        "parameters": {
                            "type": "object",
                            "description": "Parameters for the prompt template"
                        }
                    },
                    "required": ["prompt_id"]
                },
                function=lambda prompt_id, parameters=None: _mcp_get_context(host, port, prompt_id, parameters or {})
            ),
            Tool(
                name="MCPListPrompts",
                description="List available prompt templates from the MCP server",
                parameters={
                    "type": "object",
                    "properties": {}
                },
                function=lambda: _mcp_list_prompts(host, port)
            ),
            Tool(
                name="MCPGetPrompt",
                description="Get details of a specific prompt template from the MCP server",
                parameters={
                    "type": "object",
                    "properties": {
                        "prompt_id": {
                            "type": "string",
                            "description": "ID of the prompt template to get"
                        }
                    },
                    "required": ["prompt_id"]
                },
                function=lambda prompt_id: _mcp_get_prompt(host, port, prompt_id)
            )
        ]
        
        # Add MCP tools to conversation
        conversation.tools.extend(mcp_tools)
        for tool in mcp_tools:
            conversation.tool_map[tool.name] = tool.function
        
        # Main interaction loop
        console.print("[bold]MCP Client ready. Type your questions or commands.[/bold]")
        console.print("[bold]Type 'exit' to quit.[/bold]")
        
        while True:
            try:
                user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")
                
                # Handle exit
                if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
                    console.print("[bold yellow]Shutting down MCP client...[/bold yellow]")
                    break
                
                # Get response
                conversation.get_response(user_input)
                
            except KeyboardInterrupt:
                console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
                if Prompt.ask("[bold]Exit?[/bold]", choices=["y", "n"], default="n") == "y":
                    break
                continue
            except EOFError:
                # stdin closed (Ctrl-D / piped input exhausted). Without this, the
                # generic handler below would loop forever printing errors.
                console.print("\n[bold yellow]Shutting down MCP client...[/bold yellow]")
                break
            except Exception as e:
                console.print(f"[bold red]Error:[/bold red] {str(e)}")
        
    except Exception as e:
        console.print(f"[bold red]Error:[/bold red] {str(e)}")
    finally:
        # Always reap the child, including on errors and KeyboardInterrupt
        if server_process is not None and server_process.poll() is None:
            console.print("[bold]Stopping MCP server...[/bold]")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                server_process.kill()
                server_process.wait()
        server_log.close()

# MCP client helper functions
def _mcp_get_context(host, port, prompt_id, parameters):
    """Get context from MCP server"""
    try:
        url = f"http://{host}:{port}/context"
        response = requests.post(
            url,
            json={
                "prompt_id": prompt_id,
                "parameters": parameters
            }
        )
        
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"
        
        data = response.json()
        return f"Context (ID: {data['context_id']}):\n\n{data['context']}"
    except Exception as e:
        return f"Error connecting to MCP server: {str(e)}"

def _mcp_list_prompts(host, port):
    """List available prompt templates from MCP server"""
    try:
        url = f"http://{host}:{port}/prompts"
        response = requests.get(url)
        
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"
        
        data = response.json()
        prompts = data.get("prompts", [])
        
        if not prompts:
            return "No prompt templates available"
        
        result = "Available prompt templates:\n\n"
        for prompt in prompts:
            result += f"ID: {prompt['id']}\n"
            result += f"Description: {prompt['description']}\n"
            result += f"Parameters: {', '.join(prompt.get('parameters', {}).keys())}\n\n"
        
        return result
    except Exception as e:
        return f"Error connecting to MCP server: {str(e)}"

def _mcp_get_prompt(host, port, prompt_id):
    """Get details of a specific prompt template"""
    try:
        url = f"http://{host}:{port}/prompts/{prompt_id}"
        response = requests.get(url)
        
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"
        
        prompt = response.json()
        
        result = f"Prompt Template: {prompt['id']}\n\n"
        result += f"Description: {prompt['description']}\n\n"
        result += "Parameters:\n"
        
        for param_name, param_info in prompt.get("parameters", {}).items():
            result += f"- {param_name}: {param_info.get('description', '')}\n"
        
        result += f"\nTemplate:\n{prompt['template']}\n"
        
        return result
    except Exception as e:
        return f"Error connecting to MCP server: {str(e)}"

@app.command()
def mcp_multi_agent(
    server_path: str = typer.Argument(..., help="Path to the MCP server script or module"),
    config: str = typer.Option(None, "--config", "-c", help="Path to agent configuration JSON file"),
    host: str = typer.Option("127.0.0.1", "--host", help="Host address for the MCP server"),
    port: int = typer.Option(8000, "--port", "-p", help="Port for the MCP server")
):
    """
    Start a multi-agent MCP client with multiple specialized agents
    
    This allows using multiple agents with different roles to collaborate
    on complex tasks by connecting to an MCP server.
    """
    # The docstring doubles as this command's --help text, so it is unchanged.
    #
    # Flow:
    #   1. Load and validate the configuration.
    #   2. Check prerequisites.
    #   3. Start the MCP server subprocess.
    #   4. Build one Conversation per configured agent.
    #   5. Run the interactive loop.
    # The server is stopped in `finally`, so it never outlives the client,
    # however the loop ends.

    # Validate *before* launching the server. An unknown primary agent used
    # to be detected only after the server was running, and the early
    # `return` then left the server process orphaned.
    config_data = _multi_agent_load_config(config)
    if config_data is None:
        return

    # str(): names come from user JSON and may not be strings.
    agent_names = [str(agent["name"]) for agent in config_data["agents"]]
    console.print(Panel.fit(
        f"[bold green]OpenAI Code Assistant Multi-Agent MCP Client[/bold green]\n"
        f"Server: {server_path}\n"
        f"Host: {host}:{port}\n"
        f"Agents: {', '.join(agent_names)}\n"
        f"Coordination: {config_data['coordination']['strategy']}\n"
        f"Primary Agent: {config_data['coordination']['primary_agent']}",
        title="Multi-Agent MCP Client Starting",
        border_style="green"
    ))

    if not os.path.exists(server_path):
        console.print(f"[bold red]Error:[/bold red] Server script not found at {server_path}")
        return

    if not os.getenv("OPENAI_API_KEY"):
        console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
        return

    server = None
    try:
        console.print(f"[bold]Starting MCP server from {server_path}...[/bold]")
        server = _multi_agent_start_server(server_path)
        if server is None:
            return
        console.print("[green]MCP server started successfully[/green]")

        agents = {
            agent_config["name"]: _multi_agent_create_agent(agent_config, host, port)
            for agent_config in config_data["agents"]
        }
        _multi_agent_loop(agents, config_data)
    except Exception as e:
        console.print(f"[bold red]Error:[/bold red] {str(e)}")
    finally:
        # Runs on normal exit, on errors, and on an escaping Ctrl-C alike.
        if server is not None:
            _multi_agent_stop_server(*server)


def _multi_agent_load_config(config):
    """Return the multi-agent config dict, or None after printing why it is unusable.

    ``config`` is a path to a JSON file; when it is empty, a built-in
    single-agent configuration is used.
    """
    if config:
        if not os.path.exists(config):
            console.print(f"[bold red]Error:[/bold red] Configuration file not found at {config}")
            return None
        try:
            with open(config, 'r', encoding="utf-8") as f:
                config_data = json.load(f)
        except Exception as e:
            console.print(f"[bold red]Error loading configuration:[/bold red] {str(e)}")
            return None
    else:
        # Default configuration
        config_data = {
            "agents": [
                {
                    "name": "Primary",
                    "role": "primary",
                    "system_prompt": "You are a helpful assistant that uses an MCP server to provide information.",
                    "model": "gpt-4o",
                    "temperature": 0.0
                }
            ],
            "coordination": {
                "strategy": "single",
                "primary_agent": "Primary"
            },
            "settings": {
                "max_turns_per_agent": 1,
                "enable_agent_reflection": False,
                "enable_cross_agent_communication": False,
                "enable_user_selection": False
            }
        }

    problem = _multi_agent_config_problem(config_data)
    if problem is not None:
        console.print(f"[bold red]Error:[/bold red] {problem}")
        return None
    return config_data


def _multi_agent_config_problem(config_data):
    """Describe the first structural problem in ``config_data``, or return None if usable."""
    if not isinstance(config_data, dict):
        return "Configuration must be a JSON object"
    agents = config_data.get("agents")
    if not isinstance(agents, list) or not agents:
        return "Configuration must define a non-empty 'agents' list"
    for index, agent in enumerate(agents, 1):
        if not isinstance(agent, dict) or "name" not in agent:
            return f"Agent #{index} must be an object with a 'name'"
    coordination = config_data.get("coordination")
    if not isinstance(coordination, dict) or not all(
        key in coordination for key in ("strategy", "primary_agent")
    ):
        return "Configuration must define 'coordination.strategy' and 'coordination.primary_agent'"
    primary_agent_name = coordination["primary_agent"]
    if primary_agent_name not in [agent["name"] for agent in agents]:
        return f"Primary agent '{primary_agent_name}' not found in configuration"
    return None


def _multi_agent_start_server(server_path, startup_wait=2.0):
    """Launch ``server_path`` with the current Python interpreter.

    Returns:
        ``(process, log_file)`` once the process has survived the startup
        grace period, or ``None`` (after printing its output) if it exited.
    """
    import subprocess
    import tempfile

    # Send server output to a temporary file rather than an unread PIPE.
    # Nothing drains a pipe while the client runs, so a chatty server would
    # eventually fill the OS pipe buffer and block.
    log_file = tempfile.TemporaryFile()
    try:
        process = subprocess.Popen(
            [sys.executable, server_path],
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        log_file.close()
        raise

    try:
        # Give the server a moment to start (or to crash on startup).
        time.sleep(startup_wait)
        exited = process.poll() is not None
    except BaseException:
        # Interrupted during the wait (e.g. Ctrl-C): don't orphan the child.
        process.kill()
        process.wait()
        log_file.close()
        raise

    if exited:
        log_file.seek(0)
        output = log_file.read().decode(errors="replace")
        log_file.close()
        console.print("[bold red]Error:[/bold red] Failed to start MCP server")
        console.print(f"[red]Server output:[/red]\n{output}")
        return None
    return process, log_file


def _multi_agent_stop_server(process, log_file, timeout=5):
    """Terminate the server process, escalating to kill(), and close its log file."""
    import subprocess

    console.print("[bold]Stopping MCP server...[/bold]")
    try:
        process.terminate()
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # The server ignored terminate(); force it down so it does not
            # outlive the client.
            process.kill()
            process.wait()
    finally:
        log_file.close()


def _multi_agent_create_agent(agent_config, host, port):
    """Build one agent entry: a configured Conversation with the MCP tools registered.

    Returns a dict with keys ``config`` (the raw agent config),
    ``conversation`` and ``history`` (that agent's own message log).
    """
    agent = Conversation()
    agent.model = agent_config.get("model", "gpt-4o")
    agent.temperature = agent_config.get("temperature", 0.0)

    # Add system prompt
    agent.messages.append({
        "role": "system",
        "content": agent_config.get("system_prompt", "You are a helpful assistant.")
    })

    # MCP-specific tools. The lambdas close over this helper's host/port
    # parameters.
    mcp_tools = [
        Tool(
            name="MCPGetContext",
            description="Get context from the MCP server using a prompt template",
            parameters={
                "type": "object",
                "properties": {
                    "prompt_id": {
                        "type": "string",
                        "description": "ID of the prompt template to use"
                    },
                    "parameters": {
                        "type": "object",
                        "description": "Parameters for the prompt template"
                    }
                },
                "required": ["prompt_id"]
            },
            function=lambda prompt_id, parameters=None: _mcp_get_context(host, port, prompt_id, parameters or {})
        ),
        Tool(
            name="MCPListPrompts",
            description="List available prompt templates from the MCP server",
            parameters={
                "type": "object",
                "properties": {}
            },
            function=lambda: _mcp_list_prompts(host, port)
        ),
        Tool(
            name="MCPGetPrompt",
            description="Get details of a specific prompt template from the MCP server",
            parameters={
                "type": "object",
                "properties": {
                    "prompt_id": {
                        "type": "string",
                        "description": "ID of the prompt template to get"
                    }
                },
                "required": ["prompt_id"]
            },
            function=lambda prompt_id: _mcp_get_prompt(host, port, prompt_id)
        )
    ]

    agent.tools.extend(mcp_tools)
    for tool in mcp_tools:
        agent.tool_map[tool.name] = tool.function

    return {
        "config": agent_config,
        "conversation": agent,
        "history": []
    }


def _multi_agent_record_turn(agent_entry, user_message, response):
    """Append one user/assistant exchange to an agent's private history."""
    agent_entry["history"].append({"role": "user", "content": user_message})
    agent_entry["history"].append({"role": "assistant", "content": response})


def _multi_agent_ask(agents, agent_name, message, conversation_history):
    """Ask one agent, record its reply in both histories, and return the reply."""
    agent_entry = agents[agent_name]
    response = agent_entry["conversation"].get_response(message)
    conversation_history.append({
        "role": "assistant",
        "agent": agent_name,
        "content": response
    })
    _multi_agent_record_turn(agent_entry, message, response)
    return response


def _multi_agent_loop(agents, config_data):
    """Run the interactive prompt until the user exits.

    ``/agents``, ``/history`` and ``/talk`` are handled locally. Other input
    is routed by ``coordination.strategy``:
      - "single"/"primary": the primary agent answers.
      - "round_robin": every agent answers.
      - "voting": every agent answers; all replies are printed afterwards.
      - any other value: falls back to the primary agent.
    """
    primary_agent_name = config_data["coordination"]["primary_agent"]
    strategy = config_data["coordination"]["strategy"]

    console.print("[bold]Multi-Agent MCP Client ready. Type your questions or commands.[/bold]")
    console.print("[bold]Special commands:[/bold]")
    console.print("  [blue]/agents[/blue] - List available agents")
    console.print("  [blue]/talk <agent_name> <message>[/blue] - Send message to specific agent")
    console.print("  [blue]/history[/blue] - Show conversation history")
    console.print("  [blue]/exit[/blue] - Exit the client")

    conversation_history = []

    while True:
        try:
            user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")

            if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
                console.print("[bold yellow]Shutting down multi-agent MCP client...[/bold yellow]")
                break

            if user_input.startswith("/agents"):
                console.print("[bold]Available Agents:[/bold]")
                for name, agent_data in agents.items():
                    # .get(): "role" is optional, and "model" defaults to
                    # gpt-4o when the agent is built, so a config without
                    # them must not crash this listing.
                    role = agent_data["config"].get("role", "unspecified")
                    model = agent_data["config"].get("model", "gpt-4o")
                    console.print(f"  [green]{name}[/green] ({role}, {model})")
                continue

            if user_input.startswith("/history"):
                console.print("[bold]Conversation History:[/bold]")
                for i, entry in enumerate(conversation_history, 1):
                    if entry["role"] == "user":
                        console.print(f"[blue]{i}. User:[/blue] {entry['content']}")
                    else:
                        console.print(f"[green]{i}. {entry['agent']}:[/green] {entry['content']}")
                continue

            if user_input.startswith("/talk "):
                parts = user_input.split(" ", 2)
                if len(parts) < 3:
                    console.print("[yellow]Usage: /talk <agent_name> <message>[/yellow]")
                    continue
                _, agent_name, message = parts
                if agent_name not in agents:
                    console.print(f"[yellow]Agent '{agent_name}' not found. Use /agents to see available agents.[/yellow]")
                    continue
                conversation_history.append({
                    "role": "user",
                    "content": message,
                    "target_agent": agent_name
                })
                console.print(f"[bold]Asking {agent_name}...[/bold]")
                _multi_agent_ask(agents, agent_name, message, conversation_history)
                continue

            # Regular message: route it by the coordination strategy.
            conversation_history.append({
                "role": "user",
                "content": user_input
            })

            if strategy in ("single", "primary"):
                _multi_agent_ask(agents, primary_agent_name, user_input, conversation_history)

            elif strategy == "round_robin":
                console.print("[bold]Consulting all agents...[/bold]")
                for agent_name in agents:
                    console.print(f"[bold]Response from {agent_name}:[/bold]")
                    _multi_agent_ask(agents, agent_name, user_input, conversation_history)

            elif strategy == "voting":
                console.print("[bold]Collecting responses from all agents...[/bold]")
                responses = {}
                for agent_name, agent_data in agents.items():
                    response = agent_data["conversation"].get_response(user_input)
                    responses[agent_name] = response
                    _multi_agent_record_turn(agent_data, user_input, response)

                # Show (and log) every reply only after all agents answered.
                for agent_name, response in responses.items():
                    console.print(f"[bold]Response from {agent_name}:[/bold]")
                    console.print(response)
                    conversation_history.append({
                        "role": "assistant",
                        "agent": agent_name,
                        "content": response
                    })

            else:
                console.print(f"[yellow]Unknown coordination strategy: {strategy}[/yellow]")
                # Default to primary agent
                _multi_agent_ask(agents, primary_agent_name, user_input, conversation_history)

        except KeyboardInterrupt:
            console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
            if Prompt.ask("[bold]Exit?[/bold]", choices=["y", "n"], default="n") == "y":
                break
            continue
        except Exception as e:
            console.print(f"[bold red]Error:[/bold red] {str(e)}")

@app.command()
def main(
    model: str = typer.Option(DEFAULT_MODEL, "--model", "-m", help="Specify the model to use"),
    temperature: float = typer.Option(DEFAULT_TEMPERATURE, "--temperature", "-t", help="Set temperature for response generation"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output with additional information"),
    enable_rl: bool = typer.Option(True, "--enable-rl/--disable-rl", help="Enable/disable reinforcement learning for tool optimization"),
    rl_update: bool = typer.Option(False, "--rl-update", help="Manually trigger an update of the RL model"),
):
    """
    OpenAI Code Assistant - A command-line coding assistant 
    that uses OpenAI APIs with function calling and streaming
    """
    # The docstring doubles as this command's --help text, so it is unchanged.
    # TODO: Check for updates on startup
    # TODO: Add environment setup verification

    # Create welcome panel with more details
    rl_status = "enabled" if enable_rl else "disabled"
    console.print(Panel.fit(
        f"[bold green]OpenAI Code Assistant[/bold green]\n"
        f"Model: {model} (Temperature: {temperature})\n"
        f"Reinforcement Learning: {rl_status}\n"
        "Type your questions or commands. Use /help for available commands.",
        title="Welcome",
        border_style="green"
    ))

    if not os.getenv("OPENAI_API_KEY"):
        console.print("[bold red]Error:[/bold red] No OpenAI API key found. Please set the OPENAI_API_KEY environment variable.")
        console.print("You can create a .env file with your API key or set it in your environment.")
        return

    def new_conversation(announce_rl_change=False):
        """Build a Conversation configured from the CLI options.

        Used at startup and by the "restart" recovery option, so a restarted
        session gets exactly the same model, temperature, verbosity and RL
        settings as the first one. Previously "restart" forced the model even
        when none was requested and did not re-apply --disable-rl. The caller
        appends the system prompt.
        """
        conv = Conversation()
        # Only override the model when one was explicitly requested;
        # otherwise the Conversation's own default is kept.
        if model != DEFAULT_MODEL:
            conv.model = model
        conv.temperature = temperature
        conv.verbose = verbose
        if not enable_rl and getattr(conv, "tool_optimizer", None) is not None:
            os.environ["ENABLE_TOOL_OPTIMIZATION"] = "0"
            conv.tool_optimizer = None
            if announce_rl_change:
                console.print("[yellow]Reinforcement learning disabled[/yellow]")
        return conv

    conversation = new_conversation(announce_rl_change=True)

    # Handle manual RL update if requested
    if rl_update and getattr(conversation, "tool_optimizer", None) is not None:
        try:
            with console.status("[bold blue]Updating RL model...[/bold blue]"):
                result = conversation.tool_optimizer.optimizer.update_model()
            console.print(f"[green]RL model update result:[/green] {result['status']}")
            if 'metrics' in result:
                console.print(Panel.fit(
                    "\n".join(f"{k}: {v}" for k, v in result['metrics'].items()),
                    title="RL Metrics",
                    border_style="blue"
                ))
        except Exception as e:
            console.print(f"[red]Error updating RL model:[/red] {e}")

    conversation.messages.append({"role": "system", "content": get_system_prompt()})

    # TODO: Add context collection for file system and git information
    # TODO: Add session persistence to allow resuming conversations

    # Main interaction loop
    while True:
        try:
            user_input = Prompt.ask("\n[bold blue]>>[/bold blue]")

            if user_input.lower() in ("exit", "quit", "/exit", "/quit"):
                console.print("[bold yellow]Goodbye![/bold yellow]")
                break

            # Deliberately not wrapped in a status spinner, so streamed
            # output renders properly.
            try:
                conversation.get_response(user_input)
            except Exception as e:
                console.print(f"[bold red]Error during response generation:[/bold red] {str(e)}")

                hint = _cli_error_hint(str(e))
                if hint is not None:
                    console.print(f"[yellow]Hint: {hint}[/yellow]")

                recovery_choice = Prompt.ask(
                    "[bold]Would you like to:[/bold]",
                    choices=["continue", "debug", "compact", "restart", "exit"],
                    default="continue"
                )

                if recovery_choice == "debug":
                    debug_info = {
                        "model": conversation.model,
                        "temperature": conversation.temperature,
                        "message_count": len(conversation.messages),
                        "token_usage": conversation.token_usage,
                        "conversation_id": conversation.conversation_id,
                        "session_duration": time.time() - conversation.session_start_time,
                        "tools_count": len(conversation.tools),
                        "python_version": sys.version,
                        "openai_version": _cli_openai_version()
                    }
                    # default=str: this is a diagnostic dump, so stringify
                    # any non-JSON value rather than crash the debug view.
                    console.print(Panel(json.dumps(debug_info, indent=2, default=str), title="Debug Information", border_style="yellow"))
                elif recovery_choice == "compact":
                    result = conversation.compact()
                    console.print(f"[green]{result}[/green]")
                elif recovery_choice == "restart":
                    conversation = new_conversation()
                    conversation.messages.append({"role": "system", "content": get_system_prompt()})
                    console.print("[green]Conversation restarted.[/green]")
                elif recovery_choice == "exit":
                    console.print("[bold yellow]Goodbye![/bold yellow]")
                    break

        except KeyboardInterrupt:
            console.print("\n[bold yellow]Operation cancelled by user.[/bold yellow]")
            cancel_choice = Prompt.ask(
                "[bold]Would you like to:[/bold]",
                choices=["continue", "exit"],
                default="continue"
            )
            if cancel_choice == "exit":
                console.print("[bold yellow]Goodbye![/bold yellow]")
                break
            continue
        except Exception as e:
            console.print(f"[bold red]Unexpected error:[/bold red] {str(e)}")
            import traceback
            console.print(traceback.format_exc())
            if Prompt.ask("[bold]Continue?[/bold]", choices=["y", "n"], default="y") == "n":
                break


def _cli_error_hint(error_text):
    """Return a short troubleshooting hint for a recognised API error message, or None."""
    lowered = error_text.lower()
    if "api_key" in lowered:
        return "Check your OpenAI API key."
    if "rate limit" in lowered:
        return "You've hit a rate limit. Try again in a moment."
    if "context_length_exceeded" in lowered or "maximum context length" in lowered:
        return "The conversation is too long. Try using /compact to reduce its size."
    if "Missing required parameter" in error_text:
        return "There's an API format issue. Try restarting the conversation."
    return None


def _cli_openai_version():
    """Return the installed ``openai`` package version, or "Unknown".

    The version is read from installed package metadata. Checking
    ``__version__`` on the ``OpenAI`` client class does not find it; the
    package exposes the version at module level.
    """
    try:
        from importlib.metadata import version
        return version("openai")
    except Exception:
        return "Unknown"

if __name__ == "__main__":
    # Dispatch to the CLI app (the commands registered with @app.command())
    # only when this file is run as a script, not when it is imported.
    app()

```
Page 4/4 · First · Prev · Next · Last