[XPU] get_infer_param use inplace copy, remove block_tables abundant d2h copy (#7431)

* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok

* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok

* inplace_copy: all bs == 9 ok

* inplace_copy: all cpu bs == 9 ok

* inplace_copy: len_info_cpu bs == 9 ok

* finished and rm unused code

* prefix_block_tables reuse

* refine

* improve performance

* remove block_table copy to cpu

* fix unit test

* fix

* resolve conflict

* refine code

* fix

* fix

* fix

* fix

* fix

* try fix unit tests

* fix

* tmp save

* fix unit test

* get_infer_param try less return values

* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
RuohengMa
2026-04-22 11:01:32 +08:00
committed by GitHub
parent 2edb30c2d0
commit 36d47aa23e
10 changed files with 628 additions and 302 deletions
+54
View File
@@ -0,0 +1,54 @@
import paddle
def init_inplace_tensor(bsz, block_tables_shape):
encoder_batch_map = paddle.empty(bsz, dtype="int32")
decoder_batch_map = paddle.empty(bsz, dtype="int32")
encoder_batch_idx = paddle.empty(bsz, dtype="int32")
decoder_batch_idx = paddle.empty(bsz, dtype="int32")
encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
prefix_len = paddle.empty(bsz, dtype="int32")
decoder_context_len = paddle.empty(bsz, dtype="int32")
decoder_context_len_cache = paddle.empty(bsz, dtype="int32")
prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")
return (
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
)