#
tokens: 28668/50000 1/56 files (page 5/6)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 5 of 6. Use http://codebase.md/arthurcolle/openai-mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```

# Files

--------------------------------------------------------------------------------
/modal_mcp_server.py:
--------------------------------------------------------------------------------

```python
   1 | import modal
   2 | import logging
   3 | import time
   4 | import uuid
   5 | import json
   6 | import asyncio
   7 | import hashlib
   8 | import threading
   9 | import concurrent.futures
  10 | from pathlib import Path
  11 | from typing import Dict, List, Optional, Any, Tuple, Union, AsyncIterator
  12 | from datetime import datetime, timedelta
  13 | from collections import deque
  14 | 
  15 | from fastapi import FastAPI, Request, Depends, HTTPException, status, BackgroundTasks
  16 | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  17 | from fastapi.responses import JSONResponse, HTMLResponse
  18 | from fastapi.middleware.cors import CORSMiddleware
  19 | from pydantic import BaseModel, Field
  20 | 
# Configure logging
# Root logger: INFO and above, each record prefixed with time, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Create FastAPI app
# `api_app` is the object every HTTP route below is registered on.
api_app = FastAPI(
    title="Advanced LLM Inference API", 
    description="Enterprise-grade OpenAI-compatible LLM serving API with multiple model support, streaming, and advanced caching",
    version="1.1.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# generally discouraged -- confirm against the FastAPI/Starlette CORS docs
# before exposing this service to browsers.
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For production, specify specific origins instead of wildcard
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security setup
# HTTP Bearer extractor consumed by the `verify_api_key` dependency.
security = HTTPBearer()
  45 | 
  46 | # Token bucket rate limiter
  47 | class TokenBucket:
  48 |     """
  49 |     Token bucket algorithm for rate limiting.
  50 |     Each user gets a bucket that fills at a constant rate.
  51 |     """
  52 |     def __init__(self):
  53 |         self.buckets = {}
  54 |         self.lock = threading.Lock()
  55 |     
  56 |     def _get_bucket(self, user_id, rate_limit):
  57 |         """Get or create a bucket for a user"""
  58 |         now = time.time()
  59 |         
  60 |         if user_id not in self.buckets:
  61 |             # Initialize with full bucket
  62 |             self.buckets[user_id] = {
  63 |                 "tokens": rate_limit,
  64 |                 "last_refill": now,
  65 |                 "rate": rate_limit / 60.0  # tokens per second
  66 |             }
  67 |             return self.buckets[user_id]
  68 |         
  69 |         bucket = self.buckets[user_id]
  70 |         
  71 |         # Update rate if it changed
  72 |         bucket["rate"] = rate_limit / 60.0
  73 |         
  74 |         # Refill tokens based on time elapsed
  75 |         elapsed = now - bucket["last_refill"]
  76 |         new_tokens = elapsed * bucket["rate"]
  77 |         
  78 |         bucket["tokens"] = min(rate_limit, bucket["tokens"] + new_tokens)
  79 |         bucket["last_refill"] = now
  80 |         
  81 |         return bucket
  82 |     
  83 |     def consume(self, user_id, tokens=1, rate_limit=60):
  84 |         """
  85 |         Consume tokens from a user's bucket.
  86 |         Returns True if tokens were consumed, False otherwise.
  87 |         """
  88 |         with self.lock:
  89 |             bucket = self._get_bucket(user_id, rate_limit)
  90 |             
  91 |             if bucket["tokens"] >= tokens:
  92 |                 bucket["tokens"] -= tokens
  93 |                 return True
  94 |             return False
  95 | 
# Create rate limiter
# Single module-level TokenBucket instance; `verify_api_key` calls consume() on it.
rate_limiter = TokenBucket()

# Define the container image with necessary dependencies
# Debian-slim / Python 3.10 image with vLLM, flashinfer and the web stack pinned below.
vllm_image = (
    modal.Image.debian_slim(python_version="3.10")
    .pip_install(
        "vllm==0.7.3",  # Updated version
        "huggingface_hub[hf_transfer]==0.26.2",
        "flashinfer-python==0.2.0.post2",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0",
        "tiktoken>=0.5.1",
        extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model transfers
    .env({"VLLM_USE_V1": "1"})  # Enable V1 engine for better performance
)

# Define llama.cpp image for alternative models
# llama.cpp is cloned and compiled at image-build time; only the `llama-cli`
# target is built, then every `llama-*` binary is copied to /usr/local/bin.
llama_cpp_image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git", "build-essential", "cmake", "curl", "libcurl4-openssl-dev")
    .pip_install(
        "huggingface_hub==0.26.2",
        "hf_transfer>=0.1.4",
        "fastapi>=0.95.0",
        "uvicorn>=0.15.0",
        "pydantic>=2.0.0"
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_commands(
        # NOTE(review): the clone is not pinned to a tag/commit, so image
        # rebuilds may pick up a different llama.cpp revision -- confirm intended.
        "git clone https://github.com/ggerganov/llama.cpp",
        "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=ON",
        "cmake --build llama.cpp/build --config Release -j --target llama-cli",
        "cp llama.cpp/build/bin/llama-* /usr/local/bin/"
    )
)
 135 | 
# Set up model configurations
MODELS_DIR = "/models"

# Registry of models served through vLLM, keyed by the public model id.
# Entry fields: "id" (same as the key), "name" (Hugging Face repo),
# "revision" (branch or commit), "max_tokens", and "loaded" (mutable flag
# flipped by the /admin/models/* endpoints in this module).
VLLM_MODELS = {
    "llama3-8b": {
        "id": "llama3-8b",
        "name": "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
        "config": "config.json",  # Ensure this file is present in the model directory
        "revision": "a7c09948d9a632c2c840722f519672cd94af885d",
        "max_tokens": 4096,
        "loaded": False
    },
    "mistral-7b": {
        "id": "mistral-7b",
        "name": "mistralai/Mistral-7B-Instruct-v0.2",
        "revision": "main",
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "tiny-llama-1.1b": {
        "id": "tiny-llama-1.1b",
        "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "revision": "main",
        "max_tokens": 2048,
        "loaded": False
    }
}

# Registry of GGUF models served through llama.cpp, keyed by public model id.
# In addition to the vLLM fields each entry carries "quant" (quantization
# label), "pattern" (glob built around the quant label -- presumably used to
# select files to download; verify against the loader) and "gpu" (GPU spec
# string, or None for CPU).
LLAMA_CPP_MODELS = {
    "deepseek-r1": {
        "id": "deepseek-r1",
        "name": "unsloth/DeepSeek-R1-GGUF",
        "quant": "UD-IQ1_S",
        "pattern": "*UD-IQ1_S*",
        "revision": "02656f62d2aa9da4d3f0cdb34c341d30dd87c3b6",
        "gpu": "L40S:4",
        "max_tokens": 4096,
        "loaded": False
    },
    "phi-4": {
        "id": "phi-4",
        "name": "unsloth/phi-4-GGUF",
        "quant": "Q2_K",
        "pattern": "*Q2_K*",
        "revision": None,
        "gpu": "L40S:4",  # Use GPU for better performance
        "max_tokens": 4096,
        "loaded": False
    },
    # Small model for quick loading
    "phi-2": {
        "id": "phi-2",
        "name": "TheBloke/phi-2-GGUF",
        "quant": "Q4_K_M",
        "pattern": "*Q4_K_M.gguf",
        "revision": "main",
        "gpu": None,  # Can run on CPU
        "max_tokens": 2048,
        "loaded": False
    }
}

# Must be a key of one of the two registries above (currently LLAMA_CPP_MODELS).
DEFAULT_MODEL = "phi-4"
 199 | 
# Create volumes for caching
# Named Modal volumes; each is created on first use if it does not exist yet.
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
llama_cpp_cache_vol = modal.Volume.from_name("llama-cpp-cache", create_if_missing=True)
results_vol = modal.Volume.from_name("model-results", create_if_missing=True)

# Create the Modal app
app = modal.App("openai-compatible-llm-server")

# Create shared data structures
# Named Modal Dict/Queue objects, also created on first use.
model_stats_dict = modal.Dict.from_name("model-stats", create_if_missing=True)  # read by get_model_status
user_usage_dict = modal.Dict.from_name("user-usage", create_if_missing=True)  # per-user request log / token counters
request_queue = modal.Queue.from_name("request-queue", create_if_missing=True)
response_dict = modal.Dict.from_name("response-cache", create_if_missing=True)
api_keys_dict = modal.Dict.from_name("api-keys", create_if_missing=True)  # user id -> API key record
stream_queues = modal.Dict.from_name("stream-queues", create_if_missing=True)
 216 | 
 217 | # Advanced caching system
 218 | class AdvancedCache:
 219 |     """
 220 |     Advanced caching system with TTL and LRU eviction.
 221 |     """
 222 |     def __init__(self, max_size=1000, default_ttl=3600):
 223 |         self.cache = {}
 224 |         self.ttl_map = {}
 225 |         self.access_times = {}
 226 |         self.max_size = max_size
 227 |         self.default_ttl = default_ttl
 228 |         self.lock = threading.Lock()
 229 |     
 230 |     def get(self, key):
 231 |         """Get a value from the cache"""
 232 |         with self.lock:
 233 |             now = time.time()
 234 |             
 235 |             # Check if key exists and is not expired
 236 |             if key in self.cache:
 237 |                 # Check TTL
 238 |                 if key in self.ttl_map and self.ttl_map[key] < now:
 239 |                     # Expired
 240 |                     self._remove(key)
 241 |                     return None
 242 |                 
 243 |                 # Update access time
 244 |                 self.access_times[key] = now
 245 |                 return self.cache[key]
 246 |             
 247 |             return None
 248 |     
 249 |     def set(self, key, value, ttl=None):
 250 |         """Set a value in the cache with optional TTL"""
 251 |         with self.lock:
 252 |             now = time.time()
 253 |             
 254 |             # Evict if needed
 255 |             if len(self.cache) >= self.max_size and key not in self.cache:
 256 |                 self._evict_lru()
 257 |             
 258 |             # Set value
 259 |             self.cache[key] = value
 260 |             self.access_times[key] = now
 261 |             
 262 |             # Set TTL
 263 |             if ttl is not None:
 264 |                 self.ttl_map[key] = now + ttl
 265 |             elif self.default_ttl > 0:
 266 |                 self.ttl_map[key] = now + self.default_ttl
 267 |     
 268 |     def _remove(self, key):
 269 |         """Remove a key from the cache"""
 270 |         if key in self.cache:
 271 |             del self.cache[key]
 272 |         if key in self.ttl_map:
 273 |             del self.ttl_map[key]
 274 |         if key in self.access_times:
 275 |             del self.access_times[key]
 276 |     
 277 |     def _evict_lru(self):
 278 |         """Evict least recently used item"""
 279 |         if not self.access_times:
 280 |             return
 281 |         
 282 |         # Find oldest access time
 283 |         oldest_key = min(self.access_times.items(), key=lambda x: x[1])[0]
 284 |         self._remove(oldest_key)
 285 |     
 286 |     def clear_expired(self):
 287 |         """Clear all expired entries"""
 288 |         with self.lock:
 289 |             now = time.time()
 290 |             expired_keys = [k for k, v in self.ttl_map.items() if v < now]
 291 |             for key in expired_keys:
 292 |                 self._remove(key)
 293 | 
# Constants
MAX_CACHE_AGE = 3600  # 1 hour in seconds

# Create memory cache
# Process-local cache: up to 10000 entries, each expiring after MAX_CACHE_AGE by default.
memory_cache = AdvancedCache(max_size=10000, default_ttl=MAX_CACHE_AGE)

# Initialize with default key if empty
# This runs at import time and writes to the shared `api_keys_dict`.
# NOTE(review): the key below is a hard-coded, publicly visible credential;
# anyone who can read this file can authenticate as "default".
if "default" not in api_keys_dict:
    api_keys_dict["default"] = {
        "key": "sk-modal-llm-api-key",
        "rate_limit": 60,  # requests per minute
        "quota": 1000000,  # tokens per day
        "created_at": datetime.now().isoformat(),
        "owner": "default"
    }

# Add a default ADMIN API key
# NOTE(review): same concern as above -- hard-coded admin credential.
if "admin" not in api_keys_dict:
    api_keys_dict["admin"] = {
        "key": "sk-modal-admin-api-key",
        "rate_limit": 1000,  # Higher rate limit for admin
        "quota": 10000000,  # Higher quota for admin
        "created_at": datetime.now().isoformat(),
        "owner": "admin"
    }

# Constants
# DEFAULT_API_KEY is whatever is stored under "default" in the shared dict,
# which may differ from the literal above if the entry already existed.
DEFAULT_API_KEY = api_keys_dict["default"]["key"]
MINUTES = 60  # seconds
SERVER_PORT = 8000
CACHE_DIR = "/root/.cache"
RESULTS_DIR = "/root/results"
 326 | 
 327 | # Request/response models
class GenerationRequest(BaseModel):
    """Internal description of one chat-generation job."""
    # Unique id generated per request when the caller does not supply one.
    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Target model id -- presumably a key of VLLM_MODELS / LLAMA_CPP_MODELS; verify against callers.
    model_id: str
    # Chat history as a list of string-to-string dicts.
    messages: List[Dict[str, str]]
    # Sampling parameters and their defaults.
    temperature: float = 0.7
    max_tokens: int = 1024
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    user: Optional[str] = None
    # Whether the caller asked for a streamed response.
    stream: bool = False
    # Creation time as a Unix timestamp (time.time()).
    timestamp: float = Field(default_factory=time.time)
    # NOTE(review): defaults to the valid DEFAULT_API_KEY, so a request built
    # without an explicit key carries working credentials -- confirm intended.
    api_key: str = DEFAULT_API_KEY
 341 |     
class StreamChunk(BaseModel):
    """Model for streaming response chunks"""
    # Identifier of the chunk/response it belongs to.
    id: str
    # Object type tag; fixed default "chat.completion.chunk".
    object: str = "chat.completion.chunk"
    # Creation time as an integer -- presumably Unix seconds; verify against producers.
    created: int
    # Model identifier reported to the client.
    model: str
    # Per-choice payloads as free-form dicts.
    choices: List[Dict[str, Any]]
 349 |     
 350 | class StreamManager:
 351 |     """Manages streaming responses for clients"""
 352 |     def __init__(self):
 353 |         self.streams = {}
 354 |         self.lock = threading.Lock()
 355 |     
 356 |     def create_stream(self, request_id):
 357 |         """Create a new stream for a request"""
 358 |         with self.lock:
 359 |             self.streams[request_id] = {
 360 |                 "queue": asyncio.Queue(),
 361 |                 "finished": False,
 362 |                 "created_at": time.time()
 363 |             }
 364 |     
 365 |     def add_chunk(self, request_id, chunk):
 366 |         """Add a chunk to a stream"""
 367 |         with self.lock:
 368 |             if request_id in self.streams:
 369 |                 stream = self.streams[request_id]
 370 |                 if not stream["finished"]:
 371 |                     stream["queue"].put_nowait(chunk)
 372 |     
 373 |     def finish_stream(self, request_id):
 374 |         """Mark a stream as finished"""
 375 |         with self.lock:
 376 |             if request_id in self.streams:
 377 |                 self.streams[request_id]["finished"] = True
 378 |                 # Add None to signal end of stream
 379 |                 self.streams[request_id]["queue"].put_nowait(None)
 380 |     
 381 |     async def get_chunks(self, request_id):
 382 |         """Get chunks from a stream as an async generator"""
 383 |         if request_id not in self.streams:
 384 |             return
 385 |         
 386 |         stream = self.streams[request_id]
 387 |         queue = stream["queue"]
 388 |         
 389 |         while True:
 390 |             chunk = await queue.get()
 391 |             if chunk is None:  # End of stream
 392 |                 break
 393 |             yield chunk
 394 |             queue.task_done()
 395 |         
 396 |         # Clean up after streaming is done
 397 |         with self.lock:
 398 |             if request_id in self.streams:
 399 |                 del self.streams[request_id]
 400 |     
 401 |     def clean_old_streams(self, max_age=3600):
 402 |         """Clean up old streams"""
 403 |         with self.lock:
 404 |             now = time.time()
 405 |             to_remove = []
 406 |             
 407 |             for request_id, stream in self.streams.items():
 408 |                 if now - stream["created_at"] > max_age:
 409 |                     to_remove.append(request_id)
 410 |             
 411 |             for request_id in to_remove:
 412 |                 if request_id in self.streams:
 413 |                     # Mark as finished to stop any ongoing processing
 414 |                     self.streams[request_id]["finished"] = True
 415 |                     # Add None to unblock any waiting consumers
 416 |                     self.streams[request_id]["queue"].put_nowait(None)
 417 |                     # Remove from streams
 418 |                     del self.streams[request_id]
 419 | 
# Create stream manager
# Single module-level StreamManager instance for this process.
stream_manager = StreamManager()
 422 | 
 423 | # API Authentication dependency
 424 | def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
 425 |     """Verify that the API key in the authorization header is valid and check rate limits"""
 426 |     if credentials.scheme != "Bearer":
 427 |         raise HTTPException(
 428 |             status_code=status.HTTP_401_UNAUTHORIZED,
 429 |             detail="Invalid authentication scheme. Use Bearer",
 430 |         )
 431 |     
 432 |     api_key = credentials.credentials
 433 |     valid_key = False
 434 |     key_info = None
 435 |     
 436 |     # Check if this is a known API key
 437 |     for user_id, user_data in api_keys_dict.items():
 438 |         if user_data.get("key") == api_key:
 439 |             valid_key = True
 440 |             key_info = user_data
 441 |             break
 442 |     
 443 |     if not valid_key:
 444 |         raise HTTPException(
 445 |             status_code=status.HTTP_401_UNAUTHORIZED,
 446 |             detail="Invalid API key",
 447 |         )
 448 |     
 449 |     # Check rate limits
 450 |     user_id = key_info.get("owner", "unknown")
 451 |     rate_limit = key_info.get("rate_limit", 60)  # Default: 60 requests per minute
 452 |     
 453 |     # Get or initialize user usage tracking
 454 |     if user_id not in user_usage_dict:
 455 |         user_usage_dict[user_id] = {
 456 |             "requests": [],
 457 |             "tokens": {
 458 |                 "input": 0,
 459 |                 "output": 0,
 460 |                 "last_reset": datetime.now().isoformat()
 461 |             }
 462 |         }
 463 |     
 464 |     usage = user_usage_dict[user_id]
 465 |     
 466 |     # Check if user exceeded rate limit using token bucket algorithm
 467 |     if not rate_limiter.consume(user_id, tokens=1, rate_limit=rate_limit):
 468 |         # Calculate retry-after based on rate
 469 |         retry_after = int(60 / rate_limit)  # seconds until at least one token is available
 470 |         
 471 |         raise HTTPException(
 472 |             status_code=status.HTTP_429_TOO_MANY_REQUESTS,
 473 |             detail=f"Rate limit exceeded. Maximum {rate_limit} requests per minute.",
 474 |             headers={"Retry-After": str(retry_after)}
 475 |         )
 476 |     
 477 |     # Add current request timestamp for analytics
 478 |     now = datetime.now()
 479 |     usage["requests"].append(now.timestamp())
 480 |     
 481 |     # Clean up old requests (older than 1 day) to prevent unbounded growth
 482 |     day_ago = (now - timedelta(days=1)).timestamp()
 483 |     usage["requests"] = [req for req in usage["requests"] if req > day_ago]
 484 |     
 485 |     # Update usage dict
 486 |     user_usage_dict[user_id] = usage
 487 |     
 488 |     # Return the API key and user ID
 489 |     return {"key": api_key, "user_id": user_id}
 490 | 
 491 | # API Endpoints
 492 | @api_app.get("/", response_class=HTMLResponse)
 493 | async def index():
 494 |     """Root endpoint that returns HTML with API information"""
 495 |     return """
 496 |     <html>
 497 |         <head>
 498 |             <title>Modal LLM Inference API</title>
 499 |             <style>
 500 |                 body { font-family: system-ui, sans-serif; max-width: 800px; margin: 0 auto; padding: 2rem; }
 501 |                 h1 { color: #4a56e2; }
 502 |                 code { background: #f4f4f8; padding: 0.2rem 0.4rem; border-radius: 3px; }
 503 |             </style>
 504 |         </head>
 505 |         <body>
 506 |             <h1>Modal LLM Inference API</h1>
 507 |             <p>This is an OpenAI-compatible API for LLM inference powered by Modal.</p>
 508 |             <p>Use the following endpoints:</p>
 509 |             <ul>
 510 |                 <li><a href="/docs">/docs</a> - API documentation</li>
 511 |                 <li><a href="/v1/models">/v1/models</a> - List available models</li>
 512 |                 <li><code>/v1/chat/completions</code> - Chat completions endpoint</li>
 513 |             </ul>
 514 |         </body>
 515 |     </html>
 516 |     """
 517 | 
 518 | @api_app.get("/health")
 519 | async def health_check():
 520 |     """Health check endpoint"""
 521 |     return {"status": "healthy"}
 522 | 
 523 | @api_app.get("/v1/models", dependencies=[Depends(verify_api_key)])
 524 | async def list_models():
 525 |     """List all available models in OpenAI-compatible format"""
 526 |     # Combine vLLM and llama.cpp models
 527 |     all_models = []
 528 |     
 529 |     for model_id, model_info in VLLM_MODELS.items():
 530 |         all_models.append({
 531 |             "id": model_info["id"],
 532 |             "object": "model",
 533 |             "created": 1677610602,
 534 |             "owned_by": "modal",
 535 |             "engine": "vllm",
 536 |             "loaded": model_info.get("loaded", False)
 537 |         })
 538 |         
 539 |     for model_id, model_info in LLAMA_CPP_MODELS.items():
 540 |         all_models.append({
 541 |             "id": model_info["id"],
 542 |             "object": "model",
 543 |             "created": 1677610602,
 544 |             "owned_by": "modal",
 545 |             "engine": "llama.cpp",
 546 |             "loaded": model_info.get("loaded", False)
 547 |         })
 548 |         
 549 |     return {"data": all_models, "object": "list"}
 550 | 
 551 | # Model management endpoints
class ModelLoadRequest(BaseModel):
    """Request model to load a specific model"""
    # Id looked up in VLLM_MODELS and then LLAMA_CPP_MODELS by the endpoints that take this body.
    model_id: str
    # When True, load_model schedules a load even if the model is already flagged as loaded.
    force_reload: bool = False
 556 |     
class HFModelLoadRequest(BaseModel):
    """Request to load a model directly from Hugging Face"""
    # Hugging Face repository id; stored as the registry entry's "name".
    repo_id: str
    model_type: str = "vllm"  # "vllm" or "llama.cpp"
    # Repo revision; the vLLM path falls back to "main" when this is None.
    revision: Optional[str] = None
    quant: Optional[str] = None  # For llama.cpp models
    max_tokens: int = 4096
    gpu: Optional[str] = None  # For llama.cpp models
 565 | 
 566 | @api_app.post("/admin/models/load", dependencies=[Depends(verify_api_key)])
 567 | async def load_model(request: ModelLoadRequest, background_tasks: BackgroundTasks):
 568 |     """Load a specific model into memory"""
 569 |     model_id = request.model_id
 570 |     force_reload = request.force_reload
 571 |     
 572 |     # Check if model exists
 573 |     if model_id in VLLM_MODELS:
 574 |         model_type = "vllm"
 575 |         model_info = VLLM_MODELS[model_id]
 576 |     elif model_id in LLAMA_CPP_MODELS:
 577 |         model_type = "llama.cpp"
 578 |         model_info = LLAMA_CPP_MODELS[model_id]
 579 |     else:
 580 |         raise HTTPException(
 581 |             status_code=status.HTTP_404_NOT_FOUND,
 582 |             detail=f"Model {model_id} not found"
 583 |         )
 584 |     
 585 |     # Check if model is already loaded
 586 |     if model_info.get("loaded", False) and not force_reload:
 587 |         return {
 588 |             "status": "success",
 589 |             "message": f"Model {model_id} is already loaded",
 590 |             "model_id": model_id,
 591 |             "model_type": model_type
 592 |         }
 593 |     
 594 |     # Start loading the model in the background
 595 |     if model_type == "vllm":
 596 |         # Start vLLM server for this model
 597 |         background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
 598 |         # Update model status
 599 |         VLLM_MODELS[model_id]["loaded"] = True
 600 |     else:  # llama.cpp
 601 |         # For llama.cpp models, we'll preload the model
 602 |         background_tasks.add_task(preload_llama_cpp_model, model_id)
 603 |         # Update model status
 604 |         LLAMA_CPP_MODELS[model_id]["loaded"] = True
 605 |     
 606 |     return {
 607 |         "status": "success",
 608 |         "message": f"Started loading model {model_id}",
 609 |         "model_id": model_id,
 610 |         "model_type": model_type
 611 |     }
 612 | 
 613 | @api_app.post("/admin/models/load-from-hf", dependencies=[Depends(verify_api_key)])
 614 | async def load_model_from_hf(request: HFModelLoadRequest, background_tasks: BackgroundTasks):
 615 |     """Load a model directly from Hugging Face"""
 616 |     repo_id = request.repo_id
 617 |     model_type = request.model_type
 618 |     revision = request.revision
 619 |     
 620 |     # Generate a unique model_id based on the repo name
 621 |     repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
 622 |     model_id = f"hf-{repo_name}-{uuid.uuid4().hex[:6]}"
 623 |     
 624 |     # Create model info based on type
 625 |     if model_type.lower() == "vllm":
 626 |         # Add to VLLM_MODELS
 627 |         VLLM_MODELS[model_id] = {
 628 |             "id": model_id,
 629 |             "name": repo_id,
 630 |             "revision": revision or "main",
 631 |             "max_tokens": request.max_tokens,
 632 |             "loaded": False,
 633 |             "hf_direct": True  # Mark as directly loaded from HF
 634 |         }
 635 |         
 636 |         # Start vLLM server for this model
 637 |         background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
 638 |         # Update model status
 639 |         VLLM_MODELS[model_id]["loaded"] = True
 640 |         
 641 |     elif model_type.lower() == "llama.cpp":
 642 |         # For llama.cpp we need quant info
 643 |         quant = request.quant or "Q4_K_M"  # Default quantization
 644 |         pattern = f"*{quant}*"
 645 |         
 646 |         # Add to LLAMA_CPP_MODELS
 647 |         LLAMA_CPP_MODELS[model_id] = {
 648 |             "id": model_id,
 649 |             "name": repo_id,
 650 |             "quant": quant,
 651 |             "pattern": pattern,
 652 |             "revision": revision,
 653 |             "gpu": request.gpu,  # Can be None for CPU
 654 |             "max_tokens": request.max_tokens,
 655 |             "loaded": False,
 656 |             "hf_direct": True  # Mark as directly loaded from HF
 657 |         }
 658 |         
 659 |         # Preload the model
 660 |         background_tasks.add_task(preload_llama_cpp_model, model_id)
 661 |         # Update model status
 662 |         LLAMA_CPP_MODELS[model_id]["loaded"] = True
 663 |         
 664 |     else:
 665 |         raise HTTPException(
 666 |             status_code=status.HTTP_400_BAD_REQUEST,
 667 |             detail=f"Invalid model type: {model_type}. Must be 'vllm' or 'llama.cpp'"
 668 |         )
 669 |     
 670 |     return {
 671 |         "status": "success",
 672 |         "message": f"Started loading model {repo_id} as {model_id}",
 673 |         "model_id": model_id,
 674 |         "model_type": model_type,
 675 |         "repo_id": repo_id
 676 |     }
 677 | 
 678 | @api_app.post("/admin/models/unload", dependencies=[Depends(verify_api_key)])
 679 | async def unload_model(request: ModelLoadRequest):
 680 |     """Unload a specific model from memory"""
 681 |     model_id = request.model_id
 682 |     
 683 |     # Check if model exists
 684 |     if model_id in VLLM_MODELS:
 685 |         model_type = "vllm"
 686 |         model_info = VLLM_MODELS[model_id]
 687 |     elif model_id in LLAMA_CPP_MODELS:
 688 |         model_type = "llama.cpp"
 689 |         model_info = LLAMA_CPP_MODELS[model_id]
 690 |     else:
 691 |         raise HTTPException(
 692 |             status_code=status.HTTP_404_NOT_FOUND,
 693 |             detail=f"Model {model_id} not found"
 694 |         )
 695 |     
 696 |     # Check if model is loaded
 697 |     if not model_info.get("loaded", False):
 698 |         return {
 699 |             "status": "success",
 700 |             "message": f"Model {model_id} is not loaded",
 701 |             "model_id": model_id,
 702 |             "model_type": model_type
 703 |         }
 704 |     
 705 |     # Update model status
 706 |     if model_type == "vllm":
 707 |         VLLM_MODELS[model_id]["loaded"] = False
 708 |     else:  # llama.cpp
 709 |         LLAMA_CPP_MODELS[model_id]["loaded"] = False
 710 |     
 711 |     return {
 712 |         "status": "success",
 713 |         "message": f"Unloaded model {model_id}",
 714 |         "model_id": model_id,
 715 |         "model_type": model_type
 716 |     }
 717 | 
 718 | @api_app.get("/admin/models/status/{model_id}", dependencies=[Depends(verify_api_key)])
 719 | async def get_model_status(model_id: str):
 720 |     """Get the status of a specific model"""
 721 |     # Check if model exists
 722 |     if model_id in VLLM_MODELS:
 723 |         model_type = "vllm"
 724 |         model_info = VLLM_MODELS[model_id]
 725 |     elif model_id in LLAMA_CPP_MODELS:
 726 |         model_type = "llama.cpp"
 727 |         model_info = LLAMA_CPP_MODELS[model_id]
 728 |     else:
 729 |         raise HTTPException(
 730 |             status_code=status.HTTP_404_NOT_FOUND,
 731 |             detail=f"Model {model_id} not found"
 732 |         )
 733 |     
 734 |     # Get model stats if available
 735 |     model_stats = model_stats_dict.get(model_id, {})
 736 |     
 737 |     # Include HF info if available
 738 |     hf_info = {}
 739 |     if model_info.get("hf_direct"):
 740 |         hf_info = {
 741 |             "repo_id": model_info.get("name"),
 742 |             "revision": model_info.get("revision"),
 743 |         }
 744 |         if model_type == "llama.cpp":
 745 |             hf_info["quant"] = model_info.get("quant")
 746 |     
 747 |     return {
 748 |         "model_id": model_id,
 749 |         "model_type": model_type,
 750 |         "loaded": model_info.get("loaded", False),
 751 |         "stats": model_stats,
 752 |         "hf_info": hf_info if hf_info else None
 753 |     }
 754 | 
 755 | # Admin API endpoints
# Request body accepted by POST /admin/api-keys.
class APIKeyRequest(BaseModel):
    # User the new key is issued to; also used as the storage key for the key record.
    user_id: str
    # Stored alongside the key. Units are not shown in this section —
    # presumably requests per minute; TODO confirm against the rate limiter.
    rate_limit: int = 60
    # Stored alongside the key. Presumably a token allowance (usage is tracked
    # in tokens per user); TODO confirm where it is enforced.
    quota: int = 1000000
 760 |     
# Response model returned by POST /admin/api-keys.
class APIKey(BaseModel):
    # The full, unmasked key value (generated as "sk-modal-<uuid4>").
    key: str
    # User the key was issued to.
    user_id: str
    # Echo of the requested rate limit.
    rate_limit: int
    # Echo of the requested quota.
    quota: int
    # ISO-8601 timestamp string produced with datetime.now().isoformat().
    created_at: str
 767 | 
 768 | @api_app.post("/admin/api-keys", response_model=APIKey)
 769 | async def create_api_key(request: APIKeyRequest, auth_info: dict = Depends(verify_api_key)):
 770 |     """Create a new API key for a user (admin only)"""
 771 |     # Check if this is an admin request
 772 |     if auth_info["user_id"] != "default":
 773 |         raise HTTPException(
 774 |             status_code=status.HTTP_403_FORBIDDEN,
 775 |             detail="Only admin users can create API keys"
 776 |         )
 777 |     
 778 |     # Generate a new API key
 779 |     new_key = f"sk-modal-{uuid.uuid4()}"
 780 |     user_id = request.user_id
 781 |     
 782 |     # Store the key
 783 |     api_keys_dict[user_id] = {
 784 |         "key": new_key,
 785 |         "rate_limit": request.rate_limit,
 786 |         "quota": request.quota,
 787 |         "created_at": datetime.now().isoformat(),
 788 |         "owner": user_id
 789 |     }
 790 |     
 791 |     # Initialize user usage
 792 |     if not user_usage_dict.contains(user_id):
 793 |         user_usage_dict[user_id] = {
 794 |             "requests": [],
 795 |             "tokens": {
 796 |                 "input": 0,
 797 |                 "output": 0,
 798 |                 "last_reset": datetime.now().isoformat()
 799 |             }
 800 |         }
 801 |     
 802 |     return APIKey(
 803 |         key=new_key,
 804 |         user_id=user_id,
 805 |         rate_limit=request.rate_limit,
 806 |         quota=request.quota,
 807 |         created_at=datetime.now().isoformat()
 808 |     )
 809 | 
 810 | @api_app.get("/admin/api-keys")
 811 | async def list_api_keys(auth_info: dict = Depends(verify_api_key)):
 812 |     """List all API keys (admin only)"""
 813 |     # Check if this is an admin request
 814 |     if auth_info["user_id"] != "default":
 815 |         raise HTTPException(
 816 |             status_code=status.HTTP_403_FORBIDDEN,
 817 |             detail="Only admin users can list API keys"
 818 |         )
 819 |     
 820 |     # Return all keys (except the actual key values for security)
 821 |     keys = []
 822 |     for user_id, key_info in api_keys_dict.items():
 823 |         keys.append({
 824 |             "user_id": user_id,
 825 |             "rate_limit": key_info.get("rate_limit", 60),
 826 |             "quota": key_info.get("quota", 1000000),
 827 |             "created_at": key_info.get("created_at", datetime.now().isoformat()),
 828 |             # Mask the actual key
 829 |             "key": key_info.get("key", "")[:8] + "..." if key_info.get("key") else "None"
 830 |         })
 831 |     
 832 |     return {"keys": keys}
 833 | 
 834 | @api_app.get("/admin/stats")
 835 | async def get_stats(auth_info: dict = Depends(verify_api_key)):
 836 |     """Get usage statistics (admin only)"""
 837 |     # Check if this is an admin request
 838 |     if auth_info["user_id"] != "default":
 839 |         raise HTTPException(
 840 |             status_code=status.HTTP_403_FORBIDDEN,
 841 |             detail="Only admin users can view stats"
 842 |         )
 843 |     
 844 |     # Get model stats
 845 |     model_stats = {}
 846 |     for model_id in list(VLLM_MODELS.keys()) + list(LLAMA_CPP_MODELS.keys()):
 847 |         if model_id in model_stats_dict:
 848 |             model_stats[model_id] = model_stats_dict[model_id]
 849 |     
 850 |     # Get user stats
 851 |     user_stats = {}
 852 |     for user_id in user_usage_dict.keys():
 853 |         usage = user_usage_dict[user_id]
 854 |         # Don't include request timestamps for brevity
 855 |         if "requests" in usage:
 856 |             usage = usage.copy()
 857 |             usage["request_count"] = len(usage["requests"])
 858 |             del usage["requests"]
 859 |         user_stats[user_id] = usage
 860 |     
 861 |     # Get queue info
 862 |     queue_info = {
 863 |         "pending_requests": request_queue.len(),
 864 |         "active_workers": model_stats_dict.get("workers_running", 0)
 865 |     }
 866 |     
 867 |     return {
 868 |         "models": model_stats,
 869 |         "users": user_stats,
 870 |         "queue": queue_info,
 871 |         "timestamp": datetime.now().isoformat()
 872 |     }
 873 | 
 874 | @api_app.delete("/admin/api-keys/{user_id}")
 875 | async def delete_api_key(user_id: str, auth_info: dict = Depends(verify_api_key)):
 876 |     """Delete an API key (admin only)"""
 877 |     # Check if this is an admin request
 878 |     if auth_info["user_id"] != "default":
 879 |         raise HTTPException(
 880 |             status_code=status.HTTP_403_FORBIDDEN,
 881 |             detail="Only admin users can delete API keys"
 882 |         )
 883 |     
 884 |     # Check if the key exists
 885 |     if not api_keys_dict.contains(user_id):
 886 |         raise HTTPException(
 887 |             status_code=status.HTTP_404_NOT_FOUND,
 888 |             detail=f"No API key found for user {user_id}"
 889 |         )
 890 |     
 891 |     # Can't delete the default key
 892 |     if user_id == "default":
 893 |         raise HTTPException(
 894 |             status_code=status.HTTP_400_BAD_REQUEST,
 895 |             detail="Cannot delete the default API key"
 896 |         )
 897 |     
 898 |     # Delete the key
 899 |     api_keys_dict.pop(user_id)
 900 |     
 901 |     return {"status": "success", "message": f"API key deleted for user {user_id}"}
 902 | 
 903 | @api_app.post("/v1/chat/completions")
 904 | async def chat_completions(request: Request, background_tasks: BackgroundTasks, auth_info: dict = Depends(verify_api_key)):
 905 |     """OpenAI-compatible chat completions endpoint with request queueing, streaming and response caching"""
 906 |     try:
 907 |         json_data = await request.json()
 908 |         
 909 |         # Extract model or use default
 910 |         model_id = json_data.get("model", DEFAULT_MODEL)
 911 |         messages = json_data.get("messages", [])
 912 |         temperature = json_data.get("temperature", 0.7)
 913 |         max_tokens = json_data.get("max_tokens", 1024)
 914 |         stream = json_data.get("stream", False)
 915 |         user = json_data.get("user", auth_info["user_id"])
 916 |         
 917 |         # Calculate a cache key based on the request parameters
 918 |         cache_key = calculate_cache_key(model_id, messages, temperature, max_tokens)
 919 |         
 920 |         # Check if we have a cached response in memory cache first (faster)
 921 |         cached_response = memory_cache.get(cache_key)
 922 |         if cached_response and not stream:  # Don't use cache for streaming requests
 923 |             # Update stats
 924 |             update_stats(model_id, "cache_hit")
 925 |             return cached_response
 926 |         
 927 |         # Check if we have a cached response in Modal's persistent cache
 928 |         if not cached_response and cache_key in response_dict and not stream:
 929 |             cached_response = response_dict[cache_key]
 930 |             cache_age = time.time() - cached_response.get("timestamp", 0)
 931 |             
 932 |             # Use cached response if it's fresh enough
 933 |             if cache_age < MAX_CACHE_AGE:
 934 |                 # Update stats
 935 |                 update_stats(model_id, "cache_hit")
 936 |                 response_data = cached_response["response"]
 937 |                 
 938 |                 # Also cache in memory for faster access next time
 939 |                 memory_cache.set(cache_key, response_data)
 940 |                 
 941 |                 return response_data
 942 |         
 943 |         # Select best model if "auto" is specified
 944 |         if model_id == "auto" and len(messages) > 0:
 945 |             # Get the last user message
 946 |             last_message = None
 947 |             for msg in reversed(messages):
 948 |                 if msg.get("role") == "user":
 949 |                     last_message = msg.get("content", "")
 950 |                     break
 951 |             
 952 |             if last_message:
 953 |                 prompt = last_message
 954 |                 # Select best model based on prompt and parameters
 955 |                 model_id = select_best_model(prompt, max_tokens, temperature)
 956 |                 logging.info(f"Auto-selected model: {model_id} for prompt")
 957 |         
 958 |         # Check if model exists
 959 |         if model_id not in VLLM_MODELS and model_id not in LLAMA_CPP_MODELS:
 960 |             # Default to the default model if specified model not found
 961 |             logging.warning(f"Model {model_id} not found, using default: {DEFAULT_MODEL}")
 962 |             model_id = DEFAULT_MODEL
 963 |         
 964 |         # Create a unique request ID
 965 |         request_id = str(uuid.uuid4())
 966 |         
 967 |         # Create request object
 968 |         gen_request = GenerationRequest(
 969 |             request_id=request_id,
 970 |             model_id=model_id,
 971 |             messages=messages,
 972 |             temperature=temperature,
 973 |             max_tokens=max_tokens,
 974 |             top_p=json_data.get("top_p", 1.0),
 975 |             frequency_penalty=json_data.get("frequency_penalty", 0.0),
 976 |             presence_penalty=json_data.get("presence_penalty", 0.0),
 977 |             user=user,
 978 |             stream=stream,
 979 |             api_key=auth_info["key"]
 980 |         )
 981 |         
 982 |         # For streaming requests, set up streaming response
 983 |         if stream:
 984 |             # Create a new stream
 985 |             stream_manager.create_stream(request_id)
 986 |             
 987 |             # Put the request in the queue
 988 |             await request_queue.put.aio(gen_request.model_dump())
 989 |             
 990 |             # Update stats
 991 |             update_stats(model_id, "request_count")
 992 |             update_stats(model_id, "stream_count")
 993 |             
 994 |             # Start a background worker to process the request if needed
 995 |             background_tasks.add_task(ensure_worker_running)
 996 |             
 997 |             # Return a streaming response using FastAPI's StreamingResponse
 998 |             from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
 999 |             return FastAPIStreamingResponse(
1000 |                 content=stream_response(request_id, model_id, auth_info["user_id"]),
1001 |                 media_type="text/event-stream"
1002 |             )
1003 |             
1004 |         # For non-streaming, enqueue the request and wait for result
1005 |         # Put the request in the queue
1006 |         await request_queue.put.aio(gen_request.model_dump())
1007 |         
1008 |         # Update stats
1009 |         update_stats(model_id, "request_count")
1010 |         
1011 |         # Start a background worker to process the request if needed
1012 |         background_tasks.add_task(ensure_worker_running)
1013 |         
1014 |         # Wait for the response with timeout
1015 |         start_time = time.time()
1016 |         timeout = 120  # 2-minute timeout for non-streaming requests
1017 |         
1018 |         while time.time() - start_time < timeout:
1019 |             # Check memory cache first (faster)
1020 |             response_data = memory_cache.get(request_id)
1021 |             if response_data:
1022 |                 # Update stats
1023 |                 update_stats(model_id, "success_count")
1024 |                 estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
1025 |                 
1026 |                 # Save to persistent cache
1027 |                 response_dict[cache_key] = {
1028 |                     "response": response_data,
1029 |                     "timestamp": time.time()
1030 |                 }
1031 |                 
1032 |                 # Clean up request-specific cache
1033 |                 memory_cache.set(request_id, None)
1034 |                 
1035 |                 return response_data
1036 |                 
1037 |             # Check persistent cache
1038 |             if response_dict.contains(request_id):
1039 |                 response_data = response_dict[request_id]
1040 |                 
1041 |                 # Remove from response dict to save memory
1042 |                 try:
1043 |                     response_dict.pop(request_id)
1044 |                 except Exception:
1045 |                     pass
1046 |                 
1047 |                 # Save to cache
1048 |                 response_dict[cache_key] = {
1049 |                     "response": response_data,
1050 |                     "timestamp": time.time()
1051 |                 }
1052 |                 
1053 |                 # Also cache in memory
1054 |                 memory_cache.set(cache_key, response_data)
1055 |                 
1056 |                 # Update stats
1057 |                 update_stats(model_id, "success_count")
1058 |                 estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
1059 |                 
1060 |                 return response_data
1061 |             
1062 |             # Wait a bit before checking again
1063 |             await asyncio.sleep(0.1)
1064 |         
1065 |         # If we get here, we timed out
1066 |         update_stats(model_id, "timeout_count")
1067 |         raise HTTPException(
1068 |             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
1069 |             detail="Request timed out. The model may be busy. Please try again later."
1070 |         )
1071 |             
1072 |     except Exception as e:
1073 |         logging.error(f"Error in chat completions: {str(e)}")
1074 |         # Update error stats
1075 |         if "model_id" in locals():
1076 |             update_stats(model_id, "error_count")
1077 |             
1078 |         raise HTTPException(
1079 |             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1080 |             detail=f"Error generating response: {str(e)}"
1081 |         )
1082 | 
async def stream_response(request_id: str, model_id: str, user_id: str) -> AsyncIterator[str]:
    """Yield the chunks of one streamed completion as server-sent events.

    Emits an opening ``chat.completion.chunk`` event, then one ``data:`` event
    per truthy chunk obtained from ``stream_manager.get_chunks(request_id)``,
    then the ``[DONE]`` terminator. If anything raises along the way, the
    error is logged, counted under ``stream_error_count`` and sent to the
    client as a final JSON event before ``[DONE]``.

    Args:
        request_id: Identifies the stream to read chunks from.
        model_id: Model whose error counter is bumped on failure.
        user_id: Not used by this function.
    """
    terminator = "data: [DONE]\n\n"

    try:
        # Opening event
        opening = json.dumps({"object": "chat.completion.chunk"})
        yield "data: " + opening + "\n\n"

        # One event per non-empty chunk
        async for piece in stream_manager.get_chunks(request_id):
            if not piece:
                continue
            yield f"data: {json.dumps(piece)}\n\n"

        yield terminator

    except Exception as exc:
        logging.error(f"Error streaming response: {str(exc)}")
        update_stats(model_id, "stream_error_count")

        # Report the failure in-band, then close the stream as usual.
        yield "data: " + json.dumps({"error": str(exc)}) + "\n\n"
        yield terminator
1106 |         
async def ensure_worker_running(max_workers: int = 3):
    """Ensure that a worker is running to process the queue.

    The number of live workers is tracked under the ``"workers_running"`` key
    of ``model_stats_dict``. When fewer than ``max_workers`` are recorded, a
    slot is reserved by incrementing the counter and a new
    ``process_queue_worker`` is spawned.

    Args:
        max_workers: Upper bound on concurrently tracked workers (default 3).

    Raises:
        Exception: Whatever the spawn call raises, after the reserved slot
            has been released again.
    """
    # Check if workers are already running via a sentinel in shared dict
    workers_running_key = "workers_running"

    if not model_stats_dict.contains(workers_running_key):
        model_stats_dict[workers_running_key] = 0

    current_workers = model_stats_dict[workers_running_key]

    # Enough workers already: nothing to do
    if current_workers >= max_workers:
        return

    # NOTE(review): this read-then-write is not atomic, so concurrent callers
    # may briefly over- or under-count workers.
    model_stats_dict[workers_running_key] = current_workers + 1

    try:
        # Start a worker
        await process_queue_worker.spawn.aio()
    except Exception:
        # The worker never started, so it will never decrement the counter
        # itself. Release the slot here; otherwise repeated spawn failures
        # would fill every slot and no worker could start again.
        reserved = model_stats_dict.get(workers_running_key, 0)
        model_stats_dict[workers_running_key] = max(0, reserved - 1)
        logging.exception("Failed to spawn queue worker; released reserved worker slot")
        raise
1124 | 
def calculate_cache_key(model_id: str, messages: List[dict], temperature: float, max_tokens: int) -> str:
    """Derive a deterministic cache key for a generation request.

    The model id, messages, temperature (rounded to two decimals so
    near-identical values share a key) and max_tokens are serialised to JSON
    with sorted keys and hashed with SHA-256.

    Returns:
        ``"cache:"`` followed by the first 16 hex characters of the digest.
    """
    fingerprint = json.dumps(
        {
            "model": model_id,
            "messages": messages,
            "temperature": round(temperature, 2),
            "max_tokens": max_tokens,
        },
        sort_keys=True,
    )
    digest = hashlib.sha256(fingerprint.encode()).hexdigest()
    return "cache:" + digest[:16]
1138 | 
def update_stats(model_id: str, stat_type: str):
    """Increment one usage counter of a model by one.

    A stats record with all standard counters at zero is created the first
    time a model id is seen. Counters that are not part of the standard
    record are started at zero on first use.
    """
    if not model_stats_dict.contains(model_id):
        standard_counters = (
            "request_count",
            "success_count",
            "error_count",
            "timeout_count",
            "cache_hit",
            "token_count",
            "avg_latency",
        )
        model_stats_dict[model_id] = dict.fromkeys(standard_counters, 0)

    # Read, bump, write back: the record is reassigned so the change is stored.
    counters = model_stats_dict[model_id]
    counters[stat_type] = counters.get(stat_type, 0) + 1
    model_stats_dict[model_id] = counters
1155 |     
def _approx_token_count(content) -> float:
    """Estimate the tokens in one message ``content`` value as words * 1.3.

    Accepts a plain string, or a list of content parts where each part is a
    string or a dict carrying its text under ``"text"``. Anything else,
    including ``None``, counts as zero.
    """
    if isinstance(content, str):
        return len(content.split()) * 1.3

    if isinstance(content, list):
        total = 0
        for part in content:
            text = part.get("text") if isinstance(part, dict) else part
            if isinstance(text, str):
                total += len(text.split()) * 1.3
        return total

    return 0


def estimate_tokens(messages: List[dict], response: dict, user_id: str, model_id: str):
    """Estimate token usage and update the model and user counters.

    Tokens are approximated as whitespace-split words * 1.3. Message or
    choice content that is missing, ``None`` or a list of content parts is
    tolerated, so an unusual but valid message shape cannot fail a request
    after its response has already been generated.

    Args:
        messages: Request messages; each ``content`` is counted as input.
        response: Completion payload; the ``message.content`` of each entry
            in ``choices`` is counted as output.
        user_id: User whose usage is updated, if a usage record exists.
        model_id: Model whose ``token_count`` is updated, if stats exist.
    """
    input_tokens = sum(_approx_token_count(msg.get("content")) for msg in messages)

    output_tokens = 0
    if response and "choices" in response:
        for choice in response["choices"]:
            message = choice.get("message")
            if isinstance(message, dict):
                output_tokens += _approx_token_count(message.get("content"))

    # Update model stats
    if model_stats_dict.contains(model_id):
        stats = model_stats_dict[model_id]
        stats["token_count"] = stats.get("token_count", 0) + input_tokens + output_tokens
        model_stats_dict[model_id] = stats

    # Update user usage
    if user_id in user_usage_dict:
        usage = user_usage_dict[user_id]

        # Check if we need to reset daily counters
        last_reset = datetime.fromisoformat(usage["tokens"]["last_reset"])
        now = datetime.now()

        if now.date() > last_reset.date():
            # Reset daily counters
            usage["tokens"]["input"] = 0
            usage["tokens"]["output"] = 0
            usage["tokens"]["last_reset"] = now.isoformat()

        # Update token counts (stored as whole numbers)
        usage["tokens"]["input"] += int(input_tokens)
        usage["tokens"]["output"] += int(output_tokens)
        user_usage_dict[user_id] = usage
1193 | 
def select_best_model(prompt: str, n_predict: int, temperature: float) -> str:
    """
    Pick a model id for a prompt using keyword, length and temperature heuristics.

    Rules are tried in order and the first match wins:
      1. The prompt contains a code marker (case-sensitive)  -> "phi-4"
      2. It contains a creative-writing marker               -> "deepseek-r1" when
         temperature > 0.8, otherwise "phi-4"
      3. It contains an analytical marker                    -> "phi-4"
      4. Longer than 2000 characters                         -> "llama3-8b"
      5. Shorter than 1000 characters                        -> "phi-4"
      6. temperature < 0.5                                   -> "phi-4"
      7. temperature > 0.9                                   -> "deepseek-r1"
      8. Otherwise                                           -> "phi-4"

    Args:
        prompt (str): The input prompt for the model.
        n_predict (int): The number of tokens to predict (not used by the rules).
        temperature (float): The sampling temperature.

    Returns:
        str: The identifier of the selected model.
    """
    def mentions(text, markers):
        return any(marker in text for marker in markers)

    code_markers = (
        "```", "def ", "class ", "function", "import ", "from ", "<script", "<style",
        "SELECT ", "CREATE TABLE", "const ", "let ", "var ", "function(", "=>",
    )
    creative_markers = (
        "story", "poem", "creative", "imagine", "fiction", "narrative",
        "write a", "compose", "create a",
    )
    analytical_markers = (
        "explain", "analyze", "compare", "contrast", "reason",
        "evaluate", "assess", "why", "how does",
    )

    # Code markers are matched against the prompt as written; the other two
    # groups are matched case-insensitively.
    looks_like_code = mentions(prompt, code_markers)
    lowered = prompt.lower()
    wants_creativity = mentions(lowered, creative_markers)
    wants_analysis = mentions(lowered, analytical_markers)

    # Content-based rules
    if looks_like_code:
        return "phi-4"
    if wants_creativity:
        return "deepseek-r1" if temperature > 0.8 else "phi-4"
    if wants_analysis:
        return "phi-4"

    # Length-based rules
    prompt_length = len(prompt)
    if prompt_length > 2000:
        return "llama3-8b"
    if prompt_length < 1000:
        return "phi-4"

    # Temperature-based rules (only reached for prompts of 1000-2000 characters)
    if temperature < 0.5:
        return "phi-4"
    if temperature > 0.9:
        return "deepseek-r1"

    return "phi-4"
1258 | 
1259 | # vLLM serving function
@app.function(
    image=vllm_image,
    gpu="H100:1",
    allow_concurrent_inputs=100,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/vllm": vllm_cache_vol,
    },
    timeout=30 * MINUTES,
)
@modal.web_server(port=SERVER_PORT)
def serve_vllm_model(model_id: str = DEFAULT_MODEL):
    """
    Launch a vLLM server process for one of the registered vLLM models.

    The server is started with ``vllm serve`` on ``0.0.0.0:SERVER_PORT`` and
    protected with ``DEFAULT_API_KEY``. The process is started in the
    background and this function returns without waiting for it.

    Args:
        model_id (str): The identifier of the model to serve. Defaults to DEFAULT_MODEL.

    Raises:
        ValueError: If the specified model_id is not found in VLLM_MODELS.
    """
    import subprocess

    if model_id not in VLLM_MODELS:
        known_models = list(VLLM_MODELS.keys())
        problem = f"Unknown model: {model_id}. Available models: {known_models}"
        logging.error(f"Error: {problem}")
        raise ValueError(problem)

    entry = VLLM_MODELS[model_id]
    model_name = entry["name"]
    revision = entry["revision"]

    logging.basicConfig(level=logging.INFO)
    logging.info(f"Starting vLLM server with model: {model_name}")

    launch_command = ["vllm", "serve", "--uvicorn-log-level=info", model_name]
    launch_command += ["--revision", revision]
    launch_command += ["--host", "0.0.0.0"]
    launch_command += ["--port", str(SERVER_PORT)]
    launch_command += ["--api-key", DEFAULT_API_KEY]

    # Popen starts the server without blocking, so this function returns
    # while the server keeps running. The command is passed as an argument
    # list (no shell). Nothing here waits for the server to become ready.
    server = subprocess.Popen(launch_command)

    logging.info(f"Started vLLM server with PID {server.pid}")
1316 | 
1317 | # Define the worker that will process the queue
@app.function(
    image=vllm_image,
    gpu=None,  # Worker will spawn GPU functions as needed
    allow_concurrent_inputs=10,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
    timeout=30 * MINUTES,
)
async def process_queue_worker():
    """Worker function that processes requests from the queue.

    Requests are pulled from ``request_queue`` one at a time. Streaming
    requests are handed to ``generate_streaming_response``; non-streaming
    requests go through ``generate_response`` and the result is stored under
    the request id in both ``memory_cache`` and ``response_dict``. A failure
    while handling a request is stored as an ``{"error": {...}}`` payload
    (and, for streams, sent as a final chunk) and the worker moves on.

    The worker exits after 10 consecutive empty polls. On exit it always
    decrements the ``"workers_running"`` counter in ``model_stats_dict``.
    """
    import asyncio
    import queue
    import time

    # Bound before the try block so the finally clause can always log it.
    worker_id = str(uuid.uuid4())[:8]

    try:
        # Signal that we're starting a worker
        logging.info(f"Starting queue processing worker {worker_id}")

        # Process requests until timeout or empty queue
        empty_count = 0
        max_empty_count = 10  # Stop after 10 consecutive empty polls

        while empty_count < max_empty_count:
            # Reset per-request state so the error handler below never sees
            # values left over from a previous iteration.
            request_id = None
            model_id = None
            stream_mode = False

            # Try to get a request from the queue
            try:
                # NOTE(review): if request_queue is a modal.Queue, its get()
                # takes `timeout` in seconds rather than `timeout_ms` —
                # confirm against the queue's definition.
                request_dict = await request_queue.get.aio(timeout_ms=5000)
                empty_count = 0  # Reset empty counter

                # Process the request
                try:
                    # Create request object
                    request_id = request_dict.get("request_id")
                    model_id = request_dict.get("model_id")
                    messages = request_dict.get("messages", [])
                    temperature = request_dict.get("temperature", 0.7)
                    max_tokens = request_dict.get("max_tokens", 1024)
                    api_key = request_dict.get("api_key", DEFAULT_API_KEY)
                    stream_mode = request_dict.get("stream", False)

                    logging.info(f"Worker {worker_id} processing request {request_id} for model {model_id}")

                    # Start time for latency calculation
                    start_time = time.time()

                    if stream_mode:
                        # Generate streaming response
                        await generate_streaming_response(
                            request_id=request_id,
                            model_id=model_id,
                            messages=messages,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            api_key=api_key
                        )
                    else:
                        # Generate non-streaming response
                        response = await generate_response(
                            model_id=model_id,
                            messages=messages,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            api_key=api_key
                        )

                        # Calculate latency
                        latency = time.time() - start_time

                        # Update latency stats
                        if model_stats_dict.contains(model_id):
                            stats = model_stats_dict[model_id]
                            old_avg = stats.get("avg_latency", 0)
                            old_count = stats.get("success_count", 0)

                            # Running mean, weighted by the successes recorded so far
                            if old_count > 0:
                                new_avg = (old_avg * old_count + latency) / (old_count + 1)
                            else:
                                new_avg = latency

                            stats["avg_latency"] = new_avg
                            model_stats_dict[model_id] = stats

                        # Store the response in both caches
                        memory_cache.set(request_id, response)
                        response_dict[request_id] = response

                        logging.info(f"Worker {worker_id} completed request {request_id} in {latency:.2f}s")

                except Exception as e:
                    # Log error and move on
                    logging.error(f"Worker {worker_id} error processing request {request_id}: {str(e)}")

                    # Create error response
                    error_response = {
                        "error": {
                            "message": str(e),
                            "type": "internal_error",
                            "code": 500
                        }
                    }

                    # Store the error as a response. Without a request id
                    # there is nobody to deliver it to.
                    if request_id is not None:
                        memory_cache.set(request_id, error_response)
                        response_dict[request_id] = error_response

                    # If streaming, send error and finish stream
                    if stream_mode and request_id is not None:
                        stream_manager.add_chunk(request_id, {
                            "id": f"chatcmpl-{int(time.time())}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model_id,
                            "choices": [{
                                "index": 0,
                                "delta": {"content": f"Error: {str(e)}"},
                                "finish_reason": "error"
                            }]
                        })
                        stream_manager.finish_stream(request_id)

            except (asyncio.TimeoutError, queue.Empty):
                # No requests in queue. queue.Empty is handled as well
                # because queue-style get() calls commonly signal an elapsed
                # timeout that way; without it the first empty poll would
                # end the worker instead of being counted.
                empty_count += 1
                logging.info(f"Worker {worker_id}: No requests in queue. Empty count: {empty_count}")

                # Clean up expired cache entries and old streams
                if empty_count % 5 == 0:  # Every 5 empty polls
                    memory_cache.clear_expired()
                    stream_manager.clean_old_streams()

                await asyncio.sleep(1)  # Wait a bit before checking again

        # If we get here, we've had too many consecutive empty polls
        logging.info(f"Worker {worker_id} shutting down due to empty queue")

    finally:
        # Signal that this worker is done
        workers_running_key = "workers_running"
        if model_stats_dict.contains(workers_running_key):
            current_workers = model_stats_dict[workers_running_key]
            model_stats_dict[workers_running_key] = max(0, current_workers - 1)
            logging.info(f"Worker {worker_id} shutdown. Workers remaining: {max(0, current_workers - 1)}")
1461 | 
async def generate_streaming_response(
    request_id: str,
    model_id: str,
    messages: List[dict],
    temperature: float,
    max_tokens: int,
    api_key: str
):
    """
    Generate a streaming response and send chunks to the stream manager.

    For models in VLLM_MODELS the request is forwarded to the vLLM server's
    /v1/chat/completions endpoint with streaming enabled, and every parsed
    SSE "data:" payload is passed to stream_manager.add_chunk. For models in
    LLAMA_CPP_MODELS, run_llama_cpp_stream is invoked and this coroutine then
    polls stream_queues until the entry for request_id equals "DONE".

    Any exception raised along the way is not propagated: it is logged,
    emitted on the stream as a chunk with finish_reason "error", and the
    stream is finished.

    Args:
        request_id: The unique ID for this request
        model_id: The ID of the model to use
        messages: The chat messages
        temperature: The sampling temperature
        max_tokens: The maximum tokens to generate
        api_key: The API key for authentication
    """
    import httpx
    import time
    import json
    import asyncio

    # Upper bound (seconds) on waiting for the llama.cpp completion signal.
    # Mirrors the 30-minute timeout declared on run_llama_cpp_stream, so a
    # lost "DONE" signal can no longer make this coroutine spin forever.
    completion_wait_timeout = 30 * 60

    # Created before the try block so the error handler can always use it.
    response_id = f"chatcmpl-{int(time.time())}"

    try:
        if model_id in VLLM_MODELS:
            # Start vLLM server for this model
            await serve_vllm_model.remote(model_id=model_id)

            # Need to wait for server startup; fail fast instead of sending
            # the request to a server that never became healthy.
            if not await wait_for_server(serve_vllm_model.web_url, timeout=120):
                raise RuntimeError(
                    f"vLLM server for model {model_id} did not become ready in time"
                )

            # Forward request to vLLM with streaming enabled
            async with httpx.AsyncClient(timeout=120.0) as client:
                headers = {
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                    "Accept": "text/event-stream"
                }

                # Format request for vLLM OpenAI-compatible endpoint
                vllm_request = {
                    "model": VLLM_MODELS[model_id]["name"],
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "stream": True
                }

                # Make streaming request
                async with client.stream(
                    "POST",
                    f"{serve_vllm_model.web_url}/v1/chat/completions",
                    json=vllm_request,
                    headers=headers
                ) as response:
                    # An error status carries no SSE events; without this
                    # check the stream would end silently with no content.
                    response.raise_for_status()

                    # Process streaming response
                    buffer = ""
                    async for chunk in response.aiter_text():
                        buffer += chunk

                        # Process complete SSE messages (separated by a blank line)
                        while "\n\n" in buffer:
                            message, buffer = buffer.split("\n\n", 1)

                            if not message.startswith("data: "):
                                continue

                            data = message[6:]  # Remove "data: " prefix

                            if data == "[DONE]":
                                # End of stream
                                stream_manager.finish_stream(request_id)
                                return

                            try:
                                # Parse JSON data and forward to client
                                stream_manager.add_chunk(request_id, json.loads(data))
                            except json.JSONDecodeError:
                                logging.error(f"Invalid JSON in stream: {data}")

                    # Ensure stream is finished even without a [DONE] marker
                    stream_manager.finish_stream(request_id)

        elif model_id in LLAMA_CPP_MODELS:
            # For llama.cpp models, we need to simulate streaming
            # First convert the chat format to a prompt
            prompt = format_messages_to_prompt(messages)

            # Run llama.cpp with the prompt. Streaming is handled by the
            # run_llama_cpp_stream function, which directly adds chunks to
            # the stream manager, so its return value is not needed here.
            await run_llama_cpp_stream.remote(
                model_id=model_id,
                prompt=prompt,
                n_predict=max_tokens,
                temperature=temperature,
                request_id=request_id
            )

            # Wait for completion signal, but never longer than the deadline
            deadline = time.monotonic() + completion_wait_timeout
            while True:
                if request_id in stream_queues and stream_queues[request_id] == "DONE":
                    # Clean up
                    stream_queues.pop(request_id)
                    break
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"Timed out waiting for streaming completion of request {request_id}"
                    )
                await asyncio.sleep(0.1)

        else:
            raise ValueError(f"Unknown model: {model_id}")

    except Exception as e:
        logging.error(f"Error in streaming generation: {str(e)}")
        # Send error chunk
        stream_manager.add_chunk(request_id, {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [{
                "index": 0,
                "delta": {"content": f"Error: {str(e)}"},
                "finish_reason": "error"
            }]
        })
        # Finish stream
        stream_manager.finish_stream(request_id)
1593 | 
async def generate_response(model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str):
    """
    Generate a response using the appropriate model based on model_id.

    Models in VLLM_MODELS are handled by forwarding the request to the vLLM
    server's /v1/chat/completions endpoint and returning its JSON body
    unchanged. Models in LLAMA_CPP_MODELS are run through run_llama_cpp and
    the text output is wrapped in a chat.completion dictionary whose token
    counts are rough estimates (about four characters per token).

    Args:
        model_id: The ID of the model to use
        messages: The chat messages
        temperature: The sampling temperature
        max_tokens: The maximum tokens to generate
        api_key: The API key for authentication

    Returns:
        A response in OpenAI-compatible format

    Raises:
        RuntimeError: If the vLLM server does not pass its health check in time
        ValueError: If model_id is in neither VLLM_MODELS nor LLAMA_CPP_MODELS
    """
    import httpx
    import time

    if model_id in VLLM_MODELS:
        # Start vLLM server for this model
        await serve_vllm_model.remote(model_id=model_id)

        # Need to wait for server startup; fail with a clear error instead of
        # posting to a server that never reported healthy.
        if not await wait_for_server(serve_vllm_model.web_url, timeout=120):
            raise RuntimeError(
                f"vLLM server for model {model_id} did not become ready in time"
            )

        # Forward request to vLLM
        async with httpx.AsyncClient(timeout=60.0) as client:
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            }

            # Format request for vLLM OpenAI-compatible endpoint
            vllm_request = {
                "model": VLLM_MODELS[model_id]["name"],
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens
            }

            response = await client.post(
                f"{serve_vllm_model.web_url}/v1/chat/completions",
                json=vllm_request,
                headers=headers
            )

            # The upstream body is returned as-is, including upstream errors.
            return response.json()

    if model_id in LLAMA_CPP_MODELS:
        # For llama.cpp models, use the run_llama_cpp function
        # First convert the chat format to a prompt
        prompt = format_messages_to_prompt(messages)

        # Run llama.cpp with the prompt
        output = await run_llama_cpp.remote(
            model_id=model_id,
            prompt=prompt,
            n_predict=max_tokens,
            temperature=temperature
        )

        # Format the response in the OpenAI format
        completion_text = output.strip()

        # Rough estimation: ~4 characters per token
        prompt_tokens = len(prompt) // 4
        completion_tokens = len(completion_text) // 4

        # Compare like with like: max_tokens is a token budget, so it is
        # checked against the estimated token count, not the character count.
        finish_reason = "stop" if completion_tokens < max_tokens else "length"

        created = int(time.time())
        return {
            "id": f"chatcmpl-{created}",
            "object": "chat.completion",
            "created": created,
            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text
                    },
                    "finish_reason": finish_reason
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": (len(prompt) + len(completion_text)) // 4
            }
        }

    raise ValueError(f"Unknown model: {model_id}")
1682 | 
def format_messages_to_prompt(messages: List[Dict[str, str]]) -> str:
    """
    Convert chat messages to a text prompt format for llama.cpp.

    Each message becomes a "<|role|>" marker line followed by its content.
    Roles are matched case-insensitively; any role other than system, user
    or assistant (including a missing or None role) is rendered as a user
    turn. A missing or None content is rendered as an empty string rather
    than the literal text "None". A trailing "<|assistant|>" marker is
    always appended to prompt the model to respond.

    Args:
        messages: List of message dictionaries with role and content

    Returns:
        Formatted prompt string
    """
    known_roles = ("system", "user", "assistant")
    segments = []

    for message in messages:
        # `or ""` covers both a missing key and an explicit None value
        role = (message.get("role") or "").lower()
        if role not in known_roles:
            # For unknown roles, treat as user
            role = "user"

        content = message.get("content")
        if content is None:
            content = ""

        segments.append(f"<|{role}|>\n{content}\n")

    # Add final assistant marker to prompt the model to respond
    segments.append("<|assistant|>\n")

    return "".join(segments)
1713 | 
async def wait_for_server(url: str, timeout: int = 120, check_interval: int = 2):
    """
    Wait for a server to be ready by checking its health endpoint.

    Polls "{url}/health" until it answers with HTTP 200 or the timeout
    elapses. Request failures and non-200 answers are logged and retried.

    Args:
        url: The base URL of the server
        timeout: Maximum time to wait in seconds
        check_interval: Interval between checks in seconds

    Returns:
        True if server is ready, False otherwise
    """
    import httpx
    import asyncio
    import time

    # Monotonic clock: the deadline must not move if the wall clock is adjusted
    start_time = time.monotonic()
    health_url = f"{url}/health"

    logging.info(f"Waiting for server at {url} to be ready...")

    # One client is reused for every poll instead of building a new one per attempt
    async with httpx.AsyncClient(timeout=5.0) as client:
        while time.monotonic() - start_time < timeout:
            try:
                response = await client.get(health_url)
                if response.status_code == 200:
                    logging.info(f"Server at {url} is ready!")
                    return True
                elapsed = time.monotonic() - start_time
                logging.info(
                    f"Server not ready yet after {elapsed:.1f}s: "
                    f"health check returned HTTP {response.status_code}"
                )
            except Exception as e:
                # Best-effort polling: any failure just means "not ready yet"
                elapsed = time.monotonic() - start_time
                logging.info(f"Server not ready yet after {elapsed:.1f}s: {str(e)}")

            await asyncio.sleep(check_interval)

    logging.error(f"Timed out waiting for server at {url} after {timeout} seconds")
    return False
1750 | 
@app.function(
    image=llama_cpp_image,
    gpu=None,  # Will be set dynamically based on model
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
        RESULTS_DIR: results_vol,
    },
    timeout=30 * MINUTES,
)
async def run_llama_cpp_stream(
    model_id: str,
    prompt: str,
    n_predict: int = 1024,
    temperature: float = 0.7,
    request_id: str = None,
):
    """
    Run streaming inference with llama.cpp for models like DeepSeek-R1 and Phi-4

    When request_id is given, llama-cli output is read on a background thread
    and forwarded to stream_manager as chat.completion.chunk dictionaries;
    once the output ends the stream is finished and stream_queues[request_id]
    is set to "DONE". The function itself returns immediately in that case.
    Without a request_id the process output is collected, saved under
    RESULTS_DIR and returned.

    Args:
        model_id: Key into LLAMA_CPP_MODELS selecting the model to run
        prompt: The prompt text passed to llama-cli
        n_predict: Value passed to llama-cli's --n-predict option
        temperature: Value passed to llama-cli's --temp option
        request_id: Stream identifier; enables streaming mode when set

    Returns:
        "Streaming in progress" in streaming mode, otherwise the collected
        stdout of llama-cli as a string

    Raises:
        ValueError: If model_id is not in LLAMA_CPP_MODELS
        FileNotFoundError: If no downloaded file matches the model's pattern
    """
    import subprocess
    import os
    import time
    import threading
    from uuid import uuid4
    from pathlib import Path
    from huggingface_hub import snapshot_download

    def send_chunk(stream_id, delta, finish_reason=None, chunk_id=None):
        """Send one chat.completion.chunk for stream_id to the stream manager."""
        choice = {"index": 0, "delta": delta}
        if finish_reason is not None:
            choice["finish_reason"] = finish_reason
        stream_manager.add_chunk(stream_id, {
            "id": chunk_id or f"chatcmpl-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [choice],
        })

    def fail_stream(error_msg):
        """Report error_msg on the stream (if any), finish it and signal completion."""
        if request_id:
            send_chunk(request_id, {"content": f"Error: {error_msg}"}, finish_reason="error")
            stream_manager.finish_stream(request_id)
            # Signal completion
            stream_queues[request_id] = "DONE"

    if model_id not in LLAMA_CPP_MODELS:
        available_models = list(LLAMA_CPP_MODELS.keys())
        error_msg = f"Unknown model: {model_id}. Available models: {available_models}"
        logging.error(error_msg)
        fail_stream(error_msg)
        raise ValueError(error_msg)

    model_info = LLAMA_CPP_MODELS[model_id]
    repo_id = model_info["name"]
    pattern = model_info["pattern"]
    revision = model_info["revision"]

    def download_snapshot():
        """Download (or reuse) the model files matching pattern."""
        return snapshot_download(
            repo_id=repo_id,
            revision=revision,
            local_dir=f"{CACHE_DIR}/llama_cpp",
            allow_patterns=[pattern],
        )

    # Download model if not already cached
    logging.info(f"Downloading model {repo_id} if not present")
    try:
        model_path = download_snapshot()
    except ValueError as e:
        if "hf_transfer" not in str(e):
            raise
        # Fallback to standard download if hf_transfer fails
        logging.warning("hf_transfer failed, falling back to standard download")
        # Temporarily disable hf_transfer
        previous_setting = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
        try:
            model_path = download_snapshot()
        finally:
            # Restore original setting; if the variable was unset before,
            # leave it unset instead of forcing it to "1".
            if previous_setting is None:
                os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
            else:
                os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = previous_setting

    # Find the model file
    model_files = list(Path(model_path).glob(pattern))
    if not model_files:
        error_msg = f"No model files found matching pattern {pattern}"
        logging.error(error_msg)
        fail_stream(error_msg)
        raise FileNotFoundError(error_msg)

    model_file = str(model_files[0])
    logging.info(f"Using model file: {model_file}")

    # Set up command
    cmd = [
        "llama-cli",
        "--model", model_file,
        "--prompt", prompt,
        "--n-predict", str(n_predict),
        "--temp", str(temperature),
        "--ctx-size", "8192",
    ]

    # Add GPU layers if needed
    if model_info["gpu"] is not None:
        cmd.extend(["--n-gpu-layers", "9999"])  # Use all layers on GPU

    # Run inference
    result_id = str(uuid4())
    logging.info(f"Running streaming inference with ID: {result_id}")

    # Create response ID for streaming
    response_id = f"chatcmpl-{int(time.time())}"

    def drain_stderr(process):
        """Consume stderr so a full pipe can never block the child process."""
        for raw_line in iter(process.stderr.readline, b''):
            logging.debug(raw_line.decode('utf-8', errors='replace').rstrip())
        process.stderr.close()

    # Function to process output in real-time and send to stream
    def process_output(process, stream_id):
        content_buffer = ""
        last_send_time = time.time()

        try:
            # Send initial chunk with role
            send_chunk(stream_id, {"role": "assistant"}, chunk_id=response_id)

            for line in iter(process.stdout.readline, b''):
                try:
                    line_str = line.decode('utf-8', errors='replace')

                    # Skip llama.cpp info lines
                    if line_str.startswith("llama_"):
                        continue

                    # Add to buffer
                    content_buffer += line_str

                    # Send chunks at reasonable intervals or when buffer gets large
                    now = time.time()
                    if now - last_send_time > 0.1 or len(content_buffer) > 20:
                        send_chunk(stream_id, {"content": content_buffer}, chunk_id=response_id)

                        # Reset buffer and time
                        content_buffer = ""
                        last_send_time = now

                except Exception as e:
                    logging.error(f"Error processing output: {str(e)}")

            # Send any remaining content
            if content_buffer:
                send_chunk(stream_id, {"content": content_buffer}, chunk_id=response_id)

            # Send final chunk with finish reason
            send_chunk(stream_id, {}, finish_reason="stop", chunk_id=response_id)

        except Exception as e:
            logging.error(f"Error processing output: {str(e)}")

        finally:
            # Whatever happened above, the stream must be closed and the
            # completion signal set, because the caller waits on that signal.
            try:
                # Reap the child so it does not linger as a zombie
                return_code = process.wait()
                if return_code != 0:
                    logging.error(f"llama-cli exited with code {return_code}")
                stream_manager.finish_stream(stream_id)
            finally:
                # Signal completion
                stream_queues[stream_id] = "DONE"

    # Start process. No bufsize=1 here: line buffering is only supported for
    # text-mode pipes, and these pipes are read as bytes.
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=False,
    )

    # Start output processing thread if streaming
    if request_id:
        stderr_thread = threading.Thread(target=drain_stderr, args=(process,))
        stderr_thread.daemon = True
        stderr_thread.start()

        thread = threading.Thread(target=process_output, args=(process, request_id))
        thread.daemon = True
        thread.start()

        # Return immediately for streaming
        # NOTE(review): the reader threads keep running after this returns;
        # confirm the hosting runtime keeps this process alive until they finish.
        return "Streaming in progress"

    # For non-streaming, collect all output
    stdout, stderr = collect_output(process)

    # Save results
    result_dir = Path(RESULTS_DIR) / result_id
    result_dir.mkdir(parents=True, exist_ok=True)

    (result_dir / "output.txt").write_text(stdout)
    (result_dir / "stderr.txt").write_text(stderr)
    (result_dir / "prompt.txt").write_text(prompt)

    logging.info(f"Results saved to {result_dir}")
    return stdout
2004 | 
@app.function(
    image=llama_cpp_image,
    gpu=None,  # Will be set dynamically based on model
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
        f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
        RESULTS_DIR: results_vol,
    },
    timeout=30 * MINUTES,
)
async def run_llama_cpp(
    model_id: str,
    prompt: str = "Tell me about Modal and how it helps with ML deployments.",
    n_predict: int = 1024,
    temperature: float = 0.7,
):
    """
    Run inference with llama.cpp for models like DeepSeek-R1 and Phi-4

    Downloads the model files if needed, runs llama-cli with the prompt,
    saves stdout, stderr and the prompt under RESULTS_DIR/<result id>, and
    returns the collected stdout.

    Args:
        model_id: Key into LLAMA_CPP_MODELS selecting the model to run
        prompt: The prompt text passed to llama-cli
        n_predict: Value passed to llama-cli's --n-predict option
        temperature: Value passed to llama-cli's --temp option

    Returns:
        The collected stdout of llama-cli as a string

    Raises:
        ValueError: If model_id is not in LLAMA_CPP_MODELS
        FileNotFoundError: If no downloaded file matches the model's pattern
    """
    import subprocess
    import os
    from uuid import uuid4
    from pathlib import Path
    from huggingface_hub import snapshot_download

    if model_id not in LLAMA_CPP_MODELS:
        available_models = list(LLAMA_CPP_MODELS.keys())
        print(f"Error: Unknown model: {model_id}. Available models: {available_models}")
        raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")

    model_info = LLAMA_CPP_MODELS[model_id]
    repo_id = model_info["name"]
    pattern = model_info["pattern"]
    revision = model_info["revision"]

    def download_snapshot():
        """Download (or reuse) the model files matching pattern."""
        return snapshot_download(
            repo_id=repo_id,
            revision=revision,
            local_dir=f"{CACHE_DIR}/llama_cpp",
            allow_patterns=[pattern],
        )

    # Download model if not already cached
    logging.info(f"Downloading model {repo_id} if not present")
    try:
        model_path = download_snapshot()
    except ValueError as e:
        if "hf_transfer" not in str(e):
            raise
        # Fallback to standard download if hf_transfer fails
        logging.warning("hf_transfer failed, falling back to standard download")
        # Temporarily disable hf_transfer
        previous_setting = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER")
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
        try:
            model_path = download_snapshot()
        finally:
            # Restore original setting; if the variable was unset before,
            # leave it unset instead of forcing it to "1".
            if previous_setting is None:
                os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
            else:
                os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = previous_setting

    # Find the model file
    model_files = list(Path(model_path).glob(pattern))
    if not model_files:
        logging.error(f"No model files found matching pattern {pattern}")
        raise FileNotFoundError(f"No model files found matching pattern {pattern}")

    model_file = str(model_files[0])
    print(f"Using model file: {model_file}")

    # Set up command
    cmd = [
        "llama-cli",
        "--model", model_file,
        "--prompt", prompt,
        "--n-predict", str(n_predict),
        "--temp", str(temperature),
        "--ctx-size", "8192",
    ]

    # Add GPU layers if needed
    if model_info["gpu"] is not None:
        cmd.extend(["--n-gpu-layers", "9999"])  # Use all layers on GPU

    # Run inference
    result_id = str(uuid4())
    print(f"Running inference with ID: {result_id}")

    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=False
    )

    stdout, stderr = collect_output(process)

    # collect_output waits for the process, so the exit status is available.
    # Surface failures in the log; the output is still saved and returned.
    if process.returncode != 0:
        logging.error(f"llama-cli exited with code {process.returncode} for inference {result_id}")

    # Save results
    result_dir = Path(RESULTS_DIR) / result_id
    result_dir.mkdir(parents=True, exist_ok=True)

    (result_dir / "output.txt").write_text(stdout)
    (result_dir / "stderr.txt").write_text(stderr)
    (result_dir / "prompt.txt").write_text(prompt)

    print(f"Results saved to {result_dir}")
    return stdout
2117 | 
@app.function(
    image=vllm_image,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
)
def list_available_models():
    """
    Print and return the model IDs this server knows about.

    One line is printed per entry of VLLM_MODELS and LLAMA_CPP_MODELS; the
    llama.cpp lines also show the configured GPU, or "(CPU)" when none is set.

    Returns:
        dict: The vLLM model IDs under "vllm" and the llama.cpp model IDs
        under "llama_cpp".
    """
    print("Available vLLM models:")
    for name, details in VLLM_MODELS.items():
        print(f"- {name}: {details['name']}")

    print("\nAvailable llama.cpp models:")
    for name, details in LLAMA_CPP_MODELS.items():
        if details['gpu']:
            device_note = f"(GPU: {details['gpu']})"
        else:
            device_note = "(CPU)"
        print(f"- {name}: {details['name']} {device_note}")

    vllm_ids = [*VLLM_MODELS.keys()]
    llama_cpp_ids = [*LLAMA_CPP_MODELS.keys()]
    return {"vllm": vllm_ids, "llama_cpp": llama_cpp_ids}
2144 | 
def collect_output(process):
    """
    Gather a process's stdout and stderr while echoing both live.

    Each pipe is read line by line on its own thread; every line is decoded
    as UTF-8 (undecodable bytes replaced), written to this process's matching
    standard stream, and remembered. The call blocks until both pipes reach
    end-of-file and the process has exited.

    Args:
        process: A started process whose stdout and stderr are binary pipes.

    Returns:
        tuple: (stdout_text, stderr_text), each the concatenation of all
        lines read from the corresponding pipe.
    """
    import sys
    from threading import Thread

    def pump(pipe, sink, mirror):
        # Read until EOF, echoing each decoded line and keeping a copy
        for raw in iter(pipe.readline, b""):
            text = raw.decode("utf-8", errors="replace")
            mirror.write(text)
            mirror.flush()
            sink.append(text)
        pipe.close()

    captured_out = []
    captured_err = []

    readers = [
        Thread(target=pump, args=(process.stdout, captured_out, sys.stdout)),
        Thread(target=pump, args=(process.stderr, captured_err, sys.stderr)),
    ]

    for reader in readers:
        reader.start()
    for reader in readers:
        reader.join()
    process.wait()

    return "".join(captured_out), "".join(captured_err)
2184 | 
2185 | # Main ASGI app for Modal
@app.function(
    image=vllm_image,
    gpu=None,  # No GPU for the API frontend
    allow_concurrent_inputs=100,
    volumes={
        f"{CACHE_DIR}/huggingface": hf_cache_vol,
    },
)
@modal.asgi_app()
def inference_api():
    """Hand the module-level FastAPI application (api_app) to Modal as the ASGI app."""
    asgi_application = api_app
    return asgi_application
2198 | 
def _build_api_request(url, token, payload=None, method="GET"):
    """Build an authenticated ``urllib`` request for the inference API.

    Args:
        url: Absolute URL of the endpoint.
        token: Bearer token placed in the ``Authorization`` header.
        payload: Optional JSON-serializable body; sent UTF-8 encoded when given.
        method: HTTP method to use.

    Returns:
        urllib.request.Request: The prepared (not yet sent) request.
    """
    import json
    import urllib.request

    body = None if payload is None else json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method=method,
    )


def _fetch_api_json(url, token, payload=None, method="GET", timeout=None):
    """Send an authenticated request and return the decoded JSON response body.

    Args:
        url: Absolute URL of the endpoint.
        token: Bearer token used for authorization.
        payload: Optional JSON-serializable request body.
        method: HTTP method to use.
        timeout: Socket timeout in seconds; ``None`` waits indefinitely, which
            is kept as the default because inference calls may be slow.

    Returns:
        The parsed JSON document returned by the server.

    Raises:
        urllib.error.URLError: On connection failures or HTTP error statuses.
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    import json
    import urllib.request

    request = _build_api_request(url, token, payload=payload, method=method)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode())


def _print_streamed_completion(response):
    """Print the text deltas of a server-sent-events chat completion stream.

    Only lines starting with ``data: `` are interpreted. The ``[DONE]``
    sentinel is echoed on its own line, unparsable payloads are reported, and
    chunks without textual content are skipped.

    Args:
        response: Iterable of raw byte lines (an open HTTP response).
    """
    import json

    for raw_line in response:
        # Never let a stray undecodable byte abort the whole stream.
        line = raw_line.decode("utf-8", errors="replace")
        if not line.startswith("data: "):
            continue
        data = line[6:].strip()
        if data == "[DONE]":
            print("\n[DONE]")
            continue
        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            print(f"Error parsing: {data}")
            continue
        choices = chunk.get("choices") if isinstance(chunk, dict) else None
        if not choices:
            continue
        first_choice = choices[0]
        delta = first_choice.get("delta") if isinstance(first_choice, dict) else None
        content = delta.get("content") if isinstance(delta, dict) else None
        # A delta may carry ``"content": null``; printing it would emit "None".
        if content:
            print(content, end="", flush=True)


def _print_curl_examples(base_url, test_key, model):
    """Print ready-to-paste curl commands for the API's main endpoints.

    Args:
        base_url: Root URL of the deployed API.
        test_key: Bearer token embedded in the example commands.
        model: Model identifier embedded in the completion examples.
    """
    # Regular completion
    print(f"""# Regular completion:
curl -X POST {base_url}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model": "{model}",
    "messages": [
      {{
        "role": "user",
        "content": "Hello, how can you help me today?"
      }}
    ],
    "temperature": 0.7,
    "max_tokens": 500
  }}'""")

    # Streaming completion
    print(f"""\n# Streaming completion:
curl -X POST {base_url}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model": "{model}",
    "messages": [
      {{
        "role": "user",
        "content": "Write a short story about AI"
      }}
    ],
    "temperature": 0.8,
    "max_tokens": 1000,
    "stream": true
  }}' --no-buffer""")

    # List models
    print(f"""\n# List available models:
curl -X GET {base_url}/v1/models \\
  -H "Authorization: Bearer {test_key}" """)

    # Model management commands
    print(f"""\n# Load a model:
curl -X POST {base_url}/admin/models/load \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model_id": "phi-2",
    "force_reload": false
  }}'""")

    print(f"""\n# Load a model directly from Hugging Face:
curl -X POST {base_url}/admin/models/load-from-hf \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "repo_id": "microsoft/phi-2",
    "model_type": "vllm",
    "max_tokens": 4096
  }}'""")

    print(f"""\n# Get model status:
curl -X GET {base_url}/admin/models/status/phi-2 \\
  -H "Authorization: Bearer {test_key}" """)

    print(f"""\n# Unload a model:
curl -X POST {base_url}/admin/models/unload \\
  -H "Content-Type: application/json" \\
  -H "Authorization: Bearer {test_key}" \\
  -d '{{
    "model_id": "phi-2"
  }}'""")


@app.local_entrypoint()
def main(
    prompt: str = "What can you tell me about Modal?",
    n_predict: int = 1024,
    temperature: float = 0.7,
    create_admin_key: bool = False,
    stream: bool = False,
    model: str = "auto",
    load_model: str = None,
    load_hf_model: str = None,
    hf_model_type: str = "vllm",
):
    """Smoke-test the deployed API from the local machine.

    The entrypoint waits for the health endpoint, optionally creates a test
    API key, lists models, sends one chat completion, prints admin stats,
    spawns a queue worker when none is reported active, optionally triggers
    model loads, and finally prints example curl commands.

    Args:
        prompt: User message sent in the test chat completion.
        n_predict: ``max_tokens`` for the test completion (also forwarded as
            ``max_tokens`` when loading a Hugging Face model).
        temperature: Sampling temperature for the test completion.
        create_admin_key: When true, create a key via ``/admin/api-keys`` and
            use it for the subsequent calls; falls back to the default key on
            failure.
        stream: When true, request a streamed completion and print deltas.
        model: Model to use for the test completion. ``"auto"`` delegates the
            choice to ``select_best_model``; any other value is used as given.
        load_model: Optional model id to load via ``/admin/models/load``.
        load_hf_model: Optional Hugging Face repo id to load via
            ``/admin/models/load-from-hf``.
        hf_model_type: ``model_type`` sent with the Hugging Face load request.

    Raises:
        AssertionError: If the health check does not succeed within 5 minutes.
    """
    import json
    import time
    import urllib.request

    base_url = inference_api.web_url

    # Initialize the API
    print(f"Starting API at {base_url}")

    # Wait for API to be ready. The deadline and the back-off are applied on
    # every iteration so that neither a non-200 answer nor a hung connection
    # can keep this loop running past the time limit.
    print("Checking if API is ready...")
    delay = 10
    deadline = time.time() + 5 * MINUTES
    up = False
    while not up:
        try:
            with urllib.request.urlopen(base_url + "/health", timeout=30) as response:
                up = response.getcode() == 200
        except Exception:
            up = False
        if up or time.time() > deadline:
            break
        time.sleep(delay)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not up:
        raise AssertionError(f"Failed health check for API at {base_url}")
    print(f"API is up and running at {base_url}")

    # Create a test API key if requested
    test_key = DEFAULT_API_KEY
    if create_admin_key:
        print("Creating a test API key...")
        key_request = {
            "user_id": "test_user",
            "rate_limit": 120,
            "quota": 2000000,
        }
        try:
            result = _fetch_api_json(
                base_url + "/admin/api-keys",
                DEFAULT_API_KEY,  # Admin key
                payload=key_request,
                method="POST",
            )
            print("Created API key:")
            print(json.dumps(result, indent=2))
            # Use this key for the test message
            test_key = result["key"]
        except Exception as e:
            print(f"Error creating API key: {str(e)}")
            test_key = DEFAULT_API_KEY

    # List available models
    print("\nAvailable models:")
    try:
        models = _fetch_api_json(base_url + "/v1/models", test_key)
        print(json.dumps(models, indent=2))
    except Exception as e:
        print(f"Error listing models: {str(e)}")

    # Select best model for the prompt, but only when the caller did not
    # name one explicitly.
    if model == "auto":
        model = select_best_model(prompt, n_predict, temperature)

    # Send a test message
    print(f"\nSending a sample message to {base_url}")
    chat_request = _build_api_request(
        base_url + "/v1/chat/completions",
        test_key,
        payload={
            "messages": [{"role": "user", "content": prompt}],
            "model": model,
            "temperature": temperature,
            "max_tokens": n_predict,
            "stream": stream,
        },
        method="POST",
    )

    try:
        if stream:
            print("Streaming response:")
            with urllib.request.urlopen(chat_request) as response:
                _print_streamed_completion(response)
        else:
            with urllib.request.urlopen(chat_request) as response:
                result = json.loads(response.read().decode())
                print("Response:")
                print(json.dumps(result, indent=2))
    except Exception as e:
        print(f"Error: {str(e)}")

    # Check API stats
    print("\nChecking API stats...")
    stats = None
    try:
        stats = _fetch_api_json(base_url + "/admin/stats", DEFAULT_API_KEY)  # Admin key
        print("API Stats:")
        print(json.dumps(stats, indent=2))
    except Exception as e:
        print(f"Error getting stats: {str(e)}")

    # Start a worker if none running. Without stats the worker count is
    # unknown, so nothing is spawned in that case.
    if stats is None:
        print("Skipping queue worker check: API stats are unavailable")
    else:
        try:
            current_workers = stats.get("queue", {}).get("active_workers", 0)
            if current_workers < 1:
                print("\nStarting a queue worker...")
                process_queue_worker.spawn()
        except Exception as e:
            print(f"Error starting worker: {str(e)}")

    print(f"\nAPI is available at {base_url}")
    print(f"Documentation is at {base_url}/docs")
    print(f"Default Bearer token: {DEFAULT_API_KEY}")

    if create_admin_key:
        print(f"Test Bearer token: {test_key}")

    # NOTE(review): the load/status calls below hit /admin/ endpoints with
    # test_key, while the other admin calls use DEFAULT_API_KEY. Confirm that
    # a key created via /admin/api-keys is allowed to manage models.

    # If a model was specified to load, load it
    if load_model:
        print(f"\nLoading model: {load_model}")
        try:
            result = _fetch_api_json(
                f"{base_url}/admin/models/load",
                test_key,
                payload={"model_id": load_model, "force_reload": True},
                method="POST",
            )
            print("Load response:")
            print(json.dumps(result, indent=2))

            # If it's a small model, wait a bit for it to load
            if load_model in ("tiny-llama-1.1b", "phi-2"):
                print(f"Waiting for {load_model} to load...")
                time.sleep(10)

                # Check status
                status_result = _fetch_api_json(
                    f"{base_url}/admin/models/status/{load_model}",
                    test_key,
                )
                print("Model status:")
                print(json.dumps(status_result, indent=2))

            # The test message has already been sent; from here on the model
            # name only feeds the curl examples printed below.
            model = load_model
        except Exception as e:
            print(f"Error loading model: {str(e)}")

    # If a HF model was specified to load directly
    if load_hf_model:
        print(f"\nLoading HF model: {load_hf_model} with type {hf_model_type}")
        try:
            result = _fetch_api_json(
                f"{base_url}/admin/models/load-from-hf",
                test_key,
                payload={
                    "repo_id": load_hf_model,
                    "model_type": hf_model_type,
                    "max_tokens": n_predict,
                },
                method="POST",
            )
            print("HF Load response:")
            print(json.dumps(result, indent=2))

            # Get the model_id from the response
            hf_model_id = result.get("model_id")

            # Wait a bit for it to start loading
            print(f"Waiting for {load_hf_model} to start loading...")
            time.sleep(5)

            if hf_model_id:
                # Check status
                status_result = _fetch_api_json(
                    f"{base_url}/admin/models/status/{hf_model_id}",
                    test_key,
                )
                print("Model status:")
                print(json.dumps(status_result, indent=2))

                # As above, this only affects the curl examples below.
                model = hf_model_id
        except Exception as e:
            print(f"Error loading HF model: {str(e)}")

    # Show curl examples
    print("\nExample curl commands:")
    _print_curl_examples(base_url, test_key, model)
async def preload_llama_cpp_model(model_id: str):
    """Warm up a llama.cpp model so the first real request is faster.

    A short throwaway inference is issued through ``run_llama_cpp``. Model ids
    that are not present in ``LLAMA_CPP_MODELS`` are logged and ignored. If the
    warm-up call fails, the error is logged and the model's ``"loaded"`` flag
    is reset to ``False``; the exception is not propagated.

    Args:
        model_id: Key of the model in ``LLAMA_CPP_MODELS``.
    """
    is_known = model_id in LLAMA_CPP_MODELS
    if not is_known:
        logging.error(f"Unknown model: {model_id}")
        return

    # Tiny generation: the point is loading the model, not the output.
    warmup_args = dict(
        model_id=model_id,
        prompt="Hello, this is a test to preload the model.",
        n_predict=10,
        temperature=0.7,
    )

    try:
        logging.info(f"Preloading llama.cpp model: {model_id}")
        # NOTE(review): awaiting `.remote(...)` assumes the call returns an
        # awaitable in this context — confirm against how run_llama_cpp is
        # declared.
        await run_llama_cpp.remote(**warmup_args)
        logging.info(f"Successfully preloaded llama.cpp model: {model_id}")
    except Exception as exc:
        logging.error(f"Error preloading llama.cpp model {model_id}: {str(exc)}")
        # Mark as not loaded
        LLAMA_CPP_MODELS[model_id]["loaded"] = False
2564 | 
```
Page 5/6 · First · Prev · Next · Last