# Directory Structure

```
├── .gitignore
├── benchmark
│   ├── backend_request_func.py
│   ├── benchmark_dataset.py
│   ├── benchmark_serving.py
│   ├── benchmark_utils.py
│   └── readme.md
├── benchmark_tool.py
├── pyproject.toml
├── README.md
├── server.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

```
1 | __pycache__
2 | ShareGPT_V3_unfiltered_cleaned_split.json
3 | **/__pycache__
4 | 
```

--------------------------------------------------------------------------------
/benchmark/readme.md:
--------------------------------------------------------------------------------

```markdown
1 | These files come straight from the vLLM repository's benchmarks directory:
2 | 
3 | https://github.com/vllm-project/vllm/tree/main/benchmarks
```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

```markdown
 1 | # MCP vLLM Benchmarking Tool
 2 | 
 3 | This is a proof of concept of how to use MCP to interactively benchmark vLLM.
 4 | 
 5 | We are not new to benchmarking; read our blog post:
 6 | 
 7 | [Benchmarking vLLM](https://eliovp.com/introducing-our-benchmarking-tool-powered-by-dstack/)
 8 | 
 9 | This is just an exploration of possibilities with MCP.
10 | 
11 | ## Usage
12 | 
13 | 1. Clone the repository
14 | 2. Add it to your MCP servers:
15 | ```
16 | {
17 |     "mcpServers": {
18 |         "mcp-vllm": {
19 |             "command": "uv",
20 |             "args": [
21 |                 "run",
22 |                 "/Path/TO/mcp-vllm-benchmarking-tool/server.py"
23 |             ]
24 |         }
25 |     }
26 | }
27 | ```
28 | 
29 | Then you can prompt it, for example, like this:
30 | 
31 | ```
32 | Do a vllm benchmark for this endpoint: http://10.0.101.39:8888 
33 | benchmark the following model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B 
34 | run the benchmark 3 times with 32 prompts each, then compare the results, but ignore the first iteration as that is just a warmup.
35 | ```
36 | 
37 | 
38 | ## Todo:
39 | 
40 | - Due to occasional unexpected output from vLLM, the tool may report that it found invalid JSON. I have not looked into this yet.
```
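
For reference, the registered tool can also be exercised without an MCP-enabled client by driving `server.py` over stdio with the MCP Python SDK. The snippet below is a minimal sketch, not a file in this repository: it assumes the standard `mcp` client API (`StdioServerParameters`, `stdio_client`, `ClientSession`), that it is run from the repository root, and it reuses the placeholder endpoint and model from the prompt example above.

```python
# client_example.py -- hypothetical helper, not part of this repository
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Launch server.py the same way the mcpServers config above does
# (assumes the current working directory is the repository root).
server = StdioServerParameters(command="uv", args=["run", "server.py"])


async def main() -> None:
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(
                "benchmark_vllm",
                arguments={
                    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                    "base_url": "http://10.0.101.39:8888",  # placeholder endpoint
                    "num_prompts": 32,
                },
            )
            print(result.content)


if __name__ == "__main__":
    asyncio.run(main())
```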

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

```toml
 1 | [project]
 2 | name = "mcp-bencher"
 3 | version = "0.1.0"
 4 | description = "A model benchmarking tool using vLLM"
 5 | readme = "README.md"
 6 | requires-python = "~=3.10"
 7 | dependencies = [
 8 |     "mcp[cli]>=1.5.0",
 9 |     "vllm",
10 |     "requests",
11 |     "pathlib",
12 |     "pytest",
13 |     "tqdm",
14 |     "aiohttp",
15 |     "numpy",
16 |     "huggingface_hub",
17 |     "transformers",
18 |     "pandas",
19 |     "datasets",
20 | ]
21 | 
```

--------------------------------------------------------------------------------
/benchmark/benchmark_utils.py:
--------------------------------------------------------------------------------

```python
 1 | # SPDX-License-Identifier: Apache-2.0
 2 | 
 3 | import argparse
 4 | import json
 5 | import math
 6 | import os
 7 | from typing import Any
 8 | 
 9 | 
10 | def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
11 |                                         metrics: dict[str, list],
12 |                                         extra_info: dict[str, Any]) -> list:
13 |     """
14 |     Save the benchmark results in the format used by the PyTorch OSS benchmark,
15 |     with one metric per record:
16 |     https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
17 |     """
18 |     records = []
19 |     if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
20 |         return records
21 | 
22 |     for name, benchmark_values in metrics.items():
23 |         record = {
24 |             "benchmark": {
25 |                 "name": "vLLM benchmark",
26 |                 "extra_info": {
27 |                     "args": vars(args),
28 |                 },
29 |             },
30 |             "model": {
31 |                 "name": args.model,
32 |             },
33 |             "metric": {
34 |                 "name": name,
35 |                 "benchmark_values": benchmark_values,
36 |                 "extra_info": extra_info,
37 |             },
38 |         }
39 | 
40 |         tp = record["benchmark"]["extra_info"]["args"].get(
41 |             "tensor_parallel_size")
42 |         # Save tensor_parallel_size parameter if it's part of the metadata
43 |         if not tp and "tensor_parallel_size" in extra_info:
44 |             record["benchmark"]["extra_info"]["args"][
45 |                 "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
46 | 
47 |         records.append(record)
48 | 
49 |     return records
50 | 
51 | 
52 | class InfEncoder(json.JSONEncoder):
53 | 
54 |     def clear_inf(self, o: Any):
55 |         if isinstance(o, dict):
56 |             return {k: self.clear_inf(v) for k, v in o.items()}
57 |         elif isinstance(o, list):
58 |             return [self.clear_inf(v) for v in o]
59 |         elif isinstance(o, float) and math.isinf(o):
60 |             return "inf"
61 |         return o
62 | 
63 |     def iterencode(self, o: Any, *args, **kwargs) -> Any:
64 |         return super().iterencode(self.clear_inf(o), *args, **kwargs)
65 | 
66 | 
67 | def write_to_json(filename: str, records: list) -> None:
68 |     with open(filename, "w") as f:
69 |         json.dump(records, f, cls=InfEncoder)
```
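
For reference, a minimal sketch of how these helpers fit together. Note that `convert_to_pytorch_benchmark_format` returns an empty list unless the `SAVE_TO_PYTORCH_BENCHMARK_FORMAT` environment variable is set; the metric values below are placeholders, not real measurements.

```python
import argparse
import os

from benchmark.benchmark_utils import (convert_to_pytorch_benchmark_format,
                                       write_to_json)

# Without this env var the converter short-circuits and returns [].
os.environ["SAVE_TO_PYTORCH_BENCHMARK_FORMAT"] = "1"

args = argparse.Namespace(model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                          tensor_parallel_size=1)

records = convert_to_pytorch_benchmark_format(
    args=args,
    metrics={"request_throughput": [12.3], "mean_ttft_ms": [45.6]},  # placeholders
    extra_info={"num_prompts": 32},
)

# InfEncoder replaces infinities with the string "inf" when serializing.
write_to_json("pytorch_benchmark_records.json", records)
```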

--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------

```python
 1 | # server.py
 2 | from mcp.server.fastmcp import FastMCP
 3 | from typing import Dict
 4 | import concurrent.futures
 5 | import os
 6 | import requests
 7 | import pathlib
 8 | import tqdm
 9 | 
10 | from benchmark_tool import run_benchmark
11 | 
12 | # Create an MCP server
13 | mcp = FastMCP("vLLM Bencher")
14 | 
15 | @mcp.tool()
16 | def benchmark_vllm(
17 |     model: str,
18 |     base_url: str,
19 |     num_prompts: int = 10,
20 | ) -> Dict:
21 |     """
22 |     Run vLLM benchmarking tool to measure model performance
23 |     
24 |     Args:
25 |         model: The model to benchmark (e.g., 'meta-llama/Llama-2-7b-hf')
26 |         base_url: Base URL of the vLLM server to benchmark
27 |             (e.g., 'http://10.0.101.39:8888')
28 |         num_prompts: Number of prompts to benchmark with
29 | 
30 |     Notes:
31 |         All other benchmark settings are fixed by this tool:
32 |         - the backend is 'vllm'
33 |         - the dataset is ShareGPT, downloaded automatically next to this
34 |           file if it is not already present
35 |         - request rate, concurrency, max tokens and result saving use
36 |           the defaults defined in benchmark_tool.run_benchmark
37 | 
38 |         To expose more options, add parameters here and forward them
39 |         to run_benchmark.
40 | 
41 |     Returns:
42 |         Dictionary containing benchmark results including throughput, latency, and other metrics
43 |     """
44 |     
45 |     # Define the dataset path
46 |     dataset_filename = "ShareGPT_V3_unfiltered_cleaned_split.json"
47 |     current_dir = pathlib.Path(__file__).parent.absolute()
48 |     dataset_path = current_dir / dataset_filename
49 |     
50 |     # Check if dataset exists, if not, download it
51 |     if not dataset_path.exists():
52 |         dataset_url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
53 |         try:
54 |             response = requests.get(dataset_url, stream=True)
55 |             response.raise_for_status()
56 |             
57 |             # Get file size if available
58 |             total_size_in_bytes = int(response.headers.get('content-length', 0))
59 |             progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc="Downloading dataset")
60 |             
61 |             with open(dataset_path, 'wb') as f:
62 |                 for chunk in response.iter_content(chunk_size=8192):
63 |                     progress_bar.update(len(chunk))
64 |                     f.write(chunk)
65 |             
66 |             progress_bar.close()
67 |         except Exception as e:
68 |             # If download failed and partial file exists, remove it
69 |             if dataset_path.exists():
70 |                 os.remove(dataset_path)
71 |             raise
72 | 
73 |     # Run the benchmark in a separate thread to avoid asyncio event loop issues
74 |     with concurrent.futures.ThreadPoolExecutor() as executor:
75 |         future = executor.submit(
76 |             run_benchmark,
77 |             model=model,
78 |             backend="vllm",
79 |             dataset="sharegpt",
80 |             dataset_path=str(dataset_path),
81 |             num_prompts=num_prompts,
82 |             base_url=base_url,
83 |         )
84 |         return future.result()
85 | 
86 | if __name__ == "__main__":
87 |     mcp.run()
88 | 
```
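
For a quick local smoke test outside of any MCP client, the tool function can also be called in-process. This is a sketch, not part of the repository, and it assumes that FastMCP's `@mcp.tool()` decorator returns the original function unchanged so it stays directly callable; endpoint and model are placeholders.

```python
# smoke_test.py -- hypothetical, not part of this repository
from server import benchmark_vllm

# Downloads the ShareGPT dataset on first run, then benchmarks the endpoint.
result = benchmark_vllm(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    base_url="http://10.0.101.39:8888",  # placeholder endpoint
    num_prompts=4,
)
print(result)
```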

--------------------------------------------------------------------------------
/benchmark_tool.py:
--------------------------------------------------------------------------------

```python
  1 | import argparse
  2 | from typing import Dict, List, Optional
  3 | from benchmark.benchmark_serving import main as main_benchmark_serving
  4 | 
  5 | def run_benchmark(
  6 |     model: str,
  7 |     base_url: str,
  8 |     backend: str = "vllm",
  9 |     dataset: str = "sharegpt",
 10 |     dataset_path: Optional[str] = None,
 11 |     num_prompts: int = 100,
 12 |     request_rate: float = 10.0,
 13 |     concurrent_requests: int = 10,
 14 |     max_tokens: int = 128,
 15 |     vllm_dir: Optional[str] = None,
 16 |     save_result: bool = True,
 17 |     result_filename: Optional[str] = None,
 18 |     api_key: Optional[str] = None,
 19 |     trust_remote_code: bool = False,
 20 |     extra_args: Optional[List[str]] = None,
 21 | ) -> Dict:
 22 |     """
 23 |     Run vLLM benchmarking tool
 24 |     
 25 |     Args:
 26 |         model: The model to benchmark
 27 |         backend: Backend server to use (vllm, tgi, openai, etc.)
 28 |         dataset: Dataset to use for benchmarking
 29 |         dataset_path: Path to the dataset file
 30 |         num_prompts: Number of prompts to benchmark with
 31 |         request_rate: Requests per second
 32 |         concurrent_requests: Number of concurrent requests
 33 |         max_tokens: Maximum number of tokens to generate
 34 |         vllm_dir: Directory where vLLM is installed
 35 |         base_url: Base URL of the vLLM server to benchmark
 36 |         save_result: Whether to save benchmark results
 37 |         result_filename: Filename to save benchmark results
 38 |         api_key: API key for the backend
 39 |         trust_remote_code: Whether to trust remote code
 40 |         extra_args: Additional arguments to pass to benchmark_serving.py
 41 |     
 42 |     Returns:
 43 |         Dictionary with benchmark results
 44 |     """
 45 |     # Create argparse.Namespace object to pass to main_benchmark_serving
 46 |     args = argparse.Namespace()
 47 |     
 48 |     # Required parameters
 49 |     args.model = model
 50 |     args.backend = backend
 51 |     args.dataset_name = dataset
 52 |     args.dataset_path = dataset_path if dataset_path else "./ShareGPT_V3_unfiltered_cleaned_split.json"
 53 |     args.num_prompts = num_prompts
 54 |     args.request_rate = request_rate
 55 |     args.max_concurrency = concurrent_requests
 56 |     
 57 |     # Optional parameters with defaults
 58 |     args.host = "127.0.0.1"
 59 |     args.port = 8000
 60 |     args.endpoint = "/v1/completions"
 61 |     args.base_url = base_url
 62 |     args.seed = 0
 63 |     args.disable_tqdm = False
 64 |     args.profile = False
 65 |     args.use_beam_search = False
 66 |     args.tokenizer = None
 67 |     args.logprobs = None
 68 |     args.burstiness = 1.0
 69 |     args.ignore_eos = False
 70 |     args.percentile_metrics = "ttft,tpot,itl"
 71 |     args.metric_percentiles = "99"
 72 |     args.save_result = save_result
 73 |     args.save_detailed = False
 74 |     args.metadata = None
 75 |     args.result_dir = None
 76 |     args.result_filename = result_filename
 77 |     args.trust_remote_code = trust_remote_code
 78 |     args.tokenizer_mode = "auto"
 79 |     args.served_model_name = None
 80 |     args.lora_modules = None
 81 |     
 82 |     # Dataset-specific parameters
 83 |     args.sonnet_input_len = 550
 84 |     args.sonnet_output_len = 150
 85 |     args.sonnet_prefix_len = 200
 86 |     args.sharegpt_output_len = max_tokens
 87 |     args.random_input_len = 1024
 88 |     args.random_output_len = max_tokens
 89 |     args.random_range_ratio = 1.0
 90 |     args.random_prefix_len = 0
 91 |     args.hf_subset = None
 92 |     args.hf_split = None
 93 |     args.hf_output_len = max_tokens
 94 |     args.goodput = None
 95 |     
 96 |     # Handle extra args if provided
 97 |     if extra_args:
 98 |         for arg in extra_args:
 99 |             if '=' in arg:
100 |                 key, value = arg.split('=', 1)
101 |                 key = key.lstrip('-').replace('-', '_')
102 |                 try:
103 |                     # Try to convert to appropriate type
104 |                     if value.lower() in ('true', 'yes'):
105 |                         value = True
106 |                     elif value.lower() in ('false', 'no'):
107 |                         value = False
108 |                     elif value.isdigit():
109 |                         value = int(value)
110 |                     elif value.replace('.', '', 1).isdigit():
111 |                         value = float(value)
112 |                 except (ValueError, AttributeError):
113 |                     pass
114 |                 setattr(args, key, value)
115 |     
116 |     benchmark_result = main_benchmark_serving(args)
117 |     return benchmark_result
118 | 
```
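
`run_benchmark` can also be driven directly from Python, bypassing MCP entirely. A minimal sketch with placeholder endpoint and model; it assumes the ShareGPT JSON has already been downloaded (server.py does this automatically), and the exact keys in the returned dictionary depend on what `benchmark_serving.main` reports.

```python
from benchmark_tool import run_benchmark

results = run_benchmark(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # placeholder model
    base_url="http://10.0.101.39:8888",                # placeholder endpoint
    dataset_path="./ShareGPT_V3_unfiltered_cleaned_split.json",
    num_prompts=32,
    request_rate=10.0,
    concurrent_requests=8,
    max_tokens=128,
)
print(results)
```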

--------------------------------------------------------------------------------
/benchmark/backend_request_func.py:
--------------------------------------------------------------------------------

```python
  1 | # SPDX-License-Identifier: Apache-2.0
  2 | 
  3 | import json
  4 | import os
  5 | import sys
  6 | import time
  7 | import traceback
  8 | from dataclasses import dataclass, field
  9 | from typing import Optional, Union
 10 | 
 11 | import aiohttp
 12 | import huggingface_hub.constants
 13 | from tqdm.asyncio import tqdm
 14 | from transformers import (AutoTokenizer, PreTrainedTokenizer,
 15 |                           PreTrainedTokenizerFast)
 16 | 
 17 | # NOTE(simon): do not import vLLM here so the benchmark script
 18 | # can run without vLLM installed.
 19 | 
 20 | AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 21 | 
 22 | 
 23 | @dataclass
 24 | class RequestFuncInput:
 25 |     prompt: str
 26 |     api_url: str
 27 |     prompt_len: int
 28 |     output_len: int
 29 |     model: str
 30 |     model_name: Optional[str] = None
 31 |     logprobs: Optional[int] = None
 32 |     extra_body: Optional[dict] = None
 33 |     multi_modal_content: Optional[dict] = None
 34 |     ignore_eos: bool = False
 35 | 
 36 | 
 37 | @dataclass
 38 | class RequestFuncOutput:
 39 |     generated_text: str = ""
 40 |     success: bool = False
 41 |     latency: float = 0.0
 42 |     output_tokens: int = 0
 43 |     ttft: float = 0.0  # Time to first token
 44 |     itl: list[float] = field(
 45 |         default_factory=list)  # list of inter-token latencies
 46 |     tpot: float = 0.0  # avg next-token latencies
 47 |     prompt_len: int = 0
 48 |     error: str = ""
 49 | 
 50 | 
 51 | async def async_request_tgi(
 52 |     request_func_input: RequestFuncInput,
 53 |     pbar: Optional[tqdm] = None,
 54 | ) -> RequestFuncOutput:
 55 |     api_url = request_func_input.api_url
 56 |     assert api_url.endswith("generate_stream")
 57 | 
 58 |     async with aiohttp.ClientSession(trust_env=True,
 59 |                                      timeout=AIOHTTP_TIMEOUT) as session:
 60 |         params = {
 61 |             "max_new_tokens": request_func_input.output_len,
 62 |             "do_sample": True,
 63 |             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
 64 |             "top_p": 0.99,  # TGI does not accept 1.0 top_p.
 65 |             "truncate": request_func_input.prompt_len,
 66 |             "ignore_eos_token": request_func_input.ignore_eos,
 67 |         }
 68 |         payload = {
 69 |             "inputs": request_func_input.prompt,
 70 |             "parameters": params,
 71 |         }
 72 |         output = RequestFuncOutput()
 73 |         output.prompt_len = request_func_input.prompt_len
 74 |         if request_func_input.ignore_eos:
 75 |             output.output_tokens = request_func_input.output_len
 76 |         else:
 77 |             output.output_tokens = None
 78 | 
 79 |         ttft = 0.0
 80 |         st = time.perf_counter()
 81 |         most_recent_timestamp = st
 82 |         try:
 83 |             async with session.post(url=api_url, json=payload) as response:
 84 |                 if response.status == 200:
 85 |                     async for chunk_bytes in response.content:
 86 |                         chunk_bytes = chunk_bytes.strip()
 87 |                         if not chunk_bytes:
 88 |                             continue
 89 |                         chunk_bytes = chunk_bytes.decode("utf-8")
 90 | 
 91 |                         # NOTE: Sometimes TGI returns a ping response without
 92 |                         # any data, we should skip it.
 93 |                         if chunk_bytes.startswith(":"):
 94 |                             continue
 95 |                         chunk = chunk_bytes.removeprefix("data:")
 96 | 
 97 |                         data = json.loads(chunk)
 98 |                         timestamp = time.perf_counter()
 99 |                         # First token
100 |                         if ttft == 0.0:
101 |                             ttft = time.perf_counter() - st
102 |                             output.ttft = ttft
103 | 
104 |                         # Decoding phase
105 |                         else:
106 |                             output.itl.append(timestamp -
107 |                                               most_recent_timestamp)
108 | 
109 |                         most_recent_timestamp = timestamp
110 | 
111 |                     output.latency = most_recent_timestamp - st
112 |                     output.success = True
113 |                     output.generated_text = data["generated_text"]
114 |                 else:
115 |                     output.error = response.reason or ""
116 |                     output.success = False
117 |         except Exception:
118 |             output.success = False
119 |             exc_info = sys.exc_info()
120 |             output.error = "".join(traceback.format_exception(*exc_info))
121 | 
122 |         if pbar:
123 |             pbar.update(1)
124 |         return output
125 | 
126 | 
127 | async def async_request_trt_llm(
128 |     request_func_input: RequestFuncInput,
129 |     pbar: Optional[tqdm] = None,
130 | ) -> RequestFuncOutput:
131 |     api_url = request_func_input.api_url
132 |     assert api_url.endswith("generate_stream")
133 | 
134 |     async with aiohttp.ClientSession(trust_env=True,
135 |                                      timeout=AIOHTTP_TIMEOUT) as session:
136 |         payload = {
137 |             "accumulate_tokens": True,
138 |             "text_input": request_func_input.prompt,
139 |             "temperature": 0.0,
140 |             "top_p": 1.0,
141 |             "max_tokens": request_func_input.output_len,
142 |             "stream": True,
143 |         }
144 |         if request_func_input.ignore_eos:
145 |             payload["min_length"] = request_func_input.output_len
146 |         output = RequestFuncOutput()
147 |         output.prompt_len = request_func_input.prompt_len
148 | 
149 |         ttft = 0.0
150 |         st = time.perf_counter()
151 |         most_recent_timestamp = st
152 |         try:
153 |             async with session.post(url=api_url, json=payload) as response:
154 |                 if response.status == 200:
155 |                     async for chunk_bytes in response.content:
156 |                         chunk_bytes = chunk_bytes.strip()
157 |                         if not chunk_bytes:
158 |                             continue
159 | 
160 |                         chunk = chunk_bytes.decode("utf-8").removeprefix(
161 |                             "data:")
162 | 
163 |                         data = json.loads(chunk)
164 |                         output.generated_text += data["text_output"]
165 |                         timestamp = time.perf_counter()
166 |                         # First token
167 |                         if ttft == 0.0:
168 |                             ttft = timestamp - st
169 |                             output.ttft = ttft
170 | 
171 |                         # Decoding phase
172 |                         else:
173 |                             output.itl.append(timestamp -
174 |                                               most_recent_timestamp)
175 | 
176 |                         most_recent_timestamp = timestamp
177 | 
178 |                     output.latency = most_recent_timestamp - st
179 |                     output.success = True
180 | 
181 |                 else:
182 |                     output.error = response.reason or ""
183 |                     output.success = False
184 |         except Exception:
185 |             output.success = False
186 |             exc_info = sys.exc_info()
187 |             output.error = "".join(traceback.format_exception(*exc_info))
188 | 
189 |         if pbar:
190 |             pbar.update(1)
191 |         return output
192 | 
193 | 
194 | async def async_request_deepspeed_mii(
195 |     request_func_input: RequestFuncInput,
196 |     pbar: Optional[tqdm] = None,
197 | ) -> RequestFuncOutput:
198 |     async with aiohttp.ClientSession(trust_env=True,
199 |                                      timeout=AIOHTTP_TIMEOUT) as session:
200 | 
201 |         payload = {
202 |             "prompt": request_func_input.prompt,
203 |             "max_tokens": request_func_input.output_len,
204 |             "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
205 |             "top_p": 1.0,
206 |         }
207 |         output = RequestFuncOutput()
208 |         output.prompt_len = request_func_input.prompt_len
209 | 
210 |         # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
211 |         # will use 0 as placeholder.
212 |         # See https://github.com/microsoft/DeepSpeed-MII/pull/311
213 |         output.ttft = 0
214 | 
215 |         st = time.perf_counter()
216 |         try:
217 |             async with session.post(url=request_func_input.api_url,
218 |                                     json=payload) as response:
219 |                 if response.status == 200:
220 |                     parsed_resp = await response.json()
221 |                     output.latency = time.perf_counter() - st
222 |                     output.generated_text = parsed_resp["text"][0]
223 |                     output.success = True
224 |                 else:
225 |                     output.error = response.reason or ""
226 |                     output.success = False
227 |         except Exception:
228 |             output.success = False
229 |             exc_info = sys.exc_info()
230 |             output.error = "".join(traceback.format_exception(*exc_info))
231 | 
232 |         if pbar:
233 |             pbar.update(1)
234 |         return output
235 | 
236 | 
237 | async def async_request_openai_completions(
238 |     request_func_input: RequestFuncInput,
239 |     pbar: Optional[tqdm] = None,
240 | ) -> RequestFuncOutput:
241 |     api_url = request_func_input.api_url
242 |     assert api_url.endswith(
243 |         ("completions", "profile")
244 |     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
245 | 
246 |     async with aiohttp.ClientSession(trust_env=True,
247 |                                      timeout=AIOHTTP_TIMEOUT) as session:
248 |         payload = {
249 |             "model": request_func_input.model_name \
250 |                 if request_func_input.model_name else request_func_input.model,
251 |             "prompt": request_func_input.prompt,
252 |             "temperature": 0.0,
253 |             "max_tokens": request_func_input.output_len,
254 |             "logprobs": request_func_input.logprobs,
255 |             "stream": True,
256 |             "stream_options": {
257 |                 "include_usage": True,
258 |             },
259 |         }
260 |         if request_func_input.ignore_eos:
261 |             payload["ignore_eos"] = request_func_input.ignore_eos
262 |         if request_func_input.extra_body:
263 |             payload.update(request_func_input.extra_body)
264 |         headers = {
265 |             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
266 |         }
267 | 
268 |         output = RequestFuncOutput()
269 |         output.prompt_len = request_func_input.prompt_len
270 | 
271 |         generated_text = ""
272 |         st = time.perf_counter()
273 |         most_recent_timestamp = st
274 |         try:
275 |             async with session.post(url=api_url, json=payload,
276 |                                     headers=headers) as response:
277 |                 if response.status == 200:
278 |                     first_chunk_received = False
279 |                     async for chunk_bytes in response.content:
280 |                         chunk_bytes = chunk_bytes.strip()
281 |                         if not chunk_bytes:
282 |                             continue
283 | 
284 |                         chunk = chunk_bytes.decode("utf-8").removeprefix(
285 |                             "data: ")
286 |                         if chunk != "[DONE]":
287 |                             data = json.loads(chunk)
288 | 
289 |                             # NOTE: Some completion API might have a last
290 |                             # usage summary response without a token so we
291 |                             # want to check a token was generated
292 |                             if choices := data.get("choices"):
293 |                                 # Note that text could be empty here
294 |                                 # e.g. for special tokens
295 |                                 text = choices[0].get("text")
296 |                                 timestamp = time.perf_counter()
297 |                                 # First token
298 |                                 if not first_chunk_received:
299 |                                     first_chunk_received = True
300 |                                     ttft = time.perf_counter() - st
301 |                                     output.ttft = ttft
302 | 
303 |                                 # Decoding phase
304 |                                 else:
305 |                                     output.itl.append(timestamp -
306 |                                                       most_recent_timestamp)
307 | 
308 |                                 most_recent_timestamp = timestamp
309 |                                 generated_text += text or ""
310 |                             elif usage := data.get("usage"):
311 |                                 output.output_tokens = usage.get(
312 |                                     "completion_tokens")
313 |                     if first_chunk_received:
314 |                         output.success = True
315 |                     else:
316 |                         output.success = False
317 |                         output.error = (
318 |                             "Never received a valid chunk to calculate TTFT. "
319 |                             "This response will be marked as failed!")
320 |                     output.generated_text = generated_text
321 |                     output.latency = most_recent_timestamp - st
322 |                 else:
323 |                     output.error = response.reason or ""
324 |                     output.success = False
325 |         except Exception:
326 |             output.success = False
327 |             exc_info = sys.exc_info()
328 |             output.error = "".join(traceback.format_exception(*exc_info))
329 | 
330 |     if pbar:
331 |         pbar.update(1)
332 |     return output
333 | 
334 | 
335 | async def async_request_openai_chat_completions(
336 |     request_func_input: RequestFuncInput,
337 |     pbar: Optional[tqdm] = None,
338 | ) -> RequestFuncOutput:
339 |     api_url = request_func_input.api_url
340 |     assert api_url.endswith(
341 |         ("chat/completions", "profile")
342 |     ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
343 | 
344 |     async with aiohttp.ClientSession(trust_env=True,
345 |                                      timeout=AIOHTTP_TIMEOUT) as session:
346 |         content = [{"type": "text", "text": request_func_input.prompt}]
347 |         if request_func_input.multi_modal_content:
348 |             content.append(request_func_input.multi_modal_content)
349 |         payload = {
350 |             "model": request_func_input.model_name \
351 |                 if request_func_input.model_name else request_func_input.model,
352 |             "messages": [
353 |                 {
354 |                     "role": "user",
355 |                     "content": content
356 |                 },
357 |             ],
358 |             "temperature": 0.0,
359 |             "max_completion_tokens": request_func_input.output_len,
360 |             "stream": True,
361 |             "stream_options": {
362 |                 "include_usage": True,
363 |             },
364 |         }
365 |         if request_func_input.ignore_eos:
366 |             payload["ignore_eos"] = request_func_input.ignore_eos
367 |         if request_func_input.extra_body:
368 |             payload.update(request_func_input.extra_body)
369 |         headers = {
370 |             "Content-Type": "application/json",
371 |             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
372 |         }
373 | 
374 |         output = RequestFuncOutput()
375 |         output.prompt_len = request_func_input.prompt_len
376 | 
377 |         generated_text = ""
378 |         ttft = 0.0
379 |         st = time.perf_counter()
380 |         most_recent_timestamp = st
381 |         try:
382 |             async with session.post(url=api_url, json=payload,
383 |                                     headers=headers) as response:
384 |                 if response.status == 200:
385 |                     async for chunk_bytes in response.content:
386 |                         chunk_bytes = chunk_bytes.strip()
387 |                         if not chunk_bytes:
388 |                             continue
389 | 
390 |                         chunk = chunk_bytes.decode("utf-8").removeprefix(
391 |                             "data: ")
392 |                         if chunk != "[DONE]":
393 |                             timestamp = time.perf_counter()
394 |                             data = json.loads(chunk)
395 | 
396 |                             if choices := data.get("choices"):
397 |                                 content = choices[0]["delta"].get("content")
398 |                                 # First token
399 |                                 if ttft == 0.0:
400 |                                     ttft = timestamp - st
401 |                                     output.ttft = ttft
402 | 
403 |                                 # Decoding phase
404 |                                 else:
405 |                                     output.itl.append(timestamp -
406 |                                                       most_recent_timestamp)
407 | 
408 |                                 generated_text += content or ""
409 |                             elif usage := data.get("usage"):
410 |                                 output.output_tokens = usage.get(
411 |                                     "completion_tokens")
412 | 
413 |                             most_recent_timestamp = timestamp
414 | 
415 |                     output.generated_text = generated_text
416 |                     output.success = True
417 |                     output.latency = most_recent_timestamp - st
418 |                 else:
419 |                     output.error = response.reason or ""
420 |                     output.success = False
421 |         except Exception:
422 |             output.success = False
423 |             exc_info = sys.exc_info()
424 |             output.error = "".join(traceback.format_exception(*exc_info))
425 | 
426 |     if pbar:
427 |         pbar.update(1)
428 |     return output
429 | 
430 | 
431 | def get_model(pretrained_model_name_or_path: str) -> str:
432 |     if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
433 |         from modelscope import snapshot_download
434 | 
435 |         from vllm.model_executor.model_loader.weight_utils import get_lock
436 | 
437 |         # Use file lock to prevent multiple processes from
438 |         # downloading the same model weights at the same time.
439 |         with get_lock(pretrained_model_name_or_path):
440 |             model_path = snapshot_download(
441 |                 model_id=pretrained_model_name_or_path,
442 |                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
443 |                 ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
444 | 
445 |             return model_path
446 |     return pretrained_model_name_or_path
447 | 
448 | 
449 | def get_tokenizer(
450 |     pretrained_model_name_or_path: str,
451 |     tokenizer_mode: str = "auto",
452 |     trust_remote_code: bool = False,
453 |     **kwargs,
454 | ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
455 |     if pretrained_model_name_or_path is not None and not os.path.exists(
456 |             pretrained_model_name_or_path):
457 |         pretrained_model_name_or_path = get_model(
458 |             pretrained_model_name_or_path)
459 |     if tokenizer_mode == "slow":
460 |         if kwargs.get("use_fast", False):
461 |             raise ValueError(
462 |                 "Cannot use the fast tokenizer in slow tokenizer mode.")
463 |         kwargs["use_fast"] = False
464 |     if tokenizer_mode == "mistral":
465 |         try:
466 |             from vllm.transformers_utils.tokenizer import MistralTokenizer
467 |         except ImportError as e:
468 |             raise ImportError("MistralTokenizer requires vllm package.\n"
469 |                               "Please install it with `pip install vllm` "
470 |                               "to use mistral tokenizer mode.") from e
471 |         return MistralTokenizer.from_pretrained(
472 |             str(pretrained_model_name_or_path))
473 |     else:
474 |         return AutoTokenizer.from_pretrained(
475 |             pretrained_model_name_or_path,
476 |             trust_remote_code=trust_remote_code,
477 |             **kwargs,
478 |         )
479 | 
480 | 
481 | ASYNC_REQUEST_FUNCS = {
482 |     "tgi": async_request_tgi,
483 |     "vllm": async_request_openai_completions,
484 |     "lmdeploy": async_request_openai_completions,
485 |     "deepspeed-mii": async_request_deepspeed_mii,
486 |     "openai": async_request_openai_completions,
487 |     "openai-chat": async_request_openai_chat_completions,
488 |     "tensorrt-llm": async_request_trt_llm,
489 |     "scalellm": async_request_openai_completions,
490 |     "sglang": async_request_openai_completions,
491 | }
```
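
Each entry in `ASYNC_REQUEST_FUNCS` issues a single streaming request and returns a `RequestFuncOutput` carrying TTFT, inter-token latencies, and the generated text. A minimal sketch of one request against an OpenAI-compatible vLLM endpoint follows; the URL and model are placeholders, and the `"vllm"` entry requires the URL to end in `completions` (or `profile`).

```python
import asyncio

from benchmark.backend_request_func import (ASYNC_REQUEST_FUNCS,
                                             RequestFuncInput)


async def one_request() -> None:
    # Set OPENAI_API_KEY in the environment if the server requires a key.
    req = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://10.0.101.39:8888/v1/completions",  # placeholder endpoint
        prompt_len=4,
        output_len=32,
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # placeholder model
    )
    out = await ASYNC_REQUEST_FUNCS["vllm"](req)
    print(out.success, out.ttft, out.latency, out.generated_text[:80])


if __name__ == "__main__":
    asyncio.run(one_request())
```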

--------------------------------------------------------------------------------
/benchmark/benchmark_dataset.py:
--------------------------------------------------------------------------------

```python
  1 | # SPDX-License-Identifier: Apache-2.0
  2 | """
  3 | This module defines a framework for sampling benchmark requests from various
  4 | datasets. Each dataset subclass of BenchmarkDataset must implement sample
  5 | generation. Supported dataset types include:
  6 |   - ShareGPT
  7 |   - Random (synthetic)
  8 |   - Sonnet
  9 |   - BurstGPT
 10 |   - HuggingFace
 11 |   - VisionArena
 12 | 
 13 | TODO: Implement CustomDataset to parse a JSON file and convert its contents into
 14 | SampleRequest instances, similar to the approach used in ShareGPT.
 15 | """
 16 | 
 17 | import base64
 18 | import io
 19 | import json
 20 | import logging
 21 | import random
 22 | from abc import ABC, abstractmethod
 23 | from collections.abc import Mapping
 24 | from dataclasses import dataclass
 25 | from functools import cache
 26 | from typing import Any, Optional, Union
 27 | 
 28 | import numpy as np
 29 | import pandas as pd
 30 | from datasets import load_dataset
 31 | from PIL import Image
 32 | from transformers import PreTrainedTokenizerBase
 33 | 
 34 | from vllm.lora.request import LoRARequest
 35 | from vllm.lora.utils import get_adapter_absolute_path
 36 | from vllm.multimodal import MultiModalDataDict
 37 | from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 38 | 
 39 | logger = logging.getLogger(__name__)
 40 | 
 41 | # -----------------------------------------------------------------------------
 42 | # Data Classes
 43 | # -----------------------------------------------------------------------------
 44 | 
 45 | 
 46 | @dataclass
 47 | class SampleRequest:
 48 |     """
 49 |     Represents a single inference request for benchmarking.
 50 |     """
 51 | 
 52 |     prompt: Union[str, Any]
 53 |     prompt_len: int
 54 |     expected_output_len: int
 55 |     multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
 56 |     lora_request: Optional[LoRARequest] = None
 57 | 
 58 | 
 59 | # -----------------------------------------------------------------------------
 60 | # Benchmark Dataset Base Class
 61 | # -----------------------------------------------------------------------------
 62 | 
 63 | 
 64 | class BenchmarkDataset(ABC):
 65 |     DEFAULT_SEED = 0
 66 | 
 67 |     def __init__(
 68 |         self,
 69 |         dataset_path: Optional[str] = None,
 70 |         random_seed: int = DEFAULT_SEED,
 71 |     ) -> None:
 72 |         """
 73 |         Initialize the BenchmarkDataset with an optional dataset path and random
 74 |         seed.  Args:
 75 |             dataset_path (Optional[str]): Path to the dataset. If None, it
 76 |             indicates that a default or random dataset might be used.
 77 |             random_seed (int): Seed value for reproducible shuffling or
 78 |             sampling. Defaults to DEFAULT_SEED.
 79 |         """
 80 |         self.dataset_path = dataset_path
 81 |         # Set the random seed, ensuring that a None value is replaced with the
 82 |         # default seed.
 83 |         self.random_seed = (random_seed
 84 |                             if random_seed is not None else self.DEFAULT_SEED)
 85 |         self.data = None
 86 | 
 87 |     def apply_multimodal_chat_transformation(
 88 |             self,
 89 |             prompt: str,
 90 |             mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
 91 |         """
 92 |         Transform a prompt and optional multimodal content into a chat format.
 93 |         This method is used for chat models that expect a specific conversation
 94 |         format.
 95 |         """
 96 |         content = [{"text": prompt, "type": "text"}]
 97 |         if mm_content is not None:
 98 |             content.append(mm_content)
 99 |         return [{"role": "user", "content": content}]
100 | 
101 |     def load_data(self) -> None:
102 |         """
103 |         Load data from the dataset path into self.data.
104 | 
105 |         This method must be overridden by subclasses since the method to load
106 |         data will vary depending on the dataset format and source.
107 | 
108 |         Raises:
109 |             NotImplementedError: If a subclass does not implement this method.
110 |         """
111 |         # TODO (jenniferzhao): add support for downloading data
112 |         raise NotImplementedError(
113 |             "load_data must be implemented in subclasses.")
114 | 
115 |     def get_random_lora_request(
116 |         self,
117 |         tokenizer: PreTrainedTokenizerBase,
118 |         max_loras: Optional[int] = None,
119 |         lora_path: Optional[str] = None,
120 |     ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
121 |         """
122 |         Optionally select a random LoRA request and return its associated
123 |         tokenizer.
124 | 
125 |         This method is used when LoRA parameters are provided.  It randomly
126 |         selects a LoRA based on max_loras and retrieves a cached tokenizer for
127 |         that LoRA if available. Otherwise, it returns the base tokenizer.
128 | 
129 |         Args:
130 |             tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
131 |             LoRA is selected.  max_loras (Optional[int]): The maximum number of
132 |             LoRAs available. If None, LoRA is not used.  lora_path
133 |             (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
134 |             is not used.
135 | 
136 |         Returns:
137 |             tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
138 |             element is a LoRARequest (or None if not applicable) and the second
139 |             element is the tokenizer associated with the LoRA request (or the
140 |             base tokenizer).
141 |         """
142 |         if max_loras is None or lora_path is None:
143 |             return None, tokenizer
144 | 
145 |         # Generate a random LoRA ID in the range [1, max_loras].
146 |         lora_id = random.randint(1, max_loras)
147 |         lora_request = LoRARequest(
148 |             lora_name=str(lora_id),
149 |             lora_int_id=lora_id,
150 |             lora_path=lora_path_on_disk(lora_path),
151 |         )
152 |         if lora_id not in lora_tokenizer_cache:
153 |             lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
154 |         # Return lora_request and the cached tokenizer if available; otherwise,
155 |         # return the base tokenizer
156 |         return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
157 | 
158 |     @abstractmethod
159 |     def sample(self, tokenizer: PreTrainedTokenizerBase,
160 |                num_requests: int) -> list[SampleRequest]:
161 |         """
162 |         Abstract method to generate sample requests from the dataset.
163 | 
164 |         Subclasses must override this method to implement dataset-specific logic
165 |         for generating a list of SampleRequest objects.
166 | 
167 |         Args:
168 |             tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
169 |              for processing the dataset's text.
170 |             num_requests (int): The number of sample requests to generate.
171 | 
172 |         Returns:
173 |             list[SampleRequest]: A list of sample requests generated from the
174 |             dataset.
175 |         """
176 |         raise NotImplementedError("sample must be implemented in subclasses.")
177 | 
178 |     def maybe_oversample_requests(self, requests: list[SampleRequest],
179 |                                   num_requests: int) -> None:
180 |         """
181 |         Oversamples the list of requests if its size is less than the desired
182 |         number.
183 | 
184 |         Args:
185 |             requests (List[SampleRequest]): The current list of sampled
186 |             requests.  num_requests (int): The target number of requests.
187 |         """
188 |         if len(requests) < num_requests:
189 |             random.seed(self.random_seed)
190 |             additional = random.choices(requests,
191 |                                         k=num_requests - len(requests))
192 |             requests.extend(additional)
193 |             logger.info("Oversampled requests to reach %d total samples.",
194 |                         num_requests)
195 | 
196 | 
197 | # -----------------------------------------------------------------------------
198 | # Utility Functions and Global Caches
199 | # -----------------------------------------------------------------------------
200 | 
201 | 
202 | def is_valid_sequence(
203 |     prompt_len: int,
204 |     output_len: int,
205 |     min_len: int = 4,
206 |     max_prompt_len: int = 1024,
207 |     max_total_len: int = 2048,
208 |     skip_min_output_len_check: bool = False,
209 | ) -> bool:
210 |     """
211 |     Validate a sequence based on prompt and output lengths.
212 | 
213 |     Default pruning criteria are copied from the original `sample_hf_requests`
214 |     and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
215 |     from `sample_requests` in benchmark_throughput.py.
216 |     """
217 |     # Check for invalid conditions
218 |     prompt_too_short = prompt_len < min_len
219 |     output_too_short = (not skip_min_output_len_check) and (output_len
220 |                                                             < min_len)
221 |     prompt_too_long = prompt_len > max_prompt_len
222 |     combined_too_long = (prompt_len + output_len) > max_total_len
223 | 
224 |     # Return True if none of the invalid conditions are met
225 |     return not (prompt_too_short or output_too_short or prompt_too_long
226 |                 or combined_too_long)
227 | 
228 | 
229 | @cache
230 | def lora_path_on_disk(lora_path: str) -> str:
231 |     return get_adapter_absolute_path(lora_path)
232 | 
233 | 
234 | # Global cache for LoRA tokenizers.
235 | lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
236 | 
237 | 
238 | def process_image(image: Any) -> Mapping[str, Any]:
239 |     """
240 |     Process a single image input and return a multimedia content dictionary.
241 | 
242 |     For a PIL.Image.Image input:
243 |       - Converts the image to RGB.
244 |       - Saves the image as a JPEG in-memory.
245 |       - Encodes the JPEG data as a base64 string.
246 |       - Returns a dictionary with the image as a base64 data URL.
247 | 
248 |     For a string input:
249 |       - Treats the string as a URL or file path.
250 |       - Prepends "file://" if the string doesn't start with "http://" or
251 |         "file://".
252 |       - Returns a dictionary with the image URL.
253 | 
254 |     Raises:
255 |       ValueError: If the input is neither a PIL.Image.Image nor a string.
256 |     """
257 |     if isinstance(image, Image.Image):
258 |         image = image.convert("RGB")
259 |         with io.BytesIO() as image_data:
260 |             image.save(image_data, format="JPEG")
261 |             image_base64 = base64.b64encode(
262 |                 image_data.getvalue()).decode("utf-8")
263 |         return {
264 |             "type": "image_url",
265 |             "image_url": {
266 |                 "url": f"data:image/jpeg;base64,{image_base64}"
267 |             },
268 |         }
269 | 
270 |     if isinstance(image, str):
271 |         image_url = (image if image.startswith(
272 |             ("http://", "file://")) else f"file://{image}")
273 |         return {"type": "image_url", "image_url": {"url": image_url}}
274 | 
275 |     raise ValueError(
276 |         f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
277 | 
278 | 
279 | # -----------------------------------------------------------------------------
280 | # Random Dataset Implementation (Synthetic Data)
281 | # -----------------------------------------------------------------------------
282 | 
283 | 
284 | class RandomDataset(BenchmarkDataset):
285 |     # Default values copied from benchmark_serving.py for the random dataset.
286 |     DEFAULT_PREFIX_LEN = 0
287 |     DEFAULT_RANGE_RATIO = 1.0
288 |     DEFAULT_INPUT_LEN = 1024
289 |     DEFAULT_OUTPUT_LEN = 128
290 | 
291 |     def __init__(
292 |         self,
293 |         **kwargs,
294 |     ) -> None:
295 |         super().__init__(**kwargs)
296 | 
297 |     def sample(
298 |         self,
299 |         tokenizer: PreTrainedTokenizerBase,
300 |         num_requests: int,
301 |         prefix_len: int = DEFAULT_PREFIX_LEN,
302 |         range_ratio: float = DEFAULT_RANGE_RATIO,
303 |         input_len: int = DEFAULT_INPUT_LEN,
304 |         output_len: int = DEFAULT_OUTPUT_LEN,
305 |         **kwargs,
306 |     ) -> list[SampleRequest]:
307 |         vocab_size = tokenizer.vocab_size
308 | 
309 |         prefix_token_ids = (np.random.randint(
310 |             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
311 | 
312 |         input_low = int(input_len * range_ratio)
313 |         output_low = int(output_len * range_ratio)
314 | 
315 |         input_lens = np.random.randint(input_low,
316 |                                        input_len + 1,
317 |                                        size=num_requests)
318 |         output_lens = np.random.randint(output_low,
319 |                                         output_len + 1,
320 |                                         size=num_requests)
321 |         offsets = np.random.randint(0, vocab_size, size=num_requests)
322 | 
323 |         requests = []
324 |         for i in range(num_requests):
325 |             inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
326 |                          vocab_size).tolist()
327 |             token_sequence = prefix_token_ids + inner_seq
328 |             prompt = tokenizer.decode(token_sequence)
329 |             total_input_len = prefix_len + int(input_lens[i])
330 |             requests.append(
331 |                 SampleRequest(
332 |                     prompt=prompt,
333 |                     prompt_len=total_input_len,
334 |                     expected_output_len=int(output_lens[i]),
335 |                 ))
336 |         return requests
337 | 
338 | 
339 | # -----------------------------------------------------------------------------
340 | # ShareGPT Dataset Implementation
341 | # -----------------------------------------------------------------------------
342 | 
343 | 
344 | class ShareGPTDataset(BenchmarkDataset):
345 |     """
346 |     Implements the ShareGPT dataset.  Loads data from a JSON file and generates
347 |     sample requests based on conversation turns.
348 |     """
349 | 
350 |     def __init__(self, **kwargs) -> None:
351 |         super().__init__(**kwargs)
352 |         self.load_data()
353 | 
354 |     def load_data(self) -> None:
355 |         if self.dataset_path is None:
356 |             raise ValueError("dataset_path must be provided for loading data.")
357 | 
358 |         with open(self.dataset_path, encoding="utf-8") as f:
359 |             self.data = json.load(f)
360 |         # Filter entries with at least two conversation turns.
361 |         self.data = [
362 |             entry for entry in self.data
363 |             if "conversations" in entry and len(entry["conversations"]) >= 2
364 |         ]
365 |         random.seed(self.random_seed)
366 |         random.shuffle(self.data)
367 | 
368 |     def sample(
369 |         self,
370 |         tokenizer: PreTrainedTokenizerBase,
371 |         num_requests: int,
372 |         lora_path: Optional[str] = None,
373 |         max_loras: Optional[int] = None,
374 |         output_len: Optional[int] = None,
375 |         enable_multimodal_chat: bool = False,
376 |         **kwargs,
377 |     ) -> list:
378 |         samples: list = []
379 |         for entry in self.data:
380 |             if len(samples) >= num_requests:
381 |                 break
382 |             prompt, completion = (
383 |                 entry["conversations"][0]["value"],
384 |                 entry["conversations"][1]["value"],
385 |             )
386 | 
387 |             lora_request, tokenizer = self.get_random_lora_request(
388 |                 tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
389 |             prompt_ids = tokenizer(prompt).input_ids
390 |             completion_ids = tokenizer(completion).input_ids
391 |             prompt_len = len(prompt_ids)
392 |             new_output_len = (len(completion_ids)
393 |                               if output_len is None else output_len)
394 |             if not is_valid_sequence(prompt_len,
395 |                                      new_output_len,
396 |                                      skip_min_output_len_check=output_len
397 |                                      is not None):
398 |                 continue
399 |             if enable_multimodal_chat:
400 |                 prompt = self.apply_multimodal_chat_transformation(
401 |                     prompt, None)
402 |             samples.append(
403 |                 SampleRequest(
404 |                     prompt=prompt,
405 |                     prompt_len=prompt_len,
406 |                     expected_output_len=new_output_len,
407 |                     lora_request=lora_request,
408 |                 ))
409 |         self.maybe_oversample_requests(samples, num_requests)
410 |         return samples
411 | 
412 | 
413 | # -----------------------------------------------------------------------------
414 | # Sonnet Dataset Implementation
415 | # -----------------------------------------------------------------------------
416 | 
417 | 
418 | class SonnetDataset(BenchmarkDataset):
419 |     """
420 |     Simplified implementation of the Sonnet dataset.  Loads poem lines from a
421 |     text file and generates sample requests.  Default values here copied from
422 |     `benchmark_serving.py` for the sonnet dataset.
423 |     """
424 | 
425 |     DEFAULT_PREFIX_LEN = 200
426 |     DEFAULT_INPUT_LEN = 550
427 |     DEFAULT_OUTPUT_LEN = 150
428 | 
429 |     def __init__(
430 |         self,
431 |         **kwargs,
432 |     ) -> None:
433 |         super().__init__(**kwargs)
434 |         self.load_data()
435 | 
436 |     def load_data(self) -> None:
437 |         if not self.dataset_path:
438 |             raise ValueError("dataset_path must be provided.")
439 |         with open(self.dataset_path, encoding="utf-8") as f:
440 |             self.data = f.readlines()
441 | 
442 |     def sample(
443 |         self,
444 |         tokenizer,
445 |         num_requests: int,
446 |         prefix_len: int = DEFAULT_PREFIX_LEN,
447 |         input_len: int = DEFAULT_INPUT_LEN,
448 |         output_len: int = DEFAULT_OUTPUT_LEN,
449 |         return_prompt_formatted: bool = False,
450 |         **kwargs,
451 |     ) -> list:
452 |         # Calculate average token length for a poem line.
453 |         tokenized_lines = [tokenizer(line).input_ids for line in self.data]
454 |         avg_len = sum(len(tokens)
455 |                       for tokens in tokenized_lines) / len(tokenized_lines)
456 | 
457 |         # Build the base prompt.
458 |         base_prompt = "Pick as many lines as you can from these poem lines:\n"
459 |         base_msg = [{"role": "user", "content": base_prompt}]
460 |         base_fmt = tokenizer.apply_chat_template(base_msg,
461 |                                                  add_generation_prompt=True,
462 |                                                  tokenize=False)
463 |         base_offset = len(tokenizer(base_fmt).input_ids)
464 |         if input_len <= base_offset:
465 |             raise ValueError(
466 |                 f"'input_len' must be higher than the base prompt length "
467 |                 f"({base_offset}).")
468 | 
469 |         # Determine how many poem lines to use.
470 |         num_input_lines = round((input_len - base_offset) / avg_len)
471 |         num_prefix_lines = round((prefix_len - base_offset) / avg_len)
472 |         prefix_lines = self.data[:num_prefix_lines]
473 | 
474 |         samples = []
475 |         for _ in range(num_requests):
476 |             extra_lines = random.choices(self.data,
477 |                                          k=num_input_lines - num_prefix_lines)
478 |             prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
479 |             msg = [{"role": "user", "content": prompt}]
480 |             prompt_formatted = tokenizer.apply_chat_template(
481 |                 msg, add_generation_prompt=True, tokenize=False)
482 |             prompt_len = len(tokenizer(prompt_formatted).input_ids)
483 |             samples.append(
484 |                 SampleRequest(
485 |                     prompt=prompt_formatted
486 |                     if return_prompt_formatted else prompt,
487 |                     prompt_len=prompt_len,
488 |                     expected_output_len=output_len,
489 |                 ))
490 |         return samples
491 | 
492 | 
493 | # -----------------------------------------------------------------------------
494 | # BurstGPT Dataset Implementation
495 | # -----------------------------------------------------------------------------
496 | 
497 | 
498 | class BurstGPTDataset(BenchmarkDataset):
499 |     """
500 |     Implements the BurstGPT dataset.  Loads data from a CSV file and generates
501 |     sample requests based on synthetic prompt generation. Only rows with Model
502 |     "GPT-4" and positive response tokens are used.
503 |     """
504 | 
505 |     def __init__(self, **kwargs) -> None:
506 |         super().__init__(**kwargs)
507 |         self.load_data()
508 | 
509 |     def load_data(self) -> None:
510 |         if self.dataset_path is None:
511 |             raise ValueError("dataset_path must be provided for loading data.")
512 | 
513 |         df = pd.read_csv(self.dataset_path)
514 |         # Filter to keep only GPT-4 rows.
515 |         gpt4_df = df[df["Model"] == "GPT-4"]
516 |         # Remove failed requests (where Response tokens is 0 or less).
517 |         gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
518 |         # Keep the filtered rows; sampling happens later in _sample_loaded_data().
519 |         self.data = gpt4_df
520 | 
521 |     def _sample_loaded_data(self, num_requests: int) -> list:
522 |         if num_requests <= len(self.data):
523 |             data = self.data.sample(n=num_requests,
524 |                                     random_state=self.random_seed)
525 |         else:
526 |             data = self.data.sample(
527 |                 n=num_requests,
528 |                 random_state=self.random_seed,
529 |                 replace=True,
530 |             )
531 |         # Convert the dataframe to a list of lists.
532 |         return data.values.tolist()
533 | 
534 |     def sample(
535 |         self,
536 |         tokenizer: PreTrainedTokenizerBase,
537 |         num_requests: int,
538 |         max_loras: Optional[int] = None,
539 |         lora_path: Optional[str] = None,
540 |         **kwargs,
541 |     ) -> list[SampleRequest]:
542 |         samples = []
543 |         data = self._sample_loaded_data(num_requests=num_requests)
544 |         for i in range(num_requests):
545 |             input_len = int(data[i][2])
546 |             output_len = int(data[i][3])
547 |             lora_req, tokenizer = self.get_random_lora_request(
548 |                 tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
549 |             vocab_size = tokenizer.vocab_size
550 |             # Generate a synthetic prompt: a list of token IDs computed as (i +
551 |             # j) modulo vocab_size.
552 |             token_ids = [(i + j) % vocab_size for j in range(input_len)]
553 |             prompt = tokenizer.decode(token_ids)
554 |             samples.append(
555 |                 SampleRequest(
556 |                     prompt=prompt,
557 |                     prompt_len=input_len,
558 |                     expected_output_len=output_len,
559 |                     lora_request=lora_req,
560 |                 ))
561 |         return samples
562 | 
563 | 
564 | # -----------------------------------------------------------------------------
565 | # HuggingFace Dataset Implementation
566 | # -----------------------------------------------------------------------------
567 | 
568 | 
569 | class HuggingFaceDataset(BenchmarkDataset):
570 |     """
571 |     Dataset class for processing a HuggingFace dataset with conversation data
572 |     and optional images.
573 |     """
574 | 
575 |     def __init__(
576 |         self,
577 |         dataset_split: str,
578 |         dataset_subset: Optional[str] = None,
579 |         **kwargs,
580 |     ) -> None:
581 |         super().__init__(**kwargs)
582 |         self.dataset_split = dataset_split
583 |         self.dataset_subset = dataset_subset
584 | 
585 |         self.load_data()
586 | 
587 |     def load_data(self) -> None:
588 |         if not self.dataset_path:
589 |             raise ValueError("dataset_path must be provided for loading data.")
590 | 
591 |         self.data = load_dataset(
592 |             self.dataset_path,
593 |             name=self.dataset_subset,
594 |             split=self.dataset_split,
595 |             streaming=True,
596 |         )
597 |         if (self.data.features is None
598 |                 or "conversations" not in self.data.features):
599 |             raise ValueError(
600 |                 "HuggingFaceDataset currently only supports datasets with "
601 |                 "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
602 |                 "Please consider contributing if you would like to add "
603 |                 "support for additional dataset formats.")
604 |         # Shuffle and filter examples with at least 2 conversations.
605 |         self.data = self.data.shuffle(seed=self.random_seed).filter(
606 |             lambda x: len(x["conversations"]) >= 2)
607 | 
608 |     def sample(self,
609 |                tokenizer: PreTrainedTokenizerBase,
610 |                num_requests: int,
611 |                output_len: Optional[int] = None,
612 |                enable_multimodal_chat: bool = False,
613 |                **kwargs) -> list:
614 |         sampled_requests = []
615 |         dynamic_output = output_len is None
616 | 
617 |         for item in self.data:
618 |             if len(sampled_requests) >= num_requests:
619 |                 break
620 |             conv = item["conversations"]
621 |             prompt, completion = conv[0]["value"], conv[1]["value"]
622 | 
623 |             prompt_ids = tokenizer(prompt).input_ids
624 |             completion_ids = tokenizer(completion).input_ids
625 |             prompt_len = len(prompt_ids)
626 |             completion_len = len(completion_ids)
627 |             output_len = completion_len if dynamic_output else output_len
628 |             assert isinstance(output_len, int) and output_len > 0
629 |             if dynamic_output and not is_valid_sequence(
630 |                     prompt_len, completion_len):
631 |                 continue
632 |             mm_content = process_image(
633 |                 item["image"]) if "image" in item else None
634 |             if enable_multimodal_chat:
635 |                 # Note: when chat is enabled, the recorded prompt_len is no
636 |                 # longer accurate; the request output is used to count the
637 |                 # actual prompt and output lengths.
638 |                 prompt = self.apply_multimodal_chat_transformation(
639 |                     prompt, mm_content)
640 |             sampled_requests.append(
641 |                 SampleRequest(
642 |                     prompt=prompt,
643 |                     prompt_len=prompt_len,
644 |                     expected_output_len=output_len,
645 |                     multi_modal_data=mm_content,
646 |                 ))
647 |         self.maybe_oversample_requests(sampled_requests, num_requests)
648 |         return sampled_requests
649 | 
650 | 
651 | # -----------------------------------------------------------------------------
652 | # Vision Arena Dataset Implementation
653 | # -----------------------------------------------------------------------------
654 | 
655 | 
656 | class VisionArenaDataset(HuggingFaceDataset):
657 |     """
658 |     Vision Arena Dataset.
659 |     """
660 | 
661 |     DEFAULT_OUTPUT_LEN = 128
662 |     VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
663 | 
664 |     def __init__(
665 |         self,
666 |         **kwargs,
667 |     ) -> None:
668 |         super().__init__(**kwargs)
669 |         if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
670 |             raise ValueError("Only the Vision Arena dataset is supported. "
671 |                              f"This data path {self.dataset_path} is not valid.")
672 |         if self.dataset_subset is None and self.dataset_split != "train":
673 |             raise ValueError("Dataset split must be 'train'.")
674 | 
675 |         self.load_data()
676 | 
677 |     def load_data(self) -> None:
678 |         dataset = load_dataset(
679 |             self.dataset_path,
680 |             name=self.dataset_subset,
681 |             split=self.dataset_split,
682 |             streaming=True,
683 |         )
684 |         self.data = dataset.shuffle(seed=self.random_seed)
685 | 
686 |     def sample(
687 |         self,
688 |         tokenizer: PreTrainedTokenizerBase,
689 |         num_requests: int,
690 |         output_len: Optional[int] = None,
691 |         enable_multimodal_chat: bool = False,
692 |         **kwargs,
693 |     ) -> list:
694 |         output_len = (output_len
695 |                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
696 |         sampled_requests = []
697 |         for item in self.data:
698 |             if len(sampled_requests) >= num_requests:
699 |                 break
700 |             prompt = item["turns"][0][0]["content"]
701 |             mm_content = process_image(item["images"][0])
702 |             prompt_len = len(tokenizer(prompt).input_ids)
703 |             if enable_multimodal_chat:
704 |                 # Note: when chat is enabled, the recorded prompt_len is no
705 |                 # longer accurate; the request output is used to count the
706 |                 # actual prompt length.
707 |                 prompt = self.apply_multimodal_chat_transformation(
708 |                     prompt, mm_content)
709 |             sampled_requests.append(
710 |                 SampleRequest(
711 |                     prompt=prompt,
712 |                     prompt_len=prompt_len,
713 |                     expected_output_len=output_len,
714 |                     multi_modal_data=mm_content,
715 |                 ))
716 |         self.maybe_oversample_requests(sampled_requests, num_requests)
717 |         return sampled_requests
```
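
A minimal usage sketch (not part of the upstream file) of the dataset classes above, mirroring how `benchmark_serving.py` below drives `SonnetDataset`. The poem-lines file path and the tokenizer name are placeholders chosen for illustration only.

```python
# Sketch only: "sonnet.txt" and the tokenizer name are assumptions, not shipped here.
from transformers import AutoTokenizer

from benchmark.benchmark_dataset import SonnetDataset

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B")  # assumed to provide a chat template
dataset = SonnetDataset(dataset_path="sonnet.txt")  # hypothetical poem-lines file
requests = dataset.sample(
    tokenizer=tokenizer,
    num_requests=4,
    prefix_len=200,   # DEFAULT_PREFIX_LEN
    input_len=550,    # DEFAULT_INPUT_LEN
    output_len=150,   # DEFAULT_OUTPUT_LEN
    return_prompt_formatted=True,  # chat-template-formatted prompts
)
for req in requests:
    print(req.prompt_len, req.expected_output_len)
```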

--------------------------------------------------------------------------------
/benchmark/benchmark_serving.py:
--------------------------------------------------------------------------------

```python
   1 | # SPDX-License-Identifier: Apache-2.0
   2 | r"""Benchmark online serving throughput.
   3 | 
   4 | On the server side, run one of the following commands:
   5 |     (vLLM OpenAI API server)
   6 |     vllm serve <your_model> \
   7 |         --swap-space 16 \
   8 |         --disable-log-requests
   9 | 
  10 |     (TGI backend)
  11 |     ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
  12 | 
  13 | On the client side, run:
  14 |     python benchmarks/benchmark_serving.py \
  15 |         --backend <backend> \
  16 |         --model <your_model> \
  17 |         --dataset-name sharegpt \
  18 |         --dataset-path <path to dataset> \
  19 |         --request-rate <request_rate> \ # By default <request_rate> is inf
  20 |         --num-prompts <num_prompts> # By default <num_prompts> is 1000
  21 | 
  22 |     when using tgi backend, add
  23 |         --endpoint /generate_stream
  24 |     to the end of the command above.
  25 | """
  26 | import argparse
  27 | import asyncio
  28 | import gc
  29 | import json
  30 | import os
  31 | import random
  32 | import time
  33 | import warnings
  34 | from collections.abc import AsyncGenerator, Iterable
  35 | from dataclasses import dataclass
  36 | from datetime import datetime
  37 | from typing import Any, Optional
  38 | 
  39 | import numpy as np
  40 | from benchmark.backend_request_func import (
  41 |     ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput)
  42 | from tqdm.asyncio import tqdm
  43 | from transformers import PreTrainedTokenizerBase
  44 | 
  45 | try:
  46 |     from vllm.transformers_utils.tokenizer import get_tokenizer
  47 | except ImportError:
  48 |     from backend_request_func import get_tokenizer
  49 | 
  50 | try:
  51 |     from vllm.utils import FlexibleArgumentParser
  52 | except ImportError:
  53 |     from argparse import ArgumentParser as FlexibleArgumentParser
  54 | 
  55 | from benchmark.benchmark_dataset import (
  56 |     BurstGPTDataset, HuggingFaceDataset, RandomDataset, SampleRequest,
  57 |     ShareGPTDataset, SonnetDataset, VisionArenaDataset)
  58 | from benchmark.benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
  59 | 
  60 | MILLISECONDS_TO_SECONDS_CONVERSION = 1000
  61 | 
  62 | 
  63 | @dataclass
  64 | class BenchmarkMetrics:
  65 |     completed: int
  66 |     total_input: int
  67 |     total_output: int
  68 |     request_throughput: float
  69 |     request_goodput: float
  70 |     output_throughput: float
  71 |     total_token_throughput: float
  72 |     mean_ttft_ms: float
  73 |     median_ttft_ms: float
  74 |     std_ttft_ms: float
  75 |     percentiles_ttft_ms: list[tuple[float, float]]
  76 |     mean_tpot_ms: float
  77 |     median_tpot_ms: float
  78 |     std_tpot_ms: float
  79 |     percentiles_tpot_ms: list[tuple[float, float]]
  80 |     mean_itl_ms: float
  81 |     median_itl_ms: float
  82 |     std_itl_ms: float
  83 |     percentiles_itl_ms: list[tuple[float, float]]
  84 |     # E2EL stands for end-to-end latency per request.
  85 |     # It is the time taken on the client side from sending
  86 |     # a request to receiving a complete response.
  87 |     mean_e2el_ms: float
  88 |     median_e2el_ms: float
  89 |     std_e2el_ms: float
  90 |     percentiles_e2el_ms: list[tuple[float, float]]
  91 | 
  92 | 
  93 | async def get_request(
  94 |     input_requests: list[SampleRequest],
  95 |     request_rate: float,
  96 |     burstiness: float = 1.0,
  97 | ) -> AsyncGenerator[SampleRequest, None]:
  98 |     """
  99 |     Asynchronously generates requests at a specified rate
 100 |     with OPTIONAL burstiness.
 101 | 
 102 |     Args:
 103 |         input_requests:
 104 |             A list of input requests, each represented as a SampleRequest.
 105 |         request_rate:
 106 |             The rate at which requests are generated (requests/s).
 107 |         burstiness (optional):
 108 |             The burstiness factor of the request generation.
 109 |             Only takes effect when request_rate is not inf.
 110 |             Default value is 1, which follows a Poisson process.
 111 |             Otherwise, the request intervals follow a gamma distribution.
 112 |             A lower burstiness value (0 < burstiness < 1) results
 113 |             in more bursty requests, while a higher burstiness value
 114 |             (burstiness > 1) results in a more uniform arrival of requests.
 115 |     """
 116 |     input_requests: Iterable[SampleRequest] = iter(input_requests)
 117 | 
 118 |     # Calculate scale parameter theta to maintain the desired request_rate.
 119 |     assert burstiness > 0, (
 120 |         f"A positive burstiness factor is expected, but given {burstiness}.")
 121 |     theta = 1.0 / (request_rate * burstiness)
 122 | 
 123 |     for request in input_requests:
 124 |         yield request
 125 | 
 126 |         if request_rate == float("inf"):
 127 |             # If the request rate is infinity, then we don't need to wait.
 128 |             continue
 129 | 
 130 |         # Sample the request interval from the gamma distribution.
 131 |         # If burstiness is 1, it follows exponential distribution.
 132 |         interval = np.random.gamma(shape=burstiness, scale=theta)
 133 |         # The next request will be sent after the interval.
 134 |         await asyncio.sleep(interval)
 135 | 
 136 | 
 137 | def calculate_metrics(
 138 |     input_requests: list[SampleRequest],
 139 |     outputs: list[RequestFuncOutput],
 140 |     dur_s: float,
 141 |     tokenizer: PreTrainedTokenizerBase,
 142 |     selected_percentile_metrics: list[str],
 143 |     selected_percentiles: list[float],
 144 |     goodput_config_dict: dict[str, float],
 145 | ) -> tuple[BenchmarkMetrics, list[int]]:
 146 |     actual_output_lens: list[int] = []
 147 |     total_input = 0
 148 |     completed = 0
 149 |     good_completed = 0
 150 |     itls: list[float] = []
 151 |     tpots: list[float] = []
 152 |     all_tpots: list[float] = []
 153 |     ttfts: list[float] = []
 154 |     e2els: list[float] = []
 155 |     for i in range(len(outputs)):
 156 |         if outputs[i].success:
 157 |             output_len = outputs[i].output_tokens
 158 | 
 159 |             if output_len is None:
 160 |                 # We use the tokenizer to count the number of output tokens
 161 |                 # for some serving backends instead of looking at
 162 |                 # len(outputs[i].itl) since multiple output tokens may be
 163 |                 # bundled together
 164 |                 # Note : this may inflate the output token count slightly
 165 |                 output_len = len(
 166 |                     tokenizer(outputs[i].generated_text,
 167 |                               add_special_tokens=False).input_ids)
 168 |             actual_output_lens.append(output_len)
 169 |             total_input += input_requests[i].prompt_len
 170 |             tpot = 0
 171 |             if output_len > 1:
 172 |                 latency_minus_ttft = outputs[i].latency - outputs[i].ttft
 173 |                 tpot = latency_minus_ttft / (output_len - 1)
 174 |                 tpots.append(tpot)
 175 |             # Note: if output_len <= 1, we regard tpot as 0 for goodput
 176 |             all_tpots.append(tpot)
 177 |             itls += outputs[i].itl
 178 |             ttfts.append(outputs[i].ttft)
 179 |             e2els.append(outputs[i].latency)
 180 |             completed += 1
 181 |         else:
 182 |             actual_output_lens.append(0)
 183 | 
 184 |     if goodput_config_dict:
 185 |         valid_metrics = []
 186 |         slo_values = []
 187 | 
 188 |         if "ttft" in goodput_config_dict:
 189 |             valid_metrics.append(ttfts)
 190 |             slo_values.append(goodput_config_dict["ttft"] /
 191 |                               MILLISECONDS_TO_SECONDS_CONVERSION)
 192 |         if "tpot" in goodput_config_dict:
 193 |             valid_metrics.append(all_tpots)
 194 |             slo_values.append(goodput_config_dict["tpot"] /
 195 |                               MILLISECONDS_TO_SECONDS_CONVERSION)
 196 |         if "e2el" in goodput_config_dict:
 197 |             valid_metrics.append(e2els)
 198 |             slo_values.append(goodput_config_dict["e2el"] /
 199 |                               MILLISECONDS_TO_SECONDS_CONVERSION)
 200 | 
 201 |         for req_metric in zip(*valid_metrics):
 202 |             is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
 203 |             if is_good_req:
 204 |                 good_completed += 1
 205 | 
 206 |     if completed == 0:
 207 |         warnings.warn(
 208 |             "All requests failed. This is likely due to a misconfiguration "
 209 |             "on the benchmark arguments.",
 210 |             stacklevel=2)
 211 |     metrics = BenchmarkMetrics(
 212 |         completed=completed,
 213 |         total_input=total_input,
 214 |         total_output=sum(actual_output_lens),
 215 |         request_throughput=completed / dur_s,
 216 |         request_goodput=good_completed / dur_s,
 217 |         output_throughput=sum(actual_output_lens) / dur_s,
 218 |         total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
 219 |         mean_ttft_ms=np.mean(ttfts or 0) *
 220 |         1000,  # ttfts is empty if streaming is not supported by backend
 221 |         std_ttft_ms=np.std(ttfts or 0) * 1000,
 222 |         median_ttft_ms=np.median(ttfts or 0) * 1000,
 223 |         percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
 224 |                              for p in selected_percentiles],
 225 |         mean_tpot_ms=np.mean(tpots or 0) * 1000,
 226 |         std_tpot_ms=np.std(tpots or 0) * 1000,
 227 |         median_tpot_ms=np.median(tpots or 0) * 1000,
 228 |         percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
 229 |                              for p in selected_percentiles],
 230 |         mean_itl_ms=np.mean(itls or 0) * 1000,
 231 |         std_itl_ms=np.std(itls or 0) * 1000,
 232 |         median_itl_ms=np.median(itls or 0) * 1000,
 233 |         percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
 234 |                             for p in selected_percentiles],
 235 |         mean_e2el_ms=np.mean(e2els or 0) * 1000,
 236 |         std_e2el_ms=np.std(e2els or 0) * 1000,
 237 |         median_e2el_ms=np.median(e2els or 0) * 1000,
 238 |         percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
 239 |                              for p in selected_percentiles],
 240 |     )
 241 | 
 242 |     return metrics, actual_output_lens
 243 | 
 244 | 
 245 | async def benchmark(
 246 |     backend: str,
 247 |     api_url: str,
 248 |     base_url: str,
 249 |     model_id: str,
 250 |     model_name: str,
 251 |     tokenizer: PreTrainedTokenizerBase,
 252 |     input_requests: list[SampleRequest],
 253 |     logprobs: Optional[int],
 254 |     request_rate: float,
 255 |     burstiness: float,
 256 |     disable_tqdm: bool,
 257 |     profile: bool,
 258 |     selected_percentile_metrics: list[str],
 259 |     selected_percentiles: list[float],
 260 |     ignore_eos: bool,
 261 |     goodput_config_dict: dict[str, float],
 262 |     max_concurrency: Optional[int],
 263 |     lora_modules: Optional[Iterable[str]],
 264 | ):
 265 |     if backend in ASYNC_REQUEST_FUNCS:
 266 |         request_func = ASYNC_REQUEST_FUNCS[backend]
 267 |     else:
 268 |         raise ValueError(f"Unknown backend: {backend}")
 269 | 
 270 |     print("Starting initial single prompt test run...")
 271 |     test_prompt, test_prompt_len, test_output_len, test_mm_content = \
 272 |         input_requests[0].prompt, input_requests[0].prompt_len, \
 273 |         input_requests[0].expected_output_len, \
 274 |             input_requests[0].multi_modal_data
 275 | 
 276 |     if backend != "openai-chat" and test_mm_content is not None:
 277 |         # multi-modal benchmark is only available on OpenAI Chat backend.
 278 |         raise ValueError(
 279 |             "Multi-modal content is only supported on 'openai-chat' backend.")
 280 |     assert test_mm_content is None or isinstance(test_mm_content, dict)
 281 |     test_input = RequestFuncInput(
 282 |         model=model_id,
 283 |         model_name=model_name,
 284 |         prompt=test_prompt,
 285 |         api_url=api_url,
 286 |         prompt_len=test_prompt_len,
 287 |         output_len=test_output_len,
 288 |         logprobs=logprobs,
 289 |         multi_modal_content=test_mm_content,
 290 |         ignore_eos=ignore_eos,
 291 |     )
 292 | 
 293 |     test_output = await request_func(request_func_input=test_input)
 294 |     if not test_output.success:
 295 |         raise ValueError(
 296 |             "Initial test run failed - Please make sure benchmark arguments "
 297 |             f"are correctly specified. Error: {test_output.error}")
 298 |     else:
 299 |         print("Initial test run completed. Starting main benchmark run...")
 300 | 
 301 |     if lora_modules:
 302 |         # For each input request, choose a LoRA module at random.
 303 |         lora_modules = iter(
 304 |             [random.choice(lora_modules) \
 305 |                 for _ in range(len(input_requests))])
 306 | 
 307 |     if profile:
 308 |         print("Starting profiler...")
 309 |         profile_input = RequestFuncInput(model=model_id,
 310 |                                          model_name=model_name,
 311 |                                          prompt=test_prompt,
 312 |                                          api_url=base_url + "/start_profile",
 313 |                                          prompt_len=test_prompt_len,
 314 |                                          output_len=test_output_len,
 315 |                                          logprobs=logprobs,
 316 |                                          multi_modal_content=test_mm_content,
 317 |                                          ignore_eos=ignore_eos)
 318 |         profile_output = await request_func(request_func_input=profile_input)
 319 |         if profile_output.success:
 320 |             print("Profiler started")
 321 | 
 322 |     if burstiness == 1.0:
 323 |         distribution = "Poisson process"
 324 |     else:
 325 |         distribution = "Gamma distribution"
 326 | 
 327 |     print(f"Traffic request rate: {request_rate}")
 328 |     print(f"Burstiness factor: {burstiness} ({distribution})")
 329 |     print(f"Maximum request concurrency: {max_concurrency}")
 330 | 
 331 |     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 332 | 
 333 |     # This can be used once the minimum Python version is 3.10 or higher,
 334 |     # and it will simplify the code in limited_request_func.
 335 |     #    semaphore = (asyncio.Semaphore(max_concurrency)
 336 |     #                 if max_concurrency else contextlib.nullcontext())
 337 |     semaphore = (asyncio.Semaphore(max_concurrency)
 338 |                  if max_concurrency else None)
 339 | 
 340 |     async def limited_request_func(request_func_input, pbar):
 341 |         if semaphore is None:
 342 |             return await request_func(request_func_input=request_func_input,
 343 |                                       pbar=pbar)
 344 |         async with semaphore:
 345 |             return await request_func(request_func_input=request_func_input,
 346 |                                       pbar=pbar)
 347 | 
 348 |     benchmark_start_time = time.perf_counter()
 349 |     tasks: list[asyncio.Task] = []
 350 |     async for request in get_request(input_requests, request_rate, burstiness):
 351 |         prompt, prompt_len, output_len, mm_content = request.prompt, \
 352 |             request.prompt_len, request.expected_output_len, \
 353 |                 request.multi_modal_data
 354 |         req_model_id, req_model_name = model_id, model_name
 355 |         if lora_modules:
 356 |             req_lora_module = next(lora_modules)
 357 |             req_model_id, req_model_name = req_lora_module, req_lora_module
 358 | 
 359 |         request_func_input = RequestFuncInput(model=req_model_id,
 360 |                                               model_name=req_model_name,
 361 |                                               prompt=prompt,
 362 |                                               api_url=api_url,
 363 |                                               prompt_len=prompt_len,
 364 |                                               output_len=output_len,
 365 |                                               logprobs=logprobs,
 366 |                                               multi_modal_content=mm_content,
 367 |                                               ignore_eos=ignore_eos)
 368 |         tasks.append(
 369 |             asyncio.create_task(
 370 |                 limited_request_func(request_func_input=request_func_input,
 371 |                                      pbar=pbar)))
 372 |     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
 373 | 
 374 |     if profile:
 375 |         print("Stopping profiler...")
 376 |         profile_input = RequestFuncInput(
 377 |             model=model_id,
 378 |             prompt=test_prompt,
 379 |             api_url=base_url + "/stop_profile",
 380 |             prompt_len=test_prompt_len,
 381 |             output_len=test_output_len,
 382 |             logprobs=logprobs,
 383 |         )
 384 |         profile_output = await request_func(request_func_input=profile_input)
 385 |         if profile_output.success:
 386 |             print("Profiler stopped")
 387 | 
 388 |     if pbar is not None:
 389 |         pbar.close()
 390 | 
 391 |     benchmark_duration = time.perf_counter() - benchmark_start_time
 392 | 
 393 |     metrics, actual_output_lens = calculate_metrics(
 394 |         input_requests=input_requests,
 395 |         outputs=outputs,
 396 |         dur_s=benchmark_duration,
 397 |         tokenizer=tokenizer,
 398 |         selected_percentile_metrics=selected_percentile_metrics,
 399 |         selected_percentiles=selected_percentiles,
 400 |         goodput_config_dict=goodput_config_dict,
 401 |     )
 402 | 
 403 |     print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
 404 |     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
 405 |     print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
 406 |                                     benchmark_duration))
 407 |     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
 408 |     print("{:<40} {:<10}".format("Total generated tokens:",
 409 |                                  metrics.total_output))
 410 |     print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
 411 |                                     metrics.request_throughput))
 412 |     if goodput_config_dict:
 413 |         print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
 414 |                                         metrics.request_goodput))
 415 |     print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
 416 |                                     metrics.output_throughput))
 417 |     print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
 418 |                                     metrics.total_token_throughput))
 419 | 
 420 |     result = {
 421 |         "duration": benchmark_duration,
 422 |         "completed": metrics.completed,
 423 |         "total_input_tokens": metrics.total_input,
 424 |         "total_output_tokens": metrics.total_output,
 425 |         "request_throughput": metrics.request_throughput,
 426 |         "request_goodput":
 427 |         metrics.request_goodput if goodput_config_dict else None,
 428 |         "output_throughput": metrics.output_throughput,
 429 |         "total_token_throughput": metrics.total_token_throughput,
 430 |         "input_lens": [output.prompt_len for output in outputs],
 431 |         "output_lens": actual_output_lens,
 432 |         "ttfts": [output.ttft for output in outputs],
 433 |         "itls": [output.itl for output in outputs],
 434 |         "generated_texts": [output.generated_text for output in outputs],
 435 |         "errors": [output.error for output in outputs],
 436 |     }
 437 | 
 438 |     def process_one_metric(
 439 |         # E.g., "ttft"
 440 |         metric_attribute_name: str,
 441 |         # E.g., "TTFT"
 442 |         metric_name: str,
 443 |         # E.g., "Time to First Token"
 444 |         metric_header: str,
 445 |     ):
 446 |         # This function prints and adds statistics of the specified
 447 |         # metric.
 448 |         if metric_attribute_name not in selected_percentile_metrics:
 449 |             return
 450 |         print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
 451 |         print("{:<40} {:<10.2f}".format(
 452 |             f"Mean {metric_name} (ms):",
 453 |             getattr(metrics, f"mean_{metric_attribute_name}_ms")))
 454 |         print("{:<40} {:<10.2f}".format(
 455 |             f"Median {metric_name} (ms):",
 456 |             getattr(metrics, f"median_{metric_attribute_name}_ms")))
 457 |         result[f"mean_{metric_attribute_name}_ms"] = getattr(
 458 |             metrics, f"mean_{metric_attribute_name}_ms")
 459 |         result[f"median_{metric_attribute_name}_ms"] = getattr(
 460 |             metrics, f"median_{metric_attribute_name}_ms")
 461 |         result[f"std_{metric_attribute_name}_ms"] = getattr(
 462 |             metrics, f"std_{metric_attribute_name}_ms")
 463 |         for p, value in getattr(metrics,
 464 |                                 f"percentiles_{metric_attribute_name}_ms"):
 465 |             p_word = str(int(p)) if int(p) == p else str(p)
 466 |             print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
 467 |                                             value))
 468 |             result[f"p{p_word}_{metric_attribute_name}_ms"] = value
 469 | 
 470 |     process_one_metric("ttft", "TTFT", "Time to First Token")
 471 |     process_one_metric("tpot", "TPOT",
 472 |                        "Time per Output Token (excl. 1st token)")
 473 |     process_one_metric("itl", "ITL", "Inter-token Latency")
 474 |     process_one_metric("e2el", "E2EL", "End-to-end Latency")
 475 | 
 476 |     print("=" * 50)
 477 | 
 478 |     return result
 479 | 
 480 | 
 481 | def check_goodput_args(args):
 482 |     # Check and parse goodput arguments
 483 |     goodput_config_dict = {}
 484 |     VALID_NAMES = ["ttft", "tpot", "e2el"]
 485 |     if args.goodput:
 486 |         goodput_config_dict = parse_goodput(args.goodput)
 487 |         for slo_name, slo_val in goodput_config_dict.items():
 488 |             if slo_name not in VALID_NAMES:
 489 |                 raise ValueError(
 490 |                     f"Invalid metric name found, {slo_name}: {slo_val}. "
 491 |                     "The service level objective name should be one of "
 492 |                     f"{str(VALID_NAMES)}. ")
 493 |             if slo_val < 0:
 494 |                 raise ValueError(
 495 |                     f"Invalid value found, {slo_name}: {slo_val}. "
 496 |                     "The service level objective value should be "
 497 |                     "non-negative.")
 498 |     return goodput_config_dict
 499 | 
 500 | 
 501 | def parse_goodput(slo_pairs):
 502 |     goodput_config_dict = {}
 503 |     try:
 504 |         for slo_pair in slo_pairs:
 505 |             slo_name, slo_val = slo_pair.split(":")
 506 |             goodput_config_dict[slo_name] = float(slo_val)
 507 |     except ValueError as err:
 508 |         raise argparse.ArgumentTypeError(
 509 |             "Invalid format found for service level objectives. "
 510 |             "Specify service level objectives for goodput as \"KEY:VALUE\" "
 511 |             "pairs, where the key is a metric name, and the value is a "
 512 |             "number in milliseconds.") from err
 513 |     return goodput_config_dict
 514 | 
 515 | 
 516 | def save_to_pytorch_benchmark_format(args: argparse.Namespace,
 517 |                                      results: dict[str, Any],
 518 |                                      file_name: str) -> None:
 519 |     metrics = [
 520 |         "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
 521 |         "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
 522 |         "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
 523 |     ]
 524 |     # These raw data might be useful, but they are rather big. They can be added
 525 |     # later if needed
 526 |     ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
 527 |     pt_records = convert_to_pytorch_benchmark_format(
 528 |         args=args,
 529 |         metrics={k: [results[k]]
 530 |                  for k in metrics},
 531 |         extra_info={
 532 |             k: results[k]
 533 |             for k in results if k not in metrics and k not in ignored_metrics
 534 |         })
 535 |     if pt_records:
 536 |         # Don't use json suffix here as we don't want CI to pick it up
 537 |         pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
 538 |         write_to_json(pt_file, pt_records)
 539 | 
 540 | 
 541 | def main(args: argparse.Namespace):
 542 |     print(args)
 543 |     random.seed(args.seed)
 544 |     np.random.seed(args.seed)
 545 | 
 546 |     backend = args.backend
 547 |     model_id = args.model
 548 |     model_name = args.served_model_name
 549 |     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
 550 |     tokenizer_mode = args.tokenizer_mode
 551 | 
 552 |     if args.base_url is not None:
 553 |         api_url = f"{args.base_url}{args.endpoint}"
 554 |         base_url = f"{args.base_url}"
 555 |     else:
 556 |         api_url = f"http://{args.host}:{args.port}{args.endpoint}"
 557 |         base_url = f"http://{args.host}:{args.port}"
 558 | 
 559 |     tokenizer = get_tokenizer(tokenizer_id,
 560 |                               tokenizer_mode=tokenizer_mode,
 561 |                               trust_remote_code=args.trust_remote_code)
 562 | 
 563 |     if args.dataset_name is None:
 564 |         raise ValueError(
 565 |             "Please specify '--dataset-name' and the corresponding "
 566 |             "'--dataset-path' if required.")
 567 | 
 568 |     if args.dataset_name == "sonnet":
 569 |         dataset = SonnetDataset(dataset_path=args.dataset_path)
 570 |         # For the "sonnet" dataset, formatting depends on the backend.
 571 |         if args.backend == "openai-chat":
 572 |             input_requests = dataset.sample(num_requests=args.num_prompts,
 573 |                                             input_len=args.sonnet_input_len,
 574 |                                             output_len=args.sonnet_output_len,
 575 |                                             prefix_len=args.sonnet_prefix_len,
 576 |                                             tokenizer=tokenizer,
 577 |                                             return_prompt_formatted=False)
 578 |         else:
 579 |             assert tokenizer.chat_template or tokenizer.default_chat_template, (
 580 |                 "Tokenizer/model must have chat template for sonnet dataset.")
 581 |             input_requests = dataset.sample(num_requests=args.num_prompts,
 582 |                                             input_len=args.sonnet_input_len,
 583 |                                             output_len=args.sonnet_output_len,
 584 |                                             prefix_len=args.sonnet_prefix_len,
 585 |                                             tokenizer=tokenizer,
 586 |                                             return_prompt_formatted=True)
 587 | 
 588 |     elif args.dataset_name == "hf":
 589 |         # Choose between VisionArenaDataset
 590 |         # and HuggingFaceDataset based on provided parameters.
 591 |         dataset_class = (VisionArenaDataset if args.dataset_path
 592 |                          == VisionArenaDataset.VISION_ARENA_DATASET_PATH
 593 |                          and args.hf_subset is None else HuggingFaceDataset)
 594 |         input_requests = dataset_class(
 595 |             dataset_path=args.dataset_path,
 596 |             dataset_subset=args.hf_subset,
 597 |             dataset_split=args.hf_split,
 598 |         ).sample(
 599 |             num_requests=args.num_prompts,
 600 |             tokenizer=tokenizer,
 601 |             random_seed=args.seed,
 602 |             output_len=args.hf_output_len,
 603 |         )
 604 | 
 605 |     else:
 606 |         # For datasets that follow a similar structure, use a mapping.
 607 |         dataset_mapping = {
 608 |             "sharegpt":
 609 |             lambda: ShareGPTDataset(random_seed=args.seed,
 610 |                                     dataset_path=args.dataset_path).sample(
 611 |                                         tokenizer=tokenizer,
 612 |                                         num_requests=args.num_prompts,
 613 |                                         output_len=args.sharegpt_output_len,
 614 |                                     ),
 615 |             "burstgpt":
 616 |             lambda: BurstGPTDataset(random_seed=args.seed,
 617 |                                     dataset_path=args.dataset_path).
 618 |             sample(tokenizer=tokenizer, num_requests=args.num_prompts),
 619 |             "random":
 620 |             lambda: RandomDataset(dataset_path=args.dataset_path).sample(
 621 |                 tokenizer=tokenizer,
 622 |                 num_requests=args.num_prompts,
 623 |                 prefix_len=args.random_prefix_len,
 624 |                 input_len=args.random_input_len,
 625 |                 output_len=args.random_output_len,
 626 |                 range_ratio=args.random_range_ratio,
 627 |             )
 628 |         }
 629 | 
 630 |         try:
 631 |             input_requests = dataset_mapping[args.dataset_name]()
 632 |         except KeyError as err:
 633 |             raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
 634 |     goodput_config_dict = check_goodput_args(args)
 635 | 
 636 |     # Avoid GC processing "static" data - reduce pause times.
 637 |     gc.collect()
 638 |     gc.freeze()
 639 | 
 640 |     benchmark_result = asyncio.run(
 641 |         benchmark(
 642 |             backend=backend,
 643 |             api_url=api_url,
 644 |             base_url=base_url,
 645 |             model_id=model_id,
 646 |             model_name=model_name,
 647 |             tokenizer=tokenizer,
 648 |             input_requests=input_requests,
 649 |             logprobs=args.logprobs,
 650 |             request_rate=args.request_rate,
 651 |             burstiness=args.burstiness,
 652 |             disable_tqdm=args.disable_tqdm,
 653 |             profile=args.profile,
 654 |             selected_percentile_metrics=args.percentile_metrics.split(","),
 655 |             selected_percentiles=[
 656 |                 float(p) for p in args.metric_percentiles.split(",")
 657 |             ],
 658 |             ignore_eos=args.ignore_eos,
 659 |             goodput_config_dict=goodput_config_dict,
 660 |             max_concurrency=args.max_concurrency,
 661 |             lora_modules=args.lora_modules,
 662 |         ))
 663 |         
 664 |     # CUSTOM START
 665 |     # I made a slight modification here to just return the benchmark results.
 666 |     # The rest was left as is, in case we need to update it in the future.
 667 |     return benchmark_result
 668 |     # CUSTOM END
 669 | 
 670 |     # Save config and results to json
 671 |     if args.save_result:
 672 |         result_json: dict[str, Any] = {}
 673 | 
 674 |         # Setup
 675 |         current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
 676 |         result_json["date"] = current_dt
 677 |         result_json["backend"] = backend
 678 |         result_json["model_id"] = model_id
 679 |         result_json["tokenizer_id"] = tokenizer_id
 680 |         result_json["num_prompts"] = args.num_prompts
 681 | 
 682 |         # Metadata
 683 |         if args.metadata:
 684 |             for item in args.metadata:
 685 |                 if "=" in item:
 686 |                     kvstring = item.split("=")
 687 |                     result_json[kvstring[0].strip()] = kvstring[1].strip()
 688 |                 else:
 689 |                     raise ValueError(
 690 |                         "Invalid metadata format. Please use KEY=VALUE format."
 691 |                     )
 692 | 
 693 |         if not args.save_detailed:
 694 |             # Remove fields with too many data points
 695 |             for field in [
 696 |                     "input_lens", "output_lens", "ttfts", "itls",
 697 |                     "generated_texts", "errors"
 698 |             ]:
 699 |                 if field in result_json:
 700 |                     del result_json[field]
 701 | 
 702 |         # Traffic
 703 |         result_json["request_rate"] = (args.request_rate if args.request_rate
 704 |                                        < float("inf") else "inf")
 705 |         result_json["burstiness"] = args.burstiness
 706 |         result_json["max_concurrency"] = args.max_concurrency
 707 | 
 708 |         # Merge with benchmark result
 709 |         result_json = {**result_json, **benchmark_result}
 710 | 
 711 |         # Save to file
 712 |         base_model_id = model_id.split("/")[-1]
 713 |         max_concurrency_str = (f"-concurrency{args.max_concurrency}"
 714 |                                if args.max_concurrency is not None else "")
 715 |         file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  #noqa
 716 |         if args.result_filename:
 717 |             file_name = args.result_filename
 718 |         if args.result_dir:
 719 |             file_name = os.path.join(args.result_dir, file_name)
 720 |         with open(file_name, "w", encoding='utf-8') as outfile:
 721 |             json.dump(result_json, outfile)
 722 |         save_to_pytorch_benchmark_format(args, result_json, file_name)
 723 | 
 724 | 
 725 | if __name__ == "__main__":
 726 |     parser = FlexibleArgumentParser(
 727 |         description="Benchmark the online serving throughput.")
 728 |     parser.add_argument(
 729 |         "--backend",
 730 |         type=str,
 731 |         default="vllm",
 732 |         choices=list(ASYNC_REQUEST_FUNCS.keys()),
 733 |     )
 734 |     parser.add_argument(
 735 |         "--base-url",
 736 |         type=str,
 737 |         default=None,
 738 |         help="Server or API base url if not using http host and port.",
 739 |     )
 740 |     # Use 127.0.0.1 here instead of localhost to force the use of ipv4
 741 |     parser.add_argument("--host", type=str, default="127.0.0.1")
 742 |     parser.add_argument("--port", type=int, default=8000)
 743 |     parser.add_argument(
 744 |         "--endpoint",
 745 |         type=str,
 746 |         default="/v1/completions",
 747 |         help="API endpoint.",
 748 |     )
 749 |     parser.add_argument(
 750 |         "--dataset-name",
 751 |         type=str,
 752 |         default="sharegpt",
 753 |         choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
 754 |         help="Name of the dataset to benchmark on.",
 755 |     )
 756 |     parser.add_argument("--dataset-path",
 757 |                         type=str,
 758 |                         default=None,
 759 |                         help="Path to the sharegpt/sonnet dataset. "
 760 |                         "Or the huggingface dataset ID if using HF dataset.")
 761 |     parser.add_argument(
 762 |         "--max-concurrency",
 763 |         type=int,
 764 |         default=None,
 765 |         help="Maximum number of concurrent requests. This can be used "
 766 |         "to help simulate an environment where a higher level component "
 767 |         "is enforcing a maximum number of concurrent requests. While the "
 768 |         "--request-rate argument controls the rate at which requests are "
 769 |         "initiated, this argument will control how many are actually allowed "
 770 |         "to execute at a time. This means that when used in combination, the "
 771 |         "actual request rate may be lower than specified with --request-rate, "
 772 |         "if the server is not processing requests fast enough to keep up.")
 773 | 
 774 |     parser.add_argument(
 775 |         "--model",
 776 |         type=str,
 777 |         required=True,
 778 |         help="Name of the model.",
 779 |     )
 780 |     parser.add_argument(
 781 |         "--tokenizer",
 782 |         type=str,
 783 |         help=
 784 |         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
 785 |     )
 786 |     parser.add_argument("--use-beam-search", action="store_true")
 787 |     parser.add_argument(
 788 |         "--num-prompts",
 789 |         type=int,
 790 |         default=1000,
 791 |         help="Number of prompts to process.",
 792 |     )
 793 |     parser.add_argument(
 794 |         "--logprobs",
 795 |         type=int,
 796 |         default=None,
 797 |         help=("Number of logprobs-per-token to compute & return as part of "
 798 |               "the request. If unspecified, then either (1) if beam search "
 799 |               "is disabled, no logprobs are computed & a single dummy "
 800 |               "logprob is returned for each token; or (2) if beam search "
 801 |               "is enabled, 1 logprob per token is computed"),
 802 |     )
 803 |     parser.add_argument(
 804 |         "--request-rate",
 805 |         type=float,
 806 |         default=float("inf"),
 807 |         help="Number of requests per second. If this is inf, "
 808 |         "then all the requests are sent at time 0. "
 809 |         "Otherwise, we use Poisson process or gamma distribution "
 810 |         "to synthesize the request arrival times.",
 811 |     )
 812 |     parser.add_argument(
 813 |         "--burstiness",
 814 |         type=float,
 815 |         default=1.0,
 816 |         help="Burstiness factor of the request generation. "
 817 |         "Only take effect when request_rate is not inf. "
 818 |         "Default value is 1, which follows Poisson process. "
 819 |         "Otherwise, the request intervals follow a gamma distribution. "
 820 |         "A lower burstiness value (0 < burstiness < 1) results in more "
 821 |         "bursty requests. A higher burstiness value (burstiness > 1) "
 822 |         "results in a more uniform arrival of requests.",
 823 |     )
 824 |     parser.add_argument("--seed", type=int, default=0)
 825 |     parser.add_argument(
 826 |         "--trust-remote-code",
 827 |         action="store_true",
 828 |         help="Trust remote code from huggingface",
 829 |     )
 830 |     parser.add_argument(
 831 |         "--disable-tqdm",
 832 |         action="store_true",
 833 |         help="Specify to disable tqdm progress bar.",
 834 |     )
 835 |     parser.add_argument(
 836 |         "--profile",
 837 |         action="store_true",
 838 |         help="Use Torch Profiler. The endpoint must be launched with "
 839 |         "VLLM_TORCH_PROFILER_DIR to enable profiler.",
 840 |     )
 841 |     parser.add_argument(
 842 |         "--save-result",
 843 |         action="store_true",
 844 |         help="Specify to save benchmark results to a json file",
 845 |     )
 846 |     parser.add_argument(
 847 |         "--save-detailed",
 848 |         action="store_true",
 849 |         help="When saving the results, whether to include per request "
 850 |         "information such as response, error, ttfts, tpots, etc.",
 851 |     )
 852 |     parser.add_argument(
 853 |         "--metadata",
 854 |         metavar="KEY=VALUE",
 855 |         nargs="*",
 856 |         help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
 857 |         "for metadata of this run to be saved in the result JSON file "
 858 |         "for record keeping purposes.",
 859 |     )
 860 |     parser.add_argument(
 861 |         "--result-dir",
 862 |         type=str,
 863 |         default=None,
 864 |         help="Specify directory to save benchmark json results. "
 865 |         "If not specified, results are saved in the current directory.",
 866 |     )
 867 |     parser.add_argument(
 868 |         "--result-filename",
 869 |         type=str,
 870 |         default=None,
 871 |         help="Specify the filename to save benchmark json results. "
 872 |         "If not specified, results will be saved in "
 873 |         "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
 874 |         " format.",
 875 |     )
 876 |     parser.add_argument(
 877 |         "--ignore-eos",
 878 |         action="store_true",
 879 |         help="Set ignore_eos flag when sending the benchmark request. "
 880 |         "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
 881 |     parser.add_argument(
 882 |         "--percentile-metrics",
 883 |         type=str,
 884 |         default="ttft,tpot,itl",
 885 |         help="Comma-separated list of selected metrics to report percentiles. "
 886 |         "This argument specifies the metrics to report percentiles. "
 887 |         "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
 888 |         "Default value is \"ttft,tpot,itl\".")
 889 |     parser.add_argument(
 890 |         "--metric-percentiles",
 891 |         type=str,
 892 |         default="99",
 893 |         help="Comma-separated list of percentiles for selected metrics. "
 894 |         "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
 895 |         "Default value is \"99\". "
 896 |         "Use \"--percentile-metrics\" to select metrics.",
 897 |     )
 898 |     parser.add_argument(
 899 |         "--goodput",
 900 |         nargs="+",
 901 |         required=False,
 902 |         help="Specify service level objectives for goodput as \"KEY:VALUE\" "
 903 |         "pairs, where the key is a metric name, and the value is in "
 904 |         "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
 905 |         "separated by spaces. Allowed request level metric names are "
 906 |         "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
 907 |         "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
 908 |         "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
 909 | 
 910 |     # group for dataset specific arguments
 911 |     sonnet_group = parser.add_argument_group("sonnet dataset options")
 912 |     sonnet_group.add_argument(
 913 |         "--sonnet-input-len",
 914 |         type=int,
 915 |         default=550,
 916 |         help=
 917 |         "Number of input tokens per request, used only for sonnet dataset.",
 918 |     )
 919 |     sonnet_group.add_argument(
 920 |         "--sonnet-output-len",
 921 |         type=int,
 922 |         default=150,
 923 |         help=
 924 |         "Number of output tokens per request, used only for sonnet dataset.",
 925 |     )
 926 |     sonnet_group.add_argument(
 927 |         "--sonnet-prefix-len",
 928 |         type=int,
 929 |         default=200,
 930 |         help=
 931 |         "Number of prefix tokens per request, used only for sonnet dataset.",
 932 |     )
 933 | 
 934 |     sharegpt_group = parser.add_argument_group("sharegpt dataset options")
 935 |     sharegpt_group.add_argument(
 936 |         "--sharegpt-output-len",
 937 |         type=int,
 938 |         default=None,
 939 |         help="Output length for each request. Overrides the output length "
 940 |         "from the ShareGPT dataset.")
 941 | 
 942 |     random_group = parser.add_argument_group("random dataset options")
 943 |     random_group.add_argument(
 944 |         "--random-input-len",
 945 |         type=int,
 946 |         default=1024,
 947 |         help=
 948 |         "Number of input tokens per request, used only for random sampling.",
 949 |     )
 950 |     random_group.add_argument(
 951 |         "--random-output-len",
 952 |         type=int,
 953 |         default=128,
 954 |         help=
 955 |         "Number of output tokens per request, used only for random sampling.",
 956 |     )
 957 |     random_group.add_argument(
 958 |         "--random-range-ratio",
 959 |         type=float,
 960 |         default=1.0,
 961 |         help="Range of sampled ratio of input/output length, "
 962 |         "used only for random sampling.",
 963 |     )
 964 |     random_group.add_argument(
 965 |         "--random-prefix-len",
 966 |         type=int,
 967 |         default=0,
 968 |         help="Number of fixed prefix tokens before random "
 969 |         " context. The length range of context in a random "
 970 |         " request is [random-prefix-len, "
 971 |         " random-prefix-len + random-prefix-len * random-range-ratio).")
 972 | 
 973 |     hf_group = parser.add_argument_group("hf dataset options")
 974 |     hf_group.add_argument("--hf-subset",
 975 |                           type=str,
 976 |                           default=None,
 977 |                           help="Subset of the HF dataset.")
 978 |     hf_group.add_argument("--hf-split",
 979 |                           type=str,
 980 |                           default=None,
 981 |                           help="Split of the HF dataset.")
 982 |     hf_group.add_argument(
 983 |         "--hf-output-len",
 984 |         type=int,
 985 |         default=None,
 986 |         help="Output length for each request. Overrides the output lengths "
 987 |         "from the sampled HF dataset.",
 988 |     )
 989 | 
 990 |     parser.add_argument(
 991 |         '--tokenizer-mode',
 992 |         type=str,
 993 |         default="auto",
 994 |         choices=['auto', 'slow', 'mistral', 'custom'],
 995 |         help='The tokenizer mode.\n\n* "auto" will use the '
 996 |         'fast tokenizer if available.\n* "slow" will '
 997 |         'always use the slow tokenizer. \n* '
 998 |         '"mistral" will always use the `mistral_common` tokenizer. \n*'
 999 |         '"custom" will use --tokenizer to select the preregistered tokenizer.')
1000 | 
1001 |     parser.add_argument("--served-model-name",
1002 |                         type=str,
1003 |                         default=None,
1004 |                         help="The model name used in the API. "
1005 |                         "If not specified, the model name will be the "
1006 |                         "same as the ``--model`` argument. ")
1007 | 
1008 |     parser.add_argument("--lora-modules",
1009 |                         nargs='+',
1010 |                         default=None,
1011 |                         help="A subset of LoRA module names passed in when "
1012 |                         "launching the server. For each request, the "
1013 |                         "script chooses a LoRA module at random.")
1014 | 
1015 |     args = parser.parse_args()
1016 | 
1017 |     main(args)
```
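
As a standalone illustration (not part of the upstream file) of the arrival-time logic documented in `get_request()` above: intervals are drawn from a gamma distribution with shape equal to `burstiness` and scale `1 / (request_rate * burstiness)`, so the mean gap stays at `1 / request_rate` while lower burstiness only widens the spread. The sketch below simply verifies that numerically with NumPy.

```python
# Sketch only: reproduces the interval sampling used by get_request().
import numpy as np

request_rate = 4.0  # requests per second
rng = np.random.default_rng(0)
for burstiness in (0.5, 1.0, 2.0):
    theta = 1.0 / (request_rate * burstiness)          # scale parameter
    intervals = rng.gamma(shape=burstiness, scale=theta, size=100_000)
    # Mean stays ~0.25 s (= 1/request_rate); std shrinks as burstiness grows.
    print(f"burstiness={burstiness}: mean={intervals.mean():.3f}s, "
          f"std={intervals.std():.3f}s")
```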