mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] get_infer_param uses in-place copy; remove redundant block_tables d2h copy (#7431)
* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok
* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok
* inplace_copy: all bs == 9 ok
* inplace_copy: all cpu bs == 9 ok
* inplace_copy: len_info_cpu bs == 9 ok
* finished and rm unused code
* prefix_block_tables reuse
* refine
* improve performance
* remove block_table copy to cpu
* fix unit test
* fix
* resolve conflict
* refine code
* fix
* fix
* fix
* fix
* fix
* try fix unit tests
* fix
* tmp save
* fix unit test
* get_infer_param try less return values
* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from utils import init_inplace_tensor
|
||||
|
||||
# block_attn_fused is deprecated and should be removed in the future
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
|
||||
# prefix cache block attn
|
||||
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||
(
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
) # block_size
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
block_tables,
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
64,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
qkv_prefix = qkv[hit_prefix_len:]
|
||||
attn_out_prefix_cache = block_attn_func(
|
||||
qkv_prefix,
|
||||
@@ -194,6 +225,7 @@ def run_block_attn(
|
||||
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
|
||||
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
@@ -217,10 +249,39 @@ def run_block_attn(
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||
(
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
block_tables,
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
64,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
qkv = paddle.uniform(
|
||||
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
|
||||
|
||||
Reference in New Issue
Block a user