mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[XPU] get_infer_param use inplace copy, remove block_tables abundant d2h copy (#7431)
* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok * inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok * inplace_copy: all bs == 9 ok * inplace_copy: all cpu bs == 9 ok * inplace_copy: len_info_cpu bs == 9 ok * finished and rm unused code * prefix_block_tables reuse * refine * improve performance * remove block_table copy to cpu * fix unit test * fix * resolve conflict * refine code * fix * fix * fix * fix * fix * try fix unit tests * fix * tmp save * fix unit test * get_infer_param try less return values * add yinwei fix --------- Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
@@ -158,70 +158,113 @@ def xpu_pre_process(
|
||||
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
||||
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
||||
|
||||
xpu_forward_meta = XPUForwardMeta(
|
||||
ids_remove_padding=share_inputs["ids_remove_padding"],
|
||||
rotary_embs=share_inputs["rope_emb"],
|
||||
attn_backend=None,
|
||||
seq_lens_encoder=share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
||||
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
||||
batch_id_per_token=share_inputs["batch_id_per_token"],
|
||||
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
||||
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
||||
block_tables=share_inputs["block_tables"],
|
||||
caches=share_inputs["caches"],
|
||||
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
||||
is_speculative=use_speculate_method,
|
||||
)
|
||||
if use_cudagraph and forward_meta is not None:
|
||||
forward_meta.ids_remove_padding.copy_(share_inputs["ids_remove_padding"], False)
|
||||
forward_meta.rotary_embs.copy_(share_inputs["rope_emb"], False)
|
||||
forward_meta.attn_backend = None
|
||||
forward_meta.seq_lens_encoder.copy_(share_inputs["seq_lens_encoder"], False)
|
||||
forward_meta.seq_lens_decoder.copy_(share_inputs["seq_lens_decoder"], False)
|
||||
forward_meta.seq_lens_this_time.copy_(share_inputs["seq_lens_this_time"], False)
|
||||
forward_meta.batch_id_per_token.copy_(share_inputs["batch_id_per_token"], False)
|
||||
forward_meta.cu_seqlens_q.copy_(share_inputs["cu_seqlens_q"], False)
|
||||
forward_meta.cu_seqlens_k.copy_(share_inputs["cu_seqlens_k"], False)
|
||||
forward_meta.block_tables.copy_(share_inputs["block_tables"], False)
|
||||
forward_meta.caches = share_inputs["caches"]
|
||||
forward_meta.max_num_seqs = share_inputs["seq_lens_this_time"].shape[0]
|
||||
forward_meta.is_speculative = use_speculate_method
|
||||
|
||||
xpu_forward_meta = forward_meta
|
||||
else:
|
||||
xpu_forward_meta = XPUForwardMeta(
|
||||
ids_remove_padding=share_inputs["ids_remove_padding"],
|
||||
rotary_embs=share_inputs["rope_emb"],
|
||||
attn_backend=None,
|
||||
seq_lens_encoder=share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
||||
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
||||
batch_id_per_token=share_inputs["batch_id_per_token"],
|
||||
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
||||
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
||||
block_tables=share_inputs["block_tables"],
|
||||
caches=share_inputs["caches"],
|
||||
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
||||
is_speculative=use_speculate_method,
|
||||
)
|
||||
xpu_forward_meta.init_inplace_tensor(seq_lens_encoder.shape[0], share_inputs["block_tables"].shape)
|
||||
|
||||
block_tables = xpu_forward_meta.block_tables
|
||||
|
||||
encoder_batch_map = xpu_forward_meta.encoder_batch_map
|
||||
decoder_batch_map = xpu_forward_meta.decoder_batch_map
|
||||
encoder_batch_idx = xpu_forward_meta.encoder_batch_idx
|
||||
decoder_batch_idx = xpu_forward_meta.decoder_batch_idx
|
||||
encoder_seq_lod = xpu_forward_meta.encoder_seq_lod
|
||||
decoder_seq_lod = xpu_forward_meta.decoder_seq_lod
|
||||
encoder_kv_lod = xpu_forward_meta.encoder_kv_lod
|
||||
prefix_len = xpu_forward_meta.prefix_len
|
||||
decoder_context_len = xpu_forward_meta.decoder_context_len
|
||||
decoder_context_len_cache = xpu_forward_meta.decoder_context_len_cache
|
||||
|
||||
prefix_block_tables = xpu_forward_meta.prefix_block_tables
|
||||
|
||||
encoder_batch_map_cpu = xpu_forward_meta.encoder_batch_map_cpu
|
||||
decoder_batch_map_cpu = xpu_forward_meta.decoder_batch_map_cpu
|
||||
encoder_batch_idx_cpu = xpu_forward_meta.encoder_batch_idx_cpu
|
||||
decoder_batch_idx_cpu = xpu_forward_meta.decoder_batch_idx_cpu
|
||||
encoder_seq_lod_cpu = xpu_forward_meta.encoder_seq_lod_cpu
|
||||
decoder_seq_lod_cpu = xpu_forward_meta.decoder_seq_lod_cpu
|
||||
encoder_kv_lod_cpu = xpu_forward_meta.encoder_kv_lod_cpu
|
||||
prefix_len_cpu = xpu_forward_meta.prefix_len_cpu
|
||||
decoder_context_len_cpu = xpu_forward_meta.decoder_context_len_cpu
|
||||
decoder_context_len_cache_cpu = xpu_forward_meta.decoder_context_len_cache_cpu
|
||||
|
||||
len_info_cpu = xpu_forward_meta.len_info_cpu
|
||||
|
||||
(
|
||||
xpu_forward_meta.encoder_batch_map,
|
||||
xpu_forward_meta.decoder_batch_map,
|
||||
xpu_forward_meta.encoder_batch_idx,
|
||||
xpu_forward_meta.decoder_batch_idx,
|
||||
xpu_forward_meta.encoder_seq_lod,
|
||||
xpu_forward_meta.decoder_seq_lod,
|
||||
xpu_forward_meta.encoder_kv_lod,
|
||||
xpu_forward_meta.prefix_len,
|
||||
xpu_forward_meta.decoder_context_len,
|
||||
xpu_forward_meta.decoder_context_len_cache,
|
||||
xpu_forward_meta.prefix_block_tables,
|
||||
xpu_forward_meta.encoder_batch_map_cpu,
|
||||
xpu_forward_meta.decoder_batch_map_cpu,
|
||||
xpu_forward_meta.encoder_batch_idx_cpu,
|
||||
xpu_forward_meta.decoder_batch_idx_cpu,
|
||||
xpu_forward_meta.encoder_seq_lod_cpu,
|
||||
xpu_forward_meta.decoder_seq_lod_cpu,
|
||||
xpu_forward_meta.encoder_kv_lod_cpu,
|
||||
xpu_forward_meta.prefix_len_cpu,
|
||||
xpu_forward_meta.decoder_context_len_cpu,
|
||||
xpu_forward_meta.decoder_context_len_cache_cpu,
|
||||
xpu_forward_meta.len_info_cpu,
|
||||
xpu_forward_meta.slot_mapping_enc,
|
||||
xpu_forward_meta.slot_mapping_dec,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
xpu_forward_meta.block_tables,
|
||||
block_tables,
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
block_size,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
|
||||
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
|
||||
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
|
||||
|
||||
adjusted_input = adjust_batch(
|
||||
ids_remove_padding.reshape([-1, 1]),
|
||||
xpu_forward_meta.encoder_seq_lod,
|
||||
xpu_forward_meta.decoder_seq_lod,
|
||||
xpu_forward_meta.encoder_batch_idx,
|
||||
xpu_forward_meta.decoder_batch_idx,
|
||||
xpu_forward_meta.encoder_seq_lod_cpu,
|
||||
xpu_forward_meta.decoder_seq_lod_cpu,
|
||||
xpu_forward_meta.encoder_batch_idx_cpu,
|
||||
xpu_forward_meta.decoder_batch_idx_cpu,
|
||||
xpu_forward_meta.len_info_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
len_info_cpu,
|
||||
None, # output_padding_offset
|
||||
-1, # max bs
|
||||
)
|
||||
@@ -229,17 +272,22 @@ def xpu_pre_process(
|
||||
adjusted_input = adjusted_input.squeeze(1)
|
||||
|
||||
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
|
||||
|
||||
xpu_forward_meta.enc_batch = len_info_cpu[0]
|
||||
xpu_forward_meta.dec_batch = len_info_cpu[1]
|
||||
xpu_forward_meta.total_enc_len = len_info_cpu[2]
|
||||
xpu_forward_meta.ids_remove_padding = adjusted_input
|
||||
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
|
||||
# Set xpu_forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
|
||||
xpu_forward_meta.is_profiling = is_profiling
|
||||
if use_cudagraph:
|
||||
if forward_meta is None:
|
||||
return xpu_forward_meta
|
||||
else:
|
||||
forward_meta.copy_from(xpu_forward_meta)
|
||||
return forward_meta
|
||||
|
||||
# prefill does not use cudagraph, inplace copy is not needed
|
||||
xpu_forward_meta.slot_mapping_enc = slot_mapping_enc
|
||||
if use_cudagraph and forward_meta is not None:
|
||||
xpu_forward_meta.slot_mapping_dec.copy_(slot_mapping_dec, False)
|
||||
else:
|
||||
return xpu_forward_meta
|
||||
xpu_forward_meta.slot_mapping_dec = slot_mapping_dec
|
||||
|
||||
return xpu_forward_meta
|
||||
|
||||
|
||||
def xpu_process_output(
|
||||
|
||||
Reference in New Issue
Block a user