Make seq_lens_this_time, seq_lens_decoder, and seq_lens_encoder have equal shapes (#6942)

This commit is contained in:
周周周
2026-03-20 15:31:52 +08:00
committed by GitHub
parent 2b10ebc1f1
commit 1c38da2118
2 changed files with 12 additions and 3 deletions
@@ -90,6 +90,13 @@ def append_attention(
append_attention
"""
if current_platform.is_cuda():
bsz = seq_lens_encoder.shape[0]
assert seq_lens_encoder.shape == [bsz]
assert seq_lens_decoder.shape == [bsz]
assert seq_lens_this_time.shape == [bsz]
assert cu_seqlens_q.shape == [bsz + 1]
assert block_tables.shape[0] == bsz
if sliding_window > 0 and head_wise_full_hidden > 0:
out_swa = append_attention_gpu(
qkv.clone(),
+5 -3
View File
@@ -1265,6 +1265,8 @@ class GPUModelRunner(ModelRunnerBase):
routing_replay_table = None
if self.routing_replay_manager is not None:
routing_replay_table = self.routing_replay_manager.get_routing_table()
num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0]
self.forward_meta = ForwardMeta(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
@@ -1277,13 +1279,13 @@ class GPUModelRunner(ModelRunnerBase):
decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"],
decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"][:num_running_requests],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"][:num_running_requests],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
batch_id_per_token=self.share_inputs["batch_id_per_token"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
block_tables=self.share_inputs["block_tables"][:num_running_requests],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],