From 3c7ca62dc3eefc7e5d79865565934b37a4578498 Mon Sep 17 00:00:00 2001 From: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:32:09 +0800 Subject: [PATCH] [benchmark] support swe temp (#7530) --- benchmarks/backend_request_func_swe.py | 1361 ++++++++++++++++++++++++ 1 file changed, 1361 insertions(+) create mode 100644 benchmarks/backend_request_func_swe.py diff --git a/benchmarks/backend_request_func_swe.py b/benchmarks/backend_request_func_swe.py new file mode 100644 index 0000000000..7c09570022 --- /dev/null +++ b/benchmarks/backend_request_func_swe.py @@ -0,0 +1,1361 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""

# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py


import copy
import io
import json
import logging
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

# A single benchmark request may stream for a very long time; cap the whole
# HTTP exchange at 6 hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Input for requesting LLMs via API"""

    no: int
    prompt: str
    history_QA: Optional[dict]
    hyper_parameters: dict
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False
    language: Optional[str] = None
    debug: bool = False
    pd_metrics: bool = False  # when True, ask the server to collect PD-disaggregation metrics
    response_format: Optional[dict] = None
    random_flag: bool = False  # random-input mode: force exact output length via min/max tokens
    json_data: Optional[dict] = None
    prompt_token_ids: Optional[list] = None
    tokenizer_model: Optional[str] = None  # e.g. "eb", "eb_mm", "eb5"; enables token-ids splicing mode
    tokenizer_path: Optional[str] = None


@dataclass
class RequestFuncOutput:
    """Output for requesting LLMs via API"""

    no: int = 0
    request_id: str = ""
    generated_text: str = ""
    reasoning_content: str = ""
    success: bool = False
    latency: float = 0.0
    end_timestamp: float = 0.0  # timestamp (seconds, perf_counter-based) when the model fully returned
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    arrival_time: list = field(default_factory=list)  # arrival_time
    itl: list = field(default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    prompt_tokens: int = 0  # input token count reported by the inference side
    reasoning_tokens: int = 0  # length of the reasoning ("thinking") content
    res_ttft: int = 0  # time to first token including the reasoning phase
    error: str = ""
    metrics: dict = field(default_factory=dict)
    tool_calls: list = field(default_factory=list)
    output_ids: list = field(default_factory=list)
@dataclass
class SessionMetrics:
    """Aggregate metrics for one multi-turn conversation session."""

    session_no: int
    session_e2e_time: float
    pure_llm_time: float
    input_tokens: int
    output_tokens: int
    tool_calls: int


def safe_cost(a, b):
    """Return the time difference ``a - b``, or ``None`` if either endpoint is missing."""
    return None if a is None or b is None else a - b


def metrics_summary(metrics, token_timestamps):
    """Summarize per-request PD-disaggregation timing metrics.

    ``metrics`` is the list of per-chunk metric dicts collected from the
    stream; ``token_timestamps`` are the client-side receive times of the
    tokens. Returns an empty dict when no metrics were collected or fewer
    than two token timestamps are available.
    """
    if not metrics or len(token_timestamps) < 2:
        return {}

    first = metrics[0]
    last = metrics[-1]

    result = {}

    arrival_time = first.get("arrival_time")
    inference_start_time = first.get("inference_start_time")

    # total prefill cost
    result["prefill_cost_time"] = safe_cost(first.get("send_request_output_to_decode_time"), arrival_time)
    # prefill preparation cost
    result["prefill_prepare_cost_time"] = safe_cost(inference_start_time, arrival_time)
    # preprocessing cost
    result["preprocess_cost_time"] = safe_cost(first.get("scheduler_recv_req_time"), arrival_time)
    # request caching cost inside the scheduler
    result["cache_in_scheduler_cost_time"] = safe_cost(
        first.get("engine_get_req_time"), first.get("scheduler_recv_req_time")
    )
    # cost of requesting decode resources
    result["ask_decode_resource_cost_time"] = safe_cost(
        first.get("ask_decode_resource_finish_time"), first.get("ask_decode_resource_start_time")
    )
    # scheduler scheduling cost
    result["schedule_cost_time"] = safe_cost(
        first.get("inference_start_time"), first.get("ask_decode_resource_finish_time")
    )
    # first-token inference cost on the prefill side
    result["prefill_first_token_infer_cost_time"] = safe_cost(
        first.get("engine_recv_first_token_time"), inference_start_time
    )
    # cost of waiting for the KV-cache transfer on the prefill side
    result["wait_sending_cache_cost_time"] = safe_cost(
        first.get("send_request_output_to_decode_time"), first.get("wait_for_sending_cache_time")
    )
    # decode resource pre-allocation cost
    result["decode_preallocate_cost_time"] = safe_cost(
        last.get("decode_preallocate_req_time"), last.get("decode_recv_req_time")
    )
    # decode inference preparation cost
    result["decode_prepare_cost_time"] = safe_cost(
        last.get("decode_inference_start_time"), last.get("decode_recv_first_token_time")
    )
    # second-token inference cost on the decode side
    result["decode_second_token_infer_cost_time"] = safe_cost(
        last.get("decode_recv_second_token_time"), last.get("decode_inference_start_time")
    )
    # transmission cost of the first returned token
    result["first_token_transmission_cost_time"] = safe_cost(
        token_timestamps[0], last.get("decode_recv_first_token_time")
    )
    # transmission cost of the second returned token
    result["second_token_transmission_cost_time"] = safe_cost(
        token_timestamps[1], last.get("decode_recv_second_token_time")
    )

    # MIX mode: scheduler scheduling cost
    result["mixed_schedule_cost_time"] = safe_cost(first.get("inference_start_time"), first.get("engine_get_req_time"))
    # MIX mode: transmission cost of the first returned token
    result["mixed_first_token_transmission_cost_time"] = safe_cost(
        token_timestamps[0], first.get("engine_recv_first_token_time")
    )

    # cache statistics are copied through from the first chunk as-is
    for key in (
        "gpu_cache_token_num",
        "cpu_cache_token_num",
        "storage_cache_token_num",
        "cpu_cache_prepare_time",
        "storage_cache_prepare_time",
    ):
        result[key] = first.get(key)

    return result
def load_tokenizer(model, actor_tokenizer_path):
    """Load the tokenizer matching *model* from *actor_tokenizer_path*.

    Args:
        model: tokenizer family selector ("eb", "eb_mm", "eb5", anything else
            falls back to ``AutoTokenizer``).
        actor_tokenizer_path: directory containing the tokenizer files.

    Returns:
        The loaded tokenizer, or ``None`` if loading failed (the error is
        logged as a warning).
    """
    from ernie_tokenizer import Ernie5Tokenizer, ErnieBotTokenizer
    from paddleformers.transformers import AutoTokenizer

    from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer

    vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]

    def _select_vocab_file(tokenizer_cls):
        # First existing candidate wins, mirroring the original "eb"/"eb5"
        # branches. NOTE: the previous "eb_mm" branch was missing a `break`,
        # so the LAST existing file silently won there — fixed here.
        for vocab_file in vocab_file_names:
            if os.path.exists(os.path.join(actor_tokenizer_path, vocab_file)):
                tokenizer_cls.resource_files_names["vocab_file"] = vocab_file
                break

    try:
        if model == "eb":
            _select_vocab_file(ErnieBotTokenizer)
            tokenizer = ErnieBotTokenizer.from_pretrained(actor_tokenizer_path)
        elif model == "eb_mm":
            _select_vocab_file(Ernie4_5Tokenizer)
            tokenizer = Ernie4_5Tokenizer.from_pretrained(actor_tokenizer_path)
        elif model == "eb5":
            _select_vocab_file(Ernie5Tokenizer)
            tokenizer = Ernie5Tokenizer.from_pretrained(actor_tokenizer_path)
        else:
            print("tokenizer: AUTO")
            tokenizer = AutoTokenizer.from_pretrained(actor_tokenizer_path, padding_side="left", use_fast=True)
    except Exception as e:
        tokenizer = None
        logging.warning(f"Load tokenizer error: {e}")

    return tokenizer
async def async_request_eb_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
    session: aiohttp.ClientSession | None = None,
) -> RequestFuncOutput:
    """Request an LLM using EB OpenAI.

    Streams one chat-completions request and collects timing (TTFT, ITL),
    token counts, tool calls and optional PD-disaggregation metrics into a
    RequestFuncOutput. If *session* is None a private aiohttp session is
    created and closed in ``finally``; a caller-owned session is reused and
    left open.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."

    own_session = session is None
    if own_session:
        session = aiohttp.ClientSession(
            trust_env=True,
            read_bufsize=10 * 1024 * 1024,
            timeout=AIOHTTP_TIMEOUT,
        )

    # NOTE(review): `content` is built here but only `history_QA` is sent as
    # "messages"; `content` is later shadowed inside the streaming loop.
    content = [{"type": "text", "text": request_func_input.prompt}]
    if request_func_input.multi_modal_content:
        content.append(request_func_input.multi_modal_content)
    # print("######json_data:", request_func_input.json_data)
    payload = {
        "model": request_func_input.model,
        "messages": request_func_input.history_QA,
        "stream": True,
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        "max_tokens": request_func_input.output_len,
        "collect_metrics": request_func_input.pd_metrics,
    }
    # max/min tokens from the dataset's json_data override the defaults
    if request_func_input.json_data:
        json_data = request_func_input.json_data

        if json_data.get("max_tokens"):
            payload["max_tokens"] = json_data["max_tokens"]

        if json_data.get("min_tokens"):
            payload["min_tokens"] = json_data["min_tokens"]
    if request_func_input.response_format:
        payload["response_format"] = request_func_input.response_format

    # random-input switch: force exact output length
    if request_func_input.random_flag:
        payload["max_tokens"] = request_func_input.output_len
        payload["min_tokens"] = request_func_input.output_len
        # random token-ids case: the "prompt" is actually a list of token ids
        if isinstance(request_func_input.prompt, list):
            request_func_input.prompt_token_ids = request_func_input.prompt
            request_func_input.prompt = ""

    # support passing prompt_token_ids directly
    if request_func_input.prompt_token_ids:
        # bypass messages: send an empty text message plus explicit token ids
        payload["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
        payload["prompt_token_ids"] = request_func_input.prompt_token_ids
        payload["return_token_ids"] = True
        # print("use_token_ids:", payload)

    # hyper-parameters come from the yaml config
    payload.update(request_func_input.hyper_parameters)

    # tools info; yaml (hyper_parameters) has the highest priority
    json_data = request_func_input.json_data or {}
    hyper = request_func_input.hyper_parameters or {}

    tools = None
    tool_choice = None

    if hyper.get("tools"):
        tools = hyper.get("tools")
        tool_choice = hyper.get("tool_choice", "auto")
    elif json_data.get("tools"):
        tools = json_data.get("tools")
        tool_choice = json_data.get("tool_choice", "auto")

    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = tool_choice

    # random-input switch (second application: hyper_parameters may have
    # overwritten max_tokens above; min_tokens goes through metadata here)
    if request_func_input.random_flag:
        payload["max_tokens"] = request_func_input.output_len
        metadata = payload.get("metadata", {})
        metadata["min_tokens"] = request_func_input.output_len
        payload["metadata"] = metadata

    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    output = RequestFuncOutput()
    output.prompt_len = 0
    output.no = request_func_input.no
    payload["no"] = request_func_input.no
    if request_func_input.debug:
        print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
    metrics_list = []
    request_id = "None"

    ttft = 0.0
    res_ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    token_timestamps = []
    tool_call_buffer = {}
    try:
        async with session.post(url=api_url, json=payload, headers=headers, read_bufsize=10 * 1024 * 1024) as response:
            data = {}
            if response.status == 200:
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                    if chunk != "[DONE]":
                        # print("####chunk:", chunk, type(chunk))
                        timestamp = time.perf_counter()
                        data = json.loads(chunk)
                        # print("####data:", json.dumps(data, indent=2, ensure_ascii=False))

                        if "metrics" in data:
                            metrics_list.append(data["metrics"])

                        if request_id == "None" and "id" in data:
                            request_id = data["id"]

                        if choices := data.get("choices"):
                            content = choices[0]["delta"].get("content")
                            reason_content = choices[0]["delta"].get("reasoning_content")
                            tool_calls = choices[0]["delta"].get("tool_calls")
                            completion_token_ids = choices[0]["delta"].get("completion_token_ids", [])
                            # accumulate streamed tool-call fragments keyed by index
                            if tool_calls:
                                for tc in tool_calls:
                                    idx = tc.get("index", 0)

                                    if idx not in tool_call_buffer:
                                        tool_call_buffer[idx] = {
                                            "id": tc.get("id"),
                                            "name": "",
                                            "arguments": "",
                                        }

                                    func = tc.get("function", {})

                                    if "name" in func:
                                        tool_call_buffer[idx]["name"] = func["name"]

                                    if "arguments" in func:
                                        tool_call_buffer[idx]["arguments"] += func["arguments"]

                            # First token
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft
                                # cached_tokens
                                # NOTE(review): assumes every chunk carries a
                                # "usage" key — would KeyError otherwise; confirm
                                # against server behavior.
                                if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
                                    output.prompt_len = (
                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
                                    )
                                else:
                                    output.prompt_len = 0

                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)

                            # first token of the visible response (i.e. after reasoning)
                            if res_ttft == 0.0:
                                if content:
                                    res_ttft = choices[0].get("arrival_time", timestamp)
                                    output.res_ttft = res_ttft
                                    usage = data.get("usage") or {}
                                    output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0)

                            output.generated_text += content or ""
                            output.reasoning_content += reason_content or ""
                            if completion_token_ids:
                                output.output_ids.extend(completion_token_ids)
                            # print(f"####content:{data}")
                            output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                        elif usage := data.get("usage", {}):
                            # usage-only chunk (no choices)
                            output.output_tokens = usage.get("completion_tokens", 0)
                            output.prompt_tokens = usage.get("prompt_tokens", 0)
                            if output.prompt_len == 0:
                                if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
                                    output.prompt_len = (
                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
                                    )

                        most_recent_timestamp = timestamp
                        token_timestamps.append(time.time())

                # output.generated_text = generated_text
                # record the timestamp of the last chunk of the stream
                output.end_timestamp = most_recent_timestamp
                # record usage even for truncated cases
                usage = data.get("usage", {})
                if usage:
                    output.output_tokens = usage.get("completion_tokens", 0)
                    output.prompt_tokens = usage.get("prompt_tokens", 0)
                    if output.prompt_len == 0:
                        prompt_details = usage.get("prompt_tokens_details", {})
                        if prompt_details:
                            output.prompt_len = prompt_details.get("cached_tokens", 0)

                # finalize buffered tool calls; malformed JSON arguments fall
                # back to an empty dict (deliberate best-effort)
                if tool_call_buffer:
                    for _, tc in tool_call_buffer.items():
                        try:
                            args = json.loads(tc["arguments"]) if tc["arguments"] else {}
                        except:
                            args = {}

                        output.tool_calls.append({"id": tc["id"], "name": tc["name"], "arguments": args})

                # metrics summary; first timestamp is dropped to skip the empty first packet
                output.metrics = metrics_summary(metrics_list, token_timestamps[1:])

                has_text = output.generated_text.strip() or output.reasoning_content.strip()
                has_tool = getattr(output, "tool_calls", None)

                # handle over-long reasoning truncation, where the reply text is empty
                if not has_text and not has_tool:
                    output.success = False
                    output.reasoning_tokens = output.output_tokens
                    output.error = "No generated text found!"
                else:
                    output.success = True
                    output.latency = most_recent_timestamp - st
            else:
                error_text = await response.text()
                print(
                    "####error response:",
                    error_text,
                    "####payload:",
                    payload,
                )
                output.error = error_text or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))
    finally:
        # only close the session we created ourselves
        if own_session:
            await session.close()

    output.request_id = request_id

    # persist failed request results for later inspection
    if not output.success or output.output_tokens == 0:
        with open("error_output.txt", "a") as f:
            f.write(str(output) + "\n")
    if pbar:
        pbar.update(1)
    if request_func_input.debug:
        print("#####final_output:", output)
    return output
async def simple_tool_call(model_output, tool_url: str, timeout=60):
    """Execute one tool call extracted from *model_output* against *tool_url*.

    Returns a tuple ``(result, is_tool_result, tool_name, tool_id)`` where
    ``is_tool_result`` is True only when the tool service responded with a
    "result" key.
    """
    import re

    import httpx

    tool_id = None

    if getattr(model_output, "tool_calls", None):
        # structured tool call parsed from the stream
        tc = model_output.tool_calls[0]
        tool_name = tc["name"]
        args = tc.get("arguments", {})
        tool_id = tc.get("id")
    else:
        # NOTE(review): these regex patterns look garbled — r"(.*?)" matches an
        # empty string at position 0, so group(1) is always "" and the
        # splitlines()[0] below would raise IndexError. The tag literals
        # (e.g. <tool_call>...</tool_call>) were probably lost when this file
        # was generated; confirm against the original source.
        match = re.search(r"(.*?)", model_output.generated_text, re.S)
        if not match:
            return "", False, "", tool_id

        block = match.group(1).strip()
        lines = block.splitlines()
        tool_name = lines[0].strip()

        key = re.search(r"(.*?)", block)
        val = re.search(r"(.*?)", block)

        args = {key.group(1): val.group(1)} if key and val else {}

    if not tool_name:
        return "", False, "", tool_id

    headers = {"Content-Type": "application/json"}

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(
                tool_url,
                headers=headers,
                json={"tool_name": tool_name, "arguments": args},
            )

            resp.raise_for_status()
            obj = resp.json()

            # success only when the service returned a "result" key
            return obj.get("result", resp.text), "result" in obj, tool_name, tool_id

    except Exception as e:
        print(f"[TOOL ERROR] {tool_name}: {repr(e)}")
        return str(e), False, tool_name, tool_id
async def async_request_eb_openai_chat_completions_multi_turn(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
):
    """Replay one multi-turn (SWE dataset) session against the chat API.

    Walks the dataset's message history: each user/tool message triggers one
    LLM round (reusing a single aiohttp session); assistant messages are
    replayed from the dataset instead of the model's answers. Supports an
    optional token-ids splicing mode and an (currently disabled) tool-call
    loop. Returns ``(outputs, SessionMetrics)``.
    """
    # tool calling would only apply when yaml or the dataset carries tools
    json_data = request_func_input.json_data or {}
    hyper = request_func_input.hyper_parameters or {}
    # enable_tools = bool(json_data.get("tools") or hyper.get("tools"))
    # the SWE dataset has no callable tools
    enable_tools = False

    outputs = []

    tool_call_count = 0
    llm_time = 0.0
    tool_time = 0.0
    input_tokens = 0
    output_tokens = 0

    ori_history = request_func_input.history_QA
    user_count = sum(msg.get("role") == "user" for msg in ori_history)
    print("START", request_func_input.no, "user对话轮数:", user_count, flush=True)
    history = []
    prompt_no = 0
    # yaml (hyper) takes precedence over the dataset's json_data
    max_prompt_len = (
        hyper.get("max_prompt_len") if hyper.get("max_prompt_len") is not None else json_data.get("max_prompt_len")
    )
    print("max_prompt_len:", max_prompt_len)
    input_ids_all = []
    # completion_token_ids returned by FD for each round
    output_ids = []
    use_token_ids = bool(request_func_input.tokenizer_model and request_func_input.tokenizer_path)
    tokenizer = None

    if use_token_ids:
        # token-ids splicing mode
        print("token ids 拼接模式")
        enable_tools = False
        print("tokenizer_model:", request_func_input.tokenizer_model)
        print("tokenizer_path:", request_func_input.tokenizer_path)
        tokenizer = load_tokenizer(
            request_func_input.tokenizer_model,
            request_func_input.tokenizer_path,
        )
    else:
        # plain-text messages splicing mode
        print("messages 明文拼接模式")

    # create the HTTP session only once and reuse it across rounds
    session_start = time.perf_counter()
    connector = aiohttp.TCPConnector(
        limit=0,
        limit_per_host=0,
        keepalive_timeout=60,
    )

    async with aiohttp.ClientSession(
        connector=connector,
        trust_env=True,
        read_bufsize=10 * 1024 * 1024,
        timeout=AIOHTTP_TIMEOUT,
    ) as session:
        for i, message in enumerate(ori_history):
            if message["role"] == "user" or message["role"] == "tool":
                history.append(message)
                round_input = copy.deepcopy(request_func_input)
                round_input.history_QA = history
                round_input.no = f"{round_input.no}_{prompt_no}"
                if use_token_ids:
                    if len(input_ids_all) == 0:
                        # token-ids splicing mode: first round — render the whole
                        # history through the chat template
                        spliced_text = tokenizer.apply_chat_template(
                            history,
                            tokenize=False,
                            split_special_tokens=False,
                            add_special_tokens=False,
                        )
                        # convert to token ids
                        tokens = tokenizer.tokenize(spliced_text)
                        prompt_token_ids = tokenizer.convert_tokens_to_ids(tokens)
                        input_ids_all.extend(prompt_token_ids)
                        round_input.prompt_token_ids = input_ids_all
                    else:
                        prompt_length = len(input_ids_all) + len(output_ids)
                        if max_prompt_len and prompt_length >= max_prompt_len:
                            # over-length: stop the session before sending
                            print(
                                f"[SESSION STOP] {round_input.no} reach max_prompt_len={max_prompt_len}, stop session"
                            )
                            break
                        # token-ids splicing mode: subsequent rounds — append the
                        # previous round's output ids, then the new user prompt
                        input_ids_all.extend(output_ids)
                        user_prompt = message["content"]
                        # splice the user prompt
                        if round_input.tokenizer_model == "eb5":
                            # EB5 model chat format
                            user_prompt = (
                                f"\n\n<|im_start|>user\n{user_prompt}<|im_end|>\n\n<|im_start|>assistant\n\n"
                            )
                        else:
                            # 0.3B model: the trailing token 2 is replaced with
                            # 100272 <|end_of_sentence|> when splicing
                            input_ids_all[-1] = 100272
                            user_prompt = f"User: {user_prompt}\nAssistant: "
                        prompt_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(user_prompt))
                        input_ids_all.extend(prompt_token_ids)
                        round_input.prompt_token_ids = input_ids_all
                # reuse the shared session
                s0 = time.perf_counter()
                output = await async_request_eb_openai_chat_completions(
                    round_input,
                    pbar=None,
                    session=session,
                )
                s1 = time.perf_counter()
                llm_time += s1 - s0

                outputs.append(output)

                if not output.success:
                    # abort the whole session on the first failed round
                    session_end = time.perf_counter()
                    metrics = SessionMetrics(
                        session_no=request_func_input.no,
                        session_e2e_time=session_end - session_start,
                        pure_llm_time=llm_time,
                        input_tokens=input_tokens,
                        output_tokens=output_tokens,
                        tool_calls=tool_call_count,
                    )
                    return outputs, metrics

                # llm_cost = s1 - s0
                input_tokens += output.prompt_tokens
                output_tokens += output.output_tokens

                # update output_ids
                output_ids = output.output_ids

                if max_prompt_len and input_tokens >= max_prompt_len:
                    # post-hoc over-length stop
                    print(f"[SESSION STOP] {round_input.no} reach max_prompt_len={max_prompt_len}, stop session")
                    break

                if enable_tools:
                    # tool-call loop (disabled for SWE; kept for other datasets)
                    max_loop = json_data.get("max_loop", 10)
                    tool_url = json_data.get("tool_url", "")
                    max_prompt_len = json_data.get("max_prompt_len")
                    if not tool_url:
                        raise ValueError("tool_url is empty.")
                    for _ in range(max_loop):
                        t0 = time.perf_counter()
                        tool_result, is_tool_result, tool_name, tool_id = await simple_tool_call(
                            output,
                            tool_url,
                        )
                        t1 = time.perf_counter()
                        tool_time += t1 - t0
                        # print(f"#### tool_time: {t1 - t0:.3f}")
                        # print(f"#### tool_result: {tool_result}")
                        # print(f"#### is_tool_result: {is_tool_result}")

                        # tool call failed
                        if tool_name and not is_tool_result:
                            print(f"[SESSION FAIL] tool call failed: {tool_name}")

                            output.success = False

                            session_end = time.perf_counter()
                            session_e2e_time = session_end - session_start
                            tool_call_count += 1

                            metrics = SessionMetrics(
                                session_no=request_func_input.no,
                                session_e2e_time=session_e2e_time,
                                pure_llm_time=llm_time,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                tool_calls=tool_call_count,
                            )

                            return outputs, metrics

                        if not is_tool_result:
                            # no tool call in the reply: record it and leave the loop
                            history.append(
                                {
                                    "role": "assistant",
                                    "content": output.generated_text,
                                }
                            )
                            break

                        assistant_msg = {
                            "role": "assistant",
                            "content": output.generated_text,
                        }

                        if getattr(output, "tool_calls", None):
                            assistant_msg["tool_calls"] = [
                                {
                                    "id": tc["id"],
                                    "type": "function",
                                    "function": {
                                        "name": tc["name"],
                                        "arguments": json.dumps(tc["arguments"], ensure_ascii=False),
                                    },
                                }
                                for tc in output.tool_calls
                            ]

                        history.append(assistant_msg)

                        history.append(
                            {
                                "role": "tool",
                                "content": json.dumps(tool_result, ensure_ascii=False),
                                "tool_call_id": tool_id or tool_name,
                            }
                        )
                        tool_call_count += 1

                        round_input.history_QA = history

                        s0 = time.perf_counter()
                        output = await async_request_eb_openai_chat_completions(
                            round_input,
                            pbar=None,
                            session=session,
                        )
                        s1 = time.perf_counter()
                        llm_time += s1 - s0

                        outputs.append(output)

                        if not output.success:
                            session_end = time.perf_counter()
                            metrics = SessionMetrics(
                                session_no=request_func_input.no,
                                session_e2e_time=session_end - session_start,
                                pure_llm_time=llm_time,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                tool_calls=tool_call_count,
                            )
                            return outputs, metrics

                        input_tokens += output.prompt_tokens
                        output_tokens += output.output_tokens
                        # stop the session once its input length exceeds max_prompt_len
                        if max_prompt_len and input_tokens >= max_prompt_len:
                            print(
                                f"[SESSION STOP] {round_input.no} reach max_prompt_len={max_prompt_len}, stop session"
                            )
                            session_end = time.perf_counter()
                            metrics = SessionMetrics(
                                session_no=request_func_input.no,
                                session_e2e_time=session_end - session_start,
                                pure_llm_time=llm_time,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                tool_calls=tool_call_count,
                            )
                            return outputs, metrics
                    else:
                        # for-else: loop ran max_loop times without breaking
                        print(f"Warning {prompt_no} exceed max_loop={max_loop}, force stop tool loop")

                else:
                    # no tools
                    # history.append(
                    #     {
                    #         "role": "assistant",
                    #         "content": output.generated_text,
                    #     }
                    # )
                    # the SWE dataset rejects interaction: the dataset's own
                    # assistant replies are replayed instead of the model's
                    pass

                prompt_no += 1
            elif message["role"] == "assistant":
                # continue
                # SWE dataset rejects interaction: replay the dataset's assistant message
                history.append(message)
            else:
                history.append(message)

    session_end = time.perf_counter()
    session_e2e_time = session_end - session_start

    if pbar:
        pbar.update(1)

    metrics = SessionMetrics(
        session_no=request_func_input.no,
        session_e2e_time=session_e2e_time,
        pure_llm_time=llm_time,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        tool_calls=tool_call_count,
    )

    return outputs, metrics
async def async_request_eb_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request an LLM using EB OpenAI (text-completions endpoint, streaming).

    Sends one streaming /completions request, accumulates the generated text
    and timing (TTFT, inter-token latencies) into a RequestFuncOutput.

    Args:
        request_func_input: request parameters (prompt, hyper-parameters, URL).
        pbar: optional progress bar, advanced once per finished request.

    Returns:
        RequestFuncOutput with success flag, text, latency and token counts.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(
        trust_env=True, read_bufsize=10 * 1024 * 1024, timeout=AIOHTTP_TIMEOUT
    ) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "stream": True,
            "stream_options": {
                "include_usage": True,
                "continuous_usage_stats": True,
            },
        }
        # hyper-parameters come from the yaml config
        payload.update(request_func_input.hyper_parameters)

        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos

        if request_func_input.debug:
            print("payload:", json.dumps(payload, ensure_ascii=False))

        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "Content-Type": "application/json",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        output.no = request_func_input.no

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload, headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            # NOTE: some completion APIs send a final usage-only
                            # chunk without a token, so verify a token arrived.
                            if choices := data.get("choices"):
                                # text may be empty here, e.g. for special tokens
                                text = choices[0].get("text")

                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += text or ""

                                most_recent_timestamp = timestamp
                                output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                            elif usage := data.get("usage"):
                                output.prompt_tokens = usage.get("prompt_tokens")
                                output.output_tokens = usage.get("completion_tokens")

                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st

                    # BUG FIX: previously the "Never received a valid chunk"
                    # error was unconditionally clobbered by a second
                    # generated-text check, losing the more specific reason.
                    if not first_chunk_received:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!"
                        )
                    elif generated_text == "":
                        output.success = False
                        output.error = "No generated text found!"
                    else:
                        output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if request_func_input.debug:
        print(f"final_output:{output}")

    if pbar:
        pbar.update(1)
    return output
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request an LLM using the TGI API (streaming generate_stream endpoint)."""
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        if request_func_input.ignore_eos:
            # with ignore_eos the server generates exactly output_len tokens
            output.output_tokens = request_func_input.output_len
        else:
            # NOTE(review): None overrides the dataclass int default to mean
            # "unknown"; downstream consumers must tolerate it.
            output.output_tokens = None

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # NOTE(review): calls perf_counter() again instead of
                            # reusing `timestamp` — ttft is measured marginally
                            # later than the other per-token timestamps.
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        # NOTE(review): assumes every TGI chunk carries
                        # "arrival_time"; would KeyError otherwise.
                        output.arrival_time.append(data["arrival_time"])

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    # `data` is the last chunk, which carries the full text
                    output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request an LLM using TRT's llm_server (streaming generate_stream endpoint)."""
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            # force exactly output_len tokens by pinning the minimum length
            payload["min_length"] = request_func_input.output_len
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix("data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request an LLM served by DeepSpeed-MII (non-streaming endpoint).

    Args:
        request_func_input: request parameters (prompt, output length, URL).
        pbar: optional progress bar, advanced once per finished request.

    Returns:
        RequestFuncOutput; ``success`` is False for HTTP errors, exceptions,
        or an unrecognized response body.
    """
    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url, json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    # BUG FIX: success used to be set True unconditionally after
                    # this if/elif/else, so the "unexpected format" failure was
                    # silently reported as a success. Set it per branch instead.
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0]["text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream a request through an OpenAI-compatible Completions API.

    Accumulates the generated text from SSE chunks and records TTFT,
    inter-token latencies, total latency and (when the server reports it)
    the completion token count. A response that never yields a chunk with
    ``choices`` is marked as failed, since TTFT cannot be computed.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
            "prompt": request_func_input.prompt,
            # "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            # "stream_options": {
            #     "include_usage": True,
            # },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos

        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload, headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    output.ttft = timestamp - st

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        # BUG FIX: the two adjacent string literals used to be
                        # concatenated without a separating space, producing
                        # "...calculate TTFT.This response...".
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!"
                        )
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream an audio request through an OpenAI-compatible transcription or
    translation endpoint.

    Serializes the in-memory audio from ``multi_modal_content["audio"]``
    (expected to be a ``(samples, sample_rate)`` pair — TODO confirm against
    callers) to WAV, posts it as multipart/form-data, and records TTFT,
    inter-token latencies, total latency and the reported completion tokens.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    # BUG FIX: the message tail "or `translations`." previously sat on its own
    # line as a dead expression statement outside the assert, truncating the
    # assertion message; it is now part of a single literal.
    assert api_url.endswith(
        ("transcriptions", "translations")
    ), "OpenAI Chat Completions API URL must end with 'transcriptions' or `translations`."

    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
        # BUG FIX: removed dead local `content = [{"type": "text", ...}]` —
        # it was never read before being overwritten inside the stream loop.
        payload = {
            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "language": "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True,
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        # Send audio file: serialize (samples, sample_rate) to an in-memory WAV.
        def to_bytes(y, sr):
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
            form = aiohttp.FormData()
            form.add_field("file", f, content_type="audio/wav")
            # Payload fields ride along as flat form fields (multipart/form-data
            # cannot carry a nested JSON body).
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(url=api_url, data=form, headers=headers) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    content = choices[0]["delta"].get("content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(timestamp - most_recent_timestamp)

                                    generated_text += content or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get("completion_tokens")

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
# Maps each benchmark backend identifier onto the coroutine that knows how to
# drive that serving stack. Several OpenAI-compatible servers (vllm, lmdeploy,
# scalellm, sglang) share the generic Completions implementation.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_eb_openai_completions,
    "openai-chat": async_request_eb_openai_chat_completions,
    "openai-chat-multi-turn": async_request_eb_openai_chat_completions_multi_turn,
    "openai-audio": async_request_openai_audio,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}

# Implementations whose request/response schema follows the OpenAI API.
_OPENAI_STYLE_IMPLS = (
    async_request_openai_completions,
    async_request_eb_openai_chat_completions,
)

# Backend names (keys of ASYNC_REQUEST_FUNCS) that accept OpenAI-style options.
OPENAI_COMPATIBLE_BACKENDS = [
    backend for backend, impl in ASYNC_REQUEST_FUNCS.items() if impl in _OPENAI_STYLE_IMPLS
]