[Cherry-Pick] add mm token usage (#4648)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [Feature] add mm token usage (#4570)

* add mm token usage

* fix unit test

* fix unit test

* fix unit test

* fix model path

* fix unit test

* fix unit test

* fix unit test

* remove uncomment

* change var name

* fix code style

* fix code style

* fix code style

* fix code style

* fix unit test

* update doc

* update doc

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
ApplEOFDiscord
2025-10-30 09:58:07 +08:00
committed by GitHub
parent 895ca7694e
commit 52a6e0be41
13 changed files with 126 additions and 20 deletions
+2
View File
@@ -198,6 +198,8 @@ For ``LLM`` configuration, refer to [Parameter Documentation](parameters.md).
* finished(bool): Completion status
* metrics(fastdeploy.engine.request.RequestMetrics): Performance metrics
* num_cached_tokens(int): Cached token count (only valid when ``enable_prefix_caching`` is enabled)
* num_input_image_tokens(int): Number of input image tokens.
* num_input_video_tokens(int): Number of input video tokens.
* error_code(int): Error code
* error_msg(str): Error message
+26
View File
@@ -238,6 +238,19 @@ ChatMessage:
completion_token_ids: Optional[List[int]] = None
prompt_tokens: Optional[str] = None
completion_tokens: Optional[str] = None
UsageInfo:
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
completion_tokens_details: Optional[CompletionTokenUsageInfo] = None
PromptTokenUsageInfo:
cached_tokens: Optional[int] = None
image_tokens: Optional[int] = None
video_tokens: Optional[int] = None
CompletionTokenUsageInfo:
reasoning_tokens: Optional[int] = None
image_tokens: Optional[int] = None
ToolCall:
id: str = None
type: Literal["function"] = "function"
@@ -414,6 +427,19 @@ CompletionResponseChoice:
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
UsageInfo:
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
completion_tokens_details: Optional[CompletionTokenUsageInfo] = None
PromptTokenUsageInfo:
cached_tokens: Optional[int] = None
image_tokens: Optional[int] = None
video_tokens: Optional[int] = None
CompletionTokenUsageInfo:
reasoning_tokens: Optional[int] = None
image_tokens: Optional[int] = None
ToolCall:
id: str = None
type: Literal["function"] = "function"
+2
View File
@@ -198,6 +198,8 @@ for output in outputs:
* finished(bool):标识当前query 是否推理结束
* metrics(fastdeploy.engine.request.RequestMetrics):记录推理耗时指标
* num_cached_tokens(int): 缓存的token数量, 仅在开启 ``enable_prefix_caching``时有效
* num_input_image_tokens(int): 输入图片token的数量
* num_input_video_tokens(int): 输入视频token的数量
* error_code(int): 错误码
* error_msg(str): 错误信息
+26
View File
@@ -237,6 +237,19 @@ ChatMessage:
completion_token_ids: Optional[List[int]] = None
prompt_tokens: Optional[str] = None
completion_tokens: Optional[str] = None
UsageInfo:
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
completion_tokens_details: Optional[CompletionTokenUsageInfo] = None
PromptTokenUsageInfo:
cached_tokens: Optional[int] = None
image_tokens: Optional[int] = None
video_tokens: Optional[int] = None
CompletionTokenUsageInfo:
reasoning_tokens: Optional[int] = None
image_tokens: Optional[int] = None
ToolCall:
id: str = None
type: Literal["function"] = "function"
@@ -410,6 +423,19 @@ CompletionResponseChoice:
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
UsageInfo:
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
completion_tokens_details: Optional[CompletionTokenUsageInfo] = None
PromptTokenUsageInfo:
cached_tokens: Optional[int] = None
image_tokens: Optional[int] = None
video_tokens: Optional[int] = None
CompletionTokenUsageInfo:
reasoning_tokens: Optional[int] = None
image_tokens: Optional[int] = None
ToolCall:
id: str = None
type: Literal["function"] = "function"
+10
View File
@@ -447,6 +447,8 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
num_input_image_tokens: The number of input image tokens.
num_input_video_tokens: The number of input video tokens.
"""
def __init__(
@@ -459,6 +461,8 @@ class RequestOutput:
finished: bool = False,
metrics: Optional[RequestMetrics] = None,
num_cached_tokens: Optional[int] = 0,
num_input_image_tokens: Optional[int] = 0,
num_input_video_tokens: Optional[int] = 0,
error_code: Optional[int] = 200,
error_msg: Optional[str] = None,
) -> None:
@@ -470,6 +474,8 @@ class RequestOutput:
self.finished = finished
self.metrics = metrics
self.num_cached_tokens = num_cached_tokens
self.num_input_image_tokens = num_input_image_tokens
self.num_input_video_tokens = num_input_video_tokens
self.error_code = error_code
self.error_msg = error_msg
@@ -512,6 +518,8 @@ class RequestOutput:
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"num_cached_tokens={self.num_cached_tokens}, "
f"num_input_image_tokens={self.num_input_image_tokens}, "
f"num_input_video_tokens={self.num_input_video_tokens}, "
f"metrics={self.metrics}, "
)
@@ -534,6 +542,8 @@ class RequestOutput:
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,
"num_cached_tokens": self.num_cached_tokens,
"num_input_image_tokens": self.num_input_image_tokens,
"num_input_video_tokens": self.num_input_video_tokens,
"error_code": self.error_code,
"error_msg": self.error_msg,
}
+32 -13
View File
@@ -276,6 +276,8 @@ class OpenAIServingChat:
if first_iteration:
num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0)
num_input_image_tokens = res.get("num_input_image_tokens", 0)
num_input_video_tokens = res.get("num_input_video_tokens", 0)
for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice(
index=i,
@@ -312,7 +314,11 @@ class OpenAIServingChat:
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
prompt_tokens_details=PromptTokenUsageInfo(
cached_tokens=num_cached_tokens,
image_tokens=num_input_image_tokens,
video_tokens=num_input_video_tokens,
),
completion_tokens_details=CompletionTokenUsageInfo(reasoning_tokens=0),
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
@@ -476,6 +482,8 @@ class OpenAIServingChat:
draft_logprob_contents = [[] for _ in range(num_choices)]
completion_token_ids = [[] for _ in range(num_choices)]
num_cached_tokens = [0] * num_choices
num_input_image_tokens = [0] * num_choices
num_input_video_tokens = [0] * num_choices
num_image_tokens = [0] * num_choices
response_processor = ChatResponseProcessor(
data_processor=self.engine_client.data_processor,
@@ -546,14 +554,15 @@ class OpenAIServingChat:
previous_num_tokens[idx] += data["outputs"].get("image_token_num")
num_image_tokens[idx] = data["outputs"].get("image_token_num")
choice = await self._create_chat_completion_choice(
output=output,
index=idx,
data=data,
request=request,
previous_num_tokens=previous_num_tokens[idx],
prompt_token_ids=prompt_token_ids,
prompt_tokens=prompt_tokens,
completion_token_ids=completion_token_ids[idx],
previous_num_tokens=previous_num_tokens[idx],
num_cached_tokens=num_cached_tokens,
num_input_image_tokens=num_input_image_tokens,
num_input_video_tokens=num_input_video_tokens,
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
response_processor=response_processor,
@@ -571,11 +580,16 @@ class OpenAIServingChat:
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=sum(num_cached_tokens)),
prompt_tokens_details=PromptTokenUsageInfo(
cached_tokens=sum(num_cached_tokens),
image_tokens=sum(num_input_image_tokens),
video_tokens=sum(num_input_video_tokens),
),
completion_tokens_details=CompletionTokenUsageInfo(
reasoning_tokens=num_reasoning_tokens, image_tokens=sum(num_image_tokens)
),
)
choices = sorted(choices, key=lambda x: x.index)
res = ChatCompletionResponse(
id=request_id,
@@ -589,18 +603,21 @@ class OpenAIServingChat:
async def _create_chat_completion_choice(
self,
output: dict,
index: int,
data: dict,
request: ChatCompletionRequest,
previous_num_tokens: int,
prompt_token_ids: list,
prompt_tokens: str,
completion_token_ids: list,
previous_num_tokens: int,
num_cached_tokens: list,
num_input_image_tokens: list,
num_input_video_tokens: list,
num_image_tokens: list,
logprob_contents: list,
response_processor: ChatResponseProcessor,
) -> ChatCompletionResponseChoice:
idx = int(data["request_id"].split("_")[-1])
output = data["outputs"]
if output is not None and output.get("metrics") and output["metrics"].get("request_start_time"):
work_process_metrics.e2e_request_latency.observe(
@@ -621,13 +638,15 @@ class OpenAIServingChat:
message.content = output["text"]
logprobs_full_res = None
if logprob_contents[index]:
logprobs_full_res = LogProbs(content=logprob_contents[index])
if logprob_contents[idx]:
logprobs_full_res = LogProbs(content=logprob_contents[idx])
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
num_cached_tokens[index] = output.get("num_cached_tokens", 0)
num_image_tokens[index] = output.get("num_image_tokens", 0)
num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
num_image_tokens[idx] = output.get("num_image_tokens", 0)
finish_reason = "stop"
if has_no_token_limit or previous_num_tokens != max_tokens:
@@ -640,7 +659,7 @@ class OpenAIServingChat:
finish_reason = "recover_stop"
return ChatCompletionResponseChoice(
index=index,
index=idx,
message=message,
logprobs=logprobs_full_res,
finish_reason=finish_reason,
@@ -193,6 +193,8 @@ class DataProcessor:
"labels": [],
"cur_position": 0,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"mm_positions": [],
"mm_hashes": [],
}
@@ -357,6 +359,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["num_input_image_tokens"] += num_tokens
pos_ids = self._compute_3d_positions(1, patches_h, patches_w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
@@ -428,6 +431,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["num_input_video_tokens"] += num_tokens
pos_ids = self._compute_3d_positions(num_frames, patches_h, patches_w, outputs["cur_position"])
outputs["position_ids"].extend(pos_ids)
@@ -143,9 +143,10 @@ class DataProcessor:
"labels": [],
"cur_position": 0,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
"vit_seqlen": [],
"vit_position_ids": [],
}
@@ -354,6 +355,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["num_input_image_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
@@ -414,6 +416,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["num_input_video_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
@@ -142,6 +142,8 @@ class DataProcessor:
"labels": [],
"cur_position": 0,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
@@ -351,6 +353,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["num_input_image_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
@@ -409,6 +412,7 @@ class DataProcessor:
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["num_input_video_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
+6
View File
@@ -289,6 +289,9 @@ class TokenProcessor:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens
if task.get("multimodal_inputs", None):
result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0)
result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0)
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
result = self._process_per_token(task, i, token_ids, result, is_prefill)
@@ -655,6 +658,9 @@ class TokenProcessor:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens
if task.get("multimodal_inputs", None):
result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0)
result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0)
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
+1
View File
@@ -32,6 +32,7 @@ def test_unstream_with_logprobs():
"bytes": [231, 137, 155, 233, 161, 191],
"top_logprobs": None,
}
assert resp_json["usage"]["prompt_tokens"] == 22
assert resp_json["usage"]["completion_tokens"] == 3
assert resp_json["usage"]["total_tokens"] == 25
@@ -387,10 +387,10 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
"text": "Normal AI response",
"reasoning_content": "Normal reasoning",
"tool_call": None,
"num_cached_tokens": 3,
"num_image_tokens": 2,
"raw_prediction": "raw_answer_0",
},
"num_cached_tokens": 3,
"finished": True,
"previous_num_tokens": 2,
},
@@ -416,10 +416,10 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
"text": "Edge case response",
"reasoning_content": None,
"tool_call": None,
"num_cached_tokens": 0,
"num_image_tokens": 0,
"raw_prediction": None,
},
"num_cached_tokens": 0,
"finished": True,
"previous_num_tokens": 1,
},
@@ -446,18 +446,21 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
mock_response_processor.enable_multimodal_content.return_value = False
completion_token_ids = [[], []]
num_cached_tokens = [0, 0]
num_input_image_tokens = [0, 0]
num_input_video_tokens = [0, 0]
num_image_tokens = [0, 0]
for idx, case in enumerate(test_cases):
actual_choice = await self.chat_serving._create_chat_completion_choice(
output=case["test_data"]["outputs"],
index=idx,
data=case["test_data"],
request=case["mock_request"],
previous_num_tokens=case["test_data"]["previous_num_tokens"],
prompt_token_ids=prompt_token_ids,
prompt_tokens=prompt_tokens,
completion_token_ids=completion_token_ids[idx],
previous_num_tokens=case["test_data"]["previous_num_tokens"],
num_cached_tokens=num_cached_tokens,
num_input_image_tokens=num_input_image_tokens,
num_input_video_tokens=num_input_video_tokens,
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
response_processor=mock_response_processor,
+1 -1
View File
@@ -50,7 +50,7 @@ class MockTask:
def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.sampling_params, key):
elif hasattr(self, "sampling_params") and hasattr(self.sampling_params, key):
return getattr(self.sampling_params, key)
else:
return default_value