Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-22 16:07:51 +08:00)
Make seq_lens_this_time/decoder/encoder equal shape (#6942)
This commit is contained in:
@@ -90,6 +90,13 @@ def append_attention(
|
||||
append_attention
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
bsz = seq_lens_encoder.shape[0]
|
||||
assert seq_lens_encoder.shape == [bsz]
|
||||
assert seq_lens_decoder.shape == [bsz]
|
||||
assert seq_lens_this_time.shape == [bsz]
|
||||
assert cu_seqlens_q.shape == [bsz + 1]
|
||||
assert block_tables.shape[0] == bsz
|
||||
|
||||
if sliding_window > 0 and head_wise_full_hidden > 0:
|
||||
out_swa = append_attention_gpu(
|
||||
qkv.clone(),
|
||||
|
||||
@@ -1265,6 +1265,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
routing_replay_table = None
|
||||
if self.routing_replay_manager is not None:
|
||||
routing_replay_table = self.routing_replay_manager.get_routing_table()
|
||||
|
||||
num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0]
|
||||
self.forward_meta = ForwardMeta(
|
||||
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||
rotary_embs=self.share_inputs["rope_emb"],
|
||||
@@ -1277,13 +1279,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"],
|
||||
decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"],
|
||||
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
|
||||
- seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||
- seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||
+ seq_lens_encoder=self.share_inputs["seq_lens_encoder"][:num_running_requests],
|
||||
+ seq_lens_decoder=self.share_inputs["seq_lens_decoder"][:num_running_requests],
|
||||
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
|
||||
batch_id_per_token=self.share_inputs["batch_id_per_token"],
|
||||
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
|
||||
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
||||
- block_tables=self.share_inputs["block_tables"],
|
||||
+ block_tables=self.share_inputs["block_tables"][:num_running_requests],
|
||||
caches=self.share_inputs["caches"],
|
||||
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
|
||||
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
|
||||
|
||||
Reference in New Issue
Block a user