[Speculative Decoding] split draft_tokens into standalone post-processing path (#5205)
* refactor(mtp): split draft_tokens into standalone post-processing path for MTP + logprobs
* Restore Request.__repr__ implementation
* ci
* add envs
* fix unittest
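In FastDeploy's speculative-decoding output path, the token processor receives two message types: 3 for target-model tokens and 4 for MTP draft tokens (see the docstring in the TokenProcessor hunk below). Previously both were handled inline in _process_batch_output; this change diverts draft messages into a dedicated _process_batch_draft_tokens method and returns early, so MTP draft logprobs are accumulated without threading mtype checks through the target path. A minimal standalone sketch of the routing (illustrative only, not FastDeploy code):

# Minimal sketch of the split; the real handlers live on TokenProcessor.
TARGET_MTYPE, DRAFT_MTYPE = 3, 4  # message types used throughout the diff

def process_batch_output(mtype, batch_payload):
    if mtype == DRAFT_MTYPE:
        # new standalone path: accumulate draft_top_logprobs, then return early
        return [f"draft result for {req}" for req in batch_payload]
    # existing path: target tokens, top_logprobs, finish conditions
    return [f"target result for {req}" for req in batch_payload]

print(process_batch_output(DRAFT_MTYPE, ["req-0", "req-1"]))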
@@ -25,6 +25,7 @@ from typing import Any, Dict, Generic, Optional, Union
 import numpy as np
 from typing_extensions import TypeVar
 
+from fastdeploy import envs
 from fastdeploy.engine.pooling_params import PoolingParams
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
@@ -331,8 +332,20 @@ class Request:
             setattr(self, key, value)
 
     def __repr__(self) -> str:
-        """Safe string representation that ignores private and None fields."""
-        return ""
+        """Sanitized repr without private or None fields."""
+        try:
+            if not envs.FD_DEBUG:
+                return f"Request(request_id={self.request_id})"
+            else:
+                attrs_snapshot = dict(vars(self))
+                non_none_fields = [
+                    f"{attr}={value!r}"
+                    for attr, value in attrs_snapshot.items()
+                    if value is not None and not attr.startswith("_")
+                ]
+                return f"Request({', '.join(non_none_fields)})"
+        except Exception as e:
+            return f"<Request repr failed: {e}>"
 
 
 @dataclass(slots=True)
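The restored __repr__ gates verbosity on the FD_DEBUG environment flag: outside debug mode it prints only the request id; in debug mode it prints every public, non-None attribute; and any failure degrades to a tagged error string instead of raising. A self-contained sketch of the same pattern (a demo class, not FastDeploy's Request):

import os

class DemoRequest:
    """Standalone demo of the repr pattern above (not FastDeploy's Request)."""

    def __init__(self):
        self.request_id = "req-1"
        self.prompt = "hi"
        self.sampling = None       # None fields are skipped in debug output
        self._internal = object()  # underscore-prefixed fields are skipped too

    def __repr__(self) -> str:
        try:
            if not os.getenv("FD_DEBUG"):
                return f"DemoRequest(request_id={self.request_id})"
            fields = [
                f"{attr}={value!r}"
                for attr, value in vars(self).items()
                if value is not None and not attr.startswith("_")
            ]
            return f"DemoRequest({', '.join(fields)})"
        except Exception as e:  # repr must never raise
            return f"<DemoRequest repr failed: {e}>"

print(repr(DemoRequest()))  # DemoRequest(request_id=req-1) unless FD_DEBUG is set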
@@ -571,6 +571,7 @@ class OpenAIServingChat:
                 num_input_video_tokens=num_input_video_tokens,
                 num_image_tokens=num_image_tokens,
                 logprob_contents=logprob_contents,
+                draft_logprob_contents=draft_logprob_contents,
                 response_processor=response_processor,
                 max_tokens=max_tokens,
             )
@@ -622,6 +623,7 @@ class OpenAIServingChat:
         num_input_video_tokens: list,
         num_image_tokens: list,
         logprob_contents: list,
+        draft_logprob_contents: list,
         response_processor: ChatResponseProcessor,
         max_tokens: int,
     ) -> ChatCompletionResponseChoice:
@@ -649,6 +651,9 @@ class OpenAIServingChat:
         logprobs_full_res = None
         if logprob_contents[idx]:
             logprobs_full_res = LogProbs(content=logprob_contents[idx])
+        draft_logprobs_full_res = None
+        if draft_logprob_contents[idx]:
+            draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
 
         num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
         num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
@@ -669,6 +674,7 @@ class OpenAIServingChat:
             index=idx,
             message=message,
             logprobs=logprobs_full_res,
+            draft_logprobs=draft_logprobs_full_res,
             finish_reason=finish_reason,
         )
 
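On the serving side the draft field is built exactly like the existing one: per-choice draft logprob content is wrapped in LogProbs when present, otherwise the response field stays None. A standalone sketch of that guard (the dataclass stands in for the real protocol type):

from dataclasses import dataclass
from typing import Optional

@dataclass
class LogProbsStub:
    """Stand-in for the LogProbs type in fastdeploy.entrypoints.openai.protocol."""
    content: list

def wrap_logprobs(contents: Optional[list]) -> Optional[LogProbsStub]:
    # Mirrors the diff's guard: empty/missing content stays None in the response.
    return LogProbsStub(content=contents) if contents else None

print(wrap_logprobs([{"token": "hello", "logprob": -0.1}]))  # wrapped
print(wrap_logprobs([]))                                     # None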
@@ -527,6 +527,60 @@ class TokenProcessor:
             self.total_step = 0
         self.speculative_stats_step += 1
 
+    def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, ranks):
+        """
+        Process batch draft tokens and generate corresponding request outputs
+
+        Args:
+            mtype (int): Message type (3=target token, 4=draft token)
+            batch (int): Batch size
+            accept_num (list): List of accepted token counts per request
+            tokens (paddle.Tensor): Generated draft token IDs tensor
+            scores (paddle.Tensor): Token scores tensor
+            ranks (paddle.Tensor): Token sampling ranks tensor
+
+        Returns:
+            list[RequestOutput]: List containing processed results for all requests
+        """
+        batch_result = list()
+        for i in range(batch):
+            if self.resource_manager.stop_flags[i]:
+                continue
+            task = self.resource_manager.tasks_list[i]
+            task_id = task.request_id
+            result = RequestOutput(
+                request_id=task_id,
+                output_type=mtype,
+                outputs=CompletionOutput(
+                    index=i,
+                    send_idx=None,
+                    token_ids=[],
+                    draft_token_ids=[],
+                ),
+                finished=False,
+                metrics=None,
+            )
+
+            token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
+            for batch_token_index in range(len(token_ids)):
+                result.outputs.logprob = float(scores[i, batch_token_index, 0])
+                topk_token_ids = tokens[i, batch_token_index, :].tolist()
+                topk_logprobs = scores[i, batch_token_index, :].tolist()
+                sampled_rank = ranks[i, batch_token_index].item()
+
+                if result.outputs.draft_top_logprobs is None:
+                    result.outputs.draft_top_logprobs = LogprobsLists(
+                        logprob_token_ids=[topk_token_ids],
+                        logprobs=[topk_logprobs],
+                        sampled_token_ranks=[sampled_rank],
+                    )
+                else:
+                    result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                    result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
+                    result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
+            batch_result.append(result)
+        return batch_result
+
     def _process_batch_output(self):
         """
         batch post-processing function
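The method assumes the tensor layout established in _process_batch_output (visible in the next hunk): tokens and scores have shape [batch, MAX_DRAFT_TOKENS, K + 1], where column 0 of the last axis holds the sampled token/score for that draft step and the full row is the top-K candidate list; ranks is [batch, MAX_DRAFT_TOKENS]; and only the first accept_num[i] draft steps of request i are kept. A runnable numpy sketch of that indexing (shapes taken from the new tests below):

import numpy as np

batch, MAX_DRAFT_TOKENS, K = 2, 6, 20
accept_num = [3, 2]
tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))

i = 0
# column 0 = sampled draft token per step; keep only the accepted steps
accepted = tokens[i][:, 0].tolist()[: accept_num[i]]
assert len(accepted) == accept_num[i]
for step in range(len(accepted)):
    topk_token_ids = tokens[i, step, :].tolist()  # K + 1 candidates per step
    topk_logprobs = scores[i, step, :].tolist()   # matching scores
    sampled_rank = int(ranks[i, step])            # rank of the sampled token
    assert len(topk_token_ids) == K + 1 and len(topk_logprobs) == K + 1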
@@ -551,6 +605,12 @@ class TokenProcessor:
                     .reshape([batch, MAX_DRAFT_TOKENS, K + 1])
                 )
                 ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
+
+                # split draft_tokens into standalone post-processing path for MTP + logprobs
+                if mtype == 4:
+                    batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks)
+                    self.postprocess(batch_result, mtype)
+                    return
             else:
                 batch = self.output_tokens[1]
                 accept_num = tokens[2 : batch + 2]
@@ -678,8 +738,7 @@ class TokenProcessor:
                     if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
                         result.outputs.token_ids.append(token_id)
 
-                        if mtype == 3:
-                            task.output_token_ids.append(token_id)
+                        task.output_token_ids.append(token_id)
 
                     if self.use_logprobs:
                         if self.cfg.speculative_config.method:

With draft messages diverted to the standalone path earlier in _process_batch_output, only target tokens (mtype == 3) reach this code, so the guard is redundant and the append becomes unconditional.
@@ -693,29 +752,18 @@ class TokenProcessor:
                             topk_logprobs = scores[i, :].tolist()
                             sampled_rank = ranks[i].item()
 
-                            if mtype == 3:  # top_logprobs
-                                if result.outputs.top_logprobs is None:
-                                    result.outputs.top_logprobs = LogprobsLists(
-                                        logprob_token_ids=[topk_token_ids],
-                                        logprobs=[topk_logprobs],
-                                        sampled_token_ranks=[sampled_rank],
-                                    )
-                                else:
-                                    result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
-                                    result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
-                                    result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
-                            elif mtype == 4:  # draft_top_logprobs
-                                if result.outputs.draft_top_logprobs is None:
-                                    result.outputs.draft_top_logprobs = LogprobsLists(
-                                        logprob_token_ids=[topk_token_ids],
-                                        logprobs=[topk_logprobs],
-                                        sampled_token_ranks=[sampled_rank],
-                                    )
-                                else:
-                                    result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
-                                    result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
-                                    result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
-                    if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
+                            if result.outputs.top_logprobs is None:
+                                result.outputs.top_logprobs = LogprobsLists(
+                                    logprob_token_ids=[topk_token_ids],
+                                    logprobs=[topk_logprobs],
+                                    sampled_token_ranks=[sampled_rank],
+                                )
+                            else:
+                                result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                                result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
+                                result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
+
+                    if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                         result.finished = True
                         if recovery_stop:
                             result.error_msg = "Recover is not supported, the result is incomplete!"

Same reasoning as above: the elif mtype == 4 branch is dead code once draft tokens take the standalone path, and the mtype == 3 qualifier on the finish condition is always true here.
@@ -445,6 +445,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
         prompt_token_ids = [1, 2]
         prompt_tokens = "test_prompt"
         logprob_contents = [[{"token": "hello", "logprob": 0.1}], [{"token": "hello", "logprob": 0.1}]]
+        draft_logprob_contents = [[{"token": "hello", "logprob": 0.1}], [{"token": "hello", "logprob": 0.1}]]
         mock_response_processor = Mock()
        mock_response_processor.enable_multimodal_content.return_value = False
         completion_token_ids = [[], []]
@@ -467,6 +468,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
             num_input_video_tokens=num_input_video_tokens,
             num_image_tokens=num_image_tokens,
             logprob_contents=logprob_contents,
+            draft_logprob_contents=draft_logprob_contents,
             response_processor=mock_response_processor,
             max_tokens=max_tokens_list[idx],
         )
@@ -0,0 +1,157 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+import paddle
+
+from fastdeploy.engine.request import RequestOutput
+from fastdeploy.output.token_processor import TokenProcessor
+
+
+class TestProcessBatchDraftTokens(unittest.TestCase):
+
+    def setUp(self):
+        # Mock cfg
+        cfg = MagicMock()
+        cfg.speculative_config = MagicMock()
+        cfg.speculative_config.method = "mtp"
+        cfg.speculative_config.num_speculative_tokens = 3
+        cfg.model_config = MagicMock()
+        cfg.model_config.enable_logprob = True
+
+        self.processor = TokenProcessor(
+            cfg=cfg, cached_generated_tokens=MagicMock(), engine_worker_queue=MagicMock(), split_connector=MagicMock()
+        )
+
+        # Mock resource_manager
+        self.processor.resource_manager = MagicMock()
+        self.processor.resource_manager.stop_flags = [False] * 512
+        self.processor.resource_manager.tasks_list = [MagicMock()] * 512
+
+        for task in self.processor.resource_manager.tasks_list:
+            task.request_id = "test_request"
+            task.eos_token_ids = [2]
+
+    def test_process_batch_draft_tokens_normal_case(self):
+        """Test draft-token processing in the normal case"""
+        batch = 2
+        accept_num = [3, 2]
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for i, result in enumerate(results):
+            self.assertIsInstance(result, RequestOutput)
+            self.assertEqual(result.output_type, 4)
+            self.assertEqual(result.outputs.index, i)
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids), accept_num[i])
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs), accept_num[i])
+            self.assertEqual(len(result.outputs.draft_top_logprobs.sampled_token_ranks), accept_num[i])
+
+    def test_process_batch_draft_tokens_with_stop_flag(self):
+        """Test the case where a stop flag is set"""
+        batch = 3
+        self.processor.resource_manager.stop_flags[1] = True  # the second request is stopped
+
+        accept_num = [3, 2, 1]
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), 2)
+        self.assertEqual(results[0].outputs.index, 0)
+        self.assertEqual(results[1].outputs.index, 2)
+
+    def test_process_batch_draft_tokens_empty_accept(self):
+        """Test the case where accept_num is 0"""
+        batch = 2
+        accept_num = [0, 0]
+
+        K = 20
+        MAX_DRAFT_TOKENS = 6
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for result in results:
+            self.assertIsNone(result.outputs.draft_top_logprobs)
+
+    def test_process_batch_draft_tokens_different_k_values(self):
+        """Test different K values"""
+        batch = 2
+        accept_num = [3, 2]
+
+        K = 5
+        MAX_DRAFT_TOKENS = 6
+        tokens = np.random.randint(100, 200, size=(batch, MAX_DRAFT_TOKENS, K + 1))
+        scores = np.random.rand(batch, MAX_DRAFT_TOKENS, K + 1).astype(np.float32)
+        ranks = np.random.randint(0, K, size=(batch, MAX_DRAFT_TOKENS))
+
+        results = self.processor._process_batch_draft_tokens(
+            mtype=4,
+            batch=batch,
+            accept_num=accept_num,
+            tokens=paddle.to_tensor(tokens),
+            scores=paddle.to_tensor(scores),
+            ranks=paddle.to_tensor(ranks),
+        )
+
+        self.assertEqual(len(results), batch)
+        for i, result in enumerate(results):
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprob_token_ids[0]), K + 1)
+            self.assertEqual(len(result.outputs.draft_top_logprobs.logprobs[0]), K + 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -1,3 +1,19 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
 import unittest
 from unittest.mock import MagicMock, patch
 