fix bug for EP+MTP: skip XPU kernel calls when the token count is 0 (#5605)

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
lizan1999
2025-12-18 14:34:54 +08:00
committed by GitHub
parent d8587e987e
commit e1a9b282eb
3 changed files with 38 additions and 42 deletions
+12 -10
View File
@@ -71,16 +71,18 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
const_cast<int32_t *>(decoder_batch_idx.data<int32_t>())};
auto out = paddle::empty({token_num, dim}, x.type(), x.place());
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType *>(x.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
if (token_num > 0) {
int r = baidu::xpu::api::plugin::eb_adjust_batch<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType *>(x.data<data_t>()),
reinterpret_cast<XPUType *>(out.data<data_t>()),
encoder_seqs_lods_vp,
decoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
PD_CHECK(r == 0, "XPU eb_adjust_batch failed");
}
return {out};
}
@@ -57,31 +57,33 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
"Cum offsets tensor must be contiguous");
PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous");
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
if (token_num_data > 0) {
int r = baidu::xpu::api::plugin::speculate_get_padding_offset(
xpu_ctx->x_context(),
batch_id_per_token.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length,
bsz);
PD_CHECK(r == 0, "XPU speculate_get_padding_offset failed");
r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
xpu_ctx->x_context(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens,
bsz,
token_num_data);
PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
r = baidu::xpu::api::plugin::speculate_remove_padding<int64_t>(
xpu_ctx->x_context(),
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
draft_tokens.data<int64_t>(),
seq_len.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets_out.data<int>(),
seq_length,
max_draft_tokens,
bsz,
token_num_data);
PD_CHECK(r == 0, "XPU speculate_remove_padding failed");
}
return {x_remove_padding,
cum_offsets_out,