mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1 (#7159)
* [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1
* [Feature] Support set PREEMPTED_TOKEN_ID in GET_SAVE_OUTPUT_V1
* fix
This commit is contained in:
@@ -27,7 +27,7 @@ import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
|
||||
from fastdeploy.engine.pooling_params import PoolingParams
|
||||
from fastdeploy.engine.request import ImagePosition, Request, RequestType
|
||||
from fastdeploy.model_executor.graph_optimization.utils import (
|
||||
@@ -2409,6 +2409,16 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# 5.1. Async cpy
|
||||
post_process_event = paddle.device.cuda.create_event()
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
# If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished.
|
||||
paddle.assign(
|
||||
paddle.where(
|
||||
self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1,
|
||||
PREEMPTED_TOKEN_ID,
|
||||
sampler_output.sampled_token_ids,
|
||||
),
|
||||
sampler_output.sampled_token_ids,
|
||||
)
|
||||
# if not self.speculative_decoding:
|
||||
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
||||
if self.speculative_decoding:
|
||||
|
||||
Reference in New Issue
Block a user