[SOT] Remove breakgraph in post processing && fix datatype (#2780)

This commit is contained in:
Ryan
2025-07-10 11:26:00 +08:00
committed by GitHub
parent 2ea267f624
commit b0f525955c
3 changed files with 20 additions and 17 deletions
@@ -18,6 +18,11 @@ import paddle
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import \
get_block_shape_and_split_kv_block as \
get_block_shape_and_split_kv_block_cuda
def get_block_shape_and_split_kv_block(
seq_lens_encoder: paddle.Tensor,
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
decoder_num_blocks,
max_len_kv,
set_max_lengths,
) = get_block_shape_and_split_kv_block(
) = get_block_shape_and_split_kv_block_cuda(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,