This is page 5 of 6. Use http://codebase.md/arthurcolle/openai-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```
# Files
--------------------------------------------------------------------------------
/modal_mcp_server.py:
--------------------------------------------------------------------------------
```python
1 | import modal
2 | import logging
3 | import time
4 | import uuid
5 | import json
6 | import asyncio
7 | import hashlib
8 | import threading
9 | import concurrent.futures
10 | from pathlib import Path
11 | from typing import Dict, List, Optional, Any, Tuple, Union, AsyncIterator
12 | from datetime import datetime, timedelta
13 | from collections import deque
14 |
15 | from fastapi import FastAPI, Request, Depends, HTTPException, status, BackgroundTasks
16 | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
17 | from fastapi.responses import JSONResponse, HTMLResponse
18 | from fastapi.middleware.cors import CORSMiddleware
19 | from pydantic import BaseModel, Field
20 |
21 | # Configure logging
22 | logging.basicConfig(
23 | level=logging.INFO,
24 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25 | )
26 |
27 | # Create FastAPI app
28 | api_app = FastAPI(
29 | title="Advanced LLM Inference API",
30 | description="Enterprise-grade OpenAI-compatible LLM serving API with multiple model support, streaming, and advanced caching",
31 | version="1.1.0"
32 | )
33 |
34 | # Add CORS middleware
35 | api_app.add_middleware(
36 | CORSMiddleware,
37 | allow_origins=["*"], # For production, restrict this to specific origins rather than a wildcard
38 | allow_credentials=True,
39 | allow_methods=["*"],
40 | allow_headers=["*"],
41 | )
42 |
43 | # Security setup
44 | security = HTTPBearer()
45 |
46 | # Token bucket rate limiter
47 | class TokenBucket:
48 | """
49 | Token bucket algorithm for rate limiting.
50 | Each user gets a bucket that fills at a constant rate.
51 | """
52 | def __init__(self):
53 | self.buckets = {}
54 | self.lock = threading.Lock()
55 |
56 | def _get_bucket(self, user_id, rate_limit):
57 | """Get or create a bucket for a user"""
58 | now = time.time()
59 |
60 | if user_id not in self.buckets:
61 | # Initialize with full bucket
62 | self.buckets[user_id] = {
63 | "tokens": rate_limit,
64 | "last_refill": now,
65 | "rate": rate_limit / 60.0 # tokens per second
66 | }
67 | return self.buckets[user_id]
68 |
69 | bucket = self.buckets[user_id]
70 |
71 | # Update rate if it changed
72 | bucket["rate"] = rate_limit / 60.0
73 |
74 | # Refill tokens based on time elapsed
75 | elapsed = now - bucket["last_refill"]
76 | new_tokens = elapsed * bucket["rate"]
77 |
78 | bucket["tokens"] = min(rate_limit, bucket["tokens"] + new_tokens)
79 | bucket["last_refill"] = now
80 |
81 | return bucket
82 |
83 | def consume(self, user_id, tokens=1, rate_limit=60):
84 | """
85 | Consume tokens from a user's bucket.
86 | Returns True if tokens were consumed, False otherwise.
87 | """
88 | with self.lock:
89 | bucket = self._get_bucket(user_id, rate_limit)
90 |
91 | if bucket["tokens"] >= tokens:
92 | bucket["tokens"] -= tokens
93 | return True
94 | return False
95 |
96 | # Create rate limiter
97 | rate_limiter = TokenBucket()
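# Illustrative usage of the rate limiter above (shown as comments only; "alice" is a
# placeholder user id). With rate_limit=60 the bucket holds at most 60 tokens and refills
# at one token per second, so short bursts are allowed before callers get throttled:
#
#   if rate_limiter.consume("alice", tokens=1, rate_limit=60):
#       ...  # proceed with the request
#   else:
#       ...  # reject with HTTP 429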
98 |
99 | # Define the container image with necessary dependencies
100 | vllm_image = (
101 | modal.Image.debian_slim(python_version="3.10")
102 | .pip_install(
103 | "vllm==0.7.3", # Updated version
104 | "huggingface_hub[hf_transfer]==0.26.2",
105 | "flashinfer-python==0.2.0.post2",
106 | "fastapi>=0.95.0",
107 | "uvicorn>=0.15.0",
108 | "pydantic>=2.0.0",
109 | "tiktoken>=0.5.1",
110 | extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
111 | )
112 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers
113 | .env({"VLLM_USE_V1": "1"}) # Enable V1 engine for better performance
114 | )
115 |
116 | # Define llama.cpp image for alternative models
117 | llama_cpp_image = (
118 | modal.Image.debian_slim(python_version="3.10")
119 | .apt_install("git", "build-essential", "cmake", "curl", "libcurl4-openssl-dev")
120 | .pip_install(
121 | "huggingface_hub==0.26.2",
122 | "hf_transfer>=0.1.4",
123 | "fastapi>=0.95.0",
124 | "uvicorn>=0.15.0",
125 | "pydantic>=2.0.0"
126 | )
127 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
128 | .run_commands(
129 | "git clone https://github.com/ggerganov/llama.cpp",
130 | "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=ON",
131 | "cmake --build llama.cpp/build --config Release -j --target llama-cli",
132 | "cp llama.cpp/build/bin/llama-* /usr/local/bin/"
133 | )
134 | )
135 |
136 | # Set up model configurations
137 | MODELS_DIR = "/models"
138 | VLLM_MODELS = {
139 | "llama3-8b": {
140 | "id": "llama3-8b",
141 | "name": "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
142 | "config": "config.json", # Ensure this file is present in the model directory
143 | "revision": "a7c09948d9a632c2c840722f519672cd94af885d",
144 | "max_tokens": 4096,
145 | "loaded": False
146 | },
147 | "mistral-7b": {
148 | "id": "mistral-7b",
149 | "name": "mistralai/Mistral-7B-Instruct-v0.2",
150 | "revision": "main",
151 | "max_tokens": 4096,
152 | "loaded": False
153 | },
154 | # Small model for quick loading
155 | "tiny-llama-1.1b": {
156 | "id": "tiny-llama-1.1b",
157 | "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
158 | "revision": "main",
159 | "max_tokens": 2048,
160 | "loaded": False
161 | }
162 | }
163 |
164 | LLAMA_CPP_MODELS = {
165 | "deepseek-r1": {
166 | "id": "deepseek-r1",
167 | "name": "unsloth/DeepSeek-R1-GGUF",
168 | "quant": "UD-IQ1_S",
169 | "pattern": "*UD-IQ1_S*",
170 | "revision": "02656f62d2aa9da4d3f0cdb34c341d30dd87c3b6",
171 | "gpu": "L40S:4",
172 | "max_tokens": 4096,
173 | "loaded": False
174 | },
175 | "phi-4": {
176 | "id": "phi-4",
177 | "name": "unsloth/phi-4-GGUF",
178 | "quant": "Q2_K",
179 | "pattern": "*Q2_K*",
180 | "revision": None,
181 | "gpu": "L40S:4", # Use GPU for better performance
182 | "max_tokens": 4096,
183 | "loaded": False
184 | },
185 | # Small model for quick loading
186 | "phi-2": {
187 | "id": "phi-2",
188 | "name": "TheBloke/phi-2-GGUF",
189 | "quant": "Q4_K_M",
190 | "pattern": "*Q4_K_M.gguf",
191 | "revision": "main",
192 | "gpu": None, # Can run on CPU
193 | "max_tokens": 2048,
194 | "loaded": False
195 | }
196 | }
197 |
198 | DEFAULT_MODEL = "phi-4"
199 |
200 | # Create volumes for caching
201 | hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
202 | vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
203 | llama_cpp_cache_vol = modal.Volume.from_name("llama-cpp-cache", create_if_missing=True)
204 | results_vol = modal.Volume.from_name("model-results", create_if_missing=True)
205 |
206 | # Create the Modal app
207 | app = modal.App("openai-compatible-llm-server")
208 |
209 | # Create shared data structures
210 | model_stats_dict = modal.Dict.from_name("model-stats", create_if_missing=True)
211 | user_usage_dict = modal.Dict.from_name("user-usage", create_if_missing=True)
212 | request_queue = modal.Queue.from_name("request-queue", create_if_missing=True)
213 | response_dict = modal.Dict.from_name("response-cache", create_if_missing=True)
214 | api_keys_dict = modal.Dict.from_name("api-keys", create_if_missing=True)
215 | stream_queues = modal.Dict.from_name("stream-queues", create_if_missing=True)
216 |
217 | # Advanced caching system
218 | class AdvancedCache:
219 | """
220 | Advanced caching system with TTL and LRU eviction.
221 | """
222 | def __init__(self, max_size=1000, default_ttl=3600):
223 | self.cache = {}
224 | self.ttl_map = {}
225 | self.access_times = {}
226 | self.max_size = max_size
227 | self.default_ttl = default_ttl
228 | self.lock = threading.Lock()
229 |
230 | def get(self, key):
231 | """Get a value from the cache"""
232 | with self.lock:
233 | now = time.time()
234 |
235 | # Check if key exists and is not expired
236 | if key in self.cache:
237 | # Check TTL
238 | if key in self.ttl_map and self.ttl_map[key] < now:
239 | # Expired
240 | self._remove(key)
241 | return None
242 |
243 | # Update access time
244 | self.access_times[key] = now
245 | return self.cache[key]
246 |
247 | return None
248 |
249 | def set(self, key, value, ttl=None):
250 | """Set a value in the cache with optional TTL"""
251 | with self.lock:
252 | now = time.time()
253 |
254 | # Evict if needed
255 | if len(self.cache) >= self.max_size and key not in self.cache:
256 | self._evict_lru()
257 |
258 | # Set value
259 | self.cache[key] = value
260 | self.access_times[key] = now
261 |
262 | # Set TTL
263 | if ttl is not None:
264 | self.ttl_map[key] = now + ttl
265 | elif self.default_ttl > 0:
266 | self.ttl_map[key] = now + self.default_ttl
267 |
268 | def _remove(self, key):
269 | """Remove a key from the cache"""
270 | if key in self.cache:
271 | del self.cache[key]
272 | if key in self.ttl_map:
273 | del self.ttl_map[key]
274 | if key in self.access_times:
275 | del self.access_times[key]
276 |
277 | def _evict_lru(self):
278 | """Evict least recently used item"""
279 | if not self.access_times:
280 | return
281 |
282 | # Find oldest access time
283 | oldest_key = min(self.access_times.items(), key=lambda x: x[1])[0]
284 | self._remove(oldest_key)
285 |
286 | def clear_expired(self):
287 | """Clear all expired entries"""
288 | with self.lock:
289 | now = time.time()
290 | expired_keys = [k for k, v in self.ttl_map.items() if v < now]
291 | for key in expired_keys:
292 | self._remove(key)
293 |
294 | # Constants
295 | MAX_CACHE_AGE = 3600 # 1 hour in seconds
296 |
297 | # Create memory cache
298 | memory_cache = AdvancedCache(max_size=10000, default_ttl=MAX_CACHE_AGE)
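# Illustrative usage of the cache above (comments only; the key and value are placeholders).
# Entries expire after default_ttl seconds unless a per-entry ttl is passed, and the least
# recently used entry is evicted once max_size entries are stored:
#
#   memory_cache.set("greeting", {"text": "hello"}, ttl=60)  # expires after 60 seconds
#   value = memory_cache.get("greeting")                     # returns None once expired or evicted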
299 |
300 | # Initialize with default key if empty
301 | if "default" not in api_keys_dict:
302 | api_keys_dict["default"] = {
303 | "key": "sk-modal-llm-api-key",
304 | "rate_limit": 60, # requests per minute
305 | "quota": 1000000, # tokens per day
306 | "created_at": datetime.now().isoformat(),
307 | "owner": "default"
308 | }
309 |
310 | # Add a default ADMIN API key
311 | if "admin" not in api_keys_dict:
312 | api_keys_dict["admin"] = {
313 | "key": "sk-modal-admin-api-key",
314 | "rate_limit": 1000, # Higher rate limit for admin
315 | "quota": 10000000, # Higher quota for admin
316 | "created_at": datetime.now().isoformat(),
317 | "owner": "admin"
318 | }
319 |
320 | # Constants
321 | DEFAULT_API_KEY = api_keys_dict["default"]["key"]
322 | MINUTES = 60 # seconds
323 | SERVER_PORT = 8000
324 | CACHE_DIR = "/root/.cache"
325 | RESULTS_DIR = "/root/results"
326 |
327 | # Request/response models
328 | class GenerationRequest(BaseModel):
329 | request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
330 | model_id: str
331 | messages: List[Dict[str, str]]
332 | temperature: float = 0.7
333 | max_tokens: int = 1024
334 | top_p: float = 1.0
335 | frequency_penalty: float = 0.0
336 | presence_penalty: float = 0.0
337 | user: Optional[str] = None
338 | stream: bool = False
339 | timestamp: float = Field(default_factory=time.time)
340 | api_key: str = DEFAULT_API_KEY
341 |
342 | class StreamChunk(BaseModel):
343 | """Model for streaming response chunks"""
344 | id: str
345 | object: str = "chat.completion.chunk"
346 | created: int
347 | model: str
348 | choices: List[Dict[str, Any]]
349 |
350 | class StreamManager:
351 | """Manages streaming responses for clients"""
352 | def __init__(self):
353 | self.streams = {}
354 | self.lock = threading.Lock()
355 |
356 | def create_stream(self, request_id):
357 | """Create a new stream for a request"""
358 | with self.lock:
359 | self.streams[request_id] = {
360 | "queue": asyncio.Queue(),
361 | "finished": False,
362 | "created_at": time.time()
363 | }
364 |
365 | def add_chunk(self, request_id, chunk):
366 | """Add a chunk to a stream"""
367 | with self.lock:
368 | if request_id in self.streams:
369 | stream = self.streams[request_id]
370 | if not stream["finished"]:
371 | stream["queue"].put_nowait(chunk)
372 |
373 | def finish_stream(self, request_id):
374 | """Mark a stream as finished"""
375 | with self.lock:
376 | if request_id in self.streams:
377 | self.streams[request_id]["finished"] = True
378 | # Add None to signal end of stream
379 | self.streams[request_id]["queue"].put_nowait(None)
380 |
381 | async def get_chunks(self, request_id):
382 | """Get chunks from a stream as an async generator"""
383 | if request_id not in self.streams:
384 | return
385 |
386 | stream = self.streams[request_id]
387 | queue = stream["queue"]
388 |
389 | while True:
390 | chunk = await queue.get()
391 | if chunk is None: # End of stream
392 | break
393 | yield chunk
394 | queue.task_done()
395 |
396 | # Clean up after streaming is done
397 | with self.lock:
398 | if request_id in self.streams:
399 | del self.streams[request_id]
400 |
401 | def clean_old_streams(self, max_age=3600):
402 | """Clean up old streams"""
403 | with self.lock:
404 | now = time.time()
405 | to_remove = []
406 |
407 | for request_id, stream in self.streams.items():
408 | if now - stream["created_at"] > max_age:
409 | to_remove.append(request_id)
410 |
411 | for request_id in to_remove:
412 | if request_id in self.streams:
413 | # Mark as finished to stop any ongoing processing
414 | self.streams[request_id]["finished"] = True
415 | # Add None to unblock any waiting consumers
416 | self.streams[request_id]["queue"].put_nowait(None)
417 | # Remove from streams
418 | del self.streams[request_id]
419 |
420 | # Create stream manager
421 | stream_manager = StreamManager()
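# Illustrative producer/consumer flow for the stream manager above (comments only):
#
#   stream_manager.create_stream(request_id)            # producer: register the stream
#   stream_manager.add_chunk(request_id, chunk_dict)     # producer: push OpenAI-style chunks
#   stream_manager.finish_stream(request_id)             # producer: signal end of stream
#
#   async for chunk in stream_manager.get_chunks(request_id):   # consumer: drain the queue
#       ...  # forward each chunk to the client as a server-sent event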
422 |
423 | # API Authentication dependency
424 | def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
425 | """Verify that the API key in the authorization header is valid and check rate limits"""
426 | if credentials.scheme != "Bearer":
427 | raise HTTPException(
428 | status_code=status.HTTP_401_UNAUTHORIZED,
429 | detail="Invalid authentication scheme. Use Bearer",
430 | )
431 |
432 | api_key = credentials.credentials
433 | valid_key = False
434 | key_info = None
435 |
436 | # Check if this is a known API key
437 | for user_id, user_data in api_keys_dict.items():
438 | if user_data.get("key") == api_key:
439 | valid_key = True
440 | key_info = user_data
441 | break
442 |
443 | if not valid_key:
444 | raise HTTPException(
445 | status_code=status.HTTP_401_UNAUTHORIZED,
446 | detail="Invalid API key",
447 | )
448 |
449 | # Check rate limits
450 | user_id = key_info.get("owner", "unknown")
451 | rate_limit = key_info.get("rate_limit", 60) # Default: 60 requests per minute
452 |
453 | # Get or initialize user usage tracking
454 | if user_id not in user_usage_dict:
455 | user_usage_dict[user_id] = {
456 | "requests": [],
457 | "tokens": {
458 | "input": 0,
459 | "output": 0,
460 | "last_reset": datetime.now().isoformat()
461 | }
462 | }
463 |
464 | usage = user_usage_dict[user_id]
465 |
466 | # Check if user exceeded rate limit using token bucket algorithm
467 | if not rate_limiter.consume(user_id, tokens=1, rate_limit=rate_limit):
468 | # Calculate retry-after based on rate
469 | retry_after = int(60 / rate_limit) # seconds until at least one token is available
470 |
471 | raise HTTPException(
472 | status_code=status.HTTP_429_TOO_MANY_REQUESTS,
473 | detail=f"Rate limit exceeded. Maximum {rate_limit} requests per minute.",
474 | headers={"Retry-After": str(retry_after)}
475 | )
476 |
477 | # Add current request timestamp for analytics
478 | now = datetime.now()
479 | usage["requests"].append(now.timestamp())
480 |
481 | # Clean up old requests (older than 1 day) to prevent unbounded growth
482 | day_ago = (now - timedelta(days=1)).timestamp()
483 | usage["requests"] = [req for req in usage["requests"] if req > day_ago]
484 |
485 | # Update usage dict
486 | user_usage_dict[user_id] = usage
487 |
488 | # Return the API key and user ID
489 | return {"key": api_key, "user_id": user_id}
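# Illustrative client call against this dependency (comment only; the hostname is a
# placeholder for the deployed Modal URL). A missing or unknown key yields 401, and
# exceeding the per-minute budget yields 429 with a Retry-After header:
#
#   curl -X POST https://<deployment-url>/v1/chat/completions \
#        -H "Authorization: Bearer sk-modal-llm-api-key" \
#        -H "Content-Type: application/json" \
#        -d '{"model": "phi-4", "messages": [{"role": "user", "content": "Hi"}]}'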
490 |
491 | # API Endpoints
492 | @api_app.get("/", response_class=HTMLResponse)
493 | async def index():
494 | """Root endpoint that returns HTML with API information"""
495 | return """
496 | <html>
497 | <head>
498 | <title>Modal LLM Inference API</title>
499 | <style>
500 | body { font-family: system-ui, sans-serif; max-width: 800px; margin: 0 auto; padding: 2rem; }
501 | h1 { color: #4a56e2; }
502 | code { background: #f4f4f8; padding: 0.2rem 0.4rem; border-radius: 3px; }
503 | </style>
504 | </head>
505 | <body>
506 | <h1>Modal LLM Inference API</h1>
507 | <p>This is an OpenAI-compatible API for LLM inference powered by Modal.</p>
508 | <p>Use the following endpoints:</p>
509 | <ul>
510 | <li><a href="/docs">/docs</a> - API documentation</li>
511 | <li><a href="/v1/models">/v1/models</a> - List available models</li>
512 | <li><code>/v1/chat/completions</code> - Chat completions endpoint</li>
513 | </ul>
514 | </body>
515 | </html>
516 | """
517 |
518 | @api_app.get("/health")
519 | async def health_check():
520 | """Health check endpoint"""
521 | return {"status": "healthy"}
522 |
523 | @api_app.get("/v1/models", dependencies=[Depends(verify_api_key)])
524 | async def list_models():
525 | """List all available models in OpenAI-compatible format"""
526 | # Combine vLLM and llama.cpp models
527 | all_models = []
528 |
529 | for model_id, model_info in VLLM_MODELS.items():
530 | all_models.append({
531 | "id": model_info["id"],
532 | "object": "model",
533 | "created": 1677610602,
534 | "owned_by": "modal",
535 | "engine": "vllm",
536 | "loaded": model_info.get("loaded", False)
537 | })
538 |
539 | for model_id, model_info in LLAMA_CPP_MODELS.items():
540 | all_models.append({
541 | "id": model_info["id"],
542 | "object": "model",
543 | "created": 1677610602,
544 | "owned_by": "modal",
545 | "engine": "llama.cpp",
546 | "loaded": model_info.get("loaded", False)
547 | })
548 |
549 | return {"data": all_models, "object": "list"}
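# Illustrative request for this endpoint (comment only; placeholder hostname). The response
# mirrors OpenAI's /v1/models schema, plus the non-standard "engine" and "loaded" fields:
#
#   curl -H "Authorization: Bearer sk-modal-llm-api-key" https://<deployment-url>/v1/models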
550 |
551 | # Model management endpoints
552 | class ModelLoadRequest(BaseModel):
553 | """Request model to load a specific model"""
554 | model_id: str
555 | force_reload: bool = False
556 |
557 | class HFModelLoadRequest(BaseModel):
558 | """Request to load a model directly from Hugging Face"""
559 | repo_id: str
560 | model_type: str = "vllm" # "vllm" or "llama.cpp"
561 | revision: Optional[str] = None
562 | quant: Optional[str] = None # For llama.cpp models
563 | max_tokens: int = 4096
564 | gpu: Optional[str] = None # For llama.cpp models
565 |
566 | @api_app.post("/admin/models/load", dependencies=[Depends(verify_api_key)])
567 | async def load_model(request: ModelLoadRequest, background_tasks: BackgroundTasks):
568 | """Load a specific model into memory"""
569 | model_id = request.model_id
570 | force_reload = request.force_reload
571 |
572 | # Check if model exists
573 | if model_id in VLLM_MODELS:
574 | model_type = "vllm"
575 | model_info = VLLM_MODELS[model_id]
576 | elif model_id in LLAMA_CPP_MODELS:
577 | model_type = "llama.cpp"
578 | model_info = LLAMA_CPP_MODELS[model_id]
579 | else:
580 | raise HTTPException(
581 | status_code=status.HTTP_404_NOT_FOUND,
582 | detail=f"Model {model_id} not found"
583 | )
584 |
585 | # Check if model is already loaded
586 | if model_info.get("loaded", False) and not force_reload:
587 | return {
588 | "status": "success",
589 | "message": f"Model {model_id} is already loaded",
590 | "model_id": model_id,
591 | "model_type": model_type
592 | }
593 |
594 | # Start loading the model in the background
595 | if model_type == "vllm":
596 | # Start vLLM server for this model
597 | background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
598 | # Update model status
599 | VLLM_MODELS[model_id]["loaded"] = True
600 | else: # llama.cpp
601 | # For llama.cpp models, we'll preload the model
602 | background_tasks.add_task(preload_llama_cpp_model, model_id)
603 | # Update model status
604 | LLAMA_CPP_MODELS[model_id]["loaded"] = True
605 |
606 | return {
607 | "status": "success",
608 | "message": f"Started loading model {model_id}",
609 | "model_id": model_id,
610 | "model_type": model_type
611 | }
612 |
613 | @api_app.post("/admin/models/load-from-hf", dependencies=[Depends(verify_api_key)])
614 | async def load_model_from_hf(request: HFModelLoadRequest, background_tasks: BackgroundTasks):
615 | """Load a model directly from Hugging Face"""
616 | repo_id = request.repo_id
617 | model_type = request.model_type
618 | revision = request.revision
619 |
620 | # Generate a unique model_id based on the repo name
621 | repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
622 | model_id = f"hf-{repo_name}-{uuid.uuid4().hex[:6]}"
623 |
624 | # Create model info based on type
625 | if model_type.lower() == "vllm":
626 | # Add to VLLM_MODELS
627 | VLLM_MODELS[model_id] = {
628 | "id": model_id,
629 | "name": repo_id,
630 | "revision": revision or "main",
631 | "max_tokens": request.max_tokens,
632 | "loaded": False,
633 | "hf_direct": True # Mark as directly loaded from HF
634 | }
635 |
636 | # Start vLLM server for this model
637 | background_tasks.add_task(serve_vllm_model.remote, model_id=model_id)
638 | # Update model status
639 | VLLM_MODELS[model_id]["loaded"] = True
640 |
641 | elif model_type.lower() == "llama.cpp":
642 | # For llama.cpp we need quant info
643 | quant = request.quant or "Q4_K_M" # Default quantization
644 | pattern = f"*{quant}*"
645 |
646 | # Add to LLAMA_CPP_MODELS
647 | LLAMA_CPP_MODELS[model_id] = {
648 | "id": model_id,
649 | "name": repo_id,
650 | "quant": quant,
651 | "pattern": pattern,
652 | "revision": revision,
653 | "gpu": request.gpu, # Can be None for CPU
654 | "max_tokens": request.max_tokens,
655 | "loaded": False,
656 | "hf_direct": True # Mark as directly loaded from HF
657 | }
658 |
659 | # Preload the model
660 | background_tasks.add_task(preload_llama_cpp_model, model_id)
661 | # Update model status
662 | LLAMA_CPP_MODELS[model_id]["loaded"] = True
663 |
664 | else:
665 | raise HTTPException(
666 | status_code=status.HTTP_400_BAD_REQUEST,
667 | detail=f"Invalid model type: {model_type}. Must be 'vllm' or 'llama.cpp'"
668 | )
669 |
670 | return {
671 | "status": "success",
672 | "message": f"Started loading model {repo_id} as {model_id}",
673 | "model_id": model_id,
674 | "model_type": model_type,
675 | "repo_id": repo_id
676 | }
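# Illustrative request for this endpoint (comment only; placeholder hostname). Loading a
# GGUF model from the Hub via llama.cpp registers it under a generated id such as
# "hf-phi-2-GGUF-<6 hex chars>":
#
#   curl -X POST https://<deployment-url>/admin/models/load-from-hf \
#        -H "Authorization: Bearer sk-modal-admin-api-key" \
#        -H "Content-Type: application/json" \
#        -d '{"repo_id": "TheBloke/phi-2-GGUF", "model_type": "llama.cpp", "quant": "Q4_K_M"}'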
677 |
678 | @api_app.post("/admin/models/unload", dependencies=[Depends(verify_api_key)])
679 | async def unload_model(request: ModelLoadRequest):
680 | """Unload a specific model from memory"""
681 | model_id = request.model_id
682 |
683 | # Check if model exists
684 | if model_id in VLLM_MODELS:
685 | model_type = "vllm"
686 | model_info = VLLM_MODELS[model_id]
687 | elif model_id in LLAMA_CPP_MODELS:
688 | model_type = "llama.cpp"
689 | model_info = LLAMA_CPP_MODELS[model_id]
690 | else:
691 | raise HTTPException(
692 | status_code=status.HTTP_404_NOT_FOUND,
693 | detail=f"Model {model_id} not found"
694 | )
695 |
696 | # Check if model is loaded
697 | if not model_info.get("loaded", False):
698 | return {
699 | "status": "success",
700 | "message": f"Model {model_id} is not loaded",
701 | "model_id": model_id,
702 | "model_type": model_type
703 | }
704 |
705 | # Update model status
706 | if model_type == "vllm":
707 | VLLM_MODELS[model_id]["loaded"] = False
708 | else: # llama.cpp
709 | LLAMA_CPP_MODELS[model_id]["loaded"] = False
710 |
711 | return {
712 | "status": "success",
713 | "message": f"Unloaded model {model_id}",
714 | "model_id": model_id,
715 | "model_type": model_type
716 | }
717 |
718 | @api_app.get("/admin/models/status/{model_id}", dependencies=[Depends(verify_api_key)])
719 | async def get_model_status(model_id: str):
720 | """Get the status of a specific model"""
721 | # Check if model exists
722 | if model_id in VLLM_MODELS:
723 | model_type = "vllm"
724 | model_info = VLLM_MODELS[model_id]
725 | elif model_id in LLAMA_CPP_MODELS:
726 | model_type = "llama.cpp"
727 | model_info = LLAMA_CPP_MODELS[model_id]
728 | else:
729 | raise HTTPException(
730 | status_code=status.HTTP_404_NOT_FOUND,
731 | detail=f"Model {model_id} not found"
732 | )
733 |
734 | # Get model stats if available
735 | model_stats = model_stats_dict.get(model_id, {})
736 |
737 | # Include HF info if available
738 | hf_info = {}
739 | if model_info.get("hf_direct"):
740 | hf_info = {
741 | "repo_id": model_info.get("name"),
742 | "revision": model_info.get("revision"),
743 | }
744 | if model_type == "llama.cpp":
745 | hf_info["quant"] = model_info.get("quant")
746 |
747 | return {
748 | "model_id": model_id,
749 | "model_type": model_type,
750 | "loaded": model_info.get("loaded", False),
751 | "stats": model_stats,
752 | "hf_info": hf_info if hf_info else None
753 | }
754 |
755 | # Admin API endpoints
756 | class APIKeyRequest(BaseModel):
757 | user_id: str
758 | rate_limit: int = 60
759 | quota: int = 1000000
760 |
761 | class APIKey(BaseModel):
762 | key: str
763 | user_id: str
764 | rate_limit: int
765 | quota: int
766 | created_at: str
767 |
768 | @api_app.post("/admin/api-keys", response_model=APIKey)
769 | async def create_api_key(request: APIKeyRequest, auth_info: dict = Depends(verify_api_key)):
770 | """Create a new API key for a user (admin only)"""
771 | # Check if this is an admin request (the built-in "default" and "admin" keys act as admins)
772 | if auth_info["user_id"] not in ("default", "admin"):
773 | raise HTTPException(
774 | status_code=status.HTTP_403_FORBIDDEN,
775 | detail="Only admin users can create API keys"
776 | )
777 |
778 | # Generate a new API key
779 | new_key = f"sk-modal-{uuid.uuid4()}"
780 | user_id = request.user_id
781 |
782 | # Store the key
783 | api_keys_dict[user_id] = {
784 | "key": new_key,
785 | "rate_limit": request.rate_limit,
786 | "quota": request.quota,
787 | "created_at": datetime.now().isoformat(),
788 | "owner": user_id
789 | }
790 |
791 | # Initialize user usage
792 | if not user_usage_dict.contains(user_id):
793 | user_usage_dict[user_id] = {
794 | "requests": [],
795 | "tokens": {
796 | "input": 0,
797 | "output": 0,
798 | "last_reset": datetime.now().isoformat()
799 | }
800 | }
801 |
802 | return APIKey(
803 | key=new_key,
804 | user_id=user_id,
805 | rate_limit=request.rate_limit,
806 | quota=request.quota,
807 | created_at=datetime.now().isoformat()
808 | )
809 |
810 | @api_app.get("/admin/api-keys")
811 | async def list_api_keys(auth_info: dict = Depends(verify_api_key)):
812 | """List all API keys (admin only)"""
813 | # Check if this is an admin request (the built-in "default" and "admin" keys act as admins)
814 | if auth_info["user_id"] not in ("default", "admin"):
815 | raise HTTPException(
816 | status_code=status.HTTP_403_FORBIDDEN,
817 | detail="Only admin users can list API keys"
818 | )
819 |
820 | # Return all keys (except the actual key values for security)
821 | keys = []
822 | for user_id, key_info in api_keys_dict.items():
823 | keys.append({
824 | "user_id": user_id,
825 | "rate_limit": key_info.get("rate_limit", 60),
826 | "quota": key_info.get("quota", 1000000),
827 | "created_at": key_info.get("created_at", datetime.now().isoformat()),
828 | # Mask the actual key
829 | "key": key_info.get("key", "")[:8] + "..." if key_info.get("key") else "None"
830 | })
831 |
832 | return {"keys": keys}
833 |
834 | @api_app.get("/admin/stats")
835 | async def get_stats(auth_info: dict = Depends(verify_api_key)):
836 | """Get usage statistics (admin only)"""
837 | # Check if this is an admin request (the built-in "default" and "admin" keys act as admins)
838 | if auth_info["user_id"] not in ("default", "admin"):
839 | raise HTTPException(
840 | status_code=status.HTTP_403_FORBIDDEN,
841 | detail="Only admin users can view stats"
842 | )
843 |
844 | # Get model stats
845 | model_stats = {}
846 | for model_id in list(VLLM_MODELS.keys()) + list(LLAMA_CPP_MODELS.keys()):
847 | if model_id in model_stats_dict:
848 | model_stats[model_id] = model_stats_dict[model_id]
849 |
850 | # Get user stats
851 | user_stats = {}
852 | for user_id in user_usage_dict.keys():
853 | usage = user_usage_dict[user_id]
854 | # Don't include request timestamps for brevity
855 | if "requests" in usage:
856 | usage = usage.copy()
857 | usage["request_count"] = len(usage["requests"])
858 | del usage["requests"]
859 | user_stats[user_id] = usage
860 |
861 | # Get queue info
862 | queue_info = {
863 | "pending_requests": request_queue.len(),
864 | "active_workers": model_stats_dict.get("workers_running", 0)
865 | }
866 |
867 | return {
868 | "models": model_stats,
869 | "users": user_stats,
870 | "queue": queue_info,
871 | "timestamp": datetime.now().isoformat()
872 | }
873 |
874 | @api_app.delete("/admin/api-keys/{user_id}")
875 | async def delete_api_key(user_id: str, auth_info: dict = Depends(verify_api_key)):
876 | """Delete an API key (admin only)"""
877 | # Check if this is an admin request (the built-in "default" and "admin" keys act as admins)
878 | if auth_info["user_id"] not in ("default", "admin"):
879 | raise HTTPException(
880 | status_code=status.HTTP_403_FORBIDDEN,
881 | detail="Only admin users can delete API keys"
882 | )
883 |
884 | # Check if the key exists
885 | if not api_keys_dict.contains(user_id):
886 | raise HTTPException(
887 | status_code=status.HTTP_404_NOT_FOUND,
888 | detail=f"No API key found for user {user_id}"
889 | )
890 |
891 | # Can't delete the default key
892 | if user_id == "default":
893 | raise HTTPException(
894 | status_code=status.HTTP_400_BAD_REQUEST,
895 | detail="Cannot delete the default API key"
896 | )
897 |
898 | # Delete the key
899 | api_keys_dict.pop(user_id)
900 |
901 | return {"status": "success", "message": f"API key deleted for user {user_id}"}
902 |
903 | @api_app.post("/v1/chat/completions")
904 | async def chat_completions(request: Request, background_tasks: BackgroundTasks, auth_info: dict = Depends(verify_api_key)):
905 | """OpenAI-compatible chat completions endpoint with request queueing, streaming and response caching"""
906 | try:
907 | json_data = await request.json()
908 |
909 | # Extract model or use default
910 | model_id = json_data.get("model", DEFAULT_MODEL)
911 | messages = json_data.get("messages", [])
912 | temperature = json_data.get("temperature", 0.7)
913 | max_tokens = json_data.get("max_tokens", 1024)
914 | stream = json_data.get("stream", False)
915 | user = json_data.get("user", auth_info["user_id"])
916 |
917 | # Calculate a cache key based on the request parameters
918 | cache_key = calculate_cache_key(model_id, messages, temperature, max_tokens)
919 |
920 | # Check if we have a cached response in memory cache first (faster)
921 | cached_response = memory_cache.get(cache_key)
922 | if cached_response and not stream: # Don't use cache for streaming requests
923 | # Update stats
924 | update_stats(model_id, "cache_hit")
925 | return cached_response
926 |
927 | # Check if we have a cached response in Modal's persistent cache
928 | if not cached_response and cache_key in response_dict and not stream:
929 | cached_response = response_dict[cache_key]
930 | cache_age = time.time() - cached_response.get("timestamp", 0)
931 |
932 | # Use cached response if it's fresh enough
933 | if cache_age < MAX_CACHE_AGE:
934 | # Update stats
935 | update_stats(model_id, "cache_hit")
936 | response_data = cached_response["response"]
937 |
938 | # Also cache in memory for faster access next time
939 | memory_cache.set(cache_key, response_data)
940 |
941 | return response_data
942 |
943 | # Select best model if "auto" is specified
944 | if model_id == "auto" and len(messages) > 0:
945 | # Get the last user message
946 | last_message = None
947 | for msg in reversed(messages):
948 | if msg.get("role") == "user":
949 | last_message = msg.get("content", "")
950 | break
951 |
952 | if last_message:
953 | prompt = last_message
954 | # Select best model based on prompt and parameters
955 | model_id = select_best_model(prompt, max_tokens, temperature)
956 | logging.info(f"Auto-selected model: {model_id} for prompt")
957 |
958 | # Check if model exists
959 | if model_id not in VLLM_MODELS and model_id not in LLAMA_CPP_MODELS:
960 | # Default to the default model if specified model not found
961 | logging.warning(f"Model {model_id} not found, using default: {DEFAULT_MODEL}")
962 | model_id = DEFAULT_MODEL
963 |
964 | # Create a unique request ID
965 | request_id = str(uuid.uuid4())
966 |
967 | # Create request object
968 | gen_request = GenerationRequest(
969 | request_id=request_id,
970 | model_id=model_id,
971 | messages=messages,
972 | temperature=temperature,
973 | max_tokens=max_tokens,
974 | top_p=json_data.get("top_p", 1.0),
975 | frequency_penalty=json_data.get("frequency_penalty", 0.0),
976 | presence_penalty=json_data.get("presence_penalty", 0.0),
977 | user=user,
978 | stream=stream,
979 | api_key=auth_info["key"]
980 | )
981 |
982 | # For streaming requests, set up streaming response
983 | if stream:
984 | # Create a new stream
985 | stream_manager.create_stream(request_id)
986 |
987 | # Put the request in the queue
988 | await request_queue.put.aio(gen_request.model_dump())
989 |
990 | # Update stats
991 | update_stats(model_id, "request_count")
992 | update_stats(model_id, "stream_count")
993 |
994 | # Start a background worker to process the request if needed
995 | background_tasks.add_task(ensure_worker_running)
996 |
997 | # Return a streaming response using FastAPI's StreamingResponse
998 | from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
999 | return FastAPIStreamingResponse(
1000 | content=stream_response(request_id, model_id, auth_info["user_id"]),
1001 | media_type="text/event-stream"
1002 | )
1003 |
1004 | # For non-streaming, enqueue the request and wait for result
1005 | # Put the request in the queue
1006 | await request_queue.put.aio(gen_request.model_dump())
1007 |
1008 | # Update stats
1009 | update_stats(model_id, "request_count")
1010 |
1011 | # Start a background worker to process the request if needed
1012 | background_tasks.add_task(ensure_worker_running)
1013 |
1014 | # Wait for the response with timeout
1015 | start_time = time.time()
1016 | timeout = 120 # 2-minute timeout for non-streaming requests
1017 |
1018 | while time.time() - start_time < timeout:
1019 | # Check memory cache first (faster)
1020 | response_data = memory_cache.get(request_id)
1021 | if response_data:
1022 | # Update stats
1023 | update_stats(model_id, "success_count")
1024 | estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
1025 |
1026 | # Save to persistent cache
1027 | response_dict[cache_key] = {
1028 | "response": response_data,
1029 | "timestamp": time.time()
1030 | }
1031 |
1032 | # Clean up request-specific cache
1033 | memory_cache.set(request_id, None)
1034 |
1035 | return response_data
1036 |
1037 | # Check persistent cache
1038 | if response_dict.contains(request_id):
1039 | response_data = response_dict[request_id]
1040 |
1041 | # Remove from response dict to save memory
1042 | try:
1043 | response_dict.pop(request_id)
1044 | except Exception:
1045 | pass
1046 |
1047 | # Save to cache
1048 | response_dict[cache_key] = {
1049 | "response": response_data,
1050 | "timestamp": time.time()
1051 | }
1052 |
1053 | # Also cache in memory
1054 | memory_cache.set(cache_key, response_data)
1055 |
1056 | # Update stats
1057 | update_stats(model_id, "success_count")
1058 | estimate_tokens(messages, response_data, auth_info["user_id"], model_id)
1059 |
1060 | return response_data
1061 |
1062 | # Wait a bit before checking again
1063 | await asyncio.sleep(0.1)
1064 |
1065 | # If we get here, we timed out
1066 | update_stats(model_id, "timeout_count")
1067 | raise HTTPException(
1068 | status_code=status.HTTP_504_GATEWAY_TIMEOUT,
1069 | detail="Request timed out. The model may be busy. Please try again later."
1070 | )
1071 |
1072 | except Exception as e:
1073 | logging.error(f"Error in chat completions: {str(e)}")
1074 | # Update error stats
1075 | if "model_id" in locals():
1076 | update_stats(model_id, "error_count")
1077 |
1078 | raise HTTPException(
1079 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1080 | detail=f"Error generating response: {str(e)}"
1081 | )
1082 |
1083 | async def stream_response(request_id: str, model_id: str, user_id: str) -> AsyncIterator[str]:
1084 | """Stream response chunks to the client"""
1085 | try:
1086 | # Stream header
1087 | yield "data: " + json.dumps({"object": "chat.completion.chunk"}) + "\n\n"
1088 |
1089 | # Stream chunks
1090 | async for chunk in stream_manager.get_chunks(request_id):
1091 | if chunk:
1092 | yield f"data: {json.dumps(chunk)}\n\n"
1093 |
1094 | # Stream done
1095 | yield "data: [DONE]\n\n"
1096 |
1097 | except Exception as e:
1098 | logging.error(f"Error streaming response: {str(e)}")
1099 | # Update error stats
1100 | update_stats(model_id, "stream_error_count")
1101 |
1102 | # Send error as SSE
1103 | error_json = json.dumps({"error": str(e)})
1104 | yield f"data: {error_json}\n\n"
1105 | yield "data: [DONE]\n\n"
1106 |
1107 | async def ensure_worker_running():
1108 | """Ensure that a worker is running to process the queue"""
1109 | # Check if workers are already running via a sentinel in shared dict
1110 | workers_running_key = "workers_running"
1111 |
1112 | if not model_stats_dict.contains(workers_running_key):
1113 | model_stats_dict[workers_running_key] = 0
1114 |
1115 | current_workers = model_stats_dict[workers_running_key]
1116 |
1117 | # If no workers or too few workers, start more
1118 | if current_workers < 3: # Keep up to 3 workers running
1119 | # Increment worker count
1120 | model_stats_dict[workers_running_key] = current_workers + 1
1121 |
1122 | # Start a worker
1123 | await process_queue_worker.spawn.aio()
1124 |
1125 | def calculate_cache_key(model_id: str, messages: List[dict], temperature: float, max_tokens: int) -> str:
1126 | """Calculate a deterministic cache key for a request using SHA-256"""
1127 | # Create a simplified version of the request for cache key
1128 | cache_dict = {
1129 | "model": model_id,
1130 | "messages": messages,
1131 | "temperature": round(temperature, 2), # Round to reduce variations
1132 | "max_tokens": max_tokens
1133 | }
1134 | # Convert to a stable string representation and hash it with SHA-256
1135 | cache_str = json.dumps(cache_dict, sort_keys=True)
1136 | hash_obj = hashlib.sha256(cache_str.encode())
1137 | return f"cache:{hash_obj.hexdigest()[:16]}"
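# Illustrative property of the cache key (comment only): the payload is serialised with
# sort_keys=True and the temperature is rounded to two decimals, so
#   calculate_cache_key("phi-4", [{"role": "user", "content": "Hi"}], 0.701, 256)
#   calculate_cache_key("phi-4", [{"role": "user", "content": "Hi"}], 0.70,  256)
# produce the same "cache:<16 hex chars>" key and therefore share a cached response.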
1138 |
1139 | def update_stats(model_id: str, stat_type: str):
1140 | """Update usage statistics for a model"""
1141 | if not model_stats_dict.contains(model_id):
1142 | model_stats_dict[model_id] = {
1143 | "request_count": 0,
1144 | "success_count": 0,
1145 | "error_count": 0,
1146 | "timeout_count": 0,
1147 | "cache_hit": 0,
1148 | "token_count": 0,
1149 | "avg_latency": 0
1150 | }
1151 |
1152 | stats = model_stats_dict[model_id]
1153 | stats[stat_type] = stats.get(stat_type, 0) + 1
1154 | model_stats_dict[model_id] = stats
1155 |
1156 | def estimate_tokens(messages: List[dict], response: dict, user_id: str, model_id: str):
1157 | """Estimate token usage and update user quotas"""
1158 | # Very simple token estimation based on whitespace-split words * 1.3
1159 | input_tokens = 0
1160 | for msg in messages:
1161 | input_tokens += len(msg.get("content", "").split()) * 1.3
1162 |
1163 | output_tokens = 0
1164 | if response and "choices" in response:
1165 | for choice in response["choices"]:
1166 | if "message" in choice and "content" in choice["message"]:
1167 | output_tokens += len(choice["message"]["content"].split()) * 1.3
1168 |
1169 | # Update model stats
1170 | if model_stats_dict.contains(model_id):
1171 | stats = model_stats_dict[model_id]
1172 | stats["token_count"] = stats.get("token_count", 0) + input_tokens + output_tokens
1173 | model_stats_dict[model_id] = stats
1174 |
1175 | # Update user usage
1176 | if user_id in user_usage_dict:
1177 | usage = user_usage_dict[user_id]
1178 |
1179 | # Check if we need to reset daily counters
1180 | last_reset = datetime.fromisoformat(usage["tokens"]["last_reset"])
1181 | now = datetime.now()
1182 |
1183 | if now.date() > last_reset.date():
1184 | # Reset daily counters
1185 | usage["tokens"]["input"] = 0
1186 | usage["tokens"]["output"] = 0
1187 | usage["tokens"]["last_reset"] = now.isoformat()
1188 |
1189 | # Update token counts
1190 | usage["tokens"]["input"] += int(input_tokens)
1191 | usage["tokens"]["output"] += int(output_tokens)
1192 | user_usage_dict[user_id] = usage
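# Illustrative estimate (comment only): a message "How are you today?" has 4 whitespace-
# separated words, so it contributes 4 * 1.3 = 5.2 input tokens, truncated to 5 when added
# to the running total; the per-user counters reset on the first request of each new day.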
1193 |
1194 | def select_best_model(prompt: str, n_predict: int, temperature: float) -> str:
1195 | """
1196 | Intelligently selects the best model based on input parameters.
1197 |
1198 | Args:
1199 | prompt (str): The input prompt for the model.
1200 | n_predict (int): The number of tokens to predict.
1201 | temperature (float): The sampling temperature.
1202 |
1203 | Returns:
1204 | str: The identifier of the best model to use.
1205 | """
1206 | # Check for code generation patterns
1207 | code_indicators = ["```", "def ", "class ", "function", "import ", "from ", "<script", "<style",
1208 | "SELECT ", "CREATE TABLE", "const ", "let ", "var ", "function(", "=>"]
1209 |
1210 | is_likely_code = any(indicator in prompt for indicator in code_indicators)
1211 |
1212 | # Check for creative writing patterns
1213 | creative_indicators = ["story", "poem", "creative", "imagine", "fiction", "narrative",
1214 | "write a", "compose", "create a"]
1215 |
1216 | is_creative_task = any(indicator in prompt.lower() for indicator in creative_indicators)
1217 |
1218 | # Check for analytical/reasoning tasks
1219 | analytical_indicators = ["explain", "analyze", "compare", "contrast", "reason",
1220 | "evaluate", "assess", "why", "how does"]
1221 |
1222 | is_analytical_task = any(indicator in prompt.lower() for indicator in analytical_indicators)
1223 |
1224 | # Decision logic
1225 | if is_likely_code:
1226 | # For code generation, prefer phi-4 for all code tasks
1227 | return "phi-4" # Excellent for code generation
1228 |
1229 | elif is_creative_task:
1230 | # For creative tasks, use models with higher creativity
1231 | if temperature > 0.8:
1232 | return "deepseek-r1" # More creative at high temperatures
1233 | else:
1234 | return "phi-4" # Good balance of creativity and coherence
1235 |
1236 | elif is_analytical_task:
1237 | # For analytical tasks, use models with strong reasoning
1238 | return "phi-4" # Strong reasoning capabilities
1239 |
1240 | # Length-based decisions
1241 | if len(prompt) > 2000:
1242 | # For very long prompts, use models with good context handling
1243 | return "llama3-8b"
1244 | elif len(prompt) < 1000:
1245 | # For shorter prompts, prefer phi-4
1246 | return "phi-4"
1247 |
1248 | # Temperature-based decisions
1249 | if temperature < 0.5:
1250 | # For deterministic outputs
1251 | return "phi-4"
1252 | elif temperature > 0.9:
1253 | # For very creative outputs
1254 | return "deepseek-r1"
1255 |
1256 | # Default to phi-4 instead of the standard model
1257 | return "phi-4"
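# Illustrative routing decisions for the heuristics above (comments only):
#   select_best_model("def fib(n): ...", 256, 0.2)                   -> "phi-4"       (code indicators)
#   select_best_model("Write a poem about rain", 512, 0.9)           -> "deepseek-r1" (creative, temperature > 0.8)
#   select_best_model("Explain how TCP slow start works", 512, 0.7)  -> "phi-4"       (analytical)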
1258 |
1259 | # vLLM serving function
1260 | @app.function(
1261 | image=vllm_image,
1262 | gpu="H100:1",
1263 | allow_concurrent_inputs=100,
1264 | volumes={
1265 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
1266 | f"{CACHE_DIR}/vllm": vllm_cache_vol,
1267 | },
1268 | timeout=30 * MINUTES,
1269 | )
1270 | @modal.web_server(port=SERVER_PORT)
1271 | def serve_vllm_model(model_id: str = DEFAULT_MODEL):
1272 | """
1273 | Serves a model using vLLM with an OpenAI-compatible API.
1274 |
1275 | Args:
1276 | model_id (str): The identifier of the model to serve. Defaults to DEFAULT_MODEL.
1277 |
1278 | Raises:
1279 | ValueError: If the specified model_id is not found in VLLM_MODELS.
1280 | """
1281 | import subprocess
1282 |
1283 | if model_id not in VLLM_MODELS:
1284 | available_models = list(VLLM_MODELS.keys())
1285 | logging.error(f"Error: Unknown model: {model_id}. Available models: {available_models}")
1286 | raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")
1287 |
1288 | model_info = VLLM_MODELS[model_id]
1289 | model_name = model_info["name"]
1290 | revision = model_info["revision"]
1291 |
1292 | logging.basicConfig(level=logging.INFO)
1293 | logging.info(f"Starting vLLM server with model: {model_name}")
1294 |
1295 | cmd = [
1296 | "vllm",
1297 | "serve",
1298 | "--uvicorn-log-level=info",
1299 | model_name,
1300 | "--revision",
1301 | revision,
1302 | "--host",
1303 | "0.0.0.0",
1304 | "--port",
1305 | str(SERVER_PORT),
1306 | "--api-key",
1307 | DEFAULT_API_KEY,
1308 | ]
1309 |
1310 | # Launch the vLLM server as a background process (no shell=True); this function returns
1311 | # once the process is started, and @modal.web_server keeps the container alive to serve it
1312 | process = subprocess.Popen(cmd)
1313 |
1314 | # Log that we've started the server
1315 | logging.info(f"Started vLLM server with PID {process.pid}")
1316 |
1317 | # Define the worker that will process the queue
1318 | @app.function(
1319 | image=vllm_image,
1320 | gpu=None, # Worker will spawn GPU functions as needed
1321 | allow_concurrent_inputs=10,
1322 | volumes={
1323 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
1324 | },
1325 | timeout=30 * MINUTES,
1326 | )
1327 | async def process_queue_worker():
1328 | """Worker function that processes requests from the queue"""
1329 | import asyncio
1330 | import time, queue  # queue.Empty is raised by modal.Queue.get when its timeout expires
1331 |
1332 | try:
1333 | # Signal that we're starting a worker
1334 | worker_id = str(uuid.uuid4())[:8]
1335 | logging.info(f"Starting queue processing worker {worker_id}")
1336 |
1337 | # Process requests until timeout or empty queue
1338 | empty_count = 0
1339 | max_empty_count = 10 # Stop after 10 consecutive empty polls
1340 |
1341 | while empty_count < max_empty_count:
1342 | # Try to get a request from the queue
1343 | try:
1344 | request_dict = await request_queue.get.aio(timeout=5)  # timeout is in seconds, not milliseconds
1345 | empty_count = 0 # Reset empty counter
1346 |
1347 | # Process the request
1348 | try:
1349 | # Create request object
1350 | request_id = request_dict.get("request_id")
1351 | model_id = request_dict.get("model_id")
1352 | messages = request_dict.get("messages", [])
1353 | temperature = request_dict.get("temperature", 0.7)
1354 | max_tokens = request_dict.get("max_tokens", 1024)
1355 | api_key = request_dict.get("api_key", DEFAULT_API_KEY)
1356 | stream_mode = request_dict.get("stream", False)
1357 |
1358 | logging.info(f"Worker {worker_id} processing request {request_id} for model {model_id}")
1359 |
1360 | # Start time for latency calculation
1361 | start_time = time.time()
1362 |
1363 | if stream_mode:
1364 | # Generate streaming response
1365 | await generate_streaming_response(
1366 | request_id=request_id,
1367 | model_id=model_id,
1368 | messages=messages,
1369 | temperature=temperature,
1370 | max_tokens=max_tokens,
1371 | api_key=api_key
1372 | )
1373 | else:
1374 | # Generate non-streaming response
1375 | response = await generate_response(
1376 | model_id=model_id,
1377 | messages=messages,
1378 | temperature=temperature,
1379 | max_tokens=max_tokens,
1380 | api_key=api_key
1381 | )
1382 |
1383 | # Calculate latency
1384 | latency = time.time() - start_time
1385 |
1386 | # Update latency stats
1387 | if model_stats_dict.contains(model_id):
1388 | stats = model_stats_dict[model_id]
1389 | old_avg = stats.get("avg_latency", 0)
1390 | old_count = stats.get("success_count", 0)
1391 |
1392 | # Calculate new average (moving average)
1393 | if old_count > 0:
1394 | new_avg = (old_avg * old_count + latency) / (old_count + 1)
1395 | else:
1396 | new_avg = latency
1397 |
1398 | stats["avg_latency"] = new_avg
1399 | model_stats_dict[model_id] = stats
1400 |
1401 | # Store the response in both caches
1402 | memory_cache.set(request_id, response)
1403 | response_dict[request_id] = response
1404 |
1405 | logging.info(f"Worker {worker_id} completed request {request_id} in {latency:.2f}s")
1406 |
1407 | except Exception as e:
1408 | # Log error and move on
1409 | logging.error(f"Worker {worker_id} error processing request {request_id}: {str(e)}")
1410 |
1411 | # Create error response
1412 | error_response = {
1413 | "error": {
1414 | "message": str(e),
1415 | "type": "internal_error",
1416 | "code": 500
1417 | }
1418 | }
1419 |
1420 | # Store the error as a response
1421 | memory_cache.set(request_id, error_response)
1422 | response_dict[request_id] = error_response
1423 |
1424 | # If streaming, send error and finish stream
1425 | if "stream_mode" in locals() and stream_mode:
1426 | stream_manager.add_chunk(request_id, {
1427 | "id": f"chatcmpl-{int(time.time())}",
1428 | "object": "chat.completion.chunk",
1429 | "created": int(time.time()),
1430 | "model": model_id,
1431 | "choices": [{
1432 | "index": 0,
1433 | "delta": {"content": f"Error: {str(e)}"},
1434 | "finish_reason": "error"
1435 | }]
1436 | })
1437 | stream_manager.finish_stream(request_id)
1438 |
1439 | except (queue.Empty, asyncio.TimeoutError):  # modal.Queue.get raises queue.Empty when the timeout expires
1440 | # No requests in queue
1441 | empty_count += 1
1442 | logging.info(f"Worker {worker_id}: No requests in queue. Empty count: {empty_count}")
1443 |
1444 | # Clean up expired cache entries and old streams
1445 | if empty_count % 5 == 0: # Every 5 empty polls
1446 | memory_cache.clear_expired()
1447 | stream_manager.clean_old_streams()
1448 |
1449 | await asyncio.sleep(1) # Wait a bit before checking again
1450 |
1451 | # If we get here, we've had too many consecutive empty polls
1452 | logging.info(f"Worker {worker_id} shutting down due to empty queue")
1453 |
1454 | finally:
1455 | # Signal that this worker is done
1456 | workers_running_key = "workers_running"
1457 | if model_stats_dict.contains(workers_running_key):
1458 | current_workers = model_stats_dict[workers_running_key]
1459 | model_stats_dict[workers_running_key] = max(0, current_workers - 1)
1460 | logging.info(f"Worker {worker_id} shutdown. Workers remaining: {max(0, current_workers - 1)}")
1461 |
1462 | async def generate_streaming_response(
1463 | request_id: str,
1464 | model_id: str,
1465 | messages: List[dict],
1466 | temperature: float,
1467 | max_tokens: int,
1468 | api_key: str
1469 | ):
1470 | """
1471 | Generate a streaming response and send chunks to the stream manager.
1472 |
1473 | Args:
1474 | request_id: The unique ID for this request
1475 | model_id: The ID of the model to use
1476 | messages: The chat messages
1477 | temperature: The sampling temperature
1478 | max_tokens: The maximum tokens to generate
1479 | api_key: The API key for authentication
1480 | """
1481 | import httpx
1482 | import time
1483 | import json
1484 | import asyncio
1485 |
1486 | try:
1487 | # Create response ID
1488 | response_id = f"chatcmpl-{int(time.time())}"
1489 |
1490 | if model_id in VLLM_MODELS:
1491 | # Start the vLLM server for this model in the background (spawn does not block on the call)
1492 | await serve_vllm_model.spawn.aio(model_id=model_id)
1493 |
1494 | # Need to wait for server startup
1495 | await wait_for_server(serve_vllm_model.web_url, timeout=120)
1496 |
1497 | # Forward request to vLLM with streaming enabled
1498 | async with httpx.AsyncClient(timeout=120.0) as client:
1499 | headers = {
1500 | "Authorization": f"Bearer {api_key}",
1501 | "Content-Type": "application/json",
1502 | "Accept": "text/event-stream"
1503 | }
1504 |
1505 | # Format request for vLLM OpenAI-compatible endpoint
1506 | vllm_request = {
1507 | "model": VLLM_MODELS[model_id]["name"],
1508 | "messages": messages,
1509 | "temperature": temperature,
1510 | "max_tokens": max_tokens,
1511 | "stream": True
1512 | }
1513 |
1514 | # Make streaming request
1515 | async with client.stream(
1516 | "POST",
1517 | f"{serve_vllm_model.web_url}/v1/chat/completions",
1518 | json=vllm_request,
1519 | headers=headers
1520 | ) as response:
1521 | # Process streaming response
1522 | buffer = ""
1523 | async for chunk in response.aiter_text():
1524 | buffer += chunk
1525 |
1526 | # Process complete SSE messages
1527 | while "\n\n" in buffer:
1528 | message, buffer = buffer.split("\n\n", 1)
1529 |
1530 | if message.startswith("data: "):
1531 | data = message[6:] # Remove "data: " prefix
1532 |
1533 | if data == "[DONE]":
1534 | # End of stream
1535 | stream_manager.finish_stream(request_id)
1536 | return
1537 |
1538 | try:
1539 | # Parse JSON data
1540 | chunk_data = json.loads(data)
1541 | # Forward to client
1542 | stream_manager.add_chunk(request_id, chunk_data)
1543 | except json.JSONDecodeError:
1544 | logging.error(f"Invalid JSON in stream: {data}")
1545 |
1546 | # Ensure stream is finished
1547 | stream_manager.finish_stream(request_id)
1548 |
1549 | elif model_id in LLAMA_CPP_MODELS:
1550 | # For llama.cpp models, we need to simulate streaming
1551 | # First convert the chat format to a prompt
1552 | prompt = format_messages_to_prompt(messages)
1553 |
1554 | # Run llama.cpp with the prompt
1555 | output = await run_llama_cpp_stream.remote(
1556 | model_id=model_id,
1557 | prompt=prompt,
1558 | n_predict=max_tokens,
1559 | temperature=temperature,
1560 | request_id=request_id
1561 | )
1562 |
1563 | # Streaming is handled by the run_llama_cpp_stream function
1564 | # which directly adds chunks to the stream manager
1565 |
1566 | # Wait for completion signal
1567 | while True:
1568 | if request_id in stream_queues and stream_queues[request_id] == "DONE":
1569 | # Clean up
1570 | stream_queues.pop(request_id)
1571 | break
1572 | await asyncio.sleep(0.1)
1573 |
1574 | else:
1575 | raise ValueError(f"Unknown model: {model_id}")
1576 |
1577 | except Exception as e:
1578 | logging.error(f"Error in streaming generation: {str(e)}")
1579 | # Send error chunk
1580 | stream_manager.add_chunk(request_id, {
1581 | "id": response_id,
1582 | "object": "chat.completion.chunk",
1583 | "created": int(time.time()),
1584 | "model": model_id,
1585 | "choices": [{
1586 | "index": 0,
1587 | "delta": {"content": f"Error: {str(e)}"},
1588 | "finish_reason": "error"
1589 | }]
1590 | })
1591 | # Finish stream
1592 | stream_manager.finish_stream(request_id)
1593 |
1594 | async def generate_response(model_id: str, messages: List[dict], temperature: float, max_tokens: int, api_key: str):
1595 | """
1596 | Generate a response using the appropriate model based on model_id.
1597 |
1598 | Args:
1599 | model_id: The ID of the model to use
1600 | messages: The chat messages
1601 | temperature: The sampling temperature
1602 | max_tokens: The maximum tokens to generate
1603 | api_key: The API key for authentication
1604 |
1605 | Returns:
1606 | A response in OpenAI-compatible format
1607 | """
1608 | import httpx
1609 | import time
1610 | import json
1611 | import asyncio
1612 |
1613 | if model_id in VLLM_MODELS:
1614 | # Start the vLLM server for this model in the background (spawn does not block on the call)
1615 | await serve_vllm_model.spawn.aio(model_id=model_id)
1616 |
1617 | # Need to wait for server startup
1618 | await wait_for_server(serve_vllm_model.web_url, timeout=120)
1619 |
1620 | # Forward request to vLLM
1621 | async with httpx.AsyncClient(timeout=60.0) as client:
1622 | headers = {
1623 | "Authorization": f"Bearer {api_key}",
1624 | "Content-Type": "application/json"
1625 | }
1626 |
1627 | # Format request for vLLM OpenAI-compatible endpoint
1628 | vllm_request = {
1629 | "model": VLLM_MODELS[model_id]["name"],
1630 | "messages": messages,
1631 | "temperature": temperature,
1632 | "max_tokens": max_tokens
1633 | }
1634 |
1635 | response = await client.post(
1636 | f"{serve_vllm_model.web_url}/v1/chat/completions",
1637 | json=vllm_request,
1638 | headers=headers
1639 | )
1640 |
1641 | return response.json()
1642 | elif model_id in LLAMA_CPP_MODELS:
1643 | # For llama.cpp models, use the run_llama_cpp function
1644 | # First convert the chat format to a prompt
1645 | prompt = format_messages_to_prompt(messages)
1646 |
1647 | # Run llama.cpp with the prompt
1648 | output = await run_llama_cpp.remote(
1649 | model_id=model_id,
1650 | prompt=prompt,
1651 | n_predict=max_tokens,
1652 | temperature=temperature
1653 | )
1654 |
1655 | # Format the response in the OpenAI format
1656 | completion_text = output.strip()
1657 |     finish_reason = "stop" if (len(completion_text) // 4) < max_tokens else "length"  # compare estimated tokens, not characters
1658 |
1659 | return {
1660 | "id": f"chatcmpl-{int(time.time())}",
1661 | "object": "chat.completion",
1662 | "created": int(time.time()),
1663 | "model": model_id,
1664 | "choices": [
1665 | {
1666 | "index": 0,
1667 | "message": {
1668 | "role": "assistant",
1669 | "content": completion_text
1670 | },
1671 | "finish_reason": finish_reason
1672 | }
1673 | ],
1674 | "usage": {
1675 | "prompt_tokens": len(prompt) // 4, # Rough estimation
1676 | "completion_tokens": len(completion_text) // 4, # Rough estimation
1677 | "total_tokens": (len(prompt) + len(completion_text)) // 4 # Rough estimation
1678 | }
1679 | }
1680 | else:
1681 | raise ValueError(f"Unknown model: {model_id}")
1682 |
1683 | def format_messages_to_prompt(messages: List[Dict[str, str]]) -> str:
1684 | """
1685 | Convert chat messages to a text prompt format for llama.cpp.
1686 |
1687 | Args:
1688 | messages: List of message dictionaries with role and content
1689 |
1690 | Returns:
1691 | Formatted prompt string
1692 | """
1693 | formatted_prompt = ""
1694 |
1695 | for message in messages:
1696 | role = message.get("role", "").lower()
1697 | content = message.get("content", "")
1698 |
1699 | if role == "system":
1700 | formatted_prompt += f"<|system|>\n{content}\n"
1701 | elif role == "user":
1702 | formatted_prompt += f"<|user|>\n{content}\n"
1703 | elif role == "assistant":
1704 | formatted_prompt += f"<|assistant|>\n{content}\n"
1705 | else:
1706 | # For unknown roles, treat as user
1707 | formatted_prompt += f"<|user|>\n{content}\n"
1708 |
1709 | # Add final assistant marker to prompt the model to respond
1710 | formatted_prompt += "<|assistant|>\n"
1711 |
1712 | return formatted_prompt
1713 |
1714 | async def wait_for_server(url: str, timeout: int = 120, check_interval: int = 2):
1715 | """
1716 | Wait for a server to be ready by checking its health endpoint.
1717 |
1718 | Args:
1719 | url: The base URL of the server
1720 | timeout: Maximum time to wait in seconds
1721 | check_interval: Interval between checks in seconds
1722 |
1723 | Returns:
1724 | True if server is ready, False otherwise
1725 | """
1726 | import httpx
1727 | import asyncio
1728 | import time
1729 |
1730 | start_time = time.time()
1731 | health_url = f"{url}/health"
1732 |
1733 | logging.info(f"Waiting for server at {url} to be ready...")
1734 |
1735 | while time.time() - start_time < timeout:
1736 | try:
1737 | async with httpx.AsyncClient(timeout=5.0) as client:
1738 | response = await client.get(health_url)
1739 | if response.status_code == 200:
1740 | logging.info(f"Server at {url} is ready!")
1741 | return True
1742 | except Exception as e:
1743 | elapsed = time.time() - start_time
1744 | logging.info(f"Server not ready yet after {elapsed:.1f}s: {str(e)}")
1745 |
1746 | await asyncio.sleep(check_interval)
1747 |
1748 | logging.error(f"Timed out waiting for server at {url} after {timeout} seconds")
1749 | return False
1750 |
1751 | @app.function(
1752 | image=llama_cpp_image,
1753 | gpu=None, # Will be set dynamically based on model
1754 | volumes={
1755 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
1756 | f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
1757 | RESULTS_DIR: results_vol,
1758 | },
1759 | timeout=30 * MINUTES,
1760 | )
1761 | async def run_llama_cpp_stream(
1762 | model_id: str,
1763 | prompt: str,
1764 | n_predict: int = 1024,
1765 | temperature: float = 0.7,
1766 | request_id: str = None,
1767 | ):
1768 | """
1769 | Run streaming inference with llama.cpp for models like DeepSeek-R1 and Phi-4
1770 | """
1771 | import subprocess
1772 | import os
1773 | import json
1774 | import time
1775 | import threading
1776 | from uuid import uuid4
1777 | from pathlib import Path
1778 | from huggingface_hub import snapshot_download
1779 |
1780 | if model_id not in LLAMA_CPP_MODELS:
1781 | available_models = list(LLAMA_CPP_MODELS.keys())
1782 | error_msg = f"Unknown model: {model_id}. Available models: {available_models}"
1783 | logging.error(error_msg)
1784 |
1785 | if request_id:
1786 | # Send error to stream
1787 | stream_manager.add_chunk(request_id, {
1788 | "id": f"chatcmpl-{int(time.time())}",
1789 | "object": "chat.completion.chunk",
1790 | "created": int(time.time()),
1791 | "model": model_id,
1792 | "choices": [{
1793 | "index": 0,
1794 | "delta": {"content": f"Error: {error_msg}"},
1795 | "finish_reason": "error"
1796 | }]
1797 | })
1798 | stream_manager.finish_stream(request_id)
1799 | # Signal completion
1800 | stream_queues[request_id] = "DONE"
1801 |
1802 | raise ValueError(error_msg)
1803 |
1804 | model_info = LLAMA_CPP_MODELS[model_id]
1805 | repo_id = model_info["name"]
1806 | pattern = model_info["pattern"]
1807 | revision = model_info["revision"]
1808 | quant = model_info["quant"]
1809 |
1810 | # Download model if not already cached
1811 | logging.info(f"Downloading model {repo_id} if not present")
1812 | try:
1813 | model_path = snapshot_download(
1814 | repo_id=repo_id,
1815 | revision=revision,
1816 | local_dir=f"{CACHE_DIR}/llama_cpp",
1817 | allow_patterns=[pattern],
1818 | )
1819 | except ValueError as e:
1820 | if "hf_transfer" in str(e):
1821 | # Fallback to standard download if hf_transfer fails
1822 | logging.warning("hf_transfer failed, falling back to standard download")
1823 | # Temporarily disable hf_transfer
1824 | import os
1825 | old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1")
1826 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
1827 | try:
1828 | model_path = snapshot_download(
1829 | repo_id=repo_id,
1830 | revision=revision,
1831 | local_dir=f"{CACHE_DIR}/llama_cpp",
1832 | allow_patterns=[pattern],
1833 | )
1834 | finally:
1835 | # Restore original setting
1836 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env
1837 | else:
1838 | raise
1839 |
1840 | # Find the model file
1841 | model_files = list(Path(model_path).glob(pattern))
1842 | if not model_files:
1843 | error_msg = f"No model files found matching pattern {pattern}"
1844 | logging.error(error_msg)
1845 |
1846 | if request_id:
1847 | # Send error to stream
1848 | stream_manager.add_chunk(request_id, {
1849 | "id": f"chatcmpl-{int(time.time())}",
1850 | "object": "chat.completion.chunk",
1851 | "created": int(time.time()),
1852 | "model": model_id,
1853 | "choices": [{
1854 | "index": 0,
1855 | "delta": {"content": f"Error: {error_msg}"},
1856 | "finish_reason": "error"
1857 | }]
1858 | })
1859 | stream_manager.finish_stream(request_id)
1860 | # Signal completion
1861 | stream_queues[request_id] = "DONE"
1862 |
1863 | raise FileNotFoundError(error_msg)
1864 |
1865 | model_file = str(model_files[0])
1866 | logging.info(f"Using model file: {model_file}")
1867 |
1868 | # Set up command
1869 | cmd = [
1870 | "llama-cli",
1871 | "--model", model_file,
1872 | "--prompt", prompt,
1873 | "--n-predict", str(n_predict),
1874 | "--temp", str(temperature),
1875 | "--ctx-size", "8192",
1876 | ]
1877 |
1878 | # Add GPU layers if needed
1879 | if model_info["gpu"] is not None:
1880 | cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU
1881 |
1882 | # Run inference
1883 | result_id = str(uuid4())
1884 | logging.info(f"Running streaming inference with ID: {result_id}")
1885 |
1886 | # Create response ID for streaming
1887 | response_id = f"chatcmpl-{int(time.time())}"
1888 |
1889 | # Function to process output in real-time and send to stream
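     |     # Each chunk follows the OpenAI chat.completion.chunk schema; buffered content is flushed
     |     # roughly every 100 ms or once ~20 characters have accumulated. Note that stderr is piped
     |     # but not drained on this path, so very verbose stderr output from llama-cli could, in
     |     # principle, fill the pipe buffer.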
1890 | def process_output(process, request_id):
1891 | content_buffer = ""
1892 | last_send_time = time.time()
1893 |
1894 | # Send initial chunk with role
1895 | if request_id:
1896 | stream_manager.add_chunk(request_id, {
1897 | "id": response_id,
1898 | "object": "chat.completion.chunk",
1899 | "created": int(time.time()),
1900 | "model": model_id,
1901 | "choices": [{
1902 | "index": 0,
1903 | "delta": {"role": "assistant"},
1904 | }]
1905 | })
1906 |
1907 | for line in iter(process.stdout.readline, b''):
1908 | try:
1909 | line_str = line.decode('utf-8', errors='replace')
1910 |
1911 | # Skip llama.cpp info lines
1912 | if line_str.startswith("llama_"):
1913 | continue
1914 |
1915 | # Add to buffer
1916 | content_buffer += line_str
1917 |
1918 | # Send chunks at reasonable intervals or when buffer gets large
1919 | now = time.time()
1920 | if (now - last_send_time > 0.1 or len(content_buffer) > 20) and request_id:
1921 | # Send chunk
1922 | stream_manager.add_chunk(request_id, {
1923 | "id": response_id,
1924 | "object": "chat.completion.chunk",
1925 | "created": int(time.time()),
1926 | "model": model_id,
1927 | "choices": [{
1928 | "index": 0,
1929 | "delta": {"content": content_buffer},
1930 | }]
1931 | })
1932 |
1933 | # Reset buffer and time
1934 | content_buffer = ""
1935 | last_send_time = now
1936 |
1937 | except Exception as e:
1938 | logging.error(f"Error processing output: {str(e)}")
1939 |
1940 | # Send any remaining content
1941 | if content_buffer and request_id:
1942 | stream_manager.add_chunk(request_id, {
1943 | "id": response_id,
1944 | "object": "chat.completion.chunk",
1945 | "created": int(time.time()),
1946 | "model": model_id,
1947 | "choices": [{
1948 | "index": 0,
1949 | "delta": {"content": content_buffer},
1950 | }]
1951 | })
1952 |
1953 | # Send final chunk with finish reason
1954 | if request_id:
1955 | stream_manager.add_chunk(request_id, {
1956 | "id": response_id,
1957 | "object": "chat.completion.chunk",
1958 | "created": int(time.time()),
1959 | "model": model_id,
1960 | "choices": [{
1961 | "index": 0,
1962 | "delta": {},
1963 | "finish_reason": "stop"
1964 | }]
1965 | })
1966 |
1967 | # Finish stream
1968 | stream_manager.finish_stream(request_id)
1969 |
1970 | # Signal completion
1971 | stream_queues[request_id] = "DONE"
1972 |
1973 | # Start process
1974 | process = subprocess.Popen(
1975 | cmd,
1976 | stdout=subprocess.PIPE,
1977 | stderr=subprocess.PIPE,
1978 |         text=False,
1979 |         bufsize=0  # unbuffered; line buffering (bufsize=1) is only supported in text mode
1980 | )
1981 |
1982 | # Start output processing thread if streaming
1983 | if request_id:
1984 | thread = threading.Thread(target=process_output, args=(process, request_id))
1985 | thread.daemon = True
1986 | thread.start()
1987 |
1988 | # Return immediately for streaming
1989 | return "Streaming in progress"
1990 | else:
1991 | # For non-streaming, collect all output
1992 | stdout, stderr = collect_output(process)
1993 |
1994 | # Save results
1995 | result_dir = Path(RESULTS_DIR) / result_id
1996 | result_dir.mkdir(parents=True, exist_ok=True)
1997 |
1998 | (result_dir / "output.txt").write_text(stdout)
1999 | (result_dir / "stderr.txt").write_text(stderr)
2000 | (result_dir / "prompt.txt").write_text(prompt)
2001 |
2002 | logging.info(f"Results saved to {result_dir}")
2003 | return stdout
2004 |
2005 | @app.function(
2006 | image=llama_cpp_image,
2007 | gpu=None, # Will be set dynamically based on model
2008 | volumes={
2009 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
2010 | f"{CACHE_DIR}/llama_cpp": llama_cpp_cache_vol,
2011 | RESULTS_DIR: results_vol,
2012 | },
2013 | timeout=30 * MINUTES,
2014 | )
2015 | async def run_llama_cpp(
2016 | model_id: str,
2017 | prompt: str = "Tell me about Modal and how it helps with ML deployments.",
2018 | n_predict: int = 1024,
2019 | temperature: float = 0.7,
2020 | ):
2021 | """
2022 | Run inference with llama.cpp for models like DeepSeek-R1 and Phi-4
2023 | """
2024 | import subprocess
2025 | import os
2026 | from uuid import uuid4
2027 | from pathlib import Path
2028 | from huggingface_hub import snapshot_download
2029 |
2030 | if model_id not in LLAMA_CPP_MODELS:
2031 | available_models = list(LLAMA_CPP_MODELS.keys())
2032 | print(f"Error: Unknown model: {model_id}. Available models: {available_models}")
2033 | raise ValueError(f"Unknown model: {model_id}. Available models: {available_models}")
2034 |
2035 | model_info = LLAMA_CPP_MODELS[model_id]
2036 | repo_id = model_info["name"]
2037 | pattern = model_info["pattern"]
2038 | revision = model_info["revision"]
2039 | quant = model_info["quant"]
2040 |
2041 | # Download model if not already cached
2042 | logging.info(f"Downloading model {repo_id} if not present")
2043 | try:
2044 | model_path = snapshot_download(
2045 | repo_id=repo_id,
2046 | revision=revision,
2047 | local_dir=f"{CACHE_DIR}/llama_cpp",
2048 | allow_patterns=[pattern],
2049 | )
2050 | except ValueError as e:
2051 | if "hf_transfer" in str(e):
2052 | # Fallback to standard download if hf_transfer fails
2053 | logging.warning("hf_transfer failed, falling back to standard download")
2054 | # Temporarily disable hf_transfer
2055 | import os
2056 | old_env = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "1")
2057 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
2058 | try:
2059 | model_path = snapshot_download(
2060 | repo_id=repo_id,
2061 | revision=revision,
2062 | local_dir=f"{CACHE_DIR}/llama_cpp",
2063 | allow_patterns=[pattern],
2064 | )
2065 | finally:
2066 | # Restore original setting
2067 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_env
2068 | else:
2069 | raise
2070 |
2071 | # Find the model file
2072 | model_files = list(Path(model_path).glob(pattern))
2073 | if not model_files:
2074 | logging.error(f"No model files found matching pattern {pattern}")
2075 | raise FileNotFoundError(f"No model files found matching pattern {pattern}")
2076 |
2077 | model_file = str(model_files[0])
2078 | print(f"Using model file: {model_file}")
2079 |
2080 | # Set up command
2081 | cmd = [
2082 | "llama-cli",
2083 | "--model", model_file,
2084 | "--prompt", prompt,
2085 | "--n-predict", str(n_predict),
2086 | "--temp", str(temperature),
2087 | "--ctx-size", "8192",
2088 | ]
2089 |
2090 | # Add GPU layers if needed
2091 | if model_info["gpu"] is not None:
2092 | cmd.extend(["--n-gpu-layers", "9999"]) # Use all layers on GPU
2093 |
2094 | # Run inference
2095 | result_id = str(uuid4())
2096 | print(f"Running inference with ID: {result_id}")
2097 |
2098 | process = subprocess.Popen(
2099 | cmd,
2100 | stdout=subprocess.PIPE,
2101 | stderr=subprocess.PIPE,
2102 | text=False
2103 | )
2104 |
2105 | stdout, stderr = collect_output(process)
2106 |
2107 | # Save results
2108 | result_dir = Path(RESULTS_DIR) / result_id
2109 | result_dir.mkdir(parents=True, exist_ok=True)
2110 |
2111 | (result_dir / "output.txt").write_text(stdout)
2112 | (result_dir / "stderr.txt").write_text(stderr)
2113 | (result_dir / "prompt.txt").write_text(prompt)
2114 |
2115 | print(f"Results saved to {result_dir}")
2116 | return stdout
2117 |
2118 | @app.function(
2119 | image=vllm_image,
2120 | volumes={
2121 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
2122 | },
2123 | )
2124 | def list_available_models():
2125 | """
2126 | Lists available models that can be used with this server.
2127 |
2128 | Returns:
2129 | dict: A dictionary containing lists of available vLLM and llama.cpp models.
2130 | """
2131 | print("Available vLLM models:")
2132 | for model_id, model_info in VLLM_MODELS.items():
2133 | print(f"- {model_id}: {model_info['name']}")
2134 |
2135 | print("\nAvailable llama.cpp models:")
2136 | for model_id, model_info in LLAMA_CPP_MODELS.items():
2137 | gpu_info = f"(GPU: {model_info['gpu']})" if model_info['gpu'] else "(CPU)"
2138 | print(f"- {model_id}: {model_info['name']} {gpu_info}")
2139 |
2140 | return {
2141 | "vllm": list(VLLM_MODELS.keys()),
2142 | "llama_cpp": list(LLAMA_CPP_MODELS.keys())
2143 | }
2144 |
2145 | def collect_output(process):
2146 | """
2147 | Collect output from a process while streaming it.
2148 |
2149 | Args:
2150 | process: The process from which to collect output.
2151 |
2152 | Returns:
2153 | tuple: A tuple containing the collected stdout and stderr as strings.
2154 | """
2155 | import sys
2156 | from queue import Queue
2157 | from threading import Thread
2158 |
2159 | def stream_output(stream, queue, write_stream):
2160 | for line in iter(stream.readline, b""):
2161 | line_str = line.decode("utf-8", errors="replace")
2162 | write_stream.write(line_str)
2163 | write_stream.flush()
2164 | queue.put(line_str)
2165 | stream.close()
2166 |
2167 | stdout_queue = Queue()
2168 | stderr_queue = Queue()
2169 |
2170 | stdout_thread = Thread(target=stream_output, args=(process.stdout, stdout_queue, sys.stdout))
2171 | stderr_thread = Thread(target=stream_output, args=(process.stderr, stderr_queue, sys.stderr))
2172 |
2173 | stdout_thread.start()
2174 | stderr_thread.start()
2175 |
2176 | stdout_thread.join()
2177 | stderr_thread.join()
2178 | process.wait()
2179 |
2180 | stdout_collected = "".join(list(stdout_queue.queue))
2181 | stderr_collected = "".join(list(stderr_queue.queue))
2182 |
2183 | return stdout_collected, stderr_collected
2184 |
2185 | # Main ASGI app for Modal
2186 | @app.function(
2187 | image=vllm_image,
2188 | gpu=None, # No GPU for the API frontend
2189 | allow_concurrent_inputs=100,
2190 | volumes={
2191 | f"{CACHE_DIR}/huggingface": hf_cache_vol,
2192 | },
2193 | )
2194 | @modal.asgi_app()
2195 | def inference_api():
2196 | """The main ASGI app that serves the FastAPI application"""
2197 | return api_app
2198 |
2199 | @app.local_entrypoint()
2200 | def main(
2201 | prompt: str = "What can you tell me about Modal?",
2202 | n_predict: int = 1024,
2203 | temperature: float = 0.7,
2204 | create_admin_key: bool = False,
2205 | stream: bool = False,
2206 | model: str = "auto",
2207 | load_model: str = None,
2208 | load_hf_model: str = None,
2209 | hf_model_type: str = "vllm",
2210 | ):
2211 | """
2212 | Main entrypoint for testing the API
2213 | """
2214 | import json
2215 | import time
2216 | import urllib.request
2217 |
2218 | # Initialize the API
2219 | print(f"Starting API at {inference_api.web_url}")
2220 |
2221 | # Wait for API to be ready
2222 | print("Checking if API is ready...")
2223 | up, start, delay = False, time.time(), 10
2224 | while not up:
2225 | try:
2226 | with urllib.request.urlopen(inference_api.web_url + "/health") as response:
2227 | if response.getcode() == 200:
2228 | up = True
2229 | except Exception:
2230 | if time.time() - start > 5 * MINUTES:
2231 | break
2232 | time.sleep(delay)
2233 |
2234 | assert up, f"Failed health check for API at {inference_api.web_url}"
2235 | print(f"API is up and running at {inference_api.web_url}")
2236 |
2237 | # Create a test API key if requested
2238 | if create_admin_key:
2239 | print("Creating a test API key...")
2240 | key_request = {
2241 | "user_id": "test_user",
2242 | "rate_limit": 120,
2243 | "quota": 2000000
2244 | }
2245 | headers = {
2246 | "Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key
2247 | "Content-Type": "application/json",
2248 | }
2249 | req = urllib.request.Request(
2250 | inference_api.web_url + "/admin/api-keys",
2251 | data=json.dumps(key_request).encode("utf-8"),
2252 | headers=headers,
2253 | method="POST",
2254 | )
2255 | try:
2256 | with urllib.request.urlopen(req) as response:
2257 | result = json.loads(response.read().decode())
2258 | print("Created API key:")
2259 | print(json.dumps(result, indent=2))
2260 | # Use this key for the test message
2261 | test_key = result["key"]
2262 | except Exception as e:
2263 | print(f"Error creating API key: {str(e)}")
2264 | test_key = DEFAULT_API_KEY
2265 | else:
2266 | test_key = DEFAULT_API_KEY
2267 |
2268 | # List available models
2269 | print("\nAvailable models:")
2270 | try:
2271 | headers = {
2272 | "Authorization": f"Bearer {test_key}",
2273 | "Content-Type": "application/json",
2274 | }
2275 | req = urllib.request.Request(
2276 | inference_api.web_url + "/v1/models",
2277 | headers=headers,
2278 | method="GET",
2279 | )
2280 | with urllib.request.urlopen(req) as response:
2281 | models = json.loads(response.read().decode())
2282 | print(json.dumps(models, indent=2))
2283 | except Exception as e:
2284 | print(f"Error listing models: {str(e)}")
2285 |
2286 |     # Auto-select a model for the prompt unless one was explicitly requested
2287 |     model = select_best_model(prompt, n_predict, temperature) if model == "auto" else model
2288 |
2289 | # Send a test message
2290 | print(f"\nSending a sample message to {inference_api.web_url}")
2291 | messages = [{"role": "user", "content": prompt}]
2292 |
2293 | headers = {
2294 | "Authorization": f"Bearer {test_key}",
2295 | "Content-Type": "application/json",
2296 | }
2297 | payload = json.dumps({
2298 | "messages": messages,
2299 | "model": model,
2300 | "temperature": temperature,
2301 | "max_tokens": n_predict,
2302 | "stream": stream
2303 | })
2304 | req = urllib.request.Request(
2305 | inference_api.web_url + "/v1/chat/completions",
2306 | data=payload.encode("utf-8"),
2307 | headers=headers,
2308 | method="POST",
2309 | )
2310 |
2311 | try:
2312 | if stream:
2313 | print("Streaming response:")
2314 | with urllib.request.urlopen(req) as response:
2315 | for line in response:
2316 | line = line.decode('utf-8')
2317 | if line.startswith('data: '):
2318 | data = line[6:].strip()
2319 | if data == '[DONE]':
2320 | print("\n[DONE]")
2321 | else:
2322 | try:
2323 | chunk = json.loads(data)
2324 | if 'choices' in chunk and len(chunk['choices']) > 0:
2325 | if 'delta' in chunk['choices'][0] and 'content' in chunk['choices'][0]['delta']:
2326 | content = chunk['choices'][0]['delta']['content']
2327 | print(content, end='', flush=True)
2328 | except json.JSONDecodeError:
2329 | print(f"Error parsing: {data}")
2330 | else:
2331 | with urllib.request.urlopen(req) as response:
2332 | result = json.loads(response.read().decode())
2333 | print("Response:")
2334 | print(json.dumps(result, indent=2))
2335 | except Exception as e:
2336 | print(f"Error: {str(e)}")
2337 |
2338 | # Check API stats
2339 | print("\nChecking API stats...")
2340 | headers = {
2341 | "Authorization": f"Bearer {DEFAULT_API_KEY}", # Admin key
2342 | "Content-Type": "application/json",
2343 | }
2344 | req = urllib.request.Request(
2345 | inference_api.web_url + "/admin/stats",
2346 | headers=headers,
2347 | method="GET",
2348 | )
2349 | try:
2350 | with urllib.request.urlopen(req) as response:
2351 | stats = json.loads(response.read().decode())
2352 | print("API Stats:")
2353 | print(json.dumps(stats, indent=2))
2354 | except Exception as e:
2355 | print(f"Error getting stats: {str(e)}")
2356 |
2357 | # Start a worker if none running
2358 | try:
2359 |         current_workers = stats.get("queue", {}).get("active_workers", 0)  # if the stats request above failed, 'stats' is undefined and the NameError is caught below
2360 | if current_workers < 1:
2361 | print("\nStarting a queue worker...")
2362 | process_queue_worker.spawn()
2363 | except Exception as e:
2364 | print(f"Error starting worker: {str(e)}")
2365 |
2366 | print(f"\nAPI is available at {inference_api.web_url}")
2367 | print(f"Documentation is at {inference_api.web_url}/docs")
2368 | print(f"Default Bearer token: {DEFAULT_API_KEY}")
2369 |
2370 | if create_admin_key:
2371 | print(f"Test Bearer token: {test_key}")
2372 |
2373 | # If a model was specified to load, load it
2374 | if load_model:
2375 | print(f"\nLoading model: {load_model}")
2376 | load_url = f"{inference_api.web_url}/admin/models/load"
2377 | headers = {
2378 | "Authorization": f"Bearer {test_key}",
2379 | "Content-Type": "application/json",
2380 | }
2381 | payload = json.dumps({
2382 | "model_id": load_model,
2383 | "force_reload": True
2384 | })
2385 | req = urllib.request.Request(
2386 | load_url,
2387 | data=payload.encode("utf-8"),
2388 | headers=headers,
2389 | method="POST",
2390 | )
2391 | try:
2392 | with urllib.request.urlopen(req) as response:
2393 | result = json.loads(response.read().decode())
2394 | print("Load response:")
2395 | print(json.dumps(result, indent=2))
2396 |
2397 | # If it's a small model, wait a bit for it to load
2398 | if load_model in ["tiny-llama-1.1b", "phi-2"]:
2399 | print(f"Waiting for {load_model} to load...")
2400 | time.sleep(10)
2401 |
2402 | # Check status
2403 | status_url = f"{inference_api.web_url}/admin/models/status/{load_model}"
2404 | status_req = urllib.request.Request(
2405 | status_url,
2406 | headers={"Authorization": f"Bearer {test_key}"},
2407 | method="GET",
2408 | )
2409 | with urllib.request.urlopen(status_req) as status_response:
2410 | status_result = json.loads(status_response.read().decode())
2411 | print("Model status:")
2412 | print(json.dumps(status_result, indent=2))
2413 |
2414 | # Use this model for the test
2415 | model = load_model
2416 | except Exception as e:
2417 | print(f"Error loading model: {str(e)}")
2418 |
2419 | # If a HF model was specified to load directly
2420 | if load_hf_model:
2421 | print(f"\nLoading HF model: {load_hf_model} with type {hf_model_type}")
2422 | load_url = f"{inference_api.web_url}/admin/models/load-from-hf"
2423 | headers = {
2424 | "Authorization": f"Bearer {test_key}",
2425 | "Content-Type": "application/json",
2426 | }
2427 | payload = json.dumps({
2428 | "repo_id": load_hf_model,
2429 | "model_type": hf_model_type,
2430 | "max_tokens": n_predict
2431 | })
2432 | req = urllib.request.Request(
2433 | load_url,
2434 | data=payload.encode("utf-8"),
2435 | headers=headers,
2436 | method="POST",
2437 | )
2438 | try:
2439 | with urllib.request.urlopen(req) as response:
2440 | result = json.loads(response.read().decode())
2441 | print("HF Load response:")
2442 | print(json.dumps(result, indent=2))
2443 |
2444 | # Get the model_id from the response
2445 | hf_model_id = result.get("model_id")
2446 |
2447 | # Wait a bit for it to start loading
2448 | print(f"Waiting for {load_hf_model} to start loading...")
2449 | time.sleep(5)
2450 |
2451 | # Check status
2452 | if hf_model_id:
2453 | status_url = f"{inference_api.web_url}/admin/models/status/{hf_model_id}"
2454 | status_req = urllib.request.Request(
2455 | status_url,
2456 | headers={"Authorization": f"Bearer {test_key}"},
2457 | method="GET",
2458 | )
2459 | with urllib.request.urlopen(status_req) as status_response:
2460 | status_result = json.loads(status_response.read().decode())
2461 | print("Model status:")
2462 | print(json.dumps(status_result, indent=2))
2463 |
2464 | # Use this model for the test
2465 | if hf_model_id:
2466 | model = hf_model_id
2467 | except Exception as e:
2468 | print(f"Error loading HF model: {str(e)}")
2469 |
2470 | # Show curl examples
2471 | print("\nExample curl commands:")
2472 |
2473 | # Regular completion
2474 | print(f"""# Regular completion:
2475 | curl -X POST {inference_api.web_url}/v1/chat/completions \\
2476 | -H "Content-Type: application/json" \\
2477 | -H "Authorization: Bearer {test_key}" \\
2478 | -d '{{
2479 | "model": "{model}",
2480 | "messages": [
2481 | {{
2482 | "role": "user",
2483 | "content": "Hello, how can you help me today?"
2484 | }}
2485 | ],
2486 | "temperature": 0.7,
2487 | "max_tokens": 500
2488 | }}'""")
2489 |
2490 | # Streaming completion
2491 | print(f"""\n# Streaming completion:
2492 | curl -X POST {inference_api.web_url}/v1/chat/completions \\
2493 | -H "Content-Type: application/json" \\
2494 | -H "Authorization: Bearer {test_key}" \\
2495 | -d '{{
2496 | "model": "{model}",
2497 | "messages": [
2498 | {{
2499 | "role": "user",
2500 | "content": "Write a short story about AI"
2501 | }}
2502 | ],
2503 | "temperature": 0.8,
2504 | "max_tokens": 1000,
2505 | "stream": true
2506 | }}' --no-buffer""")
2507 |
2508 | # List models
2509 | print(f"""\n# List available models:
2510 | curl -X GET {inference_api.web_url}/v1/models \\
2511 | -H "Authorization: Bearer {test_key}" """)
2512 |
2513 | # Model management commands
2514 | print(f"""\n# Load a model:
2515 | curl -X POST {inference_api.web_url}/admin/models/load \\
2516 | -H "Content-Type: application/json" \\
2517 | -H "Authorization: Bearer {test_key}" \\
2518 | -d '{{
2519 | "model_id": "phi-2",
2520 | "force_reload": false
2521 | }}'""")
2522 |
2523 | print(f"""\n# Load a model directly from Hugging Face:
2524 | curl -X POST {inference_api.web_url}/admin/models/load-from-hf \\
2525 | -H "Content-Type: application/json" \\
2526 | -H "Authorization: Bearer {test_key}" \\
2527 | -d '{{
2528 | "repo_id": "microsoft/phi-2",
2529 | "model_type": "vllm",
2530 | "max_tokens": 4096
2531 | }}'""")
2532 |
2533 | print(f"""\n# Get model status:
2534 | curl -X GET {inference_api.web_url}/admin/models/status/phi-2 \\
2535 | -H "Authorization: Bearer {test_key}" """)
2536 |
2537 | print(f"""\n# Unload a model:
2538 | curl -X POST {inference_api.web_url}/admin/models/unload \\
2539 | -H "Content-Type: application/json" \\
2540 | -H "Authorization: Bearer {test_key}" \\
2541 | -d '{{
2542 | "model_id": "phi-2"
2543 | }}'""")
2544 | async def preload_llama_cpp_model(model_id: str):
2545 | """Preload a llama.cpp model to make inference faster on first request"""
2546 | if model_id not in LLAMA_CPP_MODELS:
2547 | logging.error(f"Unknown model: {model_id}")
2548 | return
2549 |
2550 | try:
2551 | # Run a simple inference to load the model
2552 | logging.info(f"Preloading llama.cpp model: {model_id}")
2553 | await run_llama_cpp.remote(
2554 | model_id=model_id,
2555 | prompt="Hello, this is a test to preload the model.",
2556 | n_predict=10,
2557 | temperature=0.7
2558 | )
2559 | logging.info(f"Successfully preloaded llama.cpp model: {model_id}")
2560 | except Exception as e:
2561 | logging.error(f"Error preloading llama.cpp model {model_id}: {str(e)}")
2562 | # Mark as not loaded
2563 | LLAMA_CPP_MODELS[model_id]["loaded"] = False
2564 |
```