# Directory Structure
```
├── .gitignore
├── benchmark
│ ├── backend_request_func.py
│ ├── benchmark_dataset.py
│ ├── benchmark_serving.py
│ ├── benchmark_utils.py
│ └── readme.md
├── benchmark_tool.py
├── pyproject.toml
├── README.md
├── server.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
1 | __pycache__
2 | ShareGPT_V3_unfiltered_cleaned_split.json
3 | **/__pycache__
4 |
```
--------------------------------------------------------------------------------
/benchmark/readme.md:
--------------------------------------------------------------------------------
```markdown
1 | These files come straight from the vLLM repository:
2 |
3 | https://github.com/vllm-project/vllm/tree/main/benchmarks
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
1 | # MCP vLLM Benchmarking Tool
2 |
3 | This is a proof of concept of how to use MCP to interactively benchmark vLLM.
4 |
5 | We are not new to benchmarking; read our blog post:
6 |
7 | [Benchmarking vLLM](https://eliovp.com/introducing-our-benchmarking-tool-powered-by-dstack/)
8 |
9 | This is just an exploration of possibilities with MCP.
10 |
11 | ## Usage
12 |
13 | 1. Clone the repository
14 | 2. Add it to your MCP servers:
15 | ```
16 | {
17 | "mcpServers": {
18 | "mcp-vllm": {
19 | "command": "uv",
20 | "args": [
21 | "run",
22 | "/Path/TO/mcp-vllm-benchmarking-tool/server.py"
23 | ]
24 | }
25 | }
26 | }
27 | ```
28 |
29 | Then you can prompt it like this, for example:
30 |
31 | ```
32 | Do a vllm benchmark for this endpoint: http://10.0.101.39:8888
33 | benchmark the following model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
34 | run the benchmark 3 times with 32 num prompts each, then compare the results, but ignore the first iteration as that is just a warmup.
35 | ```
36 |
37 |
38 | ## Todo:
39 |
40 | - Due to some nondeterministic output from vLLM, the tool may report that it found invalid JSON. This has not been investigated yet.
```
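For reference, below is a rough sketch of what an MCP client does when it calls this tool over stdio, using the MCP Python SDK. The server path, endpoint, and model are placeholders, and the exact client API may differ slightly between SDK versions.

```python
# Sketch of calling the benchmark_vllm tool over stdio with the MCP Python SDK.
# Placeholder values: server path, base_url, and model.
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

server_params = StdioServerParameters(
    command="uv",
    args=["run", "/Path/TO/mcp-vllm-benchmarking-tool/server.py"],
)


async def main() -> None:
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Invoke the tool exposed by server.py.
            result = await session.call_tool(
                "benchmark_vllm",
                arguments={
                    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                    "base_url": "http://10.0.101.39:8888",
                    "num_prompts": 32,
                })
            print(result)


if __name__ == "__main__":
    asyncio.run(main())
```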
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
1 | [project]
2 | name = "mcp-bencher"
3 | version = "0.1.0"
4 | description = "A model benchmarking tool using vLLM"
5 | readme = "README.md"
6 | requires-python = "~=3.10"
7 | dependencies = [
8 | "mcp[cli]>=1.5.0",
9 | "vllm",
10 | "requests",
11 | "pathlib",
12 | "pytest",
13 | "tqdm",
14 | "aiohttp",
15 | "numpy",
16 | "huggingface_hub",
17 | "transformers",
18 | "pandas",
19 | "datasets",
20 | ]
21 |
```
--------------------------------------------------------------------------------
/benchmark/benchmark_utils.py:
--------------------------------------------------------------------------------
```python
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import argparse
4 | import json
5 | import math
6 | import os
7 | from typing import Any
8 |
9 |
10 | def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
11 | metrics: dict[str, list],
12 | extra_info: dict[str, Any]) -> list:
13 | """
14 | Save the benchmark results in the format used by PyTorch OSS benchmark with
15 |     one metric per record
16 | https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
17 | """
18 | records = []
19 | if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
20 | return records
21 |
22 | for name, benchmark_values in metrics.items():
23 | record = {
24 | "benchmark": {
25 | "name": "vLLM benchmark",
26 | "extra_info": {
27 | "args": vars(args),
28 | },
29 | },
30 | "model": {
31 | "name": args.model,
32 | },
33 | "metric": {
34 | "name": name,
35 | "benchmark_values": benchmark_values,
36 | "extra_info": extra_info,
37 | },
38 | }
39 |
40 | tp = record["benchmark"]["extra_info"]["args"].get(
41 | "tensor_parallel_size")
42 | # Save tensor_parallel_size parameter if it's part of the metadata
43 | if not tp and "tensor_parallel_size" in extra_info:
44 | record["benchmark"]["extra_info"]["args"][
45 | "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
46 |
47 | records.append(record)
48 |
49 | return records
50 |
51 |
52 | class InfEncoder(json.JSONEncoder):
53 |
54 | def clear_inf(self, o: Any):
55 | if isinstance(o, dict):
56 | return {k: self.clear_inf(v) for k, v in o.items()}
57 | elif isinstance(o, list):
58 | return [self.clear_inf(v) for v in o]
59 | elif isinstance(o, float) and math.isinf(o):
60 | return "inf"
61 | return o
62 |
63 | def iterencode(self, o: Any, *args, **kwargs) -> Any:
64 | return super().iterencode(self.clear_inf(o), *args, **kwargs)
65 |
66 |
67 | def write_to_json(filename: str, records: list) -> None:
68 | with open(filename, "w") as f:
69 | json.dump(records, f, cls=InfEncoder)
```
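A small usage sketch of the two helpers above. The metric names and values are illustrative only, records are emitted only when `SAVE_TO_PYTORCH_BENCHMARK_FORMAT` is set, and the import assumes the `benchmark` directory is importable as a package.

```python
# Hypothetical invocation of convert_to_pytorch_benchmark_format/write_to_json;
# metric names and values are made up for illustration.
import argparse
import os

from benchmark.benchmark_utils import (convert_to_pytorch_benchmark_format,
                                       write_to_json)

# Records are only produced when this environment variable is set.
os.environ["SAVE_TO_PYTORCH_BENCHMARK_FORMAT"] = "1"

args = argparse.Namespace(model="meta-llama/Llama-2-7b-hf",
                          tensor_parallel_size=1)
records = convert_to_pytorch_benchmark_format(
    args=args,
    metrics={"request_throughput": [12.3, 12.9]},
    extra_info={"num_prompts": 32},
)
# Serializes the records, replacing infinite floats with the string "inf".
write_to_json("pytorch_benchmark_records.json", records)
```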
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
```python
1 | # server.py
2 | from mcp.server.fastmcp import FastMCP
3 | from typing import Dict
4 | import concurrent.futures
5 | import os
6 | import requests
7 | import pathlib
8 | import tqdm
9 |
10 | from benchmark_tool import run_benchmark
11 |
12 | # Create an MCP server
13 | mcp = FastMCP("vLLM Bencher")
14 |
15 | @mcp.tool()
16 | def benchmark_vllm(
17 | model: str,
18 | base_url: str,
19 | num_prompts: int = 10,
20 | ) -> Dict:
21 |     """
22 |     Run the vLLM benchmarking tool (benchmark_serving.py) to measure
23 |     model performance.
24 | 
25 |     The ShareGPT dataset is downloaded automatically on first use if it is
26 |     not already present next to this file. The benchmark is then run with
27 |     the "vllm" backend and the "sharegpt" dataset against the server's
28 |     OpenAI-compatible completions endpoint.
29 | 
30 |     Args:
31 |         model: The model to benchmark (e.g., 'meta-llama/Llama-2-7b-hf')
32 |         base_url: Base URL of the vLLM server to benchmark
33 |             (e.g., 'http://10.0.101.39:8888')
34 |         num_prompts: Number of prompts to benchmark with (default: 10)
35 | 
36 |     All other benchmark parameters (request rate, concurrency, maximum
37 |     output length, result saving, ...) are not exposed here and use the
38 |     defaults defined in benchmark_tool.run_benchmark.
39 | 
40 |     Returns:
41 |         Dictionary containing benchmark results, including throughput,
42 |         latency, and other metrics
43 |     """
44 |
45 | # Define the dataset path
46 | dataset_filename = "ShareGPT_V3_unfiltered_cleaned_split.json"
47 | current_dir = pathlib.Path(__file__).parent.absolute()
48 | dataset_path = current_dir / dataset_filename
49 |
50 | # Check if dataset exists, if not, download it
51 | if not dataset_path.exists():
52 | dataset_url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
53 | try:
54 | response = requests.get(dataset_url, stream=True)
55 | response.raise_for_status()
56 |
57 | # Get file size if available
58 | total_size_in_bytes = int(response.headers.get('content-length', 0))
59 | progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc="Downloading dataset")
60 |
61 | with open(dataset_path, 'wb') as f:
62 | for chunk in response.iter_content(chunk_size=8192):
63 | progress_bar.update(len(chunk))
64 | f.write(chunk)
65 |
66 | progress_bar.close()
67 |         except Exception:
68 | # If download failed and partial file exists, remove it
69 | if dataset_path.exists():
70 | os.remove(dataset_path)
71 | raise
72 |
73 | # Run the benchmark in a separate thread to avoid asyncio event loop issues
74 | with concurrent.futures.ThreadPoolExecutor() as executor:
75 | future = executor.submit(
76 | run_benchmark,
77 | model=model,
78 | backend="vllm",
79 | dataset="sharegpt",
80 | dataset_path=str(dataset_path),
81 | num_prompts=num_prompts,
82 | base_url=base_url,
83 | )
84 | return future.result()
85 |
86 | if __name__ == "__main__":
87 | mcp.run()
88 |
```
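For quick local testing, the tool function can be exercised without an MCP client. This is only a sketch: it assumes FastMCP's `@mcp.tool()` decorator returns the original callable after registering it, and that a vLLM server is already running at the placeholder `base_url`.

```python
# Local smoke test that bypasses the MCP transport entirely.
# Assumptions: @mcp.tool() hands back the original function, and a vLLM
# server is reachable at the placeholder base_url below.
from server import benchmark_vllm

if __name__ == "__main__":
    results = benchmark_vllm(
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        base_url="http://localhost:8000",
        num_prompts=8,
    )
    print(results)
```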
--------------------------------------------------------------------------------
/benchmark_tool.py:
--------------------------------------------------------------------------------
```python
1 | import argparse
2 | from typing import Dict, List, Optional
3 | from benchmark.benchmark_serving import main as main_benchmark_serving
4 |
5 | def run_benchmark(
6 | model: str,
7 | base_url: str,
8 | backend: str = "vllm",
9 | dataset: str = "sharegpt",
10 | dataset_path: Optional[str] = None,
11 | num_prompts: int = 100,
12 | request_rate: float = 10.0,
13 | concurrent_requests: int = 10,
14 | max_tokens: int = 128,
15 | vllm_dir: Optional[str] = None,
16 | save_result: bool = True,
17 | result_filename: Optional[str] = None,
18 | api_key: Optional[str] = None,
19 | trust_remote_code: bool = False,
20 | extra_args: Optional[List[str]] = None,
21 | ) -> Dict:
22 | """
23 | Run vLLM benchmarking tool
24 |
25 | Args:
26 | model: The model to benchmark
27 | backend: Backend server to use (vllm, tgi, openai, etc.)
28 | dataset: Dataset to use for benchmarking
29 | dataset_path: Path to the dataset file
30 | num_prompts: Number of prompts to benchmark with
31 | request_rate: Requests per second
32 | concurrent_requests: Number of concurrent requests
33 | max_tokens: Maximum number of tokens to generate
34 | vllm_dir: Directory where vLLM is installed
35 |         base_url: Base URL of the server to benchmark
36 | save_result: Whether to save benchmark results
37 | result_filename: Filename to save benchmark results
38 | api_key: API key for the backend
39 | trust_remote_code: Whether to trust remote code
40 | extra_args: Additional arguments to pass to benchmark_serving.py
41 |
42 | Returns:
43 | Dictionary with benchmark results
44 | """
45 | # Create argparse.Namespace object to pass to main_benchmark_serving
46 | args = argparse.Namespace()
47 |
48 | # Required parameters
49 | args.model = model
50 | args.backend = backend
51 | args.dataset_name = dataset
52 | args.dataset_path = dataset_path if dataset_path else "./ShareGPT_V3_unfiltered_cleaned_split.json"
53 | args.num_prompts = num_prompts
54 | args.request_rate = request_rate
55 | args.max_concurrency = concurrent_requests
56 |
57 | # Optional parameters with defaults
58 | args.host = "127.0.0.1"
59 | args.port = 8000
60 | args.endpoint = "/v1/completions"
61 | args.base_url = base_url
62 | args.seed = 0
63 | args.disable_tqdm = False
64 | args.profile = False
65 | args.use_beam_search = False
66 | args.tokenizer = None
67 | args.logprobs = None
68 | args.burstiness = 1.0
69 | args.ignore_eos = False
70 | args.percentile_metrics = "ttft,tpot,itl"
71 | args.metric_percentiles = "99"
72 | args.save_result = save_result
73 | args.save_detailed = False
74 | args.metadata = None
75 | args.result_dir = None
76 | args.result_filename = result_filename
77 | args.trust_remote_code = trust_remote_code
78 | args.tokenizer_mode = "auto"
79 | args.served_model_name = None
80 | args.lora_modules = None
81 |
82 | # Dataset-specific parameters
83 | args.sonnet_input_len = 550
84 | args.sonnet_output_len = 150
85 | args.sonnet_prefix_len = 200
86 | args.sharegpt_output_len = max_tokens
87 | args.random_input_len = 1024
88 | args.random_output_len = max_tokens
89 | args.random_range_ratio = 1.0
90 | args.random_prefix_len = 0
91 | args.hf_subset = None
92 | args.hf_split = None
93 | args.hf_output_len = max_tokens
94 | args.goodput = None
95 |
96 | # Handle extra args if provided
97 | if extra_args:
98 | for arg in extra_args:
99 | if '=' in arg:
100 | key, value = arg.split('=', 1)
101 | key = key.lstrip('-').replace('-', '_')
102 | try:
103 | # Try to convert to appropriate type
104 | if value.lower() in ('true', 'yes'):
105 | value = True
106 | elif value.lower() in ('false', 'no'):
107 | value = False
108 | elif value.isdigit():
109 | value = int(value)
110 | elif value.replace('.', '', 1).isdigit():
111 | value = float(value)
112 | except (ValueError, AttributeError):
113 | pass
114 | setattr(args, key, value)
115 |
116 | benchmark_result = main_benchmark_serving(args)
117 | return benchmark_result
118 |
```
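An illustrative direct call showing the `extra_args` format parsed above: each entry is a `--key=value` string whose key is normalised into an attribute on the `argparse.Namespace`. The endpoint and model are placeholders, and a vLLM server must be reachable at `base_url`.

```python
# Hypothetical direct call to run_benchmark; endpoint and model are placeholders.
# Each extra_args entry is a "--key=value" string that becomes a Namespace attribute.
from benchmark_tool import run_benchmark

results = run_benchmark(
    model="meta-llama/Llama-2-7b-hf",
    base_url="http://localhost:8000",
    num_prompts=16,
    request_rate=4.0,
    concurrent_requests=4,
    extra_args=["--metric-percentiles=50,95,99", "--ignore-eos=true"],
)
print(results)
```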
--------------------------------------------------------------------------------
/benchmark/backend_request_func.py:
--------------------------------------------------------------------------------
```python
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import json
4 | import os
5 | import sys
6 | import time
7 | import traceback
8 | from dataclasses import dataclass, field
9 | from typing import Optional, Union
10 |
11 | import aiohttp
12 | import huggingface_hub.constants
13 | from tqdm.asyncio import tqdm
14 | from transformers import (AutoTokenizer, PreTrainedTokenizer,
15 | PreTrainedTokenizerFast)
16 |
17 | # NOTE(simon): do not import vLLM here so the benchmark script
18 | # can run without vLLM installed.
19 |
20 | AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
21 |
22 |
23 | @dataclass
24 | class RequestFuncInput:
25 | prompt: str
26 | api_url: str
27 | prompt_len: int
28 | output_len: int
29 | model: str
30 | model_name: Optional[str] = None
31 | logprobs: Optional[int] = None
32 | extra_body: Optional[dict] = None
33 | multi_modal_content: Optional[dict] = None
34 | ignore_eos: bool = False
35 |
36 |
37 | @dataclass
38 | class RequestFuncOutput:
39 | generated_text: str = ""
40 | success: bool = False
41 | latency: float = 0.0
42 | output_tokens: int = 0
43 | ttft: float = 0.0 # Time to first token
44 | itl: list[float] = field(
45 | default_factory=list) # list of inter-token latencies
46 | tpot: float = 0.0 # avg next-token latencies
47 | prompt_len: int = 0
48 | error: str = ""
49 |
50 |
51 | async def async_request_tgi(
52 | request_func_input: RequestFuncInput,
53 | pbar: Optional[tqdm] = None,
54 | ) -> RequestFuncOutput:
55 | api_url = request_func_input.api_url
56 | assert api_url.endswith("generate_stream")
57 |
58 | async with aiohttp.ClientSession(trust_env=True,
59 | timeout=AIOHTTP_TIMEOUT) as session:
60 | params = {
61 | "max_new_tokens": request_func_input.output_len,
62 | "do_sample": True,
63 | "temperature": 0.01, # TGI does not accept 0.0 temperature.
64 | "top_p": 0.99, # TGI does not accept 1.0 top_p.
65 | "truncate": request_func_input.prompt_len,
66 | "ignore_eos_token": request_func_input.ignore_eos,
67 | }
68 | payload = {
69 | "inputs": request_func_input.prompt,
70 | "parameters": params,
71 | }
72 | output = RequestFuncOutput()
73 | output.prompt_len = request_func_input.prompt_len
74 | if request_func_input.ignore_eos:
75 | output.output_tokens = request_func_input.output_len
76 | else:
77 | output.output_tokens = None
78 |
79 | ttft = 0.0
80 | st = time.perf_counter()
81 | most_recent_timestamp = st
82 | try:
83 | async with session.post(url=api_url, json=payload) as response:
84 | if response.status == 200:
85 | async for chunk_bytes in response.content:
86 | chunk_bytes = chunk_bytes.strip()
87 | if not chunk_bytes:
88 | continue
89 | chunk_bytes = chunk_bytes.decode("utf-8")
90 |
91 | # NOTE: Sometimes TGI returns a ping response without
92 | # any data, we should skip it.
93 | if chunk_bytes.startswith(":"):
94 | continue
95 | chunk = chunk_bytes.removeprefix("data:")
96 |
97 | data = json.loads(chunk)
98 | timestamp = time.perf_counter()
99 | # First token
100 | if ttft == 0.0:
101 | ttft = time.perf_counter() - st
102 | output.ttft = ttft
103 |
104 | # Decoding phase
105 | else:
106 | output.itl.append(timestamp -
107 | most_recent_timestamp)
108 |
109 | most_recent_timestamp = timestamp
110 |
111 | output.latency = most_recent_timestamp - st
112 | output.success = True
113 | output.generated_text = data["generated_text"]
114 | else:
115 | output.error = response.reason or ""
116 | output.success = False
117 | except Exception:
118 | output.success = False
119 | exc_info = sys.exc_info()
120 | output.error = "".join(traceback.format_exception(*exc_info))
121 |
122 | if pbar:
123 | pbar.update(1)
124 | return output
125 |
126 |
127 | async def async_request_trt_llm(
128 | request_func_input: RequestFuncInput,
129 | pbar: Optional[tqdm] = None,
130 | ) -> RequestFuncOutput:
131 | api_url = request_func_input.api_url
132 | assert api_url.endswith("generate_stream")
133 |
134 | async with aiohttp.ClientSession(trust_env=True,
135 | timeout=AIOHTTP_TIMEOUT) as session:
136 | payload = {
137 | "accumulate_tokens": True,
138 | "text_input": request_func_input.prompt,
139 | "temperature": 0.0,
140 | "top_p": 1.0,
141 | "max_tokens": request_func_input.output_len,
142 | "stream": True,
143 | }
144 | if request_func_input.ignore_eos:
145 | payload["min_length"] = request_func_input.output_len
146 | output = RequestFuncOutput()
147 | output.prompt_len = request_func_input.prompt_len
148 |
149 | ttft = 0.0
150 | st = time.perf_counter()
151 | most_recent_timestamp = st
152 | try:
153 | async with session.post(url=api_url, json=payload) as response:
154 | if response.status == 200:
155 | async for chunk_bytes in response.content:
156 | chunk_bytes = chunk_bytes.strip()
157 | if not chunk_bytes:
158 | continue
159 |
160 | chunk = chunk_bytes.decode("utf-8").removeprefix(
161 | "data:")
162 |
163 | data = json.loads(chunk)
164 | output.generated_text += data["text_output"]
165 | timestamp = time.perf_counter()
166 | # First token
167 | if ttft == 0.0:
168 | ttft = timestamp - st
169 | output.ttft = ttft
170 |
171 | # Decoding phase
172 | else:
173 | output.itl.append(timestamp -
174 | most_recent_timestamp)
175 |
176 | most_recent_timestamp = timestamp
177 |
178 | output.latency = most_recent_timestamp - st
179 | output.success = True
180 |
181 | else:
182 | output.error = response.reason or ""
183 | output.success = False
184 | except Exception:
185 | output.success = False
186 | exc_info = sys.exc_info()
187 | output.error = "".join(traceback.format_exception(*exc_info))
188 |
189 | if pbar:
190 | pbar.update(1)
191 | return output
192 |
193 |
194 | async def async_request_deepspeed_mii(
195 | request_func_input: RequestFuncInput,
196 | pbar: Optional[tqdm] = None,
197 | ) -> RequestFuncOutput:
198 | async with aiohttp.ClientSession(trust_env=True,
199 | timeout=AIOHTTP_TIMEOUT) as session:
200 |
201 | payload = {
202 | "prompt": request_func_input.prompt,
203 | "max_tokens": request_func_input.output_len,
204 | "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
205 | "top_p": 1.0,
206 | }
207 | output = RequestFuncOutput()
208 | output.prompt_len = request_func_input.prompt_len
209 |
210 | # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
211 | # will use 0 as placeholder.
212 | # See https://github.com/microsoft/DeepSpeed-MII/pull/311
213 | output.ttft = 0
214 |
215 | st = time.perf_counter()
216 | try:
217 | async with session.post(url=request_func_input.api_url,
218 | json=payload) as response:
219 | if response.status == 200:
220 | parsed_resp = await response.json()
221 | output.latency = time.perf_counter() - st
222 | output.generated_text = parsed_resp["text"][0]
223 | output.success = True
224 | else:
225 | output.error = response.reason or ""
226 | output.success = False
227 | except Exception:
228 | output.success = False
229 | exc_info = sys.exc_info()
230 | output.error = "".join(traceback.format_exception(*exc_info))
231 |
232 | if pbar:
233 | pbar.update(1)
234 | return output
235 |
236 |
237 | async def async_request_openai_completions(
238 | request_func_input: RequestFuncInput,
239 | pbar: Optional[tqdm] = None,
240 | ) -> RequestFuncOutput:
241 | api_url = request_func_input.api_url
242 | assert api_url.endswith(
243 | ("completions", "profile")
244 | ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
245 |
246 | async with aiohttp.ClientSession(trust_env=True,
247 | timeout=AIOHTTP_TIMEOUT) as session:
248 | payload = {
249 | "model": request_func_input.model_name \
250 | if request_func_input.model_name else request_func_input.model,
251 | "prompt": request_func_input.prompt,
252 | "temperature": 0.0,
253 | "max_tokens": request_func_input.output_len,
254 | "logprobs": request_func_input.logprobs,
255 | "stream": True,
256 | "stream_options": {
257 | "include_usage": True,
258 | },
259 | }
260 | if request_func_input.ignore_eos:
261 | payload["ignore_eos"] = request_func_input.ignore_eos
262 | if request_func_input.extra_body:
263 | payload.update(request_func_input.extra_body)
264 | headers = {
265 | "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
266 | }
267 |
268 | output = RequestFuncOutput()
269 | output.prompt_len = request_func_input.prompt_len
270 |
271 | generated_text = ""
272 | st = time.perf_counter()
273 | most_recent_timestamp = st
274 | try:
275 | async with session.post(url=api_url, json=payload,
276 | headers=headers) as response:
277 | if response.status == 200:
278 | first_chunk_received = False
279 | async for chunk_bytes in response.content:
280 | chunk_bytes = chunk_bytes.strip()
281 | if not chunk_bytes:
282 | continue
283 |
284 | chunk = chunk_bytes.decode("utf-8").removeprefix(
285 | "data: ")
286 | if chunk != "[DONE]":
287 | data = json.loads(chunk)
288 |
289 | # NOTE: Some completion API might have a last
290 | # usage summary response without a token so we
291 | # want to check a token was generated
292 | if choices := data.get("choices"):
293 | # Note that text could be empty here
294 | # e.g. for special tokens
295 | text = choices[0].get("text")
296 | timestamp = time.perf_counter()
297 | # First token
298 | if not first_chunk_received:
299 | first_chunk_received = True
300 | ttft = time.perf_counter() - st
301 | output.ttft = ttft
302 |
303 | # Decoding phase
304 | else:
305 | output.itl.append(timestamp -
306 | most_recent_timestamp)
307 |
308 | most_recent_timestamp = timestamp
309 | generated_text += text or ""
310 | elif usage := data.get("usage"):
311 | output.output_tokens = usage.get(
312 | "completion_tokens")
313 | if first_chunk_received:
314 | output.success = True
315 | else:
316 | output.success = False
317 | output.error = (
318 | "Never received a valid chunk to calculate TTFT."
319 | "This response will be marked as failed!")
320 | output.generated_text = generated_text
321 | output.latency = most_recent_timestamp - st
322 | else:
323 | output.error = response.reason or ""
324 | output.success = False
325 | except Exception:
326 | output.success = False
327 | exc_info = sys.exc_info()
328 | output.error = "".join(traceback.format_exception(*exc_info))
329 |
330 | if pbar:
331 | pbar.update(1)
332 | return output
333 |
334 |
335 | async def async_request_openai_chat_completions(
336 | request_func_input: RequestFuncInput,
337 | pbar: Optional[tqdm] = None,
338 | ) -> RequestFuncOutput:
339 | api_url = request_func_input.api_url
340 | assert api_url.endswith(
341 | ("chat/completions", "profile")
342 | ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
343 |
344 | async with aiohttp.ClientSession(trust_env=True,
345 | timeout=AIOHTTP_TIMEOUT) as session:
346 | content = [{"type": "text", "text": request_func_input.prompt}]
347 | if request_func_input.multi_modal_content:
348 | content.append(request_func_input.multi_modal_content)
349 | payload = {
350 | "model": request_func_input.model_name \
351 | if request_func_input.model_name else request_func_input.model,
352 | "messages": [
353 | {
354 | "role": "user",
355 | "content": content
356 | },
357 | ],
358 | "temperature": 0.0,
359 | "max_completion_tokens": request_func_input.output_len,
360 | "stream": True,
361 | "stream_options": {
362 | "include_usage": True,
363 | },
364 | }
365 | if request_func_input.ignore_eos:
366 | payload["ignore_eos"] = request_func_input.ignore_eos
367 | if request_func_input.extra_body:
368 | payload.update(request_func_input.extra_body)
369 | headers = {
370 | "Content-Type": "application/json",
371 | "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
372 | }
373 |
374 | output = RequestFuncOutput()
375 | output.prompt_len = request_func_input.prompt_len
376 |
377 | generated_text = ""
378 | ttft = 0.0
379 | st = time.perf_counter()
380 | most_recent_timestamp = st
381 | try:
382 | async with session.post(url=api_url, json=payload,
383 | headers=headers) as response:
384 | if response.status == 200:
385 | async for chunk_bytes in response.content:
386 | chunk_bytes = chunk_bytes.strip()
387 | if not chunk_bytes:
388 | continue
389 |
390 | chunk = chunk_bytes.decode("utf-8").removeprefix(
391 | "data: ")
392 | if chunk != "[DONE]":
393 | timestamp = time.perf_counter()
394 | data = json.loads(chunk)
395 |
396 | if choices := data.get("choices"):
397 | content = choices[0]["delta"].get("content")
398 | # First token
399 | if ttft == 0.0:
400 | ttft = timestamp - st
401 | output.ttft = ttft
402 |
403 | # Decoding phase
404 | else:
405 | output.itl.append(timestamp -
406 | most_recent_timestamp)
407 |
408 | generated_text += content or ""
409 | elif usage := data.get("usage"):
410 | output.output_tokens = usage.get(
411 | "completion_tokens")
412 |
413 | most_recent_timestamp = timestamp
414 |
415 | output.generated_text = generated_text
416 | output.success = True
417 | output.latency = most_recent_timestamp - st
418 | else:
419 | output.error = response.reason or ""
420 | output.success = False
421 | except Exception:
422 | output.success = False
423 | exc_info = sys.exc_info()
424 | output.error = "".join(traceback.format_exception(*exc_info))
425 |
426 | if pbar:
427 | pbar.update(1)
428 | return output
429 |
430 |
431 | def get_model(pretrained_model_name_or_path: str) -> str:
432 | if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
433 | from modelscope import snapshot_download
434 |
435 | from vllm.model_executor.model_loader.weight_utils import get_lock
436 |
437 | # Use file lock to prevent multiple processes from
438 | # downloading the same model weights at the same time.
439 | with get_lock(pretrained_model_name_or_path):
440 | model_path = snapshot_download(
441 | model_id=pretrained_model_name_or_path,
442 | local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
443 | ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
444 |
445 | return model_path
446 | return pretrained_model_name_or_path
447 |
448 |
449 | def get_tokenizer(
450 | pretrained_model_name_or_path: str,
451 | tokenizer_mode: str = "auto",
452 | trust_remote_code: bool = False,
453 | **kwargs,
454 | ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
455 | if pretrained_model_name_or_path is not None and not os.path.exists(
456 | pretrained_model_name_or_path):
457 | pretrained_model_name_or_path = get_model(
458 | pretrained_model_name_or_path)
459 | if tokenizer_mode == "slow":
460 | if kwargs.get("use_fast", False):
461 | raise ValueError(
462 | "Cannot use the fast tokenizer in slow tokenizer mode.")
463 | kwargs["use_fast"] = False
464 | if tokenizer_mode == "mistral":
465 | try:
466 | from vllm.transformers_utils.tokenizer import MistralTokenizer
467 | except ImportError as e:
468 | raise ImportError("MistralTokenizer requires vllm package.\n"
469 | "Please install it with `pip install vllm` "
470 | "to use mistral tokenizer mode.") from e
471 | return MistralTokenizer.from_pretrained(
472 | str(pretrained_model_name_or_path))
473 | else:
474 | return AutoTokenizer.from_pretrained(
475 | pretrained_model_name_or_path,
476 | trust_remote_code=trust_remote_code,
477 | **kwargs,
478 | )
479 |
480 |
481 | ASYNC_REQUEST_FUNCS = {
482 | "tgi": async_request_tgi,
483 | "vllm": async_request_openai_completions,
484 | "lmdeploy": async_request_openai_completions,
485 | "deepspeed-mii": async_request_deepspeed_mii,
486 | "openai": async_request_openai_completions,
487 | "openai-chat": async_request_openai_chat_completions,
488 | "tensorrt-llm": async_request_trt_llm,
489 | "scalellm": async_request_openai_completions,
490 | "sglang": async_request_openai_completions,
491 | }
```
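A minimal sketch of driving one of these request functions directly, using the OpenAI-completions client registered for the `vllm` backend. The URL and model are placeholders, an OpenAI-compatible server must be listening there, and the import assumes the `benchmark` directory is importable as a package.

```python
# Sketch: send a single streaming completion request and print basic metrics.
# If the server requires auth, OPENAI_API_KEY is read from the environment.
import asyncio

from benchmark.backend_request_func import (ASYNC_REQUEST_FUNCS,
                                             RequestFuncInput)


async def main() -> None:
    request_func = ASYNC_REQUEST_FUNCS["vllm"]  # OpenAI completions client
    result = await request_func(
        RequestFuncInput(
            prompt="Hello, world",
            api_url="http://localhost:8000/v1/completions",
            prompt_len=3,
            output_len=32,
            model="meta-llama/Llama-2-7b-hf",
        ))
    print(result.success, result.ttft, result.latency, result.generated_text)


if __name__ == "__main__":
    asyncio.run(main())
```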
--------------------------------------------------------------------------------
/benchmark/benchmark_dataset.py:
--------------------------------------------------------------------------------
```python
1 | # SPDX-License-Identifier: Apache-2.0
2 | """
3 | This module defines a framework for sampling benchmark requests from various
4 | datasets. Each dataset subclass of BenchmarkDataset must implement sample
5 | generation. Supported dataset types include:
6 | - ShareGPT
7 | - Random (synthetic)
8 | - Sonnet
9 | - BurstGPT
10 | - HuggingFace
11 | - VisionArena
12 |
13 | TODO: Implement CustomDataset to parse a JSON file and convert its contents into
14 | SampleRequest instances, similar to the approach used in ShareGPT.
15 | """
16 |
17 | import base64
18 | import io
19 | import json
20 | import logging
21 | import random
22 | from abc import ABC, abstractmethod
23 | from collections.abc import Mapping
24 | from dataclasses import dataclass
25 | from functools import cache
26 | from typing import Any, Optional, Union
27 |
28 | import numpy as np
29 | import pandas as pd
30 | from datasets import load_dataset
31 | from PIL import Image
32 | from transformers import PreTrainedTokenizerBase
33 |
34 | from vllm.lora.request import LoRARequest
35 | from vllm.lora.utils import get_adapter_absolute_path
36 | from vllm.multimodal import MultiModalDataDict
37 | from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
38 |
39 | logger = logging.getLogger(__name__)
40 |
41 | # -----------------------------------------------------------------------------
42 | # Data Classes
43 | # -----------------------------------------------------------------------------
44 |
45 |
46 | @dataclass
47 | class SampleRequest:
48 | """
49 | Represents a single inference request for benchmarking.
50 | """
51 |
52 | prompt: Union[str, Any]
53 | prompt_len: int
54 | expected_output_len: int
55 | multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
56 | lora_request: Optional[LoRARequest] = None
57 |
58 |
59 | # -----------------------------------------------------------------------------
60 | # Benchmark Dataset Base Class
61 | # -----------------------------------------------------------------------------
62 |
63 |
64 | class BenchmarkDataset(ABC):
65 | DEFAULT_SEED = 0
66 |
67 | def __init__(
68 | self,
69 | dataset_path: Optional[str] = None,
70 | random_seed: int = DEFAULT_SEED,
71 | ) -> None:
72 |         """Initialize the BenchmarkDataset with an optional path and seed.
73 | 
74 |         Args:
75 |             dataset_path (Optional[str]): Path to the dataset. If None, a
76 |                 default or random dataset might be used.
77 |             random_seed (int): Seed value for reproducible shuffling or
78 |                 sampling. Defaults to DEFAULT_SEED.
79 |         """
80 | self.dataset_path = dataset_path
81 | # Set the random seed, ensuring that a None value is replaced with the
82 | # default seed.
83 | self.random_seed = (random_seed
84 | if random_seed is not None else self.DEFAULT_SEED)
85 | self.data = None
86 |
87 | def apply_multimodal_chat_transformation(
88 | self,
89 | prompt: str,
90 | mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
91 | """
92 | Transform a prompt and optional multimodal content into a chat format.
93 | This method is used for chat models that expect a specific conversation
94 | format.
95 | """
96 | content = [{"text": prompt, "type": "text"}]
97 | if mm_content is not None:
98 | content.append(mm_content)
99 | return [{"role": "user", "content": content}]
100 |
101 | def load_data(self) -> None:
102 | """
103 | Load data from the dataset path into self.data.
104 |
105 | This method must be overridden by subclasses since the method to load
106 | data will vary depending on the dataset format and source.
107 |
108 | Raises:
109 | NotImplementedError: If a subclass does not implement this method.
110 | """
111 | # TODO (jenniferzhao): add support for downloading data
112 | raise NotImplementedError(
113 | "load_data must be implemented in subclasses.")
114 |
115 | def get_random_lora_request(
116 | self,
117 | tokenizer: PreTrainedTokenizerBase,
118 | max_loras: Optional[int] = None,
119 | lora_path: Optional[str] = None,
120 | ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
121 | """
122 | Optionally select a random LoRA request and return its associated
123 | tokenizer.
124 |
125 | This method is used when LoRA parameters are provided. It randomly
126 | selects a LoRA based on max_loras and retrieves a cached tokenizer for
127 | that LoRA if available. Otherwise, it returns the base tokenizer.
128 |
129 |         Args:
130 |             tokenizer (PreTrainedTokenizerBase): Base tokenizer to use
131 |                 when no LoRA is selected.
132 |             max_loras (Optional[int]): Maximum number of LoRAs available.
133 |             lora_path (Optional[str]): Path to the LoRA parameters on
134 |                 disk. If max_loras or lora_path is None, LoRA is not used.
135 |
136 | Returns:
137 | tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
138 | element is a LoRARequest (or None if not applicable) and the second
139 | element is the tokenizer associated with the LoRA request (or the
140 | base tokenizer).
141 | """
142 | if max_loras is None or lora_path is None:
143 | return None, tokenizer
144 |
145 | # Generate a random LoRA ID in the range [1, max_loras].
146 | lora_id = random.randint(1, max_loras)
147 | lora_request = LoRARequest(
148 | lora_name=str(lora_id),
149 | lora_int_id=lora_id,
150 | lora_path=lora_path_on_disk(lora_path),
151 | )
152 | if lora_id not in lora_tokenizer_cache:
153 | lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
154 | # Return lora_request and the cached tokenizer if available; otherwise,
155 | # return the base tokenizer
156 | return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
157 |
158 | @abstractmethod
159 | def sample(self, tokenizer: PreTrainedTokenizerBase,
160 | num_requests: int) -> list[SampleRequest]:
161 | """
162 | Abstract method to generate sample requests from the dataset.
163 |
164 | Subclasses must override this method to implement dataset-specific logic
165 | for generating a list of SampleRequest objects.
166 |
167 | Args:
168 | tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
169 | for processing the dataset's text.
170 | num_requests (int): The number of sample requests to generate.
171 |
172 | Returns:
173 | list[SampleRequest]: A list of sample requests generated from the
174 | dataset.
175 | """
176 | raise NotImplementedError("sample must be implemented in subclasses.")
177 |
178 | def maybe_oversample_requests(self, requests: list[SampleRequest],
179 | num_requests: int) -> None:
180 | """
181 | Oversamples the list of requests if its size is less than the desired
182 | number.
183 |
184 |         Args:
185 |             requests (List[SampleRequest]): Current list of sampled requests.
186 |             num_requests (int): The target number of requests.
187 | """
188 | if len(requests) < num_requests:
189 | random.seed(self.random_seed)
190 | additional = random.choices(requests,
191 | k=num_requests - len(requests))
192 | requests.extend(additional)
193 | logger.info("Oversampled requests to reach %d total samples.",
194 | num_requests)
195 |
196 |
197 | # -----------------------------------------------------------------------------
198 | # Utility Functions and Global Caches
199 | # -----------------------------------------------------------------------------
200 |
201 |
202 | def is_valid_sequence(
203 | prompt_len: int,
204 | output_len: int,
205 | min_len: int = 4,
206 | max_prompt_len: int = 1024,
207 | max_total_len: int = 2048,
208 | skip_min_output_len_check: bool = False,
209 | ) -> bool:
210 | """
211 | Validate a sequence based on prompt and output lengths.
212 |
213 | Default pruning criteria are copied from the original `sample_hf_requests`
214 | and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
215 | from `sample_requests` in benchmark_throughput.py.
216 | """
217 | # Check for invalid conditions
218 | prompt_too_short = prompt_len < min_len
219 | output_too_short = (not skip_min_output_len_check) and (output_len
220 | < min_len)
221 | prompt_too_long = prompt_len > max_prompt_len
222 | combined_too_long = (prompt_len + output_len) > max_total_len
223 |
224 | # Return True if none of the invalid conditions are met
225 | return not (prompt_too_short or output_too_short or prompt_too_long
226 | or combined_too_long)
227 |
228 |
229 | @cache
230 | def lora_path_on_disk(lora_path: str) -> str:
231 | return get_adapter_absolute_path(lora_path)
232 |
233 |
234 | # Global cache for LoRA tokenizers.
235 | lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
236 |
237 |
238 | def process_image(image: Any) -> Mapping[str, Any]:
239 | """
240 | Process a single image input and return a multimedia content dictionary.
241 |
242 | For a PIL.Image.Image input:
243 | - Converts the image to RGB.
244 | - Saves the image as a JPEG in-memory.
245 | - Encodes the JPEG data as a base64 string.
246 | - Returns a dictionary with the image as a base64 data URL.
247 |
248 | For a string input:
249 | - Treats the string as a URL or file path.
250 | - Prepends "file://" if the string doesn't start with "http://" or
251 | "file://".
252 | - Returns a dictionary with the image URL.
253 |
254 | Raises:
255 | ValueError: If the input is neither a PIL.Image.Image nor a string.
256 | """
257 | if isinstance(image, Image.Image):
258 | image = image.convert("RGB")
259 | with io.BytesIO() as image_data:
260 | image.save(image_data, format="JPEG")
261 | image_base64 = base64.b64encode(
262 | image_data.getvalue()).decode("utf-8")
263 | return {
264 | "type": "image_url",
265 | "image_url": {
266 | "url": f"data:image/jpeg;base64,{image_base64}"
267 | },
268 | }
269 |
270 | if isinstance(image, str):
271 | image_url = (image if image.startswith(
272 | ("http://", "file://")) else f"file://{image}")
273 | return {"type": "image_url", "image_url": {"url": image_url}}
274 |
275 | raise ValueError(
276 | f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
277 |
278 |
279 | # -----------------------------------------------------------------------------
280 | # Random Dataset Implementation (Synthetic Data)
281 | # -----------------------------------------------------------------------------
282 |
283 |
284 | class RandomDataset(BenchmarkDataset):
285 | # Default values copied from benchmark_serving.py for the random dataset.
286 | DEFAULT_PREFIX_LEN = 0
287 | DEFAULT_RANGE_RATIO = 1.0
288 | DEFAULT_INPUT_LEN = 1024
289 | DEFAULT_OUTPUT_LEN = 128
290 |
291 | def __init__(
292 | self,
293 | **kwargs,
294 | ) -> None:
295 | super().__init__(**kwargs)
296 |
297 | def sample(
298 | self,
299 | tokenizer: PreTrainedTokenizerBase,
300 | num_requests: int,
301 | prefix_len: int = DEFAULT_PREFIX_LEN,
302 | range_ratio: float = DEFAULT_RANGE_RATIO,
303 | input_len: int = DEFAULT_INPUT_LEN,
304 | output_len: int = DEFAULT_OUTPUT_LEN,
305 | **kwargs,
306 | ) -> list[SampleRequest]:
307 | vocab_size = tokenizer.vocab_size
308 |
309 | prefix_token_ids = (np.random.randint(
310 | 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
311 |
312 | input_low = int(input_len * range_ratio)
313 | output_low = int(output_len * range_ratio)
314 |
315 | input_lens = np.random.randint(input_low,
316 | input_len + 1,
317 | size=num_requests)
318 | output_lens = np.random.randint(output_low,
319 | output_len + 1,
320 | size=num_requests)
321 | offsets = np.random.randint(0, vocab_size, size=num_requests)
322 |
323 | requests = []
324 | for i in range(num_requests):
325 | inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
326 | vocab_size).tolist()
327 | token_sequence = prefix_token_ids + inner_seq
328 | prompt = tokenizer.decode(token_sequence)
329 | total_input_len = prefix_len + int(input_lens[i])
330 | requests.append(
331 | SampleRequest(
332 | prompt=prompt,
333 | prompt_len=total_input_len,
334 | expected_output_len=int(output_lens[i]),
335 | ))
336 | return requests
337 |
338 |
339 | # -----------------------------------------------------------------------------
340 | # ShareGPT Dataset Implementation
341 | # -----------------------------------------------------------------------------
342 |
343 |
344 | class ShareGPTDataset(BenchmarkDataset):
345 | """
346 | Implements the ShareGPT dataset. Loads data from a JSON file and generates
347 | sample requests based on conversation turns.
348 | """
349 |
350 | def __init__(self, **kwargs) -> None:
351 | super().__init__(**kwargs)
352 | self.load_data()
353 |
354 | def load_data(self) -> None:
355 | if self.dataset_path is None:
356 | raise ValueError("dataset_path must be provided for loading data.")
357 |
358 | with open(self.dataset_path, encoding="utf-8") as f:
359 | self.data = json.load(f)
360 | # Filter entries with at least two conversation turns.
361 | self.data = [
362 | entry for entry in self.data
363 | if "conversations" in entry and len(entry["conversations"]) >= 2
364 | ]
365 | random.seed(self.random_seed)
366 | random.shuffle(self.data)
367 |
368 | def sample(
369 | self,
370 | tokenizer: PreTrainedTokenizerBase,
371 | num_requests: int,
372 | lora_path: Optional[str] = None,
373 | max_loras: Optional[int] = None,
374 | output_len: Optional[int] = None,
375 | enable_multimodal_chat: bool = False,
376 | **kwargs,
377 | ) -> list:
378 | samples: list = []
379 | for entry in self.data:
380 | if len(samples) >= num_requests:
381 | break
382 | prompt, completion = (
383 | entry["conversations"][0]["value"],
384 | entry["conversations"][1]["value"],
385 | )
386 |
387 | lora_request, tokenizer = self.get_random_lora_request(
388 | tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
389 | prompt_ids = tokenizer(prompt).input_ids
390 | completion_ids = tokenizer(completion).input_ids
391 | prompt_len = len(prompt_ids)
392 | new_output_len = (len(completion_ids)
393 | if output_len is None else output_len)
394 | if not is_valid_sequence(prompt_len,
395 | new_output_len,
396 | skip_min_output_len_check=output_len
397 | is not None):
398 | continue
399 | if enable_multimodal_chat:
400 | prompt = self.apply_multimodal_chat_transformation(
401 | prompt, None)
402 | samples.append(
403 | SampleRequest(
404 | prompt=prompt,
405 | prompt_len=prompt_len,
406 | expected_output_len=new_output_len,
407 | lora_request=lora_request,
408 | ))
409 | self.maybe_oversample_requests(samples, num_requests)
410 | return samples
411 |
412 |
413 | # -----------------------------------------------------------------------------
414 | # Sonnet Dataset Implementation
415 | # -----------------------------------------------------------------------------
416 |
417 |
418 | class SonnetDataset(BenchmarkDataset):
419 | """
420 | Simplified implementation of the Sonnet dataset. Loads poem lines from a
421 | text file and generates sample requests. Default values here copied from
422 | `benchmark_serving.py` for the sonnet dataset.
423 | """
424 |
425 | DEFAULT_PREFIX_LEN = 200
426 | DEFAULT_INPUT_LEN = 550
427 | DEFAULT_OUTPUT_LEN = 150
428 |
429 | def __init__(
430 | self,
431 | **kwargs,
432 | ) -> None:
433 | super().__init__(**kwargs)
434 | self.load_data()
435 |
436 | def load_data(self) -> None:
437 | if not self.dataset_path:
438 | raise ValueError("dataset_path must be provided.")
439 | with open(self.dataset_path, encoding="utf-8") as f:
440 | self.data = f.readlines()
441 |
442 | def sample(
443 | self,
444 | tokenizer,
445 | num_requests: int,
446 | prefix_len: int = DEFAULT_PREFIX_LEN,
447 | input_len: int = DEFAULT_INPUT_LEN,
448 | output_len: int = DEFAULT_OUTPUT_LEN,
449 | return_prompt_formatted: bool = False,
450 | **kwargs,
451 | ) -> list:
452 | # Calculate average token length for a poem line.
453 | tokenized_lines = [tokenizer(line).input_ids for line in self.data]
454 | avg_len = sum(len(tokens)
455 | for tokens in tokenized_lines) / len(tokenized_lines)
456 |
457 | # Build the base prompt.
458 | base_prompt = "Pick as many lines as you can from these poem lines:\n"
459 | base_msg = [{"role": "user", "content": base_prompt}]
460 | base_fmt = tokenizer.apply_chat_template(base_msg,
461 | add_generation_prompt=True,
462 | tokenize=False)
463 | base_offset = len(tokenizer(base_fmt).input_ids)
464 | if input_len <= base_offset:
465 | raise ValueError(
466 | f"'input_len' must be higher than the base prompt length "
467 | f"({base_offset}).")
468 |
469 | # Determine how many poem lines to use.
470 | num_input_lines = round((input_len - base_offset) / avg_len)
471 | num_prefix_lines = round((prefix_len - base_offset) / avg_len)
472 | prefix_lines = self.data[:num_prefix_lines]
473 |
474 | samples = []
475 | for _ in range(num_requests):
476 | extra_lines = random.choices(self.data,
477 | k=num_input_lines - num_prefix_lines)
478 | prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
479 | msg = [{"role": "user", "content": prompt}]
480 | prompt_formatted = tokenizer.apply_chat_template(
481 | msg, add_generation_prompt=True, tokenize=False)
482 | prompt_len = len(tokenizer(prompt_formatted).input_ids)
483 | samples.append(
484 | SampleRequest(
485 | prompt=prompt_formatted
486 | if return_prompt_formatted else prompt,
487 | prompt_len=prompt_len,
488 | expected_output_len=output_len,
489 | ))
490 | return samples
491 |
492 |
493 | # -----------------------------------------------------------------------------
494 | # BurstGPT Dataset Implementation
495 | # -----------------------------------------------------------------------------
496 |
497 |
498 | class BurstGPTDataset(BenchmarkDataset):
499 | """
500 | Implements the BurstGPT dataset. Loads data from a CSV file and generates
501 | sample requests based on synthetic prompt generation. Only rows with Model
502 | "GPT-4" and positive response tokens are used.
503 | """
504 |
505 | def __init__(self, **kwargs) -> None:
506 | super().__init__(**kwargs)
507 | self.load_data()
508 |
509 | def load_data(self, ):
510 | if self.dataset_path is None:
511 | raise ValueError("dataset_path must be provided for loading data.")
512 |
513 | df = pd.read_csv(self.dataset_path)
514 | # Filter to keep only GPT-4 rows.
515 | gpt4_df = df[df["Model"] == "GPT-4"]
516 | # Remove failed requests (where Response tokens is 0 or less).
517 | gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
518 | # Sample the desired number of rows.
519 | self.data = gpt4_df
520 |
521 | def _sample_loaded_data(self, num_requests: int) -> list:
522 | if num_requests <= len(self.data):
523 | data = self.data.sample(n=num_requests,
524 | random_state=self.random_seed)
525 | else:
526 | data = self.data.sample(
527 | n=num_requests,
528 | random_state=self.random_seed,
529 | replace=True,
530 | )
531 | # Convert the dataframe to a list of lists.
532 | return data.values.tolist()
533 |
534 | def sample(
535 | self,
536 | tokenizer: PreTrainedTokenizerBase,
537 | num_requests: int,
538 | max_loras: Optional[int] = None,
539 | lora_path: Optional[str] = None,
540 | **kwargs,
541 | ) -> list[SampleRequest]:
542 | samples = []
543 | data = self._sample_loaded_data(num_requests=num_requests)
544 | for i in range(num_requests):
545 | input_len = int(data[i][2])
546 | output_len = int(data[i][3])
547 | lora_req, tokenizer = self.get_random_lora_request(
548 | tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
549 | vocab_size = tokenizer.vocab_size
550 | # Generate a synthetic prompt: a list of token IDs computed as (i +
551 | # j) modulo vocab_size.
552 | token_ids = [(i + j) % vocab_size for j in range(input_len)]
553 | prompt = tokenizer.decode(token_ids)
554 | samples.append(
555 | SampleRequest(
556 | prompt=prompt,
557 | prompt_len=input_len,
558 | expected_output_len=output_len,
559 | lora_request=lora_req,
560 | ))
561 | return samples
562 |
563 |
564 | # -----------------------------------------------------------------------------
565 | # HuggingFace Dataset Implementation
566 | # -----------------------------------------------------------------------------
567 |
568 |
569 | class HuggingFaceDataset(BenchmarkDataset):
570 | """
571 | Dataset class for processing a HuggingFace dataset with conversation data
572 | and optional images.
573 | """
574 |
575 | def __init__(
576 | self,
577 | dataset_split: str,
578 | dataset_subset: Optional[str] = None,
579 | **kwargs,
580 | ) -> None:
581 | super().__init__(**kwargs)
582 | self.dataset_split = dataset_split
583 | self.dataset_subset = dataset_subset
584 |
585 | self.load_data()
586 |
587 | def load_data(self) -> None:
588 | if not self.dataset_path:
589 | raise ValueError("dataset_path must be provided for loading data.")
590 |
591 | self.data = load_dataset(
592 | self.dataset_path,
593 | name=self.dataset_subset,
594 | split=self.dataset_split,
595 | streaming=True,
596 | )
597 | if self.data.features is None or "conversations" \
598 | not in self.data.features:
599 | raise ValueError(
600 | "HuggingFaceDataset currently only supports datasets with "
601 | "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
602 | "Please consider contributing if you would like to add "
603 | "support for additional dataset formats.")
604 | # Shuffle and filter examples with at least 2 conversations.
605 | self.data = self.data.shuffle(seed=self.random_seed).filter(
606 | lambda x: len(x["conversations"]) >= 2)
607 |
608 | def sample(self,
609 | tokenizer: PreTrainedTokenizerBase,
610 | num_requests: int,
611 | output_len: Optional[int] = None,
612 | enable_multimodal_chat: bool = False,
613 | **kwargs) -> list:
614 | sampled_requests = []
615 | dynamic_output = output_len is None
616 |
617 | for item in self.data:
618 | if len(sampled_requests) >= num_requests:
619 | break
620 | conv = item["conversations"]
621 | prompt, completion = conv[0]["value"], conv[1]["value"]
622 |
623 | prompt_ids = tokenizer(prompt).input_ids
624 | completion_ids = tokenizer(completion).input_ids
625 | prompt_len = len(prompt_ids)
626 | completion_len = len(completion_ids)
627 | output_len = completion_len if dynamic_output else output_len
628 | assert isinstance(output_len, int) and output_len > 0
629 | if dynamic_output and not is_valid_sequence(
630 | prompt_len, completion_len):
631 | continue
632 | mm_content = process_image(
633 | item["image"]) if "image" in item else None
634 | if enable_multimodal_chat:
635 | # Note: when chat is enabled the request prompt_len is no longer
636 | # accurate and we will be using request output to count the
637 | # actual prompt len and output len
638 | prompt = self.apply_multimodal_chat_transformation(
639 | prompt, mm_content)
640 | sampled_requests.append(
641 | SampleRequest(
642 | prompt=prompt,
643 | prompt_len=prompt_len,
644 | expected_output_len=output_len,
645 | multi_modal_data=mm_content,
646 | ))
647 | self.maybe_oversample_requests(sampled_requests, num_requests)
648 | return sampled_requests
649 |
650 |
651 | # -----------------------------------------------------------------------------
652 | # Vision Arena Dataset Implementation
653 | # -----------------------------------------------------------------------------
654 |
655 |
656 | class VisionArenaDataset(HuggingFaceDataset):
657 | """
658 | Vision Arena Dataset.
659 | """
660 |
661 | DEFAULT_OUTPUT_LEN = 128
662 | VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
663 |
664 | def __init__(
665 | self,
666 | **kwargs,
667 | ) -> None:
668 | super().__init__(**kwargs)
669 | if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
670 | raise ValueError(f"Only support Vision Arena dataset.\
671 | This data path {self.dataset_path} is not valid.")
672 | if self.dataset_subset is None and self.dataset_split != "train":
673 | raise ValueError("Dataset split must be 'train'.")
674 |
675 | self.load_data()
676 |
677 | def load_data(self) -> None:
678 | dataset = load_dataset(
679 | self.dataset_path,
680 | name=self.dataset_subset,
681 | split=self.dataset_split,
682 | streaming=True,
683 | )
684 | self.data = dataset.shuffle(seed=self.random_seed)
685 |
686 | def sample(
687 | self,
688 | tokenizer: PreTrainedTokenizerBase,
689 | num_requests: int,
690 | output_len: Optional[int] = None,
691 | enable_multimodal_chat: bool = False,
692 | **kwargs,
693 | ) -> list:
694 | output_len = (output_len
695 | if output_len is not None else self.DEFAULT_OUTPUT_LEN)
696 | sampled_requests = []
697 | for item in self.data:
698 | if len(sampled_requests) >= num_requests:
699 | break
700 | prompt = item["turns"][0][0]["content"]
701 | mm_content = process_image(item["images"][0])
702 | prompt_len = len(tokenizer(prompt).input_ids)
703 | if enable_multimodal_chat:
704 | # Note: when chat is enabled the request prompt_len is no longer
705 | # accurate and we will be using request output to count the
706 | # actual prompt len
707 | prompt = self.apply_multimodal_chat_transformation(
708 | prompt, mm_content)
709 | sampled_requests.append(
710 | SampleRequest(
711 | prompt=prompt,
712 | prompt_len=prompt_len,
713 | expected_output_len=output_len,
714 | multi_modal_data=mm_content,
715 | ))
716 | self.maybe_oversample_requests(sampled_requests, num_requests)
717 | return sampled_requests
```
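A short sketch of sampling synthetic requests with `RandomDataset`. The tokenizer name is a placeholder, the import assumes the `benchmark` directory is importable as a package, and note that this module imports vLLM, PIL, pandas, and datasets at load time, so those packages must be installed.

```python
# Sketch: build four synthetic requests and inspect their lengths.
# "gpt2" is a placeholder tokenizer that must be downloadable or cached.
from transformers import AutoTokenizer

from benchmark.benchmark_dataset import RandomDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = RandomDataset(random_seed=0)
requests = dataset.sample(tokenizer=tokenizer,
                          num_requests=4,
                          input_len=128,
                          output_len=32)
for request in requests:
    print(request.prompt_len, request.expected_output_len)
```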
--------------------------------------------------------------------------------
/benchmark/benchmark_serving.py:
--------------------------------------------------------------------------------
```python
1 | # SPDX-License-Identifier: Apache-2.0
2 | r"""Benchmark online serving throughput.
3 |
4 | On the server side, run one of the following commands:
5 | vLLM OpenAI API server
6 | vllm serve <your_model> \
7 | --swap-space 16 \
8 | --disable-log-requests
9 |
10 | (TGI backend)
11 | ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
12 |
13 | On the client side, run:
14 | python benchmarks/benchmark_serving.py \
15 | --backend <backend> \
16 | --model <your_model> \
17 | --dataset-name sharegpt \
18 | --dataset-path <path to dataset> \
19 | --request-rate <request_rate> \ # By default <request_rate> is inf
20 | --num-prompts <num_prompts> # By default <num_prompts> is 1000
21 |
22 | when using tgi backend, add
23 | --endpoint /generate_stream
24 | to the end of the command above.
25 | """
26 | import argparse
27 | import asyncio
28 | import gc
29 | import json
30 | import os
31 | import random
32 | import time
33 | import warnings
34 | from collections.abc import AsyncGenerator, Iterable
35 | from dataclasses import dataclass
36 | from datetime import datetime
37 | from typing import Any, Optional
38 |
39 | import numpy as np
40 | from benchmark.backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
41 | RequestFuncOutput)
42 | from tqdm.asyncio import tqdm
43 | from transformers import PreTrainedTokenizerBase
44 |
45 | try:
46 | from vllm.transformers_utils.tokenizer import get_tokenizer
47 | except ImportError:
48 |     from benchmark.backend_request_func import get_tokenizer
49 |
50 | try:
51 | from vllm.utils import FlexibleArgumentParser
52 | except ImportError:
53 | from argparse import ArgumentParser as FlexibleArgumentParser
54 |
55 | from benchmark.benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
56 | RandomDataset, SampleRequest, ShareGPTDataset,
57 | SonnetDataset, VisionArenaDataset)
58 | from benchmark.benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
59 |
60 | MILLISECONDS_TO_SECONDS_CONVERSION = 1000
61 |
62 |
63 | @dataclass
64 | class BenchmarkMetrics:
65 | completed: int
66 | total_input: int
67 | total_output: int
68 | request_throughput: float
69 | request_goodput: float
70 | output_throughput: float
71 | total_token_throughput: float
72 | mean_ttft_ms: float
73 | median_ttft_ms: float
74 | std_ttft_ms: float
75 | percentiles_ttft_ms: list[tuple[float, float]]
76 | mean_tpot_ms: float
77 | median_tpot_ms: float
78 | std_tpot_ms: float
79 | percentiles_tpot_ms: list[tuple[float, float]]
80 | mean_itl_ms: float
81 | median_itl_ms: float
82 | std_itl_ms: float
83 | percentiles_itl_ms: list[tuple[float, float]]
84 | # E2EL stands for end-to-end latency per request.
85 | # It is the time taken on the client side from sending
86 | # a request to receiving a complete response.
87 | mean_e2el_ms: float
88 | median_e2el_ms: float
89 | std_e2el_ms: float
90 | percentiles_e2el_ms: list[tuple[float, float]]
91 |
92 |
93 | async def get_request(
94 | input_requests: list[SampleRequest],
95 | request_rate: float,
96 | burstiness: float = 1.0,
97 | ) -> AsyncGenerator[SampleRequest, None]:
98 | """
99 | Asynchronously generates requests at a specified rate
100 | with OPTIONAL burstiness.
101 |
102 | Args:
103 | input_requests:
104 | A list of input requests, each represented as a SampleRequest.
105 | request_rate:
106 | The rate at which requests are generated (requests/s).
107 | burstiness (optional):
108 | The burstiness factor of the request generation.
109 | Only takes effect when request_rate is not inf.
110 | Default value is 1, which follows a Poisson process.
111 | Otherwise, the request intervals follow a gamma distribution.
112 | A lower burstiness value (0 < burstiness < 1) results
113 | in more bursty requests, while a higher burstiness value
114 | (burstiness > 1) results in a more uniform arrival of requests.
115 | """
116 | input_requests: Iterable[SampleRequest] = iter(input_requests)
117 |
118 | # Calculate scale parameter theta to maintain the desired request_rate.
119 | assert burstiness > 0, (
120 | f"A positive burstiness factor is expected, but given {burstiness}.")
121 | theta = 1.0 / (request_rate * burstiness)
122 |
123 | for request in input_requests:
124 | yield request
125 |
126 | if request_rate == float("inf"):
127 | # If the request rate is infinity, then we don't need to wait.
128 | continue
129 |
130 | # Sample the request interval from the gamma distribution.
131 | # If burstiness is 1, it follows exponential distribution.
132 | interval = np.random.gamma(shape=burstiness, scale=theta)
133 | # The next request will be sent after the interval.
134 | await asyncio.sleep(interval)
135 |
136 |
137 | def calculate_metrics(
138 | input_requests: list[SampleRequest],
139 | outputs: list[RequestFuncOutput],
140 | dur_s: float,
141 | tokenizer: PreTrainedTokenizerBase,
142 | selected_percentile_metrics: list[str],
143 | selected_percentiles: list[float],
144 | goodput_config_dict: dict[str, float],
145 | ) -> tuple[BenchmarkMetrics, list[int]]:
146 | actual_output_lens: list[int] = []
147 | total_input = 0
148 | completed = 0
149 | good_completed = 0
150 | itls: list[float] = []
151 | tpots: list[float] = []
152 | all_tpots: list[float] = []
153 | ttfts: list[float] = []
154 | e2els: list[float] = []
155 | for i in range(len(outputs)):
156 | if outputs[i].success:
157 | output_len = outputs[i].output_tokens
158 |
159 | if output_len is None:
160 | # We use the tokenizer to count the number of output tokens
161 | # for some serving backends instead of looking at
162 | # len(outputs[i].itl) since multiple output tokens may be
163 | # bundled together
164 |                 # Note: this may inflate the output token count slightly
165 | output_len = len(
166 | tokenizer(outputs[i].generated_text,
167 | add_special_tokens=False).input_ids)
168 | actual_output_lens.append(output_len)
169 | total_input += input_requests[i].prompt_len
170 | tpot = 0
171 | if output_len > 1:
172 | latency_minus_ttft = outputs[i].latency - outputs[i].ttft
173 | tpot = latency_minus_ttft / (output_len - 1)
174 | tpots.append(tpot)
175 | # Note: if output_len <= 1, we regard tpot as 0 for goodput
176 | all_tpots.append(tpot)
177 | itls += outputs[i].itl
178 | ttfts.append(outputs[i].ttft)
179 | e2els.append(outputs[i].latency)
180 | completed += 1
181 | else:
182 | actual_output_lens.append(0)
183 |
184 | if goodput_config_dict:
185 | valid_metrics = []
186 | slo_values = []
187 |
188 | if "ttft" in goodput_config_dict:
189 | valid_metrics.append(ttfts)
190 | slo_values.append(goodput_config_dict["ttft"] /
191 | MILLISECONDS_TO_SECONDS_CONVERSION)
192 | if "tpot" in goodput_config_dict:
193 | valid_metrics.append(all_tpots)
194 | slo_values.append(goodput_config_dict["tpot"] /
195 | MILLISECONDS_TO_SECONDS_CONVERSION)
196 | if "e2el" in goodput_config_dict:
197 | valid_metrics.append(e2els)
198 | slo_values.append(goodput_config_dict["e2el"] /
199 | MILLISECONDS_TO_SECONDS_CONVERSION)
200 |
201 | for req_metric in zip(*valid_metrics):
202 | is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
203 | if is_good_req:
204 | good_completed += 1
205 |
206 | if completed == 0:
207 | warnings.warn(
208 | "All requests failed. This is likely due to a misconfiguration "
209 |             "of the benchmark arguments.",
210 | stacklevel=2)
211 | metrics = BenchmarkMetrics(
212 | completed=completed,
213 | total_input=total_input,
214 | total_output=sum(actual_output_lens),
215 | request_throughput=completed / dur_s,
216 | request_goodput=good_completed / dur_s,
217 | output_throughput=sum(actual_output_lens) / dur_s,
218 | total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
219 | mean_ttft_ms=np.mean(ttfts or 0) *
220 | 1000, # ttfts is empty if streaming is not supported by backend
221 | std_ttft_ms=np.std(ttfts or 0) * 1000,
222 | median_ttft_ms=np.median(ttfts or 0) * 1000,
223 | percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
224 | for p in selected_percentiles],
225 | mean_tpot_ms=np.mean(tpots or 0) * 1000,
226 | std_tpot_ms=np.std(tpots or 0) * 1000,
227 | median_tpot_ms=np.median(tpots or 0) * 1000,
228 | percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
229 | for p in selected_percentiles],
230 | mean_itl_ms=np.mean(itls or 0) * 1000,
231 | std_itl_ms=np.std(itls or 0) * 1000,
232 | median_itl_ms=np.median(itls or 0) * 1000,
233 | percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
234 | for p in selected_percentiles],
235 | mean_e2el_ms=np.mean(e2els or 0) * 1000,
236 | std_e2el_ms=np.std(e2els or 0) * 1000,
237 | median_e2el_ms=np.median(e2els or 0) * 1000,
238 | percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
239 | for p in selected_percentiles],
240 | )
241 |
242 | return metrics, actual_output_lens
243 |
244 |
245 | async def benchmark(
246 | backend: str,
247 | api_url: str,
248 | base_url: str,
249 | model_id: str,
250 | model_name: str,
251 | tokenizer: PreTrainedTokenizerBase,
252 | input_requests: list[SampleRequest],
253 | logprobs: Optional[int],
254 | request_rate: float,
255 | burstiness: float,
256 | disable_tqdm: bool,
257 | profile: bool,
258 | selected_percentile_metrics: list[str],
259 | selected_percentiles: list[float],
260 | ignore_eos: bool,
261 | goodput_config_dict: dict[str, float],
262 | max_concurrency: Optional[int],
263 | lora_modules: Optional[Iterable[str]],
264 | ):
265 | if backend in ASYNC_REQUEST_FUNCS:
266 | request_func = ASYNC_REQUEST_FUNCS[backend]
267 | else:
268 | raise ValueError(f"Unknown backend: {backend}")
269 |
270 | print("Starting initial single prompt test run...")
271 | test_prompt, test_prompt_len, test_output_len, test_mm_content = \
272 | input_requests[0].prompt, input_requests[0].prompt_len, \
273 | input_requests[0].expected_output_len, \
274 | input_requests[0].multi_modal_data
275 |
276 | if backend != "openai-chat" and test_mm_content is not None:
277 | # multi-modal benchmark is only available on OpenAI Chat backend.
278 | raise ValueError(
279 | "Multi-modal content is only supported on 'openai-chat' backend.")
280 | assert test_mm_content is None or isinstance(test_mm_content, dict)
281 | test_input = RequestFuncInput(
282 | model=model_id,
283 | model_name=model_name,
284 | prompt=test_prompt,
285 | api_url=api_url,
286 | prompt_len=test_prompt_len,
287 | output_len=test_output_len,
288 | logprobs=logprobs,
289 | multi_modal_content=test_mm_content,
290 | ignore_eos=ignore_eos,
291 | )
292 |
293 | test_output = await request_func(request_func_input=test_input)
294 | if not test_output.success:
295 | raise ValueError(
296 | "Initial test run failed - Please make sure benchmark arguments "
297 | f"are correctly specified. Error: {test_output.error}")
298 | else:
299 | print("Initial test run completed. Starting main benchmark run...")
300 |
301 | if lora_modules:
302 | # For each input request, choose a LoRA module at random.
303 | lora_modules = iter(
304 | [random.choice(lora_modules) \
305 | for _ in range(len(input_requests))])
306 |
307 | if profile:
308 | print("Starting profiler...")
309 | profile_input = RequestFuncInput(model=model_id,
310 | model_name=model_name,
311 | prompt=test_prompt,
312 | api_url=base_url + "/start_profile",
313 | prompt_len=test_prompt_len,
314 | output_len=test_output_len,
315 | logprobs=logprobs,
316 | multi_modal_content=test_mm_content,
317 | ignore_eos=ignore_eos)
318 | profile_output = await request_func(request_func_input=profile_input)
319 | if profile_output.success:
320 | print("Profiler started")
321 |
322 | if burstiness == 1.0:
323 | distribution = "Poisson process"
324 | else:
325 | distribution = "Gamma distribution"
326 |
327 | print(f"Traffic request rate: {request_rate}")
328 | print(f"Burstiness factor: {burstiness} ({distribution})")
329 | print(f"Maximum request concurrency: {max_concurrency}")
330 |
331 | pbar = None if disable_tqdm else tqdm(total=len(input_requests))
332 |
333 | # This can be used once the minimum Python version is 3.10 or higher,
334 | # and it will simplify the code in limited_request_func.
335 | # semaphore = (asyncio.Semaphore(max_concurrency)
336 | # if max_concurrency else contextlib.nullcontext())
337 | semaphore = (asyncio.Semaphore(max_concurrency)
338 | if max_concurrency else None)
339 |
340 | async def limited_request_func(request_func_input, pbar):
341 | if semaphore is None:
342 | return await request_func(request_func_input=request_func_input,
343 | pbar=pbar)
344 | async with semaphore:
345 | return await request_func(request_func_input=request_func_input,
346 | pbar=pbar)
347 |
348 | benchmark_start_time = time.perf_counter()
349 | tasks: list[asyncio.Task] = []
350 | async for request in get_request(input_requests, request_rate, burstiness):
351 | prompt, prompt_len, output_len, mm_content = request.prompt, \
352 | request.prompt_len, request.expected_output_len, \
353 | request.multi_modal_data
354 | req_model_id, req_model_name = model_id, model_name
355 | if lora_modules:
356 | req_lora_module = next(lora_modules)
357 | req_model_id, req_model_name = req_lora_module, req_lora_module
358 |
359 | request_func_input = RequestFuncInput(model=req_model_id,
360 | model_name=req_model_name,
361 | prompt=prompt,
362 | api_url=api_url,
363 | prompt_len=prompt_len,
364 | output_len=output_len,
365 | logprobs=logprobs,
366 | multi_modal_content=mm_content,
367 | ignore_eos=ignore_eos)
368 | tasks.append(
369 | asyncio.create_task(
370 | limited_request_func(request_func_input=request_func_input,
371 | pbar=pbar)))
372 | outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
373 |
374 | if profile:
375 | print("Stopping profiler...")
376 | profile_input = RequestFuncInput(
377 | model=model_id,
378 | prompt=test_prompt,
379 | api_url=base_url + "/stop_profile",
380 | prompt_len=test_prompt_len,
381 | output_len=test_output_len,
382 | logprobs=logprobs,
383 | )
384 | profile_output = await request_func(request_func_input=profile_input)
385 | if profile_output.success:
386 | print("Profiler stopped")
387 |
388 | if pbar is not None:
389 | pbar.close()
390 |
391 | benchmark_duration = time.perf_counter() - benchmark_start_time
392 |
393 | metrics, actual_output_lens = calculate_metrics(
394 | input_requests=input_requests,
395 | outputs=outputs,
396 | dur_s=benchmark_duration,
397 | tokenizer=tokenizer,
398 | selected_percentile_metrics=selected_percentile_metrics,
399 | selected_percentiles=selected_percentiles,
400 | goodput_config_dict=goodput_config_dict,
401 | )
402 |
403 | print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
404 | print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
405 | print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
406 | benchmark_duration))
407 | print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
408 | print("{:<40} {:<10}".format("Total generated tokens:",
409 | metrics.total_output))
410 | print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
411 | metrics.request_throughput))
412 | if goodput_config_dict:
413 | print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
414 | metrics.request_goodput))
415 | print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
416 | metrics.output_throughput))
417 | print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
418 | metrics.total_token_throughput))
419 |
420 | result = {
421 | "duration": benchmark_duration,
422 | "completed": metrics.completed,
423 | "total_input_tokens": metrics.total_input,
424 | "total_output_tokens": metrics.total_output,
425 | "request_throughput": metrics.request_throughput,
426 |         "request_goodput":
427 | metrics.request_goodput if goodput_config_dict else None,
428 | "output_throughput": metrics.output_throughput,
429 | "total_token_throughput": metrics.total_token_throughput,
430 | "input_lens": [output.prompt_len for output in outputs],
431 | "output_lens": actual_output_lens,
432 | "ttfts": [output.ttft for output in outputs],
433 | "itls": [output.itl for output in outputs],
434 | "generated_texts": [output.generated_text for output in outputs],
435 | "errors": [output.error for output in outputs],
436 | }
437 |
438 | def process_one_metric(
439 | # E.g., "ttft"
440 | metric_attribute_name: str,
441 | # E.g., "TTFT"
442 | metric_name: str,
443 | # E.g., "Time to First Token"
444 | metric_header: str,
445 | ):
446 | # This function prints and adds statistics of the specified
447 | # metric.
448 | if metric_attribute_name not in selected_percentile_metrics:
449 | return
450 | print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
451 | print("{:<40} {:<10.2f}".format(
452 | f"Mean {metric_name} (ms):",
453 | getattr(metrics, f"mean_{metric_attribute_name}_ms")))
454 | print("{:<40} {:<10.2f}".format(
455 | f"Median {metric_name} (ms):",
456 | getattr(metrics, f"median_{metric_attribute_name}_ms")))
457 | result[f"mean_{metric_attribute_name}_ms"] = getattr(
458 | metrics, f"mean_{metric_attribute_name}_ms")
459 | result[f"median_{metric_attribute_name}_ms"] = getattr(
460 | metrics, f"median_{metric_attribute_name}_ms")
461 | result[f"std_{metric_attribute_name}_ms"] = getattr(
462 | metrics, f"std_{metric_attribute_name}_ms")
463 | for p, value in getattr(metrics,
464 | f"percentiles_{metric_attribute_name}_ms"):
465 | p_word = str(int(p)) if int(p) == p else str(p)
466 | print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
467 | value))
468 | result[f"p{p_word}_{metric_attribute_name}_ms"] = value
469 |
470 | process_one_metric("ttft", "TTFT", "Time to First Token")
471 | process_one_metric("tpot", "TPOT",
472 | "Time per Output Token (excl. 1st token)")
473 | process_one_metric("itl", "ITL", "Inter-token Latency")
474 | process_one_metric("e2el", "E2EL", "End-to-end Latency")
475 |
476 | print("=" * 50)
477 |
478 | return result
479 |
480 |
481 | def check_goodput_args(args):
482 | # Check and parse goodput arguments
483 | goodput_config_dict = {}
484 | VALID_NAMES = ["ttft", "tpot", "e2el"]
485 | if args.goodput:
486 | goodput_config_dict = parse_goodput(args.goodput)
487 | for slo_name, slo_val in goodput_config_dict.items():
488 | if slo_name not in VALID_NAMES:
489 | raise ValueError(
490 | f"Invalid metric name found, {slo_name}: {slo_val}. "
491 | "The service level objective name should be one of "
492 | f"{str(VALID_NAMES)}. ")
493 | if slo_val < 0:
494 | raise ValueError(
495 | f"Invalid value found, {slo_name}: {slo_val}. "
496 | "The service level objective value should be "
497 | "non-negative.")
498 | return goodput_config_dict
499 |
500 |
501 | def parse_goodput(slo_pairs):
502 | goodput_config_dict = {}
503 | try:
504 | for slo_pair in slo_pairs:
505 | slo_name, slo_val = slo_pair.split(":")
506 | goodput_config_dict[slo_name] = float(slo_val)
507 | except ValueError as err:
508 | raise argparse.ArgumentTypeError(
509 | "Invalid format found for service level objectives. "
510 | "Specify service level objectives for goodput as \"KEY:VALUE\" "
511 | "pairs, where the key is a metric name, and the value is a "
512 | "number in milliseconds.") from err
513 | return goodput_config_dict
514 |
515 |
516 | def save_to_pytorch_benchmark_format(args: argparse.Namespace,
517 | results: dict[str, Any],
518 | file_name: str) -> None:
519 | metrics = [
520 | "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
521 | "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
522 | "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
523 | ]
524 | # These raw data might be useful, but they are rather big. They can be added
525 | # later if needed
526 | ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
527 | pt_records = convert_to_pytorch_benchmark_format(
528 | args=args,
529 | metrics={k: [results[k]]
530 | for k in metrics},
531 | extra_info={
532 | k: results[k]
533 | for k in results if k not in metrics and k not in ignored_metrics
534 | })
535 | if pt_records:
536 | # Don't use json suffix here as we don't want CI to pick it up
537 | pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
538 | write_to_json(pt_file, pt_records)
539 |
540 |
541 | def main(args: argparse.Namespace):
542 | print(args)
543 | random.seed(args.seed)
544 | np.random.seed(args.seed)
545 |
546 | backend = args.backend
547 | model_id = args.model
548 | model_name = args.served_model_name
549 | tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
550 | tokenizer_mode = args.tokenizer_mode
551 |
552 | if args.base_url is not None:
553 | api_url = f"{args.base_url}{args.endpoint}"
554 | base_url = f"{args.base_url}"
555 | else:
556 | api_url = f"http://{args.host}:{args.port}{args.endpoint}"
557 | base_url = f"http://{args.host}:{args.port}"
558 |
559 | tokenizer = get_tokenizer(tokenizer_id,
560 | tokenizer_mode=tokenizer_mode,
561 | trust_remote_code=args.trust_remote_code)
562 |
563 | if args.dataset_name is None:
564 | raise ValueError(
565 | "Please specify '--dataset-name' and the corresponding "
566 | "'--dataset-path' if required.")
567 |
568 | if args.dataset_name == "sonnet":
569 | dataset = SonnetDataset(dataset_path=args.dataset_path)
570 | # For the "sonnet" dataset, formatting depends on the backend.
571 | if args.backend == "openai-chat":
572 | input_requests = dataset.sample(num_requests=args.num_prompts,
573 | input_len=args.sonnet_input_len,
574 | output_len=args.sonnet_output_len,
575 | prefix_len=args.sonnet_prefix_len,
576 | tokenizer=tokenizer,
577 | return_prompt_formatted=False)
578 | else:
579 | assert tokenizer.chat_template or tokenizer.default_chat_template, (
580 | "Tokenizer/model must have chat template for sonnet dataset.")
581 | input_requests = dataset.sample(num_requests=args.num_prompts,
582 | input_len=args.sonnet_input_len,
583 | output_len=args.sonnet_output_len,
584 | prefix_len=args.sonnet_prefix_len,
585 | tokenizer=tokenizer,
586 | return_prompt_formatted=True)
587 |
588 | elif args.dataset_name == "hf":
589 | # Choose between VisionArenaDataset
590 | # and HuggingFaceDataset based on provided parameters.
591 | dataset_class = (VisionArenaDataset if args.dataset_path
592 | == VisionArenaDataset.VISION_ARENA_DATASET_PATH
593 | and args.hf_subset is None else HuggingFaceDataset)
594 | input_requests = dataset_class(
595 | dataset_path=args.dataset_path,
596 | dataset_subset=args.hf_subset,
597 | dataset_split=args.hf_split,
598 | ).sample(
599 | num_requests=args.num_prompts,
600 | tokenizer=tokenizer,
601 | random_seed=args.seed,
602 | output_len=args.hf_output_len,
603 | )
604 |
605 | else:
606 | # For datasets that follow a similar structure, use a mapping.
607 | dataset_mapping = {
608 | "sharegpt":
609 | lambda: ShareGPTDataset(random_seed=args.seed,
610 | dataset_path=args.dataset_path).sample(
611 | tokenizer=tokenizer,
612 | num_requests=args.num_prompts,
613 | output_len=args.sharegpt_output_len,
614 | ),
615 | "burstgpt":
616 | lambda: BurstGPTDataset(random_seed=args.seed,
617 | dataset_path=args.dataset_path).
618 | sample(tokenizer=tokenizer, num_requests=args.num_prompts),
619 | "random":
620 | lambda: RandomDataset(dataset_path=args.dataset_path).sample(
621 | tokenizer=tokenizer,
622 | num_requests=args.num_prompts,
623 | prefix_len=args.random_prefix_len,
624 | input_len=args.random_input_len,
625 | output_len=args.random_output_len,
626 | range_ratio=args.random_range_ratio,
627 | )
628 | }
629 |
630 | try:
631 | input_requests = dataset_mapping[args.dataset_name]()
632 | except KeyError as err:
633 | raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
634 | goodput_config_dict = check_goodput_args(args)
635 |
636 | # Avoid GC processing "static" data - reduce pause times.
637 | gc.collect()
638 | gc.freeze()
639 |
640 | benchmark_result = asyncio.run(
641 | benchmark(
642 | backend=backend,
643 | api_url=api_url,
644 | base_url=base_url,
645 | model_id=model_id,
646 | model_name=model_name,
647 | tokenizer=tokenizer,
648 | input_requests=input_requests,
649 | logprobs=args.logprobs,
650 | request_rate=args.request_rate,
651 | burstiness=args.burstiness,
652 | disable_tqdm=args.disable_tqdm,
653 | profile=args.profile,
654 | selected_percentile_metrics=args.percentile_metrics.split(","),
655 | selected_percentiles=[
656 | float(p) for p in args.metric_percentiles.split(",")
657 | ],
658 | ignore_eos=args.ignore_eos,
659 | goodput_config_dict=goodput_config_dict,
660 | max_concurrency=args.max_concurrency,
661 | lora_modules=args.lora_modules,
662 | ))
663 |
664 | # CUSTOM START
665 |     # I made a slight modification here to just return the benchmark results.
666 |     # The rest of the function is left as-is (now unreachable), in case we need to update it in the future.
667 | return benchmark_result
668 | # CUSTOM END
669 |
670 | # Save config and results to json
671 | if args.save_result:
672 | result_json: dict[str, Any] = {}
673 |
674 | # Setup
675 | current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
676 | result_json["date"] = current_dt
677 | result_json["backend"] = backend
678 | result_json["model_id"] = model_id
679 | result_json["tokenizer_id"] = tokenizer_id
680 | result_json["num_prompts"] = args.num_prompts
681 |
682 | # Metadata
683 | if args.metadata:
684 | for item in args.metadata:
685 | if "=" in item:
686 | kvstring = item.split("=")
687 | result_json[kvstring[0].strip()] = kvstring[1].strip()
688 | else:
689 | raise ValueError(
690 | "Invalid metadata format. Please use KEY=VALUE format."
691 | )
692 |
693 | if not args.save_detailed:
694 | # Remove fields with too many data points
695 | for field in [
696 | "input_lens", "output_lens", "ttfts", "itls",
697 | "generated_texts", "errors"
698 | ]:
699 | if field in result_json:
700 | del result_json[field]
701 |
702 | # Traffic
703 | result_json["request_rate"] = (args.request_rate if args.request_rate
704 | < float("inf") else "inf")
705 | result_json["burstiness"] = args.burstiness
706 | result_json["max_concurrency"] = args.max_concurrency
707 |
708 | # Merge with benchmark result
709 | result_json = {**result_json, **benchmark_result}
710 |
711 | # Save to file
712 | base_model_id = model_id.split("/")[-1]
713 | max_concurrency_str = (f"-concurrency{args.max_concurrency}"
714 | if args.max_concurrency is not None else "")
715 | file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
716 | if args.result_filename:
717 | file_name = args.result_filename
718 | if args.result_dir:
719 | file_name = os.path.join(args.result_dir, file_name)
720 | with open(file_name, "w", encoding='utf-8') as outfile:
721 | json.dump(result_json, outfile)
722 | save_to_pytorch_benchmark_format(args, result_json, file_name)
723 |
724 |
725 | if __name__ == "__main__":
726 | parser = FlexibleArgumentParser(
727 | description="Benchmark the online serving throughput.")
728 | parser.add_argument(
729 | "--backend",
730 | type=str,
731 | default="vllm",
732 | choices=list(ASYNC_REQUEST_FUNCS.keys()),
733 | )
734 | parser.add_argument(
735 | "--base-url",
736 | type=str,
737 | default=None,
738 | help="Server or API base url if not using http host and port.",
739 | )
740 | # Use 127.0.0.1 here instead of localhost to force the use of ipv4
741 | parser.add_argument("--host", type=str, default="127.0.0.1")
742 | parser.add_argument("--port", type=int, default=8000)
743 | parser.add_argument(
744 | "--endpoint",
745 | type=str,
746 | default="/v1/completions",
747 | help="API endpoint.",
748 | )
749 | parser.add_argument(
750 | "--dataset-name",
751 | type=str,
752 | default="sharegpt",
753 | choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
754 | help="Name of the dataset to benchmark on.",
755 | )
756 | parser.add_argument("--dataset-path",
757 | type=str,
758 | default=None,
759 | help="Path to the sharegpt/sonnet dataset. "
760 | "Or the huggingface dataset ID if using HF dataset.")
761 | parser.add_argument(
762 | "--max-concurrency",
763 | type=int,
764 | default=None,
765 | help="Maximum number of concurrent requests. This can be used "
766 | "to help simulate an environment where a higher level component "
767 | "is enforcing a maximum number of concurrent requests. While the "
768 | "--request-rate argument controls the rate at which requests are "
769 | "initiated, this argument will control how many are actually allowed "
770 | "to execute at a time. This means that when used in combination, the "
771 | "actual request rate may be lower than specified with --request-rate, "
772 | "if the server is not processing requests fast enough to keep up.")
773 |
774 | parser.add_argument(
775 | "--model",
776 | type=str,
777 | required=True,
778 | help="Name of the model.",
779 | )
780 | parser.add_argument(
781 | "--tokenizer",
782 | type=str,
783 | help=
784 | "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
785 | )
786 | parser.add_argument("--use-beam-search", action="store_true")
787 | parser.add_argument(
788 | "--num-prompts",
789 | type=int,
790 | default=1000,
791 | help="Number of prompts to process.",
792 | )
793 | parser.add_argument(
794 | "--logprobs",
795 | type=int,
796 | default=None,
797 | help=("Number of logprobs-per-token to compute & return as part of "
798 | "the request. If unspecified, then either (1) if beam search "
799 | "is disabled, no logprobs are computed & a single dummy "
800 | "logprob is returned for each token; or (2) if beam search "
801 |               "is enabled, 1 logprob per token is computed"),
802 | )
803 | parser.add_argument(
804 | "--request-rate",
805 | type=float,
806 | default=float("inf"),
807 | help="Number of requests per second. If this is inf, "
808 | "then all the requests are sent at time 0. "
809 |     "Otherwise, we use a Poisson process or gamma distribution "
810 | "to synthesize the request arrival times.",
811 | )
812 | parser.add_argument(
813 | "--burstiness",
814 | type=float,
815 | default=1.0,
816 | help="Burstiness factor of the request generation. "
817 |     "Only takes effect when request_rate is not inf. "
818 |     "Default value is 1, which follows a Poisson process. "
819 | "Otherwise, the request intervals follow a gamma distribution. "
820 | "A lower burstiness value (0 < burstiness < 1) results in more "
821 | "bursty requests. A higher burstiness value (burstiness > 1) "
822 | "results in a more uniform arrival of requests.",
823 | )
824 | parser.add_argument("--seed", type=int, default=0)
825 | parser.add_argument(
826 | "--trust-remote-code",
827 | action="store_true",
828 | help="Trust remote code from huggingface",
829 | )
830 | parser.add_argument(
831 | "--disable-tqdm",
832 | action="store_true",
833 | help="Specify to disable tqdm progress bar.",
834 | )
835 | parser.add_argument(
836 | "--profile",
837 | action="store_true",
838 | help="Use Torch Profiler. The endpoint must be launched with "
839 | "VLLM_TORCH_PROFILER_DIR to enable profiler.",
840 | )
841 | parser.add_argument(
842 | "--save-result",
843 | action="store_true",
844 | help="Specify to save benchmark results to a json file",
845 | )
846 | parser.add_argument(
847 | "--save-detailed",
848 | action="store_true",
849 | help="When saving the results, whether to include per request "
850 |     "information such as response, error, ttfts, itls, etc.",
851 | )
852 | parser.add_argument(
853 | "--metadata",
854 | metavar="KEY=VALUE",
855 | nargs="*",
856 |     help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
857 | "for metadata of this run to be saved in the result JSON file "
858 | "for record keeping purposes.",
859 | )
860 | parser.add_argument(
861 | "--result-dir",
862 | type=str,
863 | default=None,
864 |     help="Specify directory to save benchmark json results. "
865 | "If not specified, results are saved in the current directory.",
866 | )
867 | parser.add_argument(
868 | "--result-filename",
869 | type=str,
870 | default=None,
871 |     help="Specify the filename to save benchmark json results. "
872 | "If not specified, results will be saved in "
873 | "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
874 | " format.",
875 | )
876 | parser.add_argument(
877 | "--ignore-eos",
878 | action="store_true",
879 |     help="Set ignore_eos flag when sending the benchmark request. "
880 | "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
881 | parser.add_argument(
882 | "--percentile-metrics",
883 | type=str,
884 | default="ttft,tpot,itl",
885 |     help="Comma-separated list of selected metrics to report percentiles for. "
886 |     "Percentiles are reported only for the metrics listed here. "
887 | "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
888 | "Default value is \"ttft,tpot,itl\".")
889 | parser.add_argument(
890 | "--metric-percentiles",
891 | type=str,
892 | default="99",
893 |     help="Comma-separated list of percentiles for selected metrics. "
894 | "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
895 | "Default value is \"99\". "
896 | "Use \"--percentile-metrics\" to select metrics.",
897 | )
898 | parser.add_argument(
899 | "--goodput",
900 | nargs="+",
901 | required=False,
902 | help="Specify service level objectives for goodput as \"KEY:VALUE\" "
903 | "pairs, where the key is a metric name, and the value is in "
904 | "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
905 | "separated by spaces. Allowed request level metric names are "
906 | "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
907 | "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
908 | "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
909 |
910 | # group for dataset specific arguments
911 | sonnet_group = parser.add_argument_group("sonnet dataset options")
912 | sonnet_group.add_argument(
913 | "--sonnet-input-len",
914 | type=int,
915 | default=550,
916 | help=
917 | "Number of input tokens per request, used only for sonnet dataset.",
918 | )
919 | sonnet_group.add_argument(
920 | "--sonnet-output-len",
921 | type=int,
922 | default=150,
923 | help=
924 | "Number of output tokens per request, used only for sonnet dataset.",
925 | )
926 | sonnet_group.add_argument(
927 | "--sonnet-prefix-len",
928 | type=int,
929 | default=200,
930 | help=
931 | "Number of prefix tokens per request, used only for sonnet dataset.",
932 | )
933 |
934 | sharegpt_group = parser.add_argument_group("sharegpt dataset options")
935 | sharegpt_group.add_argument(
936 | "--sharegpt-output-len",
937 | type=int,
938 | default=None,
939 | help="Output length for each request. Overrides the output length "
940 | "from the ShareGPT dataset.")
941 |
942 | random_group = parser.add_argument_group("random dataset options")
943 | random_group.add_argument(
944 | "--random-input-len",
945 | type=int,
946 | default=1024,
947 | help=
948 | "Number of input tokens per request, used only for random sampling.",
949 | )
950 | random_group.add_argument(
951 | "--random-output-len",
952 | type=int,
953 | default=128,
954 | help=
955 | "Number of output tokens per request, used only for random sampling.",
956 | )
957 | random_group.add_argument(
958 | "--random-range-ratio",
959 | type=float,
960 | default=1.0,
961 | help="Range of sampled ratio of input/output length, "
962 | "used only for random sampling.",
963 | )
964 | random_group.add_argument(
965 | "--random-prefix-len",
966 | type=int,
967 | default=0,
968 | help="Number of fixed prefix tokens before random "
969 |     "context. The length range of context in a random "
970 |     "request is [random-prefix-len, "
971 |     "random-prefix-len + random-prefix-len * random-range-ratio).")
972 |
973 | hf_group = parser.add_argument_group("hf dataset options")
974 | hf_group.add_argument("--hf-subset",
975 | type=str,
976 | default=None,
977 | help="Subset of the HF dataset.")
978 | hf_group.add_argument("--hf-split",
979 | type=str,
980 | default=None,
981 | help="Split of the HF dataset.")
982 | hf_group.add_argument(
983 | "--hf-output-len",
984 | type=int,
985 | default=None,
986 | help="Output length for each request. Overrides the output lengths "
987 | "from the sampled HF dataset.",
988 | )
989 |
990 | parser.add_argument(
991 | '--tokenizer-mode',
992 | type=str,
993 | default="auto",
994 | choices=['auto', 'slow', 'mistral', 'custom'],
995 | help='The tokenizer mode.\n\n* "auto" will use the '
996 | 'fast tokenizer if available.\n* "slow" will '
997 |     'always use the slow tokenizer.\n* '
998 |     '"mistral" will always use the `mistral_common` tokenizer.\n* '
999 | '"custom" will use --tokenizer to select the preregistered tokenizer.')
1000 |
1001 | parser.add_argument("--served-model-name",
1002 | type=str,
1003 | default=None,
1004 | help="The model name used in the API. "
1005 | "If not specified, the model name will be the "
1006 | "same as the ``--model`` argument. ")
1007 |
1008 | parser.add_argument("--lora-modules",
1009 | nargs='+',
1010 | default=None,
1011 | help="A subset of LoRA module names passed in when "
1012 | "launching the server. For each request, the "
1013 | "script chooses a LoRA module at random.")
1014 |
1015 | args = parser.parse_args()
1016 |
1017 | main(args)
```
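Because `main()` above was modified to return the benchmark result dictionary directly (see the `# CUSTOM START` block), the script can also be driven programmatically rather than only through its CLI. Below is a minimal sketch of such a call, assuming the package is importable as `benchmark`, an OpenAI-compatible server is reachable at the default host/port used by the parser, and the "random" dataset works without a `--dataset-path`; the model name is a placeholder and the argument values are illustrative, not defined by this file.

```python
# Minimal sketch: drive benchmark_serving.main() programmatically.
# Assumptions: `benchmark` is importable from the repo root, a server is
# listening on 127.0.0.1:8000, and "<your_model>" is replaced by the caller.
from argparse import Namespace

from benchmark.benchmark_serving import main

args = Namespace(
    seed=0,
    backend="vllm",
    model="<your_model>",            # placeholder, as in the module docstring
    served_model_name=None,
    tokenizer=None,
    tokenizer_mode="auto",
    trust_remote_code=False,
    base_url=None,                   # fall back to host/port below
    host="127.0.0.1",
    port=8000,
    endpoint="/v1/completions",
    dataset_name="random",
    dataset_path=None,
    num_prompts=32,
    random_input_len=1024,
    random_output_len=128,
    random_range_ratio=1.0,
    random_prefix_len=0,
    logprobs=None,
    request_rate=float("inf"),
    burstiness=1.0,
    disable_tqdm=True,
    profile=False,
    percentile_metrics="ttft,tpot,itl",
    metric_percentiles="99",
    ignore_eos=False,
    goodput=None,
    max_concurrency=None,
    lora_modules=None,
)

result = main(args)  # returns the benchmark result dict built in benchmark()
print(result["request_throughput"], result["mean_ttft_ms"])
```

The returned dictionary is the same `result` object assembled in `benchmark()`, so per-request fields such as `ttfts` and `itls` are available to the caller as well.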