Zhang Yulong
2026-02-03 21:59:32 +08:00
committed by GitHub
parent 793dac0f9d
commit 16d03c3127
2 changed files with 218 additions and 133 deletions
+203 -133
@@ -17,6 +17,7 @@
# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py
import copy
import io
import json
import os
@@ -166,166 +167,173 @@ def metrics_summary(metrics, token_timestamps):
async def async_request_eb_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
    session: aiohttp.ClientSession | None = None,
) -> RequestFuncOutput:
    """Request an LLM using the EB OpenAI Chat Completions API."""
    api_url = request_func_input.api_url
    assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
    # Reuse the caller's session when one is passed in; otherwise create
    # (and later close) a session owned by this call.
    own_session = session is None
    if own_session:
        session = aiohttp.ClientSession(
            trust_env=True,
            read_bufsize=10 * 1024 * 1024,
            timeout=AIOHTTP_TIMEOUT,
        )
    content = [{"type": "text", "text": request_func_input.prompt}]
    if request_func_input.multi_modal_content:
        content.append(request_func_input.multi_modal_content)
    payload = {
        "model": request_func_input.model,
        "messages": request_func_input.history_QA,
        "stream": True,
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        "max_tokens": request_func_input.output_len,
        "collect_metrics": request_func_input.pd_metrics,
    }
    if request_func_input.response_format:
        payload["response_format"] = request_func_input.response_format
    # Hyper-parameters are passed in via the YAML config.
    payload.update(request_func_input.hyper_parameters)
    # Random-input switch: pin the output length exactly.
    if request_func_input.random_flag:
        payload["max_tokens"] = request_func_input.output_len
        metadata = payload.get("metadata", {})
        metadata["min_tokens"] = request_func_input.output_len
        payload["metadata"] = metadata
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    if request_func_input.debug:
        print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    output = RequestFuncOutput()
    output.prompt_len = 0
    output.no = request_func_input.no
    metrics_list = []
    request_id = "None"
    ttft = 0.0
    res_ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    token_timestamps = []
    try:
        async with session.post(
            url=api_url, json=payload, headers=headers, read_bufsize=10 * 1024 * 1024
        ) as response:
            data = {}
            if response.status == 200:
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue
                    chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                    if chunk != "[DONE]":
                        timestamp = time.perf_counter()
                        data = json.loads(chunk)
                        if "metrics" in data:
                            metrics_list.append(data["metrics"])
                        if request_id == "None" and "id" in data:
                            request_id = data["id"]
                        if choices := data.get("choices"):
                            content = choices[0]["delta"].get("content")
                            reason_content = choices[0]["delta"].get("reasoning_content")
                            # First token
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft
                                # cached_tokens
                                if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
                                    output.prompt_len = (
                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
                                    )
                                else:
                                    output.prompt_len = 0
                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)
                            # First token of the response content.
                            if res_ttft == 0.0:
                                if content:
                                    res_ttft = choices[0].get("arrival_time", timestamp)
                                    output.res_ttft = res_ttft
                                    usage = data.get("usage") or {}
                                    output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0)
                            output.generated_text += content or ""
                            output.reasoning_content += reason_content or ""
                            output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                        elif usage := data.get("usage", {}):
                            output.output_tokens = usage.get("completion_tokens", 0)
                            output.prompt_tokens = usage.get("prompt_tokens", 0)
                            if output.prompt_len == 0:
                                if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
                                    output.prompt_len = (
                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
                                    )
                        most_recent_timestamp = timestamp
                        token_timestamps.append(time.time())
                # At the end of streaming, record the timestamp of the last received chunk.
                output.end_timestamp = most_recent_timestamp
                # Metrics summary; skip the empty first packet when computing
                # first-token statistics.
                output.metrics = metrics_summary(metrics_list, token_timestamps[1:])
                # Handle over-long reasoning content being truncated, in which
                # case the reply content is empty.
                if output.generated_text.strip() == "" and output.reasoning_content.strip() == "":
                    output.success = False
                    output.reasoning_tokens = output.output_tokens
                    output.error = "No generated text found!"
                else:
                    output.success = True
                    output.latency = most_recent_timestamp - st
            else:
                error_text = await response.text()
                print(
                    "####error response:",
                    error_text,
                    "####payload:",
                    payload,
                )
                output.error = error_text or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))
    finally:
        # Only close the session if this call created it.
        if own_session:
            await session.close()
    output.request_id = request_id
    # Persist the results of failed requests.
    if not output.success or output.output_tokens == 0:
        with open("error_output.txt", "a") as f:
            f.write(str(output) + "\n")
    if pbar:
        pbar.update(1)
    if request_func_input.debug:
@@ -333,6 +341,67 @@ async def async_request_eb_openai_chat_completions(
    return output
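The new optional `session` parameter makes connection reuse explicit: a caller issuing many sequential requests can hand in one long-lived `aiohttp.ClientSession` and skip per-request TCP/TLS setup, while plain calls still get a private session that is closed in the `finally` block. A minimal usage sketch (the coroutine name `run_pair` and its inputs are illustrative, not part of this patch):

    # Hypothetical sketch: share one ClientSession across two requests.
    async def run_pair(req_a: RequestFuncInput, req_b: RequestFuncInput):
        async with aiohttp.ClientSession(
            trust_env=True, read_bufsize=10 * 1024 * 1024, timeout=AIOHTTP_TIMEOUT
        ) as session:
            out_a = await async_request_eb_openai_chat_completions(req_a, session=session)
            out_b = await async_request_eb_openai_chat_completions(req_b, session=session)
        return out_a, out_b
    # Omitting session= keeps the old behavior: the function opens and closes its own.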
async def async_request_eb_openai_chat_completions_multi_turn(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
):
    outputs = []
    ori_history = request_func_input.history_QA
    user_count = sum(msg.get("role") == "user" for msg in ori_history)
    print("START", request_func_input.no, "dialogue turns:", user_count, flush=True)
    history = []
    prompt_no = 0
    # Create the session only once and share it across all turns.
    connector = aiohttp.TCPConnector(
        limit=0,
        limit_per_host=0,
        keepalive_timeout=60,
    )
    async with aiohttp.ClientSession(
        connector=connector,
        trust_env=True,
        read_bufsize=10 * 1024 * 1024,
        timeout=AIOHTTP_TIMEOUT,
    ) as session:
        for i, message in enumerate(ori_history):
            if message["role"] == "user":
                history.append(message)
                round_input = copy.deepcopy(request_func_input)
                round_input.history_QA = history
                round_input.no = f"{round_input.no}_{prompt_no}"
                # Reuse the shared session.
                output = await async_request_eb_openai_chat_completions(
                    round_input,
                    pbar=None,
                    session=session,
                )
                outputs.append(output)
                if not output.success:
                    return outputs
                prompt_no += 1
                history.append(
                    {
                        "role": "assistant",
                        "content": output.generated_text,
                    }
                )
            elif message["role"] == "assistant":
                continue
            else:
                history.append(message)
    if pbar:
        pbar.update(1)
    return outputs
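The loop replays an OpenAI-style message list: each recorded user turn is sent together with the history accumulated so far, while the recorded assistant turns are skipped in favor of the model's live answers. An illustrative `history_QA` input (values are made up):

    # Illustrative multi-turn input; the stored assistant reply is ignored and
    # replaced by the generated_text of the previous round.
    history_QA = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi there."},
        {"role": "assistant", "content": "Hello!"},  # skipped at replay time
        {"role": "user", "content": "Summarize our chat."},
    ]

Keeping one `TCPConnector`-backed session for the whole conversation is what makes the per-turn latencies comparable: no turn pays connection setup except the first.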
async def async_request_eb_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
@@ -834,6 +903,7 @@ ASYNC_REQUEST_FUNCS = {
"deepspeed-mii": async_request_deepspeed_mii,
"openai": async_request_eb_openai_completions,
"openai-chat": async_request_eb_openai_chat_completions,
"openai-chat-multi-turn": async_request_eb_openai_chat_completions_multi_turn,
"openai-audio": async_request_openai_audio,
"tensorrt-llm": async_request_trt_llm,
"scalellm": async_request_openai_completions,
+15
@@ -389,6 +389,9 @@ async def benchmark(
print("test_output:", test_output)
if args.multi_turn:
test_output = test_output[0]
if not test_output.success:
raise ValueError(
f"Initial test run failed - Please make sure that 1. benchmark arguments are correctly specified and 2. the http_proxy and https_proxy are turned off. Error: {test_output.error}"
@@ -569,6 +572,10 @@ async def benchmark(
    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
    # Multi-turn outputs are lists (one per conversation); flatten them
    # before computing statistics.
    if args.multi_turn:
        outputs = [x for sub in outputs for x in sub]
        outputs.sort(key=lambda x: x.end_timestamp)
    if profile:
@@ -1040,6 +1047,9 @@ def main(args: argparse.Namespace):
    np.random.seed(args.seed)
    backend = args.backend
    # Multi-turn conversation requests are supported only on the chat endpoint.
    if args.multi_turn:
        backend = "openai-chat-multi-turn"
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
@@ -1345,6 +1355,11 @@ if __name__ == "__main__":
        action="store_true",
        help="Add the PD-disaggregation parameter (metrics: True) to requests",
    )
    parser.add_argument(
        "--multi-turn",
        action="store_true",
        help="Send requests in multi-turn conversation mode",
    )
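    # Example invocation (script name and the other flags are illustrative):
    #   python benchmark_serving.py --backend openai-chat --multi-turn \
    #       --model <model> --base-url http://localhost:8000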
    parser.add_argument(
        "--drop-ratio",
        type=float,