[XPU] get_infer_param use inplace copy, remove block_tables redundant d2h copy (#7431)

* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok

* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok

* inplace_copy: all bs == 9 ok

* inplace_copy: all cpu bs == 9 ok

* inplace_copy: len_info_cpu bs == 9 ok

* finished and rm unused code

* prefix_block_tables reuse

* refine

* improve performance

* remove block_table copy to cpu

* fix unit test

* fix

* resolve conflict

* refine code

* fix

* fix

* fix

* fix

* fix

* try fix unit tests

* fix

* tmp save

* fix unit test

* get_infer_param try less return values

* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
RuohengMa
2026-04-22 11:01:32 +08:00
committed by GitHub
parent 2edb30c2d0
commit 36d47aa23e
10 changed files with 628 additions and 302 deletions
@@ -158,70 +158,113 @@ def xpu_pre_process(
share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k
xpu_forward_meta = XPUForwardMeta(
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
batch_id_per_token=share_inputs["batch_id_per_token"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"],
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
is_speculative=use_speculate_method,
)
if use_cudagraph and forward_meta is not None:
forward_meta.ids_remove_padding.copy_(share_inputs["ids_remove_padding"], False)
forward_meta.rotary_embs.copy_(share_inputs["rope_emb"], False)
forward_meta.attn_backend = None
forward_meta.seq_lens_encoder.copy_(share_inputs["seq_lens_encoder"], False)
forward_meta.seq_lens_decoder.copy_(share_inputs["seq_lens_decoder"], False)
forward_meta.seq_lens_this_time.copy_(share_inputs["seq_lens_this_time"], False)
forward_meta.batch_id_per_token.copy_(share_inputs["batch_id_per_token"], False)
forward_meta.cu_seqlens_q.copy_(share_inputs["cu_seqlens_q"], False)
forward_meta.cu_seqlens_k.copy_(share_inputs["cu_seqlens_k"], False)
forward_meta.block_tables.copy_(share_inputs["block_tables"], False)
forward_meta.caches = share_inputs["caches"]
forward_meta.max_num_seqs = share_inputs["seq_lens_this_time"].shape[0]
forward_meta.is_speculative = use_speculate_method
xpu_forward_meta = forward_meta
else:
xpu_forward_meta = XPUForwardMeta(
ids_remove_padding=share_inputs["ids_remove_padding"],
rotary_embs=share_inputs["rope_emb"],
attn_backend=None,
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
batch_id_per_token=share_inputs["batch_id_per_token"],
cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"],
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
is_speculative=use_speculate_method,
)
xpu_forward_meta.init_inplace_tensor(seq_lens_encoder.shape[0], share_inputs["block_tables"].shape)
block_tables = xpu_forward_meta.block_tables
encoder_batch_map = xpu_forward_meta.encoder_batch_map
decoder_batch_map = xpu_forward_meta.decoder_batch_map
encoder_batch_idx = xpu_forward_meta.encoder_batch_idx
decoder_batch_idx = xpu_forward_meta.decoder_batch_idx
encoder_seq_lod = xpu_forward_meta.encoder_seq_lod
decoder_seq_lod = xpu_forward_meta.decoder_seq_lod
encoder_kv_lod = xpu_forward_meta.encoder_kv_lod
prefix_len = xpu_forward_meta.prefix_len
decoder_context_len = xpu_forward_meta.decoder_context_len
decoder_context_len_cache = xpu_forward_meta.decoder_context_len_cache
prefix_block_tables = xpu_forward_meta.prefix_block_tables
encoder_batch_map_cpu = xpu_forward_meta.encoder_batch_map_cpu
decoder_batch_map_cpu = xpu_forward_meta.decoder_batch_map_cpu
encoder_batch_idx_cpu = xpu_forward_meta.encoder_batch_idx_cpu
decoder_batch_idx_cpu = xpu_forward_meta.decoder_batch_idx_cpu
encoder_seq_lod_cpu = xpu_forward_meta.encoder_seq_lod_cpu
decoder_seq_lod_cpu = xpu_forward_meta.decoder_seq_lod_cpu
encoder_kv_lod_cpu = xpu_forward_meta.encoder_kv_lod_cpu
prefix_len_cpu = xpu_forward_meta.prefix_len_cpu
decoder_context_len_cpu = xpu_forward_meta.decoder_context_len_cpu
decoder_context_len_cache_cpu = xpu_forward_meta.decoder_context_len_cache_cpu
len_info_cpu = xpu_forward_meta.len_info_cpu
(
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_kv_lod,
xpu_forward_meta.prefix_len,
xpu_forward_meta.decoder_context_len,
xpu_forward_meta.decoder_context_len_cache,
xpu_forward_meta.prefix_block_tables,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_kv_lod_cpu,
xpu_forward_meta.prefix_len_cpu,
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.len_info_cpu,
xpu_forward_meta.slot_mapping_enc,
xpu_forward_meta.slot_mapping_dec,
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
xpu_forward_meta.block_tables,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
block_size,
num_speculative_tokens,
)
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]),
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.len_info_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
len_info_cpu,
None, # output_padding_offset
-1, # max bs
)
@@ -229,17 +272,22 @@ def xpu_pre_process(
adjusted_input = adjusted_input.squeeze(1)
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
xpu_forward_meta.enc_batch = len_info_cpu[0]
xpu_forward_meta.dec_batch = len_info_cpu[1]
xpu_forward_meta.total_enc_len = len_info_cpu[2]
xpu_forward_meta.ids_remove_padding = adjusted_input
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
# Set xpu_forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
xpu_forward_meta.is_profiling = is_profiling
if use_cudagraph:
if forward_meta is None:
return xpu_forward_meta
else:
forward_meta.copy_from(xpu_forward_meta)
return forward_meta
# prefill does not use cudagraph, inplace copy is not needed
xpu_forward_meta.slot_mapping_enc = slot_mapping_enc
if use_cudagraph and forward_meta is not None:
xpu_forward_meta.slot_mapping_dec.copy_(slot_mapping_dec, False)
else:
return xpu_forward_meta
xpu_forward_meta.slot_mapping_dec = slot_mapping_dec
return xpu_forward_meta
def xpu_process_output(