[XPU]ZMQ logprob (#5628)

* [XPU]ZMQ logprob
This commit is contained in:
qw86972190
2025-12-25 14:50:01 +08:00
committed by GitHub
parent 75b3180280
commit 135e47d551
2 changed files with 192 additions and 24 deletions
@@ -14,15 +14,18 @@
# limitations under the License.
"""
from typing import Dict, Optional
import queue
from typing import Dict, List, Optional
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import ModelOutputData
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
@@ -49,6 +52,43 @@ if current_platform.is_xpu():
)
def _build_stream_transfer_data(
    output_tokens: paddle.Tensor,
    pooler_outputs: Optional[List] = None,
    logprobs: Optional[LogprobsTensors] = None,
    prompt_logprobs_list: Optional[List[LogprobsTensors]] = None,
):
    """Split batched model outputs into per-sample StreamTransferData entries.

    Exactly one of ``output_tokens`` / ``pooler_outputs`` is expected to carry
    data; ``output_tokens`` takes precedence when both are given.

    Args:
        output_tokens: Batched sampled token ids. Flattened to 1-D and split so
            each sample receives a length-1 numpy array holding its token.
        pooler_outputs: Per-sample pooler output tensors, used only when
            ``output_tokens`` is None. ``None`` entries are skipped.
        logprobs: Optional batched logprob tensors; row ``bid`` is sliced out
            for sample ``bid`` via ``slice_rows``.
        prompt_logprobs_list: Optional per-sample prompt logprobs, indexed by
            batch id. (Previously mis-annotated as a single LogprobsTensors.)

    Returns:
        list[StreamTransferData]: one entry per (non-skipped) sample.
    """
    stream_transfer_datas = []
    if output_tokens is not None:
        # Flatten to 1-D on host, then split into per-sample length-1 arrays.
        output_tokens = output_tokens.reshape([-1]).numpy()
        output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
        for bid, output_token_per_sample in enumerate(output_tokens_lists):
            stream_transfer_data = StreamTransferData(
                decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
            )
            # Explicit None check: truthiness of a tensor-holding record is an
            # accident of its container type; "present" is what we mean here.
            if logprobs is not None:
                stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
            if prompt_logprobs_list:
                stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
            stream_transfer_datas.append(stream_transfer_data)
    elif pooler_outputs is not None:
        for bid, pooler_output in enumerate(pooler_outputs):
            if pooler_output is None:
                # This sample produced no pooler output; skip it entirely.
                continue
            # Upcast bfloat16 before .numpy(): numpy has no bfloat16 dtype.
            if pooler_output.dtype == paddle.bfloat16:
                pooler_output = pooler_output.astype("float32")
            pooler_output = pooler_output.numpy()
            stream_transfer_data = StreamTransferData(
                decoder_state=DecoderState.TEXT, pooler_output=pooler_output, batch_id=bid
            )
            stream_transfer_datas.append(stream_transfer_data)
    return stream_transfer_datas
def xpu_pre_process(
input_ids: paddle.Tensor,
seq_lens_this_time: int,
@@ -217,6 +257,8 @@ def xpu_post_process_normal(
share_inputs: Dict[str, paddle.Tensor],
block_size: int = 64,
skip_save_output: bool = False,
save_each_rank: bool = False,
async_output_queue: queue.Queue = None,
think_end_id: int = None,
line_break_id: int = None,
) -> None:
@@ -314,27 +356,37 @@ def xpu_post_process_normal(
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
else:
if save_output_topk is None:
raise ImportError(
"save_output_topk operator is not available. "
"Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(
sampled_token_ids,
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_list=model_output.prompt_logprobs_list,
)
if async_output_queue is not None:
async_output_queue.put(output)
else:
if sampler_output.logprobs_tensors is None:
save_output(
sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
False, # use_ep
)
else:
if save_output_topk is None:
raise ImportError(
"save_output_topk operator is not available. "
"Please rebuild the XPU operators with the new get_output_msg_with_topk.cc and save_output_msg_with_topk.cc files."
)
save_output_topk(
sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)
save_output_topk(
sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)
def xpu_post_process_specualate(