remove speculate_get_padding_offset op (#6308)

This commit is contained in:
周周周
2026-02-03 15:18:12 +08:00
committed by GitHub
parent 39dc4b0c2e
commit 8277b95fa6
7 changed files with 44 additions and 324 deletions
+34 -8
View File
@@ -25,7 +25,10 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
int *cu_seqlens_k,
const int64_t *input_data,
const int *seq_lens,
const int max_seq_len) {
const int max_seq_len,
const int64_t *draft_tokens,
const int *seq_lens_encoder,
const int max_draft_tokens_per_batch) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
#ifdef PADDLE_WITH_COREX
@@ -62,15 +65,26 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
const int src_seq_id = bi * max_seq_len + i;
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) {
// speculative decoding
const int src_seq_id = bi * max_draft_tokens_per_batch + i;
ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id];
} else {
// Non-speculative decoding
const int src_seq_id = bi * max_seq_len + i;
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
}
batch_id_per_token[tgt_seq_id] = bi;
}
}
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &seq_len,
const int64_t cpu_token_num) {
std::vector<paddle::Tensor> GetPaddingOffset(
const paddle::Tensor &input_ids,
const paddle::Tensor &seq_len,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &seq_lens_encoder,
const int64_t cpu_token_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
@@ -98,6 +112,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
int blockSize =
min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
int max_draft_tokens_per_batch = -1;
if (draft_tokens) {
max_draft_tokens_per_batch = draft_tokens.get().shape()[1];
}
PrefixSumKernel<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
batch_id_per_token.data<int>(),
@@ -105,7 +125,10 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cu_seqlens_k.data<int>(),
input_ids.data<int64_t>(),
seq_len.data<int>(),
max_seq_len);
max_seq_len,
draft_tokens ? draft_tokens.get().data<int64_t>() : nullptr,
seq_lens_encoder ? seq_lens_encoder.get().data<int32_t>() : nullptr,
max_draft_tokens_per_batch);
return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k};
}
@@ -127,7 +150,10 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}
PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "seq_len"})
.Inputs({"input_ids",
"seq_len",
paddle::Optional("draft_tokens"),
paddle::Optional("seq_lens_encoder")})
.Outputs({"x_remove_padding",
"batch_id_per_token",
"cu_seqlens_q",