[XPU] get_infer_param use inplace copy, remove redundant block_tables d2h copy (#7431)

* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok

* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok

* inplace_copy: all bs == 9 ok

* inplace_copy: all cpu bs == 9 ok

* inplace_copy: len_info_cpu bs == 9 ok

* finish and remove unused code

* prefix_block_tables reuse

* refine

* improve performance

* remove block_table copy to cpu

* fix unit test

* fix

* resolve conflict

* refine code

* fix

* fix

* fix

* fix

* fix

* try to fix unit tests

* fix

* tmp save

* fix unit test

* get_infer_param try less return values

* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
RuohengMa
2026-04-22 11:01:32 +08:00
committed by GitHub
parent 2edb30c2d0
commit 36d47aa23e
10 changed files with 628 additions and 302 deletions
+64 -3
View File
@@ -16,6 +16,7 @@ import random
import numpy as np
import paddle
from utils import init_inplace_tensor
# block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import (
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
# prefix cache block attn
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
(
encoder_batch_map,
decoder_batch_map,
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
) # block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_func(
qkv_prefix,
@@ -194,6 +225,7 @@ def run_block_attn(
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
(
encoder_batch_map,
decoder_batch_map,
@@ -217,10 +249,39 @@ def run_block_attn(
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed