mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] get_infer_param uses in-place copy; remove redundant block_tables d2h copy (#7431)
* inplace_copy: encoder_batch_idx/decoder_batch_idx bs == 9 ok
* inplace_copy: encoder_seq_lod/decoder_seq_lod bs == 9 ok
* inplace_copy: all bs == 9 ok
* inplace_copy: all cpu bs == 9 ok
* inplace_copy: len_info_cpu bs == 9 ok
* finished and rm unused code
* prefix_block_tables reuse
* refine
* improve performance
* remove block_table copy to cpu
* fix unit test
* fix
* resolve conflict
* refine code
* fix
* fix
* fix
* fix
* fix
* try fix unit tests
* fix
* tmp save
* fix unit test
* get_infer_param try less return values
* add yinwei fix

---------

Co-authored-by: yinwei <yinwei_hust@163.com>
This commit is contained in:
@@ -25,7 +25,7 @@ namespace api = baidu::xpu::api;
|
|||||||
|
|
||||||
void lod_to_slot_mapping(api::Context* xpu_ctx,
|
void lod_to_slot_mapping(api::Context* xpu_ctx,
|
||||||
paddle::Place place,
|
paddle::Place place,
|
||||||
const std::vector<int32_t>& block_table,
|
const paddle::Tensor& block_table_xpu,
|
||||||
const std::vector<int32_t>& kv_seq_lod,
|
const std::vector<int32_t>& kv_seq_lod,
|
||||||
const std::vector<int32_t>& start_tokens,
|
const std::vector<int32_t>& start_tokens,
|
||||||
const std::vector<int32_t>& real_batch,
|
const std::vector<int32_t>& real_batch,
|
||||||
@@ -35,10 +35,16 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
|
|||||||
int32_t batch_size,
|
int32_t batch_size,
|
||||||
int32_t max_num_blocks_per_seq,
|
int32_t max_num_blocks_per_seq,
|
||||||
int32_t num_speculative_tokens) {
|
int32_t num_speculative_tokens) {
|
||||||
if (token_num <= 0) {
|
int32_t actual_token_num = kv_seq_lod[batch_size];
|
||||||
|
if (token_num <= 0 || actual_token_num <= 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<int32_t> slot_mapping_vec(token_num, -1);
|
|
||||||
|
int ret;
|
||||||
|
|
||||||
|
std::vector<int32_t> block_table_idx_vec(actual_token_num);
|
||||||
|
std::vector<int32_t> seq_offset_vec(actual_token_num);
|
||||||
|
|
||||||
int32_t idx = 0;
|
int32_t idx = 0;
|
||||||
// For each Batch
|
// For each Batch
|
||||||
for (auto batch_ = 0; batch_ < batch_size; batch_++) {
|
for (auto batch_ = 0; batch_ < batch_size; batch_++) {
|
||||||
@@ -47,20 +53,56 @@ void lod_to_slot_mapping(api::Context* xpu_ctx,
|
|||||||
int32_t dst_batch_id = real_batch[batch_];
|
int32_t dst_batch_id = real_batch[batch_];
|
||||||
// for each token
|
// for each token
|
||||||
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
|
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
|
||||||
int32_t table_id = seq_ / block_size;
|
block_table_idx_vec[idx] =
|
||||||
int32_t block_id =
|
seq_ / block_size + dst_batch_id * max_num_blocks_per_seq;
|
||||||
block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
|
seq_offset_vec[idx] = seq_ % block_size;
|
||||||
int32_t seq_offset = seq_ % block_size;
|
|
||||||
int32_t dst_token_offset = block_id * block_size + seq_offset;
|
|
||||||
slot_mapping_vec[idx] = dst_token_offset;
|
|
||||||
idx++;
|
idx++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ret = api::do_host2device(xpu_ctx,
|
auto block_table_idx =
|
||||||
slot_mapping_vec.data(),
|
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
|
||||||
slot_mapping,
|
auto seq_offset =
|
||||||
token_num * sizeof(int32_t));
|
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
|
||||||
|
ret = api::do_host2device(xpu_ctx,
|
||||||
|
block_table_idx_vec.data(),
|
||||||
|
block_table_idx.data<int32_t>(),
|
||||||
|
actual_token_num * sizeof(int32_t));
|
||||||
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
|
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
|
||||||
|
ret = api::do_host2device(xpu_ctx,
|
||||||
|
seq_offset_vec.data(),
|
||||||
|
seq_offset.data<int32_t>(),
|
||||||
|
actual_token_num * sizeof(int32_t));
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
|
||||||
|
|
||||||
|
// int32_t block_id =
|
||||||
|
// block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
|
||||||
|
auto block_id =
|
||||||
|
paddle::empty({actual_token_num}, paddle::DataType::INT32, place);
|
||||||
|
auto block_size_tensor =
|
||||||
|
paddle::full({1}, block_size, paddle::DataType::INT32, place);
|
||||||
|
ret = api::index_select<int32_t, int32_t>(xpu_ctx,
|
||||||
|
block_table_xpu.data<int32_t>(),
|
||||||
|
block_table_idx.data<int32_t>(),
|
||||||
|
block_id.data<int32_t>(),
|
||||||
|
{block_table_xpu.numel()},
|
||||||
|
actual_token_num,
|
||||||
|
0);
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "api::index_select failed.");
|
||||||
|
// int32_t dst_token_offset = block_id * block_size + seq_offset;
|
||||||
|
ret = api::broadcast_mul<int32_t>(xpu_ctx,
|
||||||
|
block_id.data<int32_t>(),
|
||||||
|
block_size_tensor.data<int32_t>(),
|
||||||
|
block_id.data<int32_t>(),
|
||||||
|
{actual_token_num},
|
||||||
|
{1});
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed.");
|
||||||
|
ret = api::broadcast_add<int32_t>(xpu_ctx,
|
||||||
|
block_id.data<int32_t>(),
|
||||||
|
seq_offset.data<int32_t>(),
|
||||||
|
slot_mapping,
|
||||||
|
{actual_token_num},
|
||||||
|
{actual_token_num});
|
||||||
|
PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GetInferParam(
|
std::vector<paddle::Tensor> GetInferParam(
|
||||||
@@ -68,6 +110,28 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
|
paddle::Tensor& encoder_batch_map,
|
||||||
|
paddle::Tensor& decoder_batch_map,
|
||||||
|
paddle::Tensor& encoder_batch_idx,
|
||||||
|
paddle::Tensor& decoder_batch_idx,
|
||||||
|
paddle::Tensor& encoder_seq_lod,
|
||||||
|
paddle::Tensor& decoder_seq_lod,
|
||||||
|
paddle::Tensor& encoder_kv_lod,
|
||||||
|
paddle::Tensor& prefix_len,
|
||||||
|
paddle::Tensor& decoder_context_len,
|
||||||
|
paddle::Tensor& decoder_context_len_cache,
|
||||||
|
paddle::Tensor& prefix_block_tables,
|
||||||
|
paddle::Tensor& encoder_batch_map_cpu,
|
||||||
|
paddle::Tensor& decoder_batch_map_cpu,
|
||||||
|
paddle::Tensor& encoder_batch_idx_cpu,
|
||||||
|
paddle::Tensor& decoder_batch_idx_cpu,
|
||||||
|
paddle::Tensor& encoder_seq_lod_cpu,
|
||||||
|
paddle::Tensor& decoder_seq_lod_cpu,
|
||||||
|
paddle::Tensor& encoder_kv_lod_cpu,
|
||||||
|
paddle::Tensor& prefix_len_cpu,
|
||||||
|
paddle::Tensor& decoder_context_len_cpu,
|
||||||
|
paddle::Tensor& decoder_context_len_cache_cpu,
|
||||||
|
paddle::Tensor& len_info_cpu,
|
||||||
int block_size,
|
int block_size,
|
||||||
int num_speculative_tokens) {
|
int num_speculative_tokens) {
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
@@ -128,18 +192,18 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
if (seq_lens_encoder_vec[i] > 0) {
|
if (seq_lens_encoder_vec[i] > 0) {
|
||||||
enc_batch++;
|
enc_batch++;
|
||||||
int seq_len = seq_lens_encoder_vec[i];
|
int seq_len = seq_lens_encoder_vec[i];
|
||||||
int prefix_len = seq_lens_decoder_vec[i];
|
int prefix_len_int = seq_lens_decoder_vec[i];
|
||||||
total_enc_len += seq_len;
|
total_enc_len += seq_len;
|
||||||
max_seq_len = std::max(max_seq_len, seq_len);
|
max_seq_len = std::max(max_seq_len, seq_len);
|
||||||
max_prefix_len = std::max(max_prefix_len, prefix_len);
|
max_prefix_len = std::max(max_prefix_len, prefix_len_int);
|
||||||
max_kv_len = std::max(max_kv_len, seq_len + prefix_len);
|
max_kv_len = std::max(max_kv_len, seq_len + prefix_len_int);
|
||||||
encoder_batch_map_vec[enc_batch - 1] = i;
|
encoder_batch_map_vec[enc_batch - 1] = i;
|
||||||
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
|
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
|
||||||
encoder_seq_lod_vec[enc_batch] =
|
encoder_seq_lod_vec[enc_batch] =
|
||||||
seq_len + encoder_seq_lod_vec[enc_batch - 1];
|
seq_len + encoder_seq_lod_vec[enc_batch - 1];
|
||||||
encoder_kv_lod_vec[enc_batch] =
|
encoder_kv_lod_vec[enc_batch] =
|
||||||
seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1];
|
seq_len + prefix_len_int + encoder_kv_lod_vec[enc_batch - 1];
|
||||||
prefix_len_vec[enc_batch - 1] = prefix_len;
|
prefix_len_vec[enc_batch - 1] = prefix_len_int;
|
||||||
} else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) {
|
} else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) {
|
||||||
dec_batch++;
|
dec_batch++;
|
||||||
max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]);
|
max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]);
|
||||||
@@ -186,42 +250,6 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
prefix_block_num_per_seq = -1;
|
prefix_block_num_per_seq = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto prefix_len = paddle::empty({prefix_len_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto decoder_context_len_cache =
|
|
||||||
paddle::empty({decoder_context_len_cache_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
auto prefix_block_tables =
|
|
||||||
paddle::empty({block_bs, block_num_per_seq}, // full size
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
seq_lens_encoder.place());
|
|
||||||
|
|
||||||
// for store_paged_kv_cache of cudagraph mode
|
// for store_paged_kv_cache of cudagraph mode
|
||||||
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
|
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
|
||||||
paddle::Tensor slot_mapping_enc = paddle::full(
|
paddle::Tensor slot_mapping_enc = paddle::full(
|
||||||
@@ -232,21 +260,15 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
-1,
|
-1,
|
||||||
paddle::DataType::INT32,
|
paddle::DataType::INT32,
|
||||||
seq_lens_decoder.place());
|
seq_lens_decoder.place());
|
||||||
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
|
|
||||||
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
|
|
||||||
r = xpu_memcpy(block_tables_vec.data(),
|
|
||||||
block_tables.data<int32_t>(),
|
|
||||||
sizeof(int32_t) * block_bs * block_num_per_seq,
|
|
||||||
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
|
|
||||||
if (FLAGS_encoder_splice) {
|
if (FLAGS_encoder_splice) {
|
||||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||||
seq_lens_encoder.place(),
|
seq_lens_encoder.place(),
|
||||||
block_tables_vec,
|
block_tables,
|
||||||
encoder_seq_lod_vec,
|
encoder_seq_lod_vec,
|
||||||
prefix_len_vec,
|
prefix_len_vec,
|
||||||
encoder_batch_map_vec,
|
encoder_batch_map_vec,
|
||||||
slot_mapping_enc.data<int32_t>(),
|
slot_mapping_enc.data<int32_t>(),
|
||||||
total_enc_len,
|
slot_mapping_enc.numel(),
|
||||||
block_size,
|
block_size,
|
||||||
enc_batch,
|
enc_batch,
|
||||||
block_num_per_seq,
|
block_num_per_seq,
|
||||||
@@ -255,49 +277,17 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
if (FLAGS_decoder_splice) {
|
if (FLAGS_decoder_splice) {
|
||||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||||
seq_lens_decoder.place(),
|
seq_lens_decoder.place(),
|
||||||
block_tables_vec,
|
block_tables,
|
||||||
decoder_seq_lod_vec,
|
decoder_seq_lod_vec,
|
||||||
decoder_context_len_cache_vec,
|
decoder_context_len_cache_vec,
|
||||||
decoder_batch_map_vec,
|
decoder_batch_map_vec,
|
||||||
slot_mapping_dec.data<int32_t>(),
|
slot_mapping_dec.data<int32_t>(),
|
||||||
bsz * (1 + num_speculative_tokens),
|
slot_mapping_dec.numel(),
|
||||||
block_size,
|
block_size,
|
||||||
dec_batch,
|
dec_batch,
|
||||||
block_num_per_seq,
|
block_num_per_seq,
|
||||||
num_speculative_tokens);
|
num_speculative_tokens);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
|
|
||||||
auto encoder_kv_lod_cpu = paddle::empty(
|
|
||||||
{encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
|
|
||||||
auto prefix_len_cpu = paddle::empty(
|
|
||||||
{prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
|
|
||||||
auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
auto decoder_context_len_cache_cpu =
|
|
||||||
paddle::empty({decoder_context_len_cache_vec.size()},
|
|
||||||
seq_lens_encoder.type(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
|
|
||||||
ret = api::do_host2device(
|
ret = api::do_host2device(
|
||||||
xpu_ctx->x_context(),
|
xpu_ctx->x_context(),
|
||||||
@@ -400,65 +390,25 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
max_kv_len,
|
max_kv_len,
|
||||||
prefix_block_num_per_seq,
|
prefix_block_num_per_seq,
|
||||||
max_dec_len};
|
max_dec_len};
|
||||||
auto len_info_cpu =
|
|
||||||
paddle::empty({7}, seq_lens_encoder.type(), paddle::CPUPlace());
|
|
||||||
std::memcpy(len_info_cpu.data<int32_t>(),
|
std::memcpy(len_info_cpu.data<int32_t>(),
|
||||||
len_info_vec.data(),
|
len_info_vec.data(),
|
||||||
sizeof(int32_t) * len_info_vec.size());
|
sizeof(int32_t) * len_info_vec.size());
|
||||||
|
|
||||||
return {encoder_batch_map,
|
return {slot_mapping_enc, slot_mapping_dec};
|
||||||
decoder_batch_map,
|
|
||||||
encoder_batch_idx,
|
|
||||||
decoder_batch_idx,
|
|
||||||
encoder_seq_lod,
|
|
||||||
decoder_seq_lod,
|
|
||||||
encoder_kv_lod,
|
|
||||||
prefix_len,
|
|
||||||
decoder_context_len,
|
|
||||||
decoder_context_len_cache,
|
|
||||||
prefix_block_tables,
|
|
||||||
encoder_batch_map_cpu,
|
|
||||||
decoder_batch_map_cpu,
|
|
||||||
encoder_batch_idx_cpu,
|
|
||||||
decoder_batch_idx_cpu,
|
|
||||||
encoder_seq_lod_cpu,
|
|
||||||
decoder_seq_lod_cpu,
|
|
||||||
encoder_kv_lod_cpu,
|
|
||||||
prefix_len_cpu,
|
|
||||||
decoder_context_len_cpu,
|
|
||||||
decoder_context_len_cache_cpu,
|
|
||||||
len_info_cpu,
|
|
||||||
slot_mapping_enc,
|
|
||||||
slot_mapping_dec};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> GetInferParamInferShape(
|
std::vector<std::vector<int64_t>> GetInferParamInferShape(
|
||||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||||
const std::vector<int64_t>& block_tables_shape) {
|
const std::vector<int64_t>& block_tables_shape,
|
||||||
return {seq_lens_encoder_shape,
|
int num_speculative_tokens) {
|
||||||
seq_lens_encoder_shape,
|
// Return shapes for slot_mapping_enc and slot_mapping_dec
|
||||||
seq_lens_encoder_shape,
|
// slot_mapping_enc shape depends on encoder token count (unknown at shape
|
||||||
seq_lens_encoder_shape,
|
// inference time) slot_mapping_dec shape depends on batch size and
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
// speculative token count
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
return {{-1}, {seq_lens_encoder_shape[0] * (1 + num_speculative_tokens)}};
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
block_tables_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
|
||||||
{seq_lens_encoder_shape[0] + 1},
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
seq_lens_encoder_shape,
|
|
||||||
{7}};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> GetInferParamInferDtype(
|
std::vector<paddle::DataType> GetInferParamInferDtype(
|
||||||
@@ -466,23 +416,16 @@ std::vector<paddle::DataType> GetInferParamInferDtype(
|
|||||||
const paddle::DataType& seq_lens_decoder_dtype,
|
const paddle::DataType& seq_lens_decoder_dtype,
|
||||||
const paddle::DataType& seq_lens_this_time_dtype,
|
const paddle::DataType& seq_lens_this_time_dtype,
|
||||||
const paddle::DataType& block_tables_dtype) {
|
const paddle::DataType& block_tables_dtype) {
|
||||||
return {
|
// Return dtypes for slot_mapping_enc and slot_mapping_dec (both INT32)
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
return {paddle::DataType::INT32, paddle::DataType::INT32};
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
|
|
||||||
seq_lens_encoder_dtype};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(get_infer_param)
|
PD_BUILD_OP(get_infer_param)
|
||||||
.Inputs({"seq_lens_encoder",
|
.Inputs({"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
"block_tables"})
|
"block_tables",
|
||||||
.Outputs({"encoder_batch_map",
|
"encoder_batch_map",
|
||||||
"decoder_batch_map",
|
"decoder_batch_map",
|
||||||
"encoder_batch_idx",
|
"encoder_batch_idx",
|
||||||
"decoder_batch_idx",
|
"decoder_batch_idx",
|
||||||
@@ -503,9 +446,8 @@ PD_BUILD_OP(get_infer_param)
|
|||||||
"prefix_len_cpu",
|
"prefix_len_cpu",
|
||||||
"decoder_context_len_cpu",
|
"decoder_context_len_cpu",
|
||||||
"decoder_context_len_cache_cpu",
|
"decoder_context_len_cache_cpu",
|
||||||
"len_info_cpu",
|
"len_info_cpu"})
|
||||||
"slot_mapping_enc",
|
.Outputs({"slot_mapping_enc", "slot_mapping_dec"})
|
||||||
"slot_mapping_dec"})
|
|
||||||
.SetKernelFn(PD_KERNEL(GetInferParam))
|
.SetKernelFn(PD_KERNEL(GetInferParam))
|
||||||
.Attrs({"block_size: int", "num_speculative_tokens: int"})
|
.Attrs({"block_size: int", "num_speculative_tokens: int"})
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
|
||||||
|
|||||||
@@ -478,6 +478,28 @@ std::vector<paddle::Tensor> GetInferParam(
|
|||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
|
paddle::Tensor& encoder_batch_map,
|
||||||
|
paddle::Tensor& decoder_batch_map,
|
||||||
|
paddle::Tensor& encoder_batch_idx,
|
||||||
|
paddle::Tensor& decoder_batch_idx,
|
||||||
|
paddle::Tensor& encoder_seq_lod,
|
||||||
|
paddle::Tensor& decoder_seq_lod,
|
||||||
|
paddle::Tensor& encoder_kv_lod,
|
||||||
|
paddle::Tensor& prefix_len,
|
||||||
|
paddle::Tensor& decoder_context_len,
|
||||||
|
paddle::Tensor& decoder_context_len_cache,
|
||||||
|
paddle::Tensor& prefix_block_tables,
|
||||||
|
paddle::Tensor& encoder_batch_map_cpu,
|
||||||
|
paddle::Tensor& decoder_batch_map_cpu,
|
||||||
|
paddle::Tensor& encoder_batch_idx_cpu,
|
||||||
|
paddle::Tensor& decoder_batch_idx_cpu,
|
||||||
|
paddle::Tensor& encoder_seq_lod_cpu,
|
||||||
|
paddle::Tensor& decoder_seq_lod_cpu,
|
||||||
|
paddle::Tensor& encoder_kv_lod_cpu,
|
||||||
|
paddle::Tensor& prefix_len_cpu,
|
||||||
|
paddle::Tensor& decoder_context_len_cpu,
|
||||||
|
paddle::Tensor& decoder_context_len_cache_cpu,
|
||||||
|
paddle::Tensor& len_info_cpu,
|
||||||
int block_size,
|
int block_size,
|
||||||
int num_speculative_tokens);
|
int num_speculative_tokens);
|
||||||
|
|
||||||
@@ -1052,6 +1074,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("seq_lens_decoder"),
|
py::arg("seq_lens_decoder"),
|
||||||
py::arg("seq_lens_this_time"),
|
py::arg("seq_lens_this_time"),
|
||||||
py::arg("block_tables"),
|
py::arg("block_tables"),
|
||||||
|
py::arg("encoder_batch_map"),
|
||||||
|
py::arg("decoder_batch_map"),
|
||||||
|
py::arg("encoder_batch_idx"),
|
||||||
|
py::arg("decoder_batch_idx"),
|
||||||
|
py::arg("encoder_seq_lod"),
|
||||||
|
py::arg("decoder_seq_lod"),
|
||||||
|
py::arg("encoder_kv_lod"),
|
||||||
|
py::arg("prefix_len"),
|
||||||
|
py::arg("decoder_context_len"),
|
||||||
|
py::arg("decoder_context_len_cache"),
|
||||||
|
py::arg("prefix_block_tables"),
|
||||||
|
py::arg("encoder_batch_map_cpu"),
|
||||||
|
py::arg("decoder_batch_map_cpu"),
|
||||||
|
py::arg("encoder_batch_idx_cpu"),
|
||||||
|
py::arg("decoder_batch_idx_cpu"),
|
||||||
|
py::arg("encoder_seq_lod_cpu"),
|
||||||
|
py::arg("decoder_seq_lod_cpu"),
|
||||||
|
py::arg("encoder_kv_lod_cpu"),
|
||||||
|
py::arg("prefix_len_cpu"),
|
||||||
|
py::arg("decoder_context_len_cpu"),
|
||||||
|
py::arg("decoder_context_len_cache_cpu"),
|
||||||
|
py::arg("len_info_cpu"),
|
||||||
py::arg("block_size"),
|
py::arg("block_size"),
|
||||||
py::arg("num_speculative_tokens"),
|
py::arg("num_speculative_tokens"),
|
||||||
"Get infer parameters for block attention in XPU");
|
"Get infer parameters for block attention in XPU");
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import unittest # 导入 unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import init_inplace_tensor
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.xpu import (
|
from fastdeploy.model_executor.ops.xpu import (
|
||||||
adjust_batch,
|
adjust_batch,
|
||||||
@@ -33,11 +34,8 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")
|
seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32")
|
||||||
|
|
||||||
bsz = seq_lens_this_time.shape[0]
|
bsz = seq_lens_this_time.shape[0]
|
||||||
cum_offsets = paddle.zeros(bsz, dtype="int32")
|
|
||||||
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
|
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
|
||||||
|
|
||||||
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
|
|
||||||
|
|
||||||
(
|
(
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
decoder_batch_map,
|
decoder_batch_map,
|
||||||
@@ -45,23 +43,56 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
decoder_batch_idx,
|
decoder_batch_idx,
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
_,
|
encoder_kv_lod,
|
||||||
_,
|
prefix_len,
|
||||||
_,
|
decoder_context_len,
|
||||||
_,
|
decoder_context_len_cache,
|
||||||
_,
|
prefix_block_tables,
|
||||||
encoder_batch_map_cpu,
|
encoder_batch_map_cpu,
|
||||||
decoder_batch_map_cpu,
|
decoder_batch_map_cpu,
|
||||||
encoder_batch_idx_cpu,
|
encoder_batch_idx_cpu,
|
||||||
decoder_batch_idx_cpu,
|
decoder_batch_idx_cpu,
|
||||||
encoder_seq_lod_cpu,
|
encoder_seq_lod_cpu,
|
||||||
decoder_seq_lod_cpu,
|
decoder_seq_lod_cpu,
|
||||||
_,
|
encoder_kv_lod_cpu,
|
||||||
_,
|
prefix_len_cpu,
|
||||||
_,
|
decoder_context_len_cpu,
|
||||||
_,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
) = infer_params
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
|
||||||
|
(
|
||||||
|
slot_mapping_enc,
|
||||||
|
slot_mapping_dec,
|
||||||
|
) = get_infer_param(
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_table,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
token_num = seq_lens_this_time.sum().cpu().item()
|
token_num = seq_lens_this_time.sum().cpu().item()
|
||||||
hidden_dim = 8192
|
hidden_dim = 8192
|
||||||
@@ -72,7 +103,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
# 测试 adjust_batch
|
# 测试 adjust_batch
|
||||||
adjusted_output = adjust_batch(
|
adjusted_output = adjust_batch(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
@@ -88,7 +118,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
|
|
||||||
adjusted_output_cpu = adjust_batch(
|
adjusted_output_cpu = adjust_batch(
|
||||||
input_tensor.cpu(),
|
input_tensor.cpu(),
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
@@ -110,7 +139,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
# 测试 gather_next_token
|
# 测试 gather_next_token
|
||||||
gather_out = gather_next_token(
|
gather_out = gather_next_token(
|
||||||
adjusted_output,
|
adjusted_output,
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
@@ -126,7 +154,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative):
|
|||||||
|
|
||||||
gather_out_cpu = gather_next_token(
|
gather_out_cpu = gather_next_token(
|
||||||
adjusted_output.cpu(),
|
adjusted_output.cpu(),
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import init_inplace_tensor
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.xpu import (
|
from fastdeploy.model_executor.ops.xpu import (
|
||||||
adjust_batch,
|
adjust_batch,
|
||||||
@@ -33,8 +34,6 @@ def _run_test_base(seq_lens_this_time_data):
|
|||||||
cum_offsets = paddle.zeros(bsz, dtype="int32")
|
cum_offsets = paddle.zeros(bsz, dtype="int32")
|
||||||
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
|
block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8))
|
||||||
|
|
||||||
infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64)
|
|
||||||
|
|
||||||
(
|
(
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
decoder_batch_map,
|
decoder_batch_map,
|
||||||
@@ -42,23 +41,56 @@ def _run_test_base(seq_lens_this_time_data):
|
|||||||
decoder_batch_idx,
|
decoder_batch_idx,
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
_,
|
encoder_kv_lod,
|
||||||
_,
|
prefix_len,
|
||||||
_,
|
decoder_context_len,
|
||||||
_,
|
decoder_context_len_cache,
|
||||||
_,
|
prefix_block_tables,
|
||||||
encoder_batch_map_cpu,
|
encoder_batch_map_cpu,
|
||||||
decoder_batch_map_cpu,
|
decoder_batch_map_cpu,
|
||||||
encoder_batch_idx_cpu,
|
encoder_batch_idx_cpu,
|
||||||
decoder_batch_idx_cpu,
|
decoder_batch_idx_cpu,
|
||||||
encoder_seq_lod_cpu,
|
encoder_seq_lod_cpu,
|
||||||
decoder_seq_lod_cpu,
|
decoder_seq_lod_cpu,
|
||||||
_,
|
encoder_kv_lod_cpu,
|
||||||
_,
|
prefix_len_cpu,
|
||||||
_,
|
decoder_context_len_cpu,
|
||||||
_,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
) = infer_params
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
|
||||||
|
(
|
||||||
|
slot_mapping_enc,
|
||||||
|
slot_mapping_dec,
|
||||||
|
) = get_infer_param(
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_table,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
token_num = seq_lens_this_time.sum().cpu().item()
|
token_num = seq_lens_this_time.sum().cpu().item()
|
||||||
hidden_dim = 8192
|
hidden_dim = 8192
|
||||||
@@ -68,7 +100,6 @@ def _run_test_base(seq_lens_this_time_data):
|
|||||||
# test adjust_batch
|
# test adjust_batch
|
||||||
adjusted_output = adjust_batch(
|
adjusted_output = adjust_batch(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
@@ -84,7 +115,6 @@ def _run_test_base(seq_lens_this_time_data):
|
|||||||
|
|
||||||
adjusted_output_cpu = adjust_batch(
|
adjusted_output_cpu = adjust_batch(
|
||||||
input_tensor.cpu(),
|
input_tensor.cpu(),
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import init_inplace_tensor
|
||||||
|
|
||||||
# block_attn_fused is deprecated and should be removed in the future
|
# block_attn_fused is deprecated and should be removed in the future
|
||||||
from fastdeploy.model_executor.ops.xpu import (
|
from fastdeploy.model_executor.ops.xpu import (
|
||||||
@@ -76,6 +77,7 @@ def run_prefix_cache_block_attn(
|
|||||||
# prefix cache block attn
|
# prefix cache block attn
|
||||||
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||||
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||||
|
|
||||||
(
|
(
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
decoder_batch_map,
|
decoder_batch_map,
|
||||||
@@ -99,11 +101,40 @@ def run_prefix_cache_block_attn(
|
|||||||
decoder_context_len_cpu,
|
decoder_context_len_cpu,
|
||||||
decoder_context_len_cache_cpu,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||||
|
(
|
||||||
slot_mapping_enc,
|
slot_mapping_enc,
|
||||||
slot_mapping_dec,
|
slot_mapping_dec,
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
seq_lens_encoder,
|
||||||
) # block_size
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_tables,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
num_speculative_tokens,
|
||||||
|
)
|
||||||
qkv_prefix = qkv[hit_prefix_len:]
|
qkv_prefix = qkv[hit_prefix_len:]
|
||||||
attn_out_prefix_cache = block_attn_func(
|
attn_out_prefix_cache = block_attn_func(
|
||||||
qkv_prefix,
|
qkv_prefix,
|
||||||
@@ -194,6 +225,7 @@ def run_block_attn(
|
|||||||
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
|
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
|
||||||
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
|
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
|
||||||
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||||
|
|
||||||
(
|
(
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
decoder_batch_map,
|
decoder_batch_map,
|
||||||
@@ -217,10 +249,39 @@ def run_block_attn(
|
|||||||
decoder_context_len_cpu,
|
decoder_context_len_cpu,
|
||||||
decoder_context_len_cache_cpu,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||||
|
(
|
||||||
slot_mapping_enc,
|
slot_mapping_enc,
|
||||||
slot_mapping_dec,
|
slot_mapping_dec,
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_tables,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
num_speculative_tokens,
|
||||||
)
|
)
|
||||||
qkv = paddle.uniform(
|
qkv = paddle.uniform(
|
||||||
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
|
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import init_inplace_tensor
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
|
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ seq_len = 128
|
|||||||
block_batch = 5
|
block_batch = 5
|
||||||
max_block_per_seq = 128
|
max_block_per_seq = 128
|
||||||
block_size = 64
|
block_size = 64
|
||||||
|
num_speculative_tokens = 0
|
||||||
|
|
||||||
seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
|
seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32")
|
||||||
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
|
seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32")
|
||||||
@@ -53,11 +55,40 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
|||||||
decoder_context_len_cpu,
|
decoder_context_len_cpu,
|
||||||
decoder_context_len_cache_cpu,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||||
|
(
|
||||||
slot_mapping_enc,
|
slot_mapping_enc,
|
||||||
slot_mapping_dec,
|
slot_mapping_dec,
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
seq_lens_encoder,
|
||||||
) # block_size
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_tables,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
num_speculative_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
qkv = paddle.uniform(
|
qkv = paddle.uniform(
|
||||||
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
|
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim],
|
||||||
@@ -247,11 +278,40 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
|||||||
decoder_context_len_cpu,
|
decoder_context_len_cpu,
|
||||||
decoder_context_len_cache_cpu,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape)
|
||||||
|
(
|
||||||
slot_mapping_enc,
|
slot_mapping_enc,
|
||||||
slot_mapping_dec,
|
slot_mapping_dec,
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
seq_lens_encoder,
|
||||||
) # block_size
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_tables,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
num_speculative_tokens,
|
||||||
|
)
|
||||||
qkv_prefix = qkv[hit_prefix_len:]
|
qkv_prefix = qkv[hit_prefix_len:]
|
||||||
|
|
||||||
attn_out_prefix_cache = block_attn_fused(
|
attn_out_prefix_cache = block_attn_fused(
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import init_inplace_tensor
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.xpu import get_infer_param
|
from fastdeploy.model_executor.ops.xpu import get_infer_param
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64], dtype="int32")
|
|||||||
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
|
seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32")
|
||||||
block_table = paddle.arange(0, 40, dtype="int32")
|
block_table = paddle.arange(0, 40, dtype="int32")
|
||||||
block_table = block_table.reshape((5, 8))
|
block_table = block_table.reshape((5, 8))
|
||||||
|
|
||||||
(
|
(
|
||||||
encoder_batch_map,
|
encoder_batch_map,
|
||||||
decoder_batch_map,
|
decoder_batch_map,
|
||||||
@@ -44,9 +46,40 @@ block_table = block_table.reshape((5, 8))
|
|||||||
decoder_context_len_cpu,
|
decoder_context_len_cpu,
|
||||||
decoder_context_len_cache_cpu,
|
decoder_context_len_cache_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
|
) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape)
|
||||||
|
(
|
||||||
|
slot_mapping_enc,
|
||||||
|
slot_mapping_dec,
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64
|
seq_lens_encoder,
|
||||||
) # block_size
|
seq_lens_decoder,
|
||||||
|
seq_lens_this_time,
|
||||||
|
block_table,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
64,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
print("block_table", block_table)
|
print("block_table", block_table)
|
||||||
print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0]
|
print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0]
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def init_inplace_tensor(bsz, block_tables_shape):
|
||||||
|
encoder_batch_map = paddle.empty(bsz, dtype="int32")
|
||||||
|
decoder_batch_map = paddle.empty(bsz, dtype="int32")
|
||||||
|
encoder_batch_idx = paddle.empty(bsz, dtype="int32")
|
||||||
|
decoder_batch_idx = paddle.empty(bsz, dtype="int32")
|
||||||
|
encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
prefix_len = paddle.empty(bsz, dtype="int32")
|
||||||
|
decoder_context_len = paddle.empty(bsz, dtype="int32")
|
||||||
|
decoder_context_len_cache = paddle.empty(bsz, dtype="int32")
|
||||||
|
|
||||||
|
prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
|
||||||
|
|
||||||
|
encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
|
||||||
|
len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")
|
||||||
|
|
||||||
|
return (
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
|
)
|
||||||
@@ -289,6 +289,33 @@ class XPUForwardMeta(ForwardMeta):
|
|||||||
#
|
#
|
||||||
slot_mapping_dec: Optional[paddle.Tensor] = None
|
slot_mapping_dec: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
|
def init_inplace_tensor(self, bsz, block_tables_shape):
|
||||||
|
self.encoder_batch_map = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.decoder_batch_map = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.encoder_batch_idx = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.decoder_batch_idx = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
self.decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
self.encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32")
|
||||||
|
self.prefix_len = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.decoder_context_len = paddle.empty(bsz, dtype="int32")
|
||||||
|
self.decoder_context_len_cache = paddle.empty(bsz, dtype="int32")
|
||||||
|
|
||||||
|
self.prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32")
|
||||||
|
|
||||||
|
self.encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
self.decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
self.encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu")
|
||||||
|
self.prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
self.decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu")
|
||||||
|
|
||||||
|
self.len_info_cpu = paddle.empty(7, dtype="int32", device="cpu")
|
||||||
|
|
||||||
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
|
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
|
||||||
"""
|
"""
|
||||||
Synchronize attributes from another XPUForwardMeta object
|
Synchronize attributes from another XPUForwardMeta object
|
||||||
|
|||||||
@@ -158,6 +158,23 @@ def xpu_pre_process(
|
|||||||
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
||||||
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
||||||
|
|
||||||
|
if use_cudagraph and forward_meta is not None:
|
||||||
|
forward_meta.ids_remove_padding.copy_(share_inputs["ids_remove_padding"], False)
|
||||||
|
forward_meta.rotary_embs.copy_(share_inputs["rope_emb"], False)
|
||||||
|
forward_meta.attn_backend = None
|
||||||
|
forward_meta.seq_lens_encoder.copy_(share_inputs["seq_lens_encoder"], False)
|
||||||
|
forward_meta.seq_lens_decoder.copy_(share_inputs["seq_lens_decoder"], False)
|
||||||
|
forward_meta.seq_lens_this_time.copy_(share_inputs["seq_lens_this_time"], False)
|
||||||
|
forward_meta.batch_id_per_token.copy_(share_inputs["batch_id_per_token"], False)
|
||||||
|
forward_meta.cu_seqlens_q.copy_(share_inputs["cu_seqlens_q"], False)
|
||||||
|
forward_meta.cu_seqlens_k.copy_(share_inputs["cu_seqlens_k"], False)
|
||||||
|
forward_meta.block_tables.copy_(share_inputs["block_tables"], False)
|
||||||
|
forward_meta.caches = share_inputs["caches"]
|
||||||
|
forward_meta.max_num_seqs = share_inputs["seq_lens_this_time"].shape[0]
|
||||||
|
forward_meta.is_speculative = use_speculate_method
|
||||||
|
|
||||||
|
xpu_forward_meta = forward_meta
|
||||||
|
else:
|
||||||
xpu_forward_meta = XPUForwardMeta(
|
xpu_forward_meta = XPUForwardMeta(
|
||||||
ids_remove_padding=share_inputs["ids_remove_padding"],
|
ids_remove_padding=share_inputs["ids_remove_padding"],
|
||||||
rotary_embs=share_inputs["rope_emb"],
|
rotary_embs=share_inputs["rope_emb"],
|
||||||
@@ -173,55 +190,81 @@ def xpu_pre_process(
|
|||||||
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
||||||
is_speculative=use_speculate_method,
|
is_speculative=use_speculate_method,
|
||||||
)
|
)
|
||||||
|
xpu_forward_meta.init_inplace_tensor(seq_lens_encoder.shape[0], share_inputs["block_tables"].shape)
|
||||||
|
|
||||||
|
block_tables = xpu_forward_meta.block_tables
|
||||||
|
|
||||||
|
encoder_batch_map = xpu_forward_meta.encoder_batch_map
|
||||||
|
decoder_batch_map = xpu_forward_meta.decoder_batch_map
|
||||||
|
encoder_batch_idx = xpu_forward_meta.encoder_batch_idx
|
||||||
|
decoder_batch_idx = xpu_forward_meta.decoder_batch_idx
|
||||||
|
encoder_seq_lod = xpu_forward_meta.encoder_seq_lod
|
||||||
|
decoder_seq_lod = xpu_forward_meta.decoder_seq_lod
|
||||||
|
encoder_kv_lod = xpu_forward_meta.encoder_kv_lod
|
||||||
|
prefix_len = xpu_forward_meta.prefix_len
|
||||||
|
decoder_context_len = xpu_forward_meta.decoder_context_len
|
||||||
|
decoder_context_len_cache = xpu_forward_meta.decoder_context_len_cache
|
||||||
|
|
||||||
|
prefix_block_tables = xpu_forward_meta.prefix_block_tables
|
||||||
|
|
||||||
|
encoder_batch_map_cpu = xpu_forward_meta.encoder_batch_map_cpu
|
||||||
|
decoder_batch_map_cpu = xpu_forward_meta.decoder_batch_map_cpu
|
||||||
|
encoder_batch_idx_cpu = xpu_forward_meta.encoder_batch_idx_cpu
|
||||||
|
decoder_batch_idx_cpu = xpu_forward_meta.decoder_batch_idx_cpu
|
||||||
|
encoder_seq_lod_cpu = xpu_forward_meta.encoder_seq_lod_cpu
|
||||||
|
decoder_seq_lod_cpu = xpu_forward_meta.decoder_seq_lod_cpu
|
||||||
|
encoder_kv_lod_cpu = xpu_forward_meta.encoder_kv_lod_cpu
|
||||||
|
prefix_len_cpu = xpu_forward_meta.prefix_len_cpu
|
||||||
|
decoder_context_len_cpu = xpu_forward_meta.decoder_context_len_cpu
|
||||||
|
decoder_context_len_cache_cpu = xpu_forward_meta.decoder_context_len_cache_cpu
|
||||||
|
|
||||||
|
len_info_cpu = xpu_forward_meta.len_info_cpu
|
||||||
|
|
||||||
(
|
(
|
||||||
xpu_forward_meta.encoder_batch_map,
|
slot_mapping_enc,
|
||||||
xpu_forward_meta.decoder_batch_map,
|
slot_mapping_dec,
|
||||||
xpu_forward_meta.encoder_batch_idx,
|
|
||||||
xpu_forward_meta.decoder_batch_idx,
|
|
||||||
xpu_forward_meta.encoder_seq_lod,
|
|
||||||
xpu_forward_meta.decoder_seq_lod,
|
|
||||||
xpu_forward_meta.encoder_kv_lod,
|
|
||||||
xpu_forward_meta.prefix_len,
|
|
||||||
xpu_forward_meta.decoder_context_len,
|
|
||||||
xpu_forward_meta.decoder_context_len_cache,
|
|
||||||
xpu_forward_meta.prefix_block_tables,
|
|
||||||
xpu_forward_meta.encoder_batch_map_cpu,
|
|
||||||
xpu_forward_meta.decoder_batch_map_cpu,
|
|
||||||
xpu_forward_meta.encoder_batch_idx_cpu,
|
|
||||||
xpu_forward_meta.decoder_batch_idx_cpu,
|
|
||||||
xpu_forward_meta.encoder_seq_lod_cpu,
|
|
||||||
xpu_forward_meta.decoder_seq_lod_cpu,
|
|
||||||
xpu_forward_meta.encoder_kv_lod_cpu,
|
|
||||||
xpu_forward_meta.prefix_len_cpu,
|
|
||||||
xpu_forward_meta.decoder_context_len_cpu,
|
|
||||||
xpu_forward_meta.decoder_context_len_cache_cpu,
|
|
||||||
xpu_forward_meta.len_info_cpu,
|
|
||||||
xpu_forward_meta.slot_mapping_enc,
|
|
||||||
xpu_forward_meta.slot_mapping_dec,
|
|
||||||
) = get_infer_param(
|
) = get_infer_param(
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
seq_lens_this_time,
|
seq_lens_this_time,
|
||||||
xpu_forward_meta.block_tables,
|
block_tables,
|
||||||
|
encoder_batch_map,
|
||||||
|
decoder_batch_map,
|
||||||
|
encoder_batch_idx,
|
||||||
|
decoder_batch_idx,
|
||||||
|
encoder_seq_lod,
|
||||||
|
decoder_seq_lod,
|
||||||
|
encoder_kv_lod,
|
||||||
|
prefix_len,
|
||||||
|
decoder_context_len,
|
||||||
|
decoder_context_len_cache,
|
||||||
|
prefix_block_tables,
|
||||||
|
encoder_batch_map_cpu,
|
||||||
|
decoder_batch_map_cpu,
|
||||||
|
encoder_batch_idx_cpu,
|
||||||
|
decoder_batch_idx_cpu,
|
||||||
|
encoder_seq_lod_cpu,
|
||||||
|
decoder_seq_lod_cpu,
|
||||||
|
encoder_kv_lod_cpu,
|
||||||
|
prefix_len_cpu,
|
||||||
|
decoder_context_len_cpu,
|
||||||
|
decoder_context_len_cache_cpu,
|
||||||
|
len_info_cpu,
|
||||||
block_size,
|
block_size,
|
||||||
num_speculative_tokens,
|
num_speculative_tokens,
|
||||||
)
|
)
|
||||||
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
|
|
||||||
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
|
|
||||||
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
|
|
||||||
|
|
||||||
adjusted_input = adjust_batch(
|
adjusted_input = adjust_batch(
|
||||||
ids_remove_padding.reshape([-1, 1]),
|
ids_remove_padding.reshape([-1, 1]),
|
||||||
xpu_forward_meta.encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
xpu_forward_meta.decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
xpu_forward_meta.encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
xpu_forward_meta.decoder_batch_idx,
|
decoder_batch_idx,
|
||||||
xpu_forward_meta.encoder_seq_lod_cpu,
|
encoder_seq_lod_cpu,
|
||||||
xpu_forward_meta.decoder_seq_lod_cpu,
|
decoder_seq_lod_cpu,
|
||||||
xpu_forward_meta.encoder_batch_idx_cpu,
|
encoder_batch_idx_cpu,
|
||||||
xpu_forward_meta.decoder_batch_idx_cpu,
|
decoder_batch_idx_cpu,
|
||||||
xpu_forward_meta.len_info_cpu,
|
len_info_cpu,
|
||||||
None, # output_padding_offset
|
None, # output_padding_offset
|
||||||
-1, # max bs
|
-1, # max bs
|
||||||
)
|
)
|
||||||
@@ -229,16 +272,21 @@ def xpu_pre_process(
|
|||||||
adjusted_input = adjusted_input.squeeze(1)
|
adjusted_input = adjusted_input.squeeze(1)
|
||||||
|
|
||||||
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
|
share_inputs["ids_remove_padding"].copy_(adjusted_input, False)
|
||||||
|
|
||||||
|
xpu_forward_meta.enc_batch = len_info_cpu[0]
|
||||||
|
xpu_forward_meta.dec_batch = len_info_cpu[1]
|
||||||
|
xpu_forward_meta.total_enc_len = len_info_cpu[2]
|
||||||
xpu_forward_meta.ids_remove_padding = adjusted_input
|
xpu_forward_meta.ids_remove_padding = adjusted_input
|
||||||
# Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
|
# Set xpu_forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends
|
||||||
xpu_forward_meta.is_profiling = is_profiling
|
xpu_forward_meta.is_profiling = is_profiling
|
||||||
if use_cudagraph:
|
|
||||||
if forward_meta is None:
|
# prefill does not use cudagraph, inplace copy is not needed
|
||||||
return xpu_forward_meta
|
xpu_forward_meta.slot_mapping_enc = slot_mapping_enc
|
||||||
else:
|
if use_cudagraph and forward_meta is not None:
|
||||||
forward_meta.copy_from(xpu_forward_meta)
|
xpu_forward_meta.slot_mapping_dec.copy_(slot_mapping_dec, False)
|
||||||
return forward_meta
|
|
||||||
else:
|
else:
|
||||||
|
xpu_forward_meta.slot_mapping_dec = slot_mapping_dec
|
||||||
|
|
||||||
return xpu_forward_meta
|
return xpu_forward_meta
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user