mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] get_infer_param use inplace copy, remove block_tables redundant d2h copy (#7431)
* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok * inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok * inplace_copy: all bs == 9 ok * inplace_copy: all cpu bs == 9 ok * inplace_copy: len_info_cpu bs == 9 ok * finished and rm unused code * prefix_block_tables reuse * refine * improve performance * remove block_table copy to cpu * fix unit test * fix * resolve conflict * refine code * fix * fix * fix * fix * fix * try fix unit tests * fix * tmp save * fix unit test * get_infer_param try less return values * add yinwei fix --------- Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
import paddle
|
||||
|
||||
|
||||
def init_inplace_tensor(bsz, block_tables_shape):
    """Allocate uninitialized int32 buffers, once without and once with ``device="cpu"``.

    Every tensor comes from ``paddle.empty``, so its contents are arbitrary
    until something writes into it.

    Args:
        bsz: Batch size. Per-batch buffers get ``bsz`` elements; the three
            ``*_lod`` buffers get ``bsz + 1``.
        block_tables_shape: Shape forwarded unchanged to ``paddle.empty`` for
            ``prefix_block_tables``.

    Returns:
        A 22-tuple, in this order:
        ``encoder_batch_map``, ``decoder_batch_map``, ``encoder_batch_idx``,
        ``decoder_batch_idx``, ``encoder_seq_lod``, ``decoder_seq_lod``,
        ``encoder_kv_lod``, ``prefix_len``, ``decoder_context_len``,
        ``decoder_context_len_cache`` (all created without a ``device``
        argument), then ``prefix_block_tables``, then the same ten buffers
        again created with ``device="cpu"`` (the ``*_cpu`` variants), and
        finally ``len_info_cpu`` (7 elements, CPU).
    """
    lod_length = bsz + 1

    # (buffer name, element count) in return order. The names are kept here
    # purely as documentation of which tuple slot is which.
    layout = (
        ("encoder_batch_map", bsz),
        ("decoder_batch_map", bsz),
        ("encoder_batch_idx", bsz),
        ("decoder_batch_idx", bsz),
        ("encoder_seq_lod", lod_length),
        ("decoder_seq_lod", lod_length),
        ("encoder_kv_lod", lod_length),
        ("prefix_len", bsz),
        ("decoder_context_len", bsz),
        ("decoder_context_len_cache", bsz),
    )

    # Allocation order mirrors the return order: default-placement buffers,
    # the block-table buffer, the CPU mirrors, then the length-info buffer.
    default_buffers = [paddle.empty(length, dtype="int32") for _name, length in layout]
    prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
    cpu_buffers = [
        paddle.empty(length, dtype="int32", device="cpu") for _name, length in layout
    ]
    len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")

    return (*default_buffers, prefix_block_tables, *cpu_buffers, len_info_cpu)
|
||||
Reference in New Issue
Block a user