mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[SOT] Remove breakgraph in post processing && fix datatype (#2780)
This commit is contained in:
+6
-2
@@ -18,6 +18,11 @@ import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
get_block_shape_and_split_kv_block as \
|
||||
get_block_shape_and_split_kv_block_cuda
|
||||
|
||||
|
||||
def get_block_shape_and_split_kv_block(
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
|
||||
get_block_shape_and_split_kv_block
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
|
||||
(
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
|
||||
decoder_num_blocks,
|
||||
max_len_kv,
|
||||
set_max_lengths,
|
||||
) = get_block_shape_and_split_kv_block(
|
||||
) = get_block_shape_and_split_kv_block_cuda(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
|
||||
Reference in New Issue
Block a user