[XPU] get_infer_param use inplace copy, remove block_tables abundant d2h copy (#7431)

* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok

* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok

* inplace_copy: all bs == 9 ok

* inplace_copy: all cpu bs == 9 ok

* inplace_copy: len_info_cpu bs == 9 ok

* finished and rm unused code

* prefix_block_tables reuse

* refine

* improve performance

* remove block_table copy to cpu

* fix unit test

* fix

* resolve conflict

* refine code

* fix

* fix

* fix

* fix

* fix

* try fix unit tests

* fix

* tmp save

* fix unit test

* get_infer_param try less return values

* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
RuohengMa
2026-04-22 11:01:32 +08:00
committed by GitHub
parent 2edb30c2d0
commit 36d47aa23e
10 changed files with 628 additions and 302 deletions
+144 -202
View File
@@ -25,7 +25,7 @@ namespace api = baidu::xpu::api;
void lod_to_slot_mapping(api::Context* xpu_ctx,
paddle::Place place,
const std::vector<int32_t>& block_table,
const paddle::Tensor& block_table_xpu,
const std::vector<int32_t>& kv_seq_lod,
const std::vector<int32_t>& start_tokens,
const std::vector<int32_t>& real_batch,
@@ -35,10 +35,16 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
int32_t batch_size,
int32_t max_num_blocks_per_seq,
int32_t num_speculative_tokens) {
if (token_num <= 0) {
int32_t actual_token_num = kv_seq_lod[batch_size];
if (token_num <= 0 || actual_token_num <= 0) {
return;
}
std::vector<int32_t> slot_mapping_vec(token_num, -1);
int ret;
std::vector<int32_t> block_table_idx_vec(actual_token_num);
std::vector<int32_t> seq_offset_vec(actual_token_num);
int32_t idx = 0;
// For each Batch
for (auto batch_ = 0; batch_ < batch_size; batch_++) {
@@ -47,20 +53,56 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
int32_t dst_batch_id = real_batch[batch_];
// for each token
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
int32_t table_id = seq_ / block_size;
int32_t block_id =
block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
int32_t seq_offset = seq_ % block_size;
int32_t dst_token_offset = block_id * block_size + seq_offset;
slot_mapping_vec[idx] = dst_token_offset;
block_table_idx_vec[idx] =
seq_ / block_size + dst_batch_id * max_num_blocks_per_seq;
seq_offset_vec[idx] = seq_ % block_size;
idx++;
}
}
int ret = api::do_host2device(xpu_ctx,
slot_mapping_vec.data(),
slot_mapping,
token_num * sizeof(int32_t));
auto block_table_idx =
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
auto seq_offset =
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
ret = api::do_host2device(xpu_ctx,
block_table_idx_vec.data(),
block_table_idx.data<int32_t>(),
actual_token_num * sizeof(int32_t));
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
ret = api::do_host2device(xpu_ctx,
seq_offset_vec.data(),
seq_offset.data<int32_t>(),
actual_token_num * sizeof(int32_t));
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
// int32_t block_id =
// block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
auto block_id =
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
auto block_size_tensor =
paddle::full({1}, block_size, paddle::DataType::INT32, place);
ret = api::index_select<int32_t, int32_t>(xpu_ctx,
block_table_xpu.data<int32_t>(),
block_table_idx.data<int32_t>(),
block_id.data<int32_t>(),
{block_table_xpu.numel()},
actual_token_num,
0);
PD_CHECK(ret == api::SUCCESS, "api::index_select failed.");
// int32_t dst_token_offset = block_id * block_size + seq_offset;
ret = api::broadcast_mul<int32_t>(xpu_ctx,
block_id.data<int32_t>(),
block_size_tensor.data<int32_t>(),
block_id.data<int32_t>(),
{actual_token_num},
{1});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
ret = api::broadcast_add<int32_t>(xpu_ctx,
block_id.data<int32_t>(),
seq_offset.data<int32_t>(),
slot_mapping,
{actual_token_num},
{actual_token_num});
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
}
std::vector<paddle::Tensor> GetInferParam(
@@ -68,6 +110,28 @@ std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
paddle::Tensor& encoder_batch_map,
paddle::Tensor& decoder_batch_map,
paddle::Tensor& encoder_batch_idx,
paddle::Tensor& decoder_batch_idx,
paddle::Tensor& encoder_seq_lod,
paddle::Tensor& decoder_seq_lod,
paddle::Tensor& encoder_kv_lod,
paddle::Tensor& prefix_len,
paddle::Tensor& decoder_context_len,
paddle::Tensor& decoder_context_len_cache,
paddle::Tensor& prefix_block_tables,
paddle::Tensor& encoder_batch_map_cpu,
paddle::Tensor& decoder_batch_map_cpu,
paddle::Tensor& encoder_batch_idx_cpu,
paddle::Tensor& decoder_batch_idx_cpu,
paddle::Tensor& encoder_seq_lod_cpu,
paddle::Tensor& decoder_seq_lod_cpu,
paddle::Tensor& encoder_kv_lod_cpu,
paddle::Tensor& prefix_len_cpu,
paddle::Tensor& decoder_context_len_cpu,
paddle::Tensor& decoder_context_len_cache_cpu,
paddle::Tensor& len_info_cpu,
int block_size,
int num_speculative_tokens) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -128,18 +192,18 @@ std::vector<paddle::Tensor> GetInferParam(
if (seq_lens_encoder_vec[i] > 0) {
enc_batch++;
int seq_len = seq_lens_encoder_vec[i];
int prefix_len = seq_lens_decoder_vec[i];
int prefix_len_int = seq_lens_decoder_vec[i];
total_enc_len += seq_len;
max_seq_len = std::max(max_seq_len, seq_len);
max_prefix_len = std::max(max_prefix_len, prefix_len);
max_kv_len = std::max(max_kv_len, seq_len + prefix_len);
max_prefix_len = std::max(max_prefix_len, prefix_len_int);
max_kv_len = std::max(max_kv_len, seq_len + prefix_len_int);
encoder_batch_map_vec[enc_batch - 1] = i;
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
encoder_seq_lod_vec[enc_batch] =
seq_len + encoder_seq_lod_vec[enc_batch - 1];
encoder_kv_lod_vec[enc_batch] =
seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1];
prefix_len_vec[enc_batch - 1] = prefix_len;
seq_len + prefix_len_int + encoder_kv_lod_vec[enc_batch - 1];
prefix_len_vec[enc_batch - 1] = prefix_len_int;
} else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) {
dec_batch++;
max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]);
@@ -186,42 +250,6 @@ std::vector<paddle::Tensor> GetInferParam(
prefix_block_num_per_seq = -1;
}
auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_len = paddle::empty({prefix_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len_cache =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_block_tables =
paddle::empty({block_bs, block_num_per_seq}, // full size
seq_lens_encoder.type(),
seq_lens_encoder.place());
// for store_paged_kv_cache of cudagraph mode
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
paddle::Tensor slot_mapping_enc = paddle::full(
@@ -232,72 +260,34 @@ std::vector<paddle::Tensor> GetInferParam(
-1,
paddle::DataType::INT32,
seq_lens_decoder.place());
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
r = xpu_memcpy(block_tables_vec.data(),
block_tables.data<int32_t>(),
sizeof(int32_t) * block_bs * block_num_per_seq,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
if (FLAGS_encoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_encoder.place(),
block_tables_vec,
encoder_seq_lod_vec,
prefix_len_vec,
encoder_batch_map_vec,
slot_mapping_enc.data<int32_t>(),
total_enc_len,
block_size,
enc_batch,
block_num_per_seq,
0);
}
if (FLAGS_decoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_decoder.place(),
block_tables_vec,
decoder_seq_lod_vec,
decoder_context_len_cache_vec,
decoder_batch_map_vec,
slot_mapping_dec.data<int32_t>(),
bsz * (1 + num_speculative_tokens),
block_size,
dec_batch,
block_num_per_seq,
num_speculative_tokens);
}
if (FLAGS_encoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_encoder.place(),
block_tables,
encoder_seq_lod_vec,
prefix_len_vec,
encoder_batch_map_vec,
slot_mapping_enc.data<int32_t>(),
slot_mapping_enc.numel(),
block_size,
enc_batch,
block_num_per_seq,
0);
}
if (FLAGS_decoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_decoder.place(),
block_tables,
decoder_seq_lod_vec,
decoder_context_len_cache_vec,
decoder_batch_map_vec,
slot_mapping_dec.data<int32_t>(),
slot_mapping_dec.numel(),
block_size,
dec_batch,
block_num_per_seq,
num_speculative_tokens);
}
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_kv_lod_cpu = paddle::empty(
{encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto prefix_len_cpu = paddle::empty(
{prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_context_len_cache_cpu =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
ret = api::do_host2device(
xpu_ctx->x_context(),
@@ -400,65 +390,25 @@ std::vector<paddle::Tensor> GetInferParam(
max_kv_len,
prefix_block_num_per_seq,
max_dec_len};
auto len_info_cpu =
paddle::empty({7}, seq_lens_encoder.type(), paddle::CPUPlace());
std::memcpy(len_info_cpu.data<int32_t>(),
len_info_vec.data(),
sizeof(int32_t) * len_info_vec.size());
return {encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec};
return {slot_mapping_enc, slot_mapping_dec};
}
std::vector<std::vector<int64_t>> GetInferParamInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& block_tables_shape) {
return {seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
block_tables_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{7}};
const std::vector<int64_t>& block_tables_shape,
int num_speculative_tokens) {
// Return shapes for slot_mapping_enc and slot_mapping_dec
// slot_mapping_enc shape depends on encoder token count (unknown at shape
// inference time) slot_mapping_dec shape depends on batch size and
// speculative token count
return {{-1}, {seq_lens_encoder_shape[0] * (1 + num_speculative_tokens)}};
}
std::vector<paddle::DataType> GetInferParamInferDtype(
@@ -466,46 +416,38 @@ std::vector<paddle::DataType> GetInferParamInferDtype(
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& block_tables_dtype) {
return {
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype};
// Return dtypes for slot_mapping_enc and slot_mapping_dec (both INT32)
return {paddle::DataType::INT32, paddle::DataType::INT32};
}
PD_BUILD_OP(get_infer_param)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"block_tables"})
.Outputs({"encoder_batch_map",
"decoder_batch_map",
"encoder_batch_idx",
"decoder_batch_idx",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_kv_lod",
"prefix_len",
"decoder_context_len",
"decoder_context_len_cache",
"prefix_block_tables",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"encoder_batch_idx_cpu",
"decoder_batch_idx_cpu",
"encoder_seq_lod_cpu",
"decoder_seq_lod_cpu",
"encoder_kv_lod_cpu",
"prefix_len_cpu",
"decoder_context_len_cpu",
"decoder_context_len_cache_cpu",
"len_info_cpu",
"slot_mapping_enc",
"slot_mapping_dec"})
"block_tables",
"encoder_batch_map",
"decoder_batch_map",
"encoder_batch_idx",
"decoder_batch_idx",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_kv_lod",
"prefix_len",
"decoder_context_len",
"decoder_context_len_cache",
"prefix_block_tables",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"encoder_batch_idx_cpu",
"decoder_batch_idx_cpu",
"encoder_seq_lod_cpu",
"decoder_seq_lod_cpu",
"encoder_kv_lod_cpu",
"prefix_len_cpu",
"decoder_context_len_cpu",
"decoder_context_len_cache_cpu",
"len_info_cpu"})
.Outputs({"slot_mapping_enc", "slot_mapping_dec"})
.SetKernelFn(PD_KERNEL(GetInferParam))
.Attrs({"block_size: int", "num_speculative_tokens: int"})
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
@@ -478,6 +478,28 @@ std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
paddle::Tensor& encoder_batch_map,
paddle::Tensor& decoder_batch_map,
paddle::Tensor& encoder_batch_idx,
paddle::Tensor& decoder_batch_idx,
paddle::Tensor& encoder_seq_lod,
paddle::Tensor& decoder_seq_lod,
paddle::Tensor& encoder_kv_lod,
paddle::Tensor& prefix_len,
paddle::Tensor& decoder_context_len,
paddle::Tensor& decoder_context_len_cache,
paddle::Tensor& prefix_block_tables,
paddle::Tensor& encoder_batch_map_cpu,
paddle::Tensor& decoder_batch_map_cpu,
paddle::Tensor& encoder_batch_idx_cpu,
paddle::Tensor& decoder_batch_idx_cpu,
paddle::Tensor& encoder_seq_lod_cpu,
paddle::Tensor& decoder_seq_lod_cpu,
paddle::Tensor& encoder_kv_lod_cpu,
paddle::Tensor& prefix_len_cpu,
paddle::Tensor& decoder_context_len_cpu,
paddle::Tensor& decoder_context_len_cache_cpu,
paddle::Tensor& len_info_cpu,
int block_size,
int num_speculative_tokens);
@@ -1052,6 +1074,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_decoder"),
py::arg("seq_lens_this_time"),
py::arg("block_tables"),
py::arg("encoder_batch_map"),
py::arg("decoder_batch_map"),
py::arg("encoder_batch_idx"),
py::arg("decoder_batch_idx"),
py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"),
py::arg("encoder_kv_lod"),
py::arg("prefix_len"),
py::arg("decoder_context_len"),
py::arg("decoder_context_len_cache"),
py::arg("prefix_block_tables"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("encoder_batch_idx_cpu"),
py::arg("decoder_batch_idx_cpu"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_kv_lod_cpu"),
py::arg("prefix_len_cpu"),
py::arg("decoder_context_len_cpu"),
py::arg("decoder_context_len_cache_cpu"),
py::arg("len_info_cpu"),
py::arg("block_size"),
py::arg("num_speculative_tokens"),
"Get infer parameters for block attention in XPU");
@@ -16,6 +16,7 @@ import unittest # 导入 unittest
import numpy as np
import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
@@ -33,11 +34,8 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")
bsz = seq_lens_this_time.shape[0]
cum_offsets = paddle.zeros(bsz, dtype="int32")
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
(
encoder_batch_map,
decoder_batch_map,
@@ -45,23 +43,56 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
_,
_,
_,
_,
_,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
_,
_,
_,
_,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = infer_params
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
token_num = seq_lens_this_time.sum().cpu().item()
hidden_dim = 8192
@@ -72,7 +103,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
# 测试 adjust_batch
adjusted_output = adjust_batch(
input_tensor,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
@@ -88,7 +118,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
adjusted_output_cpu = adjust_batch(
input_tensor.cpu(),
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
@@ -110,7 +139,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
# 测试 gather_next_token
gather_out = gather_next_token(
adjusted_output,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_map,
@@ -126,7 +154,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
gather_out_cpu = gather_next_token(
adjusted_output.cpu(),
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_map,
@@ -16,6 +16,7 @@ import unittest
import numpy as np
import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
@@ -33,8 +34,6 @@ def _run_test_base(seq_lens_this_time_data):
cum_offsets = paddle.zeros(bsz, dtype="int32")
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
(
encoder_batch_map,
decoder_batch_map,
@@ -42,23 +41,56 @@ def _run_test_base(seq_lens_this_time_data):
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
_,
_,
_,
_,
_,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
_,
_,
_,
_,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = infer_params
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
token_num = seq_lens_this_time.sum().cpu().item()
hidden_dim = 8192
@@ -68,7 +100,6 @@ def _run_test_base(seq_lens_this_time_data):
# test adjust_batch
adjusted_output = adjust_batch(
input_tensor,
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
@@ -84,7 +115,6 @@ def _run_test_base(seq_lens_this_time_data):
adjusted_output_cpu = adjust_batch(
input_tensor.cpu(),
cum_offsets,
encoder_seq_lod,
decoder_seq_lod,
encoder_batch_idx,
+64 -3
View File
@@ -16,6 +16,7 @@ import random
import numpy as np
import paddle
from utils import init_inplace_tensor
# block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import (
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
# prefix cache block attn
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
(
encoder_batch_map,
decoder_batch_map,
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
) # block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_func(
qkv_prefix,
@@ -194,6 +225,7 @@ def run_block_attn(
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
(
encoder_batch_map,
decoder_batch_map,
@@ -217,10 +249,39 @@ def run_block_attn(
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
@@ -14,6 +14,7 @@
import numpy as np
import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
@@ -24,6 +25,7 @@ seq_len = 128
block_batch = 5
max_block_per_seq = 128
block_size = 64
num_speculative_tokens = 0
seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
@@ -53,11 +55,40 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
) # block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
@@ -247,11 +278,40 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
) # block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_tables,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_fused(
@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
from utils import init_inplace_tensor
from fastdeploy.model_executor.ops.xpu import get_infer_param
@@ -21,6 +22,7 @@ seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
block_table = paddle.arange(0, 40, dtype="int32")
block_table = block_table.reshape((5, 8))
(
encoder_batch_map,
decoder_batch_map,
@@ -44,9 +46,40 @@ block_table = block_table.reshape((5, 8))
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64
) # block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
block_table,
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
64,
0,
)
print("block_table", block_table)
print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0]
+54
View File
@@ -0,0 +1,54 @@
import paddle
def init_inplace_tensor(bsz, block_tables_shape):
encoder_batch_map = paddle.empty(bsz, dtype="int32")
decoder_batch_map = paddle.empty(bsz, dtype="int32")
encoder_batch_idx = paddle.empty(bsz, dtype="int32")
decoder_batch_idx = paddle.empty(bsz, dtype="int32")
encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
prefix_len = paddle.empty(bsz, dtype="int32")
decoder_context_len = paddle.empty(bsz, dtype="int32")
decoder_context_len_cache = paddle.empty(bsz, dtype="int32")
prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")
return (
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
)