mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix bug for EP+MTP (#5605)
Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
@@ -71,16 +71,18 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
|
||||
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
|
||||
|
||||
auto out = paddle::empty({token_num, dim}, x.type(), x.place());
|
||||
|
||||
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType *>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
decoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
if (token_num > 0) {
|
||||
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
|
||||
ctx,
|
||||
reinterpret_cast<const XPUType *>(x.data<data_t>()),
|
||||
reinterpret_cast<XPUType *>(out.data<data_t>()),
|
||||
encoder_seqs_lods_vp,
|
||||
decoder_seqs_lods_vp,
|
||||
encoder_batch_map_vp,
|
||||
decoder_batch_map_vp,
|
||||
dim);
|
||||
PD_CHECK(r == 0, "XPU eb_adjust_batch failed");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
|
||||
@@ -57,31 +57,33 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
"Cum offsets tensor must be contiguous");
|
||||
PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
|
||||
|
||||
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
|
||||
if (token_num_data > 0) {
|
||||
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
|
||||
xpu_ctx->x_context(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
seq_len.data<int>(),
|
||||
seq_length,
|
||||
bsz);
|
||||
PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
|
||||
|
||||
r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
|
||||
xpu_ctx->x_context(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
seq_len.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
seq_length,
|
||||
max_draft_tokens,
|
||||
bsz,
|
||||
token_num_data);
|
||||
PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
|
||||
r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
|
||||
xpu_ctx->x_context(),
|
||||
x_remove_padding.data<int64_t>(),
|
||||
input_ids.data<int64_t>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
seq_len.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
seq_length,
|
||||
max_draft_tokens,
|
||||
bsz,
|
||||
token_num_data);
|
||||
PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
|
||||
}
|
||||
|
||||
return {x_remove_padding,
|
||||
cum_offsets_out,
|
||||
|
||||
Reference in New Issue
Block a user