Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[XPU] get_infer_param: use in-place copy, remove redundant block_tables d2h copy (#7431)
* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok
* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok
* inplace_copy: all bs == 9 ok
* inplace_copy: all cpu bs == 9 ok
* inplace_copy: len_info_cpu bs == 9 ok
* finished and rm unused code
* prefix_block_tables reuse
* refine
* improve performance
* remove block_table copy to cpu
* fix unit test
* fix
* resolve conflict
* refine code
* fix
* fix
* fix
* fix
* fix
* try fix unit tests
* fix
* tmp save
* fix unit test
* get_infer_param try less return values
* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
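In short, the change inverts the op's calling convention: instead of allocating its batch-map/LoD outputs internally and returning them every step, get_infer_param now fills caller-provided tensors in place and returns only the two slot-mapping tensors. A condensed Python sketch of the new call, following the updated unit tests (the concrete seq_lens values below are made-up example inputs; init_inplace_tensor is the test helper added in this PR):

import paddle
from utils import init_inplace_tensor  # test helper added in this PR
from fastdeploy.model_executor.ops.xpu import get_infer_param

# Example inputs only; shapes follow the unit tests (batch of 5, 8 blocks per sequence).
seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 1], dtype="int32")
block_table = paddle.arange(0, 40, dtype="int32").reshape((5, 8))

# Pre-allocate the 22 in-place outputs (XPU tensors, their CPU mirrors, len_info_cpu).
inplace_outputs = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)

# The op writes into the pre-allocated tensors and returns only the slot mappings.
slot_mapping_enc, slot_mapping_dec = get_infer_param(
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    block_table,
    *inplace_outputs,
    64,  # block_size
    0,   # num_speculative_tokens
)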
@@ -25,7 +25,7 @@ namespace api = baidu::xpu::api;
void lod_to_slot_mapping(api::Context* xpu_ctx,
    paddle::Place place,
    const std::vector<int32_t>& block_table,
    const paddle::Tensor& block_table_xpu,
    const std::vector<int32_t>& kv_seq_lod,
    const std::vector<int32_t>& start_tokens,
    const std::vector<int32_t>& real_batch,
@@ -35,10 +35,16 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
    int32_t batch_size,
    int32_t max_num_blocks_per_seq,
    int32_t num_speculative_tokens) {
  if (token_num <= 0) {
  int32_t actual_token_num = kv_seq_lod[batch_size];
  if (token_num <= 0 || actual_token_num <= 0) {
    return;
  }
  std::vector<int32_t> slot_mapping_vec(token_num, -1);

  int ret;

  std::vector<int32_t> block_table_idx_vec(actual_token_num);
  std::vector<int32_t> seq_offset_vec(actual_token_num);

  int32_t idx = 0;
  // For each Batch
  for (auto batch_ = 0; batch_ < batch_size; batch_++) {
@@ -47,20 +53,56 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
    int32_t dst_batch_id = real_batch[batch_];
    // for each token
    for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
      int32_t table_id = seq_ / block_size;
      int32_t block_id =
          block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
      int32_t seq_offset = seq_ % block_size;
      int32_t dst_token_offset = block_id * block_size + seq_offset;
      slot_mapping_vec[idx] = dst_token_offset;
      block_table_idx_vec[idx] =
          seq_ / block_size + dst_batch_id * max_num_blocks_per_seq;
      seq_offset_vec[idx] = seq_ % block_size;
      idx++;
    }
  }
  int ret = api::do_host2device(xpu_ctx,
      slot_mapping_vec.data(),
      slot_mapping,
      token_num * sizeof(int32_t));
  auto block_table_idx =
      paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
  auto seq_offset =
      paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
  ret = api::do_host2device(xpu_ctx,
      block_table_idx_vec.data(),
      block_table_idx.data<int32_t>(),
      actual_token_num * sizeof(int32_t));
  PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
  ret = api::do_host2device(xpu_ctx,
      seq_offset_vec.data(),
      seq_offset.data<int32_t>(),
      actual_token_num * sizeof(int32_t));
  PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");

  // int32_t block_id =
  //     block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
  auto block_id =
      paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
  auto block_size_tensor =
      paddle::full({1}, block_size, paddle::DataType::INT32, place);
  ret = api::index_select<int32_t, int32_t>(xpu_ctx,
      block_table_xpu.data<int32_t>(),
      block_table_idx.data<int32_t>(),
      block_id.data<int32_t>(),
      {block_table_xpu.numel()},
      actual_token_num,
      0);
  PD_CHECK(ret == api::SUCCESS, "api::index_select failed.");
  // int32_t dst_token_offset = block_id * block_size + seq_offset;
  ret = api::broadcast_mul<int32_t>(xpu_ctx,
      block_id.data<int32_t>(),
      block_size_tensor.data<int32_t>(),
      block_id.data<int32_t>(),
      {actual_token_num},
      {1});
  PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
  ret = api::broadcast_add<int32_t>(xpu_ctx,
      block_id.data<int32_t>(),
      seq_offset.data<int32_t>(),
      slot_mapping,
      {actual_token_num},
      {actual_token_num});
  PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
}
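Both paths compute the same per-token slot, slot = block_table[dst_batch_id][seq_ / block_size] * block_size + seq_ % block_size; the difference is that the new path keeps the block table on the XPU. The host only prepares the flat gather indices (block_table_idx_vec) and the in-block offsets (seq_offset_vec), and index_select, broadcast_mul and broadcast_add assemble slot_mapping on the device, which is what lets the caller drop the block_tables device-to-host copy. A small NumPy sketch of that vectorized form (illustrative only, not part of the commit):

import numpy as np

def slot_mapping_device_style(block_table_flat, block_table_idx, seq_offset, block_size):
    # block_table_flat: the block table flattened to 1-D, as it is addressed on the XPU
    # block_table_idx:  per-token index into the flat table
    #                   (dst_batch_id * max_num_blocks_per_seq + seq_ // block_size)
    # seq_offset:       per-token in-block offset (seq_ % block_size)
    block_id = block_table_flat[block_table_idx]      # api::index_select
    block_id = block_id * block_size                  # api::broadcast_mul
    return (block_id + seq_offset).astype(np.int32)   # api::broadcast_add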
std::vector<paddle::Tensor> GetInferParam(
@@ -68,6 +110,28 @@ std::vector<paddle::Tensor> GetInferParam(
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& block_tables,
    paddle::Tensor& encoder_batch_map,
    paddle::Tensor& decoder_batch_map,
    paddle::Tensor& encoder_batch_idx,
    paddle::Tensor& decoder_batch_idx,
    paddle::Tensor& encoder_seq_lod,
    paddle::Tensor& decoder_seq_lod,
    paddle::Tensor& encoder_kv_lod,
    paddle::Tensor& prefix_len,
    paddle::Tensor& decoder_context_len,
    paddle::Tensor& decoder_context_len_cache,
    paddle::Tensor& prefix_block_tables,
    paddle::Tensor& encoder_batch_map_cpu,
    paddle::Tensor& decoder_batch_map_cpu,
    paddle::Tensor& encoder_batch_idx_cpu,
    paddle::Tensor& decoder_batch_idx_cpu,
    paddle::Tensor& encoder_seq_lod_cpu,
    paddle::Tensor& decoder_seq_lod_cpu,
    paddle::Tensor& encoder_kv_lod_cpu,
    paddle::Tensor& prefix_len_cpu,
    paddle::Tensor& decoder_context_len_cpu,
    paddle::Tensor& decoder_context_len_cache_cpu,
    paddle::Tensor& len_info_cpu,
    int block_size,
    int num_speculative_tokens) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -128,18 +192,18 @@ std::vector<paddle::Tensor> GetInferParam(
    if (seq_lens_encoder_vec[i] > 0) {
      enc_batch++;
      int seq_len = seq_lens_encoder_vec[i];
      int prefix_len = seq_lens_decoder_vec[i];
      int prefix_len_int = seq_lens_decoder_vec[i];
      total_enc_len += seq_len;
      max_seq_len = std::max(max_seq_len, seq_len);
      max_prefix_len = std::max(max_prefix_len, prefix_len);
      max_kv_len = std::max(max_kv_len, seq_len + prefix_len);
      max_prefix_len = std::max(max_prefix_len, prefix_len_int);
      max_kv_len = std::max(max_kv_len, seq_len + prefix_len_int);
      encoder_batch_map_vec[enc_batch - 1] = i;
      encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
      encoder_seq_lod_vec[enc_batch] =
          seq_len + encoder_seq_lod_vec[enc_batch - 1];
      encoder_kv_lod_vec[enc_batch] =
          seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1];
      prefix_len_vec[enc_batch - 1] = prefix_len;
          seq_len + prefix_len_int + encoder_kv_lod_vec[enc_batch - 1];
      prefix_len_vec[enc_batch - 1] = prefix_len_int;
    } else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) {
      dec_batch++;
      max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]);
@@ -186,42 +250,6 @@ std::vector<paddle::Tensor> GetInferParam(
    prefix_block_num_per_seq = -1;
  }

  auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto prefix_len = paddle::empty({prefix_len_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()},
      seq_lens_encoder.type(),
      seq_lens_encoder.place());
  auto decoder_context_len_cache =
      paddle::empty({decoder_context_len_cache_vec.size()},
          seq_lens_encoder.type(),
          seq_lens_encoder.place());
  auto prefix_block_tables =
      paddle::empty({block_bs, block_num_per_seq},  // full size
          seq_lens_encoder.type(),
          seq_lens_encoder.place());

  // for store_paged_kv_cache of cudagraph mode
  // if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
  paddle::Tensor slot_mapping_enc = paddle::full(
@@ -232,72 +260,34 @@ std::vector<paddle::Tensor> GetInferParam(
      -1,
      paddle::DataType::INT32,
      seq_lens_decoder.place());
  if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
    std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
    r = xpu_memcpy(block_tables_vec.data(),
        block_tables.data<int32_t>(),
        sizeof(int32_t) * block_bs * block_num_per_seq,
        XPUMemcpyKind::XPU_DEVICE_TO_HOST);
    if (FLAGS_encoder_splice) {
      lod_to_slot_mapping(xpu_ctx->x_context(),
          seq_lens_encoder.place(),
          block_tables_vec,
          encoder_seq_lod_vec,
          prefix_len_vec,
          encoder_batch_map_vec,
          slot_mapping_enc.data<int32_t>(),
          total_enc_len,
          block_size,
          enc_batch,
          block_num_per_seq,
          0);
    }
    if (FLAGS_decoder_splice) {
      lod_to_slot_mapping(xpu_ctx->x_context(),
          seq_lens_decoder.place(),
          block_tables_vec,
          decoder_seq_lod_vec,
          decoder_context_len_cache_vec,
          decoder_batch_map_vec,
          slot_mapping_dec.data<int32_t>(),
          bsz * (1 + num_speculative_tokens),
          block_size,
          dec_batch,
          block_num_per_seq,
          num_speculative_tokens);
    }
  if (FLAGS_encoder_splice) {
    lod_to_slot_mapping(xpu_ctx->x_context(),
        seq_lens_encoder.place(),
        block_tables,
        encoder_seq_lod_vec,
        prefix_len_vec,
        encoder_batch_map_vec,
        slot_mapping_enc.data<int32_t>(),
        slot_mapping_enc.numel(),
        block_size,
        enc_batch,
        block_num_per_seq,
        0);
  }
  if (FLAGS_decoder_splice) {
    lod_to_slot_mapping(xpu_ctx->x_context(),
        seq_lens_decoder.place(),
        block_tables,
        decoder_seq_lod_vec,
        decoder_context_len_cache_vec,
        decoder_batch_map_vec,
        slot_mapping_dec.data<int32_t>(),
        slot_mapping_dec.numel(),
        block_size,
        dec_batch,
        block_num_per_seq,
        num_speculative_tokens);
  }

  auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());

  auto encoder_kv_lod_cpu = paddle::empty(
      {encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
  auto prefix_len_cpu = paddle::empty(
      {prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
  auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()},
      seq_lens_encoder.type(),
      paddle::CPUPlace());
  auto decoder_context_len_cache_cpu =
      paddle::empty({decoder_context_len_cache_vec.size()},
          seq_lens_encoder.type(),
          paddle::CPUPlace());

  ret = api::do_host2device(
      xpu_ctx->x_context(),
@@ -400,65 +390,25 @@ std::vector<paddle::Tensor> GetInferParam(
      max_kv_len,
      prefix_block_num_per_seq,
      max_dec_len};
  auto len_info_cpu =
      paddle::empty({7}, seq_lens_encoder.type(), paddle::CPUPlace());

  std::memcpy(len_info_cpu.data<int32_t>(),
      len_info_vec.data(),
      sizeof(int32_t) * len_info_vec.size());

  return {encoder_batch_map,
      decoder_batch_map,
      encoder_batch_idx,
      decoder_batch_idx,
      encoder_seq_lod,
      decoder_seq_lod,
      encoder_kv_lod,
      prefix_len,
      decoder_context_len,
      decoder_context_len_cache,
      prefix_block_tables,
      encoder_batch_map_cpu,
      decoder_batch_map_cpu,
      encoder_batch_idx_cpu,
      decoder_batch_idx_cpu,
      encoder_seq_lod_cpu,
      decoder_seq_lod_cpu,
      encoder_kv_lod_cpu,
      prefix_len_cpu,
      decoder_context_len_cpu,
      decoder_context_len_cache_cpu,
      len_info_cpu,
      slot_mapping_enc,
      slot_mapping_dec};
  return {slot_mapping_enc, slot_mapping_dec};
}

std::vector<std::vector<int64_t>> GetInferParamInferShape(
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& block_tables_shape) {
  return {seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      {seq_lens_encoder_shape[0] + 1},
      {seq_lens_encoder_shape[0] + 1},
      {seq_lens_encoder_shape[0] + 1},
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      block_tables_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      {seq_lens_encoder_shape[0] + 1},
      {seq_lens_encoder_shape[0] + 1},
      {seq_lens_encoder_shape[0] + 1},
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      seq_lens_encoder_shape,
      {7}};
    const std::vector<int64_t>& block_tables_shape,
    int num_speculative_tokens) {
  // Return shapes for slot_mapping_enc and slot_mapping_dec
  // slot_mapping_enc shape depends on encoder token count (unknown at shape
  // inference time) slot_mapping_dec shape depends on batch size and
  // speculative token count
  return {{-1}, {seq_lens_encoder_shape[0] * (1 + num_speculative_tokens)}};
}

std::vector<paddle::DataType> GetInferParamInferDtype(
@@ -466,46 +416,38 @@ std::vector<paddle::DataType> GetInferParamInferDtype(
    const paddle::DataType& seq_lens_decoder_dtype,
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& block_tables_dtype) {
  return {
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
      seq_lens_encoder_dtype};
  // Return dtypes for slot_mapping_enc and slot_mapping_dec (both INT32)
  return {paddle::DataType::INT32, paddle::DataType::INT32};
}

PD_BUILD_OP(get_infer_param)
    .Inputs({"seq_lens_encoder",
        "seq_lens_decoder",
        "seq_lens_this_time",
        "block_tables"})
    .Outputs({"encoder_batch_map",
        "decoder_batch_map",
        "encoder_batch_idx",
        "decoder_batch_idx",
        "encoder_seq_lod",
        "decoder_seq_lod",
        "encoder_kv_lod",
        "prefix_len",
        "decoder_context_len",
        "decoder_context_len_cache",
        "prefix_block_tables",
        "encoder_batch_map_cpu",
        "decoder_batch_map_cpu",
        "encoder_batch_idx_cpu",
        "decoder_batch_idx_cpu",
        "encoder_seq_lod_cpu",
        "decoder_seq_lod_cpu",
        "encoder_kv_lod_cpu",
        "prefix_len_cpu",
        "decoder_context_len_cpu",
        "decoder_context_len_cache_cpu",
        "len_info_cpu",
        "slot_mapping_enc",
        "slot_mapping_dec"})
        "block_tables",
        "encoder_batch_map",
        "decoder_batch_map",
        "encoder_batch_idx",
        "decoder_batch_idx",
        "encoder_seq_lod",
        "decoder_seq_lod",
        "encoder_kv_lod",
        "prefix_len",
        "decoder_context_len",
        "decoder_context_len_cache",
        "prefix_block_tables",
        "encoder_batch_map_cpu",
        "decoder_batch_map_cpu",
        "encoder_batch_idx_cpu",
        "decoder_batch_idx_cpu",
        "encoder_seq_lod_cpu",
        "decoder_seq_lod_cpu",
        "encoder_kv_lod_cpu",
        "prefix_len_cpu",
        "decoder_context_len_cpu",
        "decoder_context_len_cache_cpu",
        "len_info_cpu"})
    .Outputs({"slot_mapping_enc", "slot_mapping_dec"})
    .SetKernelFn(PD_KERNEL(GetInferParam))
    .Attrs({"block_size: int", "num_speculative_tokens: int"})
    .SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))

@@ -478,6 +478,28 @@ std::vector<paddle::Tensor> GetInferParam(
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& block_tables,
    paddle::Tensor& encoder_batch_map,
    paddle::Tensor& decoder_batch_map,
    paddle::Tensor& encoder_batch_idx,
    paddle::Tensor& decoder_batch_idx,
    paddle::Tensor& encoder_seq_lod,
    paddle::Tensor& decoder_seq_lod,
    paddle::Tensor& encoder_kv_lod,
    paddle::Tensor& prefix_len,
    paddle::Tensor& decoder_context_len,
    paddle::Tensor& decoder_context_len_cache,
    paddle::Tensor& prefix_block_tables,
    paddle::Tensor& encoder_batch_map_cpu,
    paddle::Tensor& decoder_batch_map_cpu,
    paddle::Tensor& encoder_batch_idx_cpu,
    paddle::Tensor& decoder_batch_idx_cpu,
    paddle::Tensor& encoder_seq_lod_cpu,
    paddle::Tensor& decoder_seq_lod_cpu,
    paddle::Tensor& encoder_kv_lod_cpu,
    paddle::Tensor& prefix_len_cpu,
    paddle::Tensor& decoder_context_len_cpu,
    paddle::Tensor& decoder_context_len_cache_cpu,
    paddle::Tensor& len_info_cpu,
    int block_size,
    int num_speculative_tokens);

@@ -1052,6 +1074,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("seq_lens_decoder"),
        py::arg("seq_lens_this_time"),
        py::arg("block_tables"),
        py::arg("encoder_batch_map"),
        py::arg("decoder_batch_map"),
        py::arg("encoder_batch_idx"),
        py::arg("decoder_batch_idx"),
        py::arg("encoder_seq_lod"),
        py::arg("decoder_seq_lod"),
        py::arg("encoder_kv_lod"),
        py::arg("prefix_len"),
        py::arg("decoder_context_len"),
        py::arg("decoder_context_len_cache"),
        py::arg("prefix_block_tables"),
        py::arg("encoder_batch_map_cpu"),
        py::arg("decoder_batch_map_cpu"),
        py::arg("encoder_batch_idx_cpu"),
        py::arg("decoder_batch_idx_cpu"),
        py::arg("encoder_seq_lod_cpu"),
        py::arg("decoder_seq_lod_cpu"),
        py::arg("encoder_kv_lod_cpu"),
        py::arg("prefix_len_cpu"),
        py::arg("decoder_context_len_cpu"),
        py::arg("decoder_context_len_cache_cpu"),
        py::arg("len_info_cpu"),
        py::arg("block_size"),
        py::arg("num_speculative_tokens"),
        "Get infer parameters for block attention in XPU");

@@ -16,6 +16,7 @@ import unittest  # import unittest

import numpy as np
import paddle
from utils import init_inplace_tensor

from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
@@ -33,11 +34,8 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
    seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")

    bsz = seq_lens_this_time.shape[0]
    cum_offsets = paddle.zeros(bsz, dtype="int32")
    block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))

    infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)

    (
        encoder_batch_map,
        decoder_batch_map,
@@ -45,23 +43,56 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        _,
        _,
        _,
        _,
        _,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        _,
        _,
        _,
        _,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
    ) = infer_params
    ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
    (
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_table,
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        64,
        0,
    )

    token_num = seq_lens_this_time.sum().cpu().item()
    hidden_dim = 8192
@@ -72,7 +103,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
    # test adjust_batch
    adjusted_output = adjust_batch(
        input_tensor,
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,
@@ -88,7 +118,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):

    adjusted_output_cpu = adjust_batch(
        input_tensor.cpu(),
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,
@@ -110,7 +139,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
    # test gather_next_token
    gather_out = gather_next_token(
        adjusted_output,
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_map,
@@ -126,7 +154,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):

    gather_out_cpu = gather_next_token(
        adjusted_output.cpu(),
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_map,

@@ -16,6 +16,7 @@ import unittest

import numpy as np
import paddle
from utils import init_inplace_tensor

from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
@@ -33,8 +34,6 @@ def _run_test_base(seq_lens_this_time_data):
    cum_offsets = paddle.zeros(bsz, dtype="int32")
    block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))

    infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)

    (
        encoder_batch_map,
        decoder_batch_map,
@@ -42,23 +41,56 @@ def _run_test_base(seq_lens_this_time_data):
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        _,
        _,
        _,
        _,
        _,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        _,
        _,
        _,
        _,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
    ) = infer_params
    ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
    (
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_table,
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        64,
        0,
    )

    token_num = seq_lens_this_time.sum().cpu().item()
    hidden_dim = 8192
@@ -68,7 +100,6 @@ def _run_test_base(seq_lens_this_time_data):
    # test adjust_batch
    adjusted_output = adjust_batch(
        input_tensor,
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,
@@ -84,7 +115,6 @@ def _run_test_base(seq_lens_this_time_data):

    adjusted_output_cpu = adjust_batch(
        input_tensor.cpu(),
        cum_offsets,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_batch_idx,

@@ -16,6 +16,7 @@ import random

import numpy as np
import paddle
from utils import init_inplace_tensor

# block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import (
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
    # prefix cache block attn
    seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")

    (
        encoder_batch_map,
        decoder_batch_map,
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
    ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
    (
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
    )  # block_size
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_tables,
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        64,
        num_speculative_tokens,
    )
    qkv_prefix = qkv[hit_prefix_len:]
    attn_out_prefix_cache = block_attn_func(
        qkv_prefix,
@@ -194,6 +225,7 @@ def run_block_attn(
    seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
    block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
    block_tables = block_tables.reshape((block_batch, max_block_per_seq))

    (
        encoder_batch_map,
        decoder_batch_map,
@@ -217,10 +249,39 @@ def run_block_attn(
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
    ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
    (
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_tables,
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        64,
        num_speculative_tokens,
    )
    qkv = paddle.uniform(
        shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed

@@ -14,6 +14,7 @@

import numpy as np
import paddle
from utils import init_inplace_tensor

from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param

@@ -24,6 +25,7 @@ seq_len = 128
block_batch = 5
max_block_per_seq = 128
block_size = 64
num_speculative_tokens = 0

seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
@@ -53,11 +55,40 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
    slot_mapping_enc,
    slot_mapping_dec,
) = get_infer_param(
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
)  # block_size
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    block_tables,
    encoder_batch_map,
    decoder_batch_map,
    encoder_batch_idx,
    decoder_batch_idx,
    encoder_seq_lod,
    decoder_seq_lod,
    encoder_kv_lod,
    prefix_len,
    decoder_context_len,
    decoder_context_len_cache,
    prefix_block_tables,
    encoder_batch_map_cpu,
    decoder_batch_map_cpu,
    encoder_batch_idx_cpu,
    decoder_batch_idx_cpu,
    encoder_seq_lod_cpu,
    decoder_seq_lod_cpu,
    encoder_kv_lod_cpu,
    prefix_len_cpu,
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
    64,
    num_speculative_tokens,
)

qkv = paddle.uniform(
    shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
@@ -247,11 +278,40 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
(
    slot_mapping_enc,
    slot_mapping_dec,
) = get_infer_param(
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
)  # block_size
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    block_tables,
    encoder_batch_map,
    decoder_batch_map,
    encoder_batch_idx,
    decoder_batch_idx,
    encoder_seq_lod,
    decoder_seq_lod,
    encoder_kv_lod,
    prefix_len,
    decoder_context_len,
    decoder_context_len_cache,
    prefix_block_tables,
    encoder_batch_map_cpu,
    decoder_batch_map_cpu,
    encoder_batch_idx_cpu,
    decoder_batch_idx_cpu,
    encoder_seq_lod_cpu,
    decoder_seq_lod_cpu,
    encoder_kv_lod_cpu,
    prefix_len_cpu,
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
    64,
    num_speculative_tokens,
)
qkv_prefix = qkv[hit_prefix_len:]

attn_out_prefix_cache = block_attn_fused(

@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
from utils import init_inplace_tensor

from fastdeploy.model_executor.ops.xpu import get_infer_param

@@ -21,6 +22,7 @@ seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
block_table = paddle.arange(0, 40, dtype="int32")
block_table = block_table.reshape((5, 8))

(
    encoder_batch_map,
    decoder_batch_map,
@@ -44,9 +46,40 @@ block_table = block_table.reshape((5, 8))
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
(
    slot_mapping_enc,
    slot_mapping_dec,
) = get_infer_param(
    seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64
)  # block_size
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    block_table,
    encoder_batch_map,
    decoder_batch_map,
    encoder_batch_idx,
    decoder_batch_idx,
    encoder_seq_lod,
    decoder_seq_lod,
    encoder_kv_lod,
    prefix_len,
    decoder_context_len,
    decoder_context_len_cache,
    prefix_block_tables,
    encoder_batch_map_cpu,
    decoder_batch_map_cpu,
    encoder_batch_idx_cpu,
    decoder_batch_idx_cpu,
    encoder_seq_lod_cpu,
    decoder_seq_lod_cpu,
    encoder_kv_lod_cpu,
    prefix_len_cpu,
    decoder_context_len_cpu,
    decoder_context_len_cache_cpu,
    len_info_cpu,
    64,
    0,
)

print("block_table", block_table)
print("encoder_batch_map", encoder_batch_map)  # [0, 4, 0, 0, 0]

@@ -0,0 +1,54 @@
import paddle


def init_inplace_tensor(bsz, block_tables_shape):
    encoder_batch_map = paddle.empty(bsz, dtype="int32")
    decoder_batch_map = paddle.empty(bsz, dtype="int32")
    encoder_batch_idx = paddle.empty(bsz, dtype="int32")
    decoder_batch_idx = paddle.empty(bsz, dtype="int32")
    encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
    decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
    encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
    prefix_len = paddle.empty(bsz, dtype="int32")
    decoder_context_len = paddle.empty(bsz, dtype="int32")
    decoder_context_len_cache = paddle.empty(bsz, dtype="int32")

    prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")

    encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
    prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
    decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")

    len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")

    return (
        encoder_batch_map,
        decoder_batch_map,
        encoder_batch_idx,
        decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        encoder_batch_idx_cpu,
        decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
    )