[Speculative Decoding]Reformat input preprocess for spec decode (#6501)

* add speculate_pre_process kernel

* reduce one slice

* make d2h async && fix mtp bug for new pre_process

* fix

* add unit test

* fix: code style formatting

* fix

* fix: thread race in speculate_preprocess && rename d2h event
This commit is contained in:
huicongyao
2026-03-03 10:22:07 +08:00
committed by GitHub
parent 33d6d2403c
commit 0f718baaf2
6 changed files with 619 additions and 25 deletions
+12
View File
@@ -751,6 +751,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculatePreProcess(
const int64_t cpu_token_num,
const paddle::Tensor& input_ids,
const paddle::Tensor& seq_len,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
void SpecTokenPenaltyMultiScores(
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
@@ -1604,6 +1612,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateGetSeqLensOutput,
"speculate_get_seq_lens_output function");
m.def("speculate_pre_process",
&SpeculatePreProcess,
"speculate_pre_process function");
m.def("speculate_get_token_penalty_multi_scores",
&SpecTokenPenaltyMultiScores,
"speculate_get_token_penalty_multi_scores function");
@@ -0,0 +1,247 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
#include <cooperative_groups.h>
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace cg = cooperative_groups;
// Fused input pre-processing for speculative decoding.
//
// Launch contract: one thread block per batch element (gridDim.x == real_bsz),
// launched with cudaLaunchCooperativeKernel (see SpeculatePreProcess below)
// because grid.sync() is used to order phase 2 before phase 3.
//
// Phase 1 (per block bi): inclusive prefix sum of seq_lens into
//   cu_seqlens_q/cu_seqlens_k, then pack batch bi's tokens (draft_tokens when
//   decoding speculatively, input_data otherwise) into ids_remove_padding and
//   record the owning batch id per packed token.
// Phase 2 (grid-stride over batches): derive seq_lens_output — batches with
//   seq_lens <= 1 or a nonzero seq_lens_encoder contribute one output token,
//   otherwise all seq_lens[bid] tokens produce outputs.
// Phase 3 (after grid.sync()): prefix sum of seq_lens_output into
//   cu_seq_lens_q_output, per-output-token batch ids, and the total output
//   token count in real_output_token_num[0].
__global__ void SpeculatePreProcessKernel(int64_t *ids_remove_padding,
                                          int *batch_id_per_token,
                                          int *cu_seqlens_q,
                                          int *cu_seqlens_k,
                                          int *seq_lens_output,
                                          int *cu_seq_lens_q_output,
                                          int *batch_id_per_token_output,
                                          int *real_output_token_num,
                                          const int64_t *input_data,
                                          const int *seq_lens,
                                          const int max_seq_len,
                                          const int64_t *draft_tokens,
                                          const int *seq_lens_encoder,
                                          const int max_draft_tokens_per_batch,
                                          const int real_bsz) {
  auto grid = cg::this_grid();
  const int bi = blockIdx.x;   // batch element handled by this block
  const int tid = threadIdx.x;
  // NOTE(review): warp_id is currently unused in this kernel.
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;
  // Every warp redundantly reduces the same range, so after the shuffle
  // reduction all threads of the block hold the same cum_seq_len.
  int cum_seq_len = 0;
  // compute sum of seq_lens[0, 1, 2, ...,bi] per warp
  for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
    cum_seq_len += seq_lens[i];
  }
#pragma unroll
  for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
    cum_seq_len += __shfl_xor_sync(0xffffffff, cum_seq_len, mask);
  }
  if (tid == 0) {
    cu_seqlens_q[bi + 1] = cum_seq_len;
    cu_seqlens_k[bi + 1] = cum_seq_len;
  }
  if (bi == 0 && tid == 0) {
    cu_seqlens_q[0] = 0;
    cu_seqlens_k[0] = 0;
  }
  // Pack this batch's tokens into the padding-free buffer.
  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
    const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i;
    if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) {
      // speculative decoding
      const int src_seq_id = bi * max_draft_tokens_per_batch + i;
      ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id];
    } else {
      // Non-speculative decoding
      const int src_seq_id = bi * max_seq_len + i;
      ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
    }
    batch_id_per_token[tgt_seq_id] = bi;
  }
  // Grid-stride loop so all batches get a seq_lens_output entry regardless
  // of how many blocks/threads were launched.
  for (int bid = blockIdx.x * blockDim.x + threadIdx.x; bid < real_bsz;
       bid += gridDim.x * blockDim.x) {
    if (seq_lens[bid] == 0) {
      seq_lens_output[bid] = 0;
    } else if (seq_lens[bid] == 1) {
      seq_lens_output[bid] = 1;
    } else if (seq_lens_encoder[bid] != 0) {
      seq_lens_output[bid] = 1;
    } else {
      seq_lens_output[bid] = seq_lens[bid];
    }
  }
  // All seq_lens_output writes (possibly made by other blocks) must be
  // visible before the prefix sums below read them.
  grid.sync();
  int cum_seq_len_output = 0;
  // compute sum of seq_lens_output[0,1,2,...,bi] per warp
  for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
    cum_seq_len_output += seq_lens_output[i];
  }
#pragma unroll
  for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
    cum_seq_len_output += __shfl_xor_sync(0xffffffff, cum_seq_len_output, mask);
  }
  if (tid == 0) {
    cu_seq_lens_q_output[bi + 1] = cum_seq_len_output;
  }
  if (bi == 0 && tid == 0) {
    cu_seq_lens_q_output[0] = 0;
  }
  // get real output token num
  if (bi == real_bsz - 1 && tid == 0) {
    real_output_token_num[0] = cum_seq_len_output;
  }
  for (int i = tid; i < seq_lens_output[bi]; i += blockDim.x) {
    const int tgt_seq_id_output = cum_seq_len_output - seq_lens_output[bi] + i;
    batch_id_per_token_output[tgt_seq_id_output] = bi;
  }
}
// Host wrapper for SpeculatePreProcessKernel.
//
// cpu_token_num is the host-known total token count (sum of seq_len), so all
// packed output buffers can be allocated without a device synchronization.
// seq_lens_decoder is accepted for interface parity but is not read here.
//
// Returns 7 tensors:
//   {ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
//    cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num}
// When cpu_token_num == 0, the last three are empty placeholder tensors.
std::vector<paddle::Tensor> SpeculatePreProcess(
    const int64_t cpu_token_num,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &seq_len,
    const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext *>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          input_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = input_ids.stream();
#endif
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int max_seq_len = input_ids_shape[1];
  const int token_num_data = cpu_token_num;
  auto ids_remove_padding = paddle::empty(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto batch_id_per_token = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  // No tokens to process: skip the kernel launch and return placeholders
  // for the output-side tensors.
  if (token_num_data == 0) {
    return {ids_remove_padding,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
            paddle::Tensor(),
            paddle::Tensor(),
            paddle::Tensor()};
  }
  // Block size: token count rounded up to whole warps, capped at 128 threads.
#ifdef PADDLE_WITH_COREX
  int blockSize =
      std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
  int blockSize =
      min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
  const int max_draft_tokens_per_batch = draft_tokens.shape()[1];
  auto seq_lens_output =
      paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place());
  auto cu_seq_lens_q_output =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  // Worst case: every batch emits max_draft_tokens_per_batch output tokens;
  // only the first real_output_token_num entries are meaningful.
  auto batch_id_per_token_output =
      paddle::empty({bsz * max_draft_tokens_per_batch},
                    paddle::DataType::INT32,
                    input_ids.place());
  auto real_output_token_num =
      paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
  int64_t *ids_remove_padding_ptr = ids_remove_padding.data<int64_t>();
  int *batch_id_per_token_ptr = batch_id_per_token.data<int>();
  int *cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
  int *cu_seqlens_k_ptr = cu_seqlens_k.data<int>();
  int *seq_lens_output_ptr = seq_lens_output.data<int>();
  int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data<int>();
  int *batch_id_per_token_output_ptr = batch_id_per_token_output.data<int>();
  int *real_output_token_num_ptr = real_output_token_num.data<int>();
  const int64_t *input_data_ptr = input_ids.data<int64_t>();
  const int *seq_len_ptr = seq_len.data<int>();
  const int64_t *draft_tokens_ptr = draft_tokens.data<int64_t>();
  const int *seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
  // NOTE: this array's order must match the kernel's parameter list exactly.
  void *kernel_args[] = {(void *)&ids_remove_padding_ptr,
                         (void *)&batch_id_per_token_ptr,
                         (void *)&cu_seqlens_q_ptr,
                         (void *)&cu_seqlens_k_ptr,
                         (void *)&seq_lens_output_ptr,
                         (void *)&cu_seq_lens_q_output_ptr,
                         (void *)&batch_id_per_token_output_ptr,
                         (void *)&real_output_token_num_ptr,
                         (void *)&input_data_ptr,
                         (void *)&seq_len_ptr,
                         (void *)&max_seq_len,
                         (void *)&draft_tokens_ptr,
                         (void *)&seq_lens_encoder_ptr,
                         (void *)&max_draft_tokens_per_batch,
                         (void *)&bsz};
  // Cooperative launch is required because the kernel calls grid.sync();
  // one block per batch element.
  cudaLaunchCooperativeKernel((void *)SpeculatePreProcessKernel,
                              dim3(bsz),
                              dim3(blockSize),
                              kernel_args,
                              0,
                              cu_stream);
  return {ids_remove_padding,
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k,
          cu_seq_lens_q_output,
          batch_id_per_token_output,
          real_output_token_num};
}
// Register the custom op. cpu_token_num is passed as an attribute because it
// is a host-side scalar rather than a tensor input; seq_lens_decoder is part
// of the input signature but is not read by SpeculatePreProcess.
PD_BUILD_STATIC_OP(speculate_pre_process)
    .Inputs({"input_ids",
             "seq_len",
             "draft_tokens",
             "seq_lens_encoder",
             "seq_lens_decoder"})
    .Outputs({"ids_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k",
              "cu_seq_lens_q_output",
              "batch_id_per_token_output",
              "real_output_token_num"})
    .Attrs({"cpu_token_num: int64_t"})
    .SetKernelFn(PD_KERNEL(SpeculatePreProcess));
@@ -61,7 +61,6 @@ elif current_platform.is_maca():
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length,
speculate_save_output,
speculate_save_output_topk,
@@ -85,7 +84,7 @@ else:
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_pre_process,
speculate_save_output,
speculate_save_output_topk,
speculate_set_value_by_flags_and_idx,
@@ -152,6 +151,7 @@ def pre_process(
cu_seqlens_k,
None,
None,
None,
)
# Remove padding
if speculative_decoding:
@@ -160,27 +160,12 @@ def pre_process(
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
# compute each batch's output token num
seq_lens_output = speculate_get_seq_lens_output(
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
) = speculate_pre_process(
token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder
)
if isinstance(seq_lens_output, list):
seq_lens_output = seq_lens_output[0]
output_token_num = paddle.sum(seq_lens_output)
useless_input_ids = input_ids
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
useless_input_ids,
seq_lens_output,
None,
None,
output_token_num.item(),
)
return (
ids_remove_padding,
batch_id_per_token,
@@ -188,6 +173,7 @@ def pre_process(
cu_seqlens_k,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
)
+12 -1
View File
@@ -123,6 +123,10 @@ class MTPProposer(Proposer):
self.model_inputs = ProposerInputBatch(self.fd_config, self.target_model_inputs)
self.model_inputs.init_share_inputs()
if current_platform.is_cuda() or current_platform.is_maca():
self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
self.output_token_num_event = paddle.device.cuda.Event()
# CUDA Graph
self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
@@ -815,6 +819,7 @@ class MTPProposer(Proposer):
cu_seqlens_k,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
) = pre_process(
token_num_cpu,
self.model_inputs["input_ids"],
@@ -851,6 +856,8 @@ class MTPProposer(Proposer):
# For speculative decoding
self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
self._real_output_token_num_host.copy_(real_output_token_num, False)
self.output_token_num_event.record()
# Initialize forward meta data
self._initialize_forward_meta(
@@ -896,13 +903,17 @@ class MTPProposer(Proposer):
)
if self.forward_meta.step_use_cudagraph:
model_output = model_output[: self.real_token_num]
self.output_token_num_event.synchronize()
real_num = int(self._real_output_token_num_host)
real_batch_id_per_token_output = self.model_inputs["batch_id_per_token_output"][:real_num]
hidden_states = rebuild_padding(
model_output,
self.model_inputs["cu_seqlens_q"],
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["batch_id_per_token_output"],
real_batch_id_per_token_output,
self.model_inputs["cu_seqlens_q_output"],
self.model_inputs["first_token_hidden_states"],
self.enable_logprob if substep == 0 else False,
+21 -2
View File
@@ -148,6 +148,10 @@ class GPUModelRunner(ModelRunnerBase):
self.cache_kvs_map: dict = {}
self.exist_prefill_flag = False
if self.speculative_decoding:
self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
self.output_token_num_event = paddle.device.cuda.Event()
# VL model config:
if self.enable_mm:
if "ernie" in self.fd_config.model_config.model_type:
@@ -1129,6 +1133,7 @@ class GPUModelRunner(ModelRunnerBase):
cu_seqlens_k,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
) = pre_process(
token_num,
self.share_inputs["input_ids"],
@@ -1150,6 +1155,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
self._real_output_token_num_host.copy_(real_output_token_num, False)
self.output_token_num_event.record()
# Initialize forward meta data
self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
@@ -1750,13 +1758,19 @@ class GPUModelRunner(ModelRunnerBase):
self._dummy_pooler_run(model_output, model_output)
break
else:
if self.speculative_decoding:
self.output_token_num_event.synchronize()
real_num = int(self._real_output_token_num_host)
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
else:
real_batch_id_per_token_output = None
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
real_batch_id_per_token_output,
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
)
self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)
@@ -2059,6 +2073,11 @@ class GPUModelRunner(ModelRunnerBase):
+ self.share_inputs["is_block_step_cpu"].numpy().sum().item()
)
if self.speculative_decoding:
self.output_token_num_event.synchronize()
real_num = int(self._real_output_token_num_host)
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
if self.is_pooling_model:
pooler_output = self._pool(model_output, num_running_requests)
@@ -2117,7 +2136,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
(real_batch_id_per_token_output if self.speculative_decoding else None),
(self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
)
@@ -0,0 +1,319 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import speculate_pre_process
def speculate_pre_process_ref(
    input_ids,
    seq_lens,
    draft_tokens,
    seq_lens_encoder,
    max_seq_len,
    max_draft_tokens_per_batch,
    real_bsz,
    token_num,
):
    """
    Python reference implementation for SpeculatePreProcessKernel.
    Returns:
        ids_remove_padding: int64[token_num]
        batch_id_per_token: int32[token_num]
        cu_seqlens_q: int32[real_bsz + 1]
        cu_seqlens_k: int32[real_bsz + 1]
        seq_lens_output: int32[real_bsz]
        cu_seq_lens_q_output: int32[real_bsz + 1]
        batch_id_per_token_output: int32[real_bsz * max_draft_tokens_per_batch]
        real_output_token_num: int32[1]
    """
    # Inclusive prefix sums of the input sequence lengths.
    cu_seqlens_q = np.zeros(real_bsz + 1, dtype=np.int32)
    cu_seqlens_q[1:] = np.cumsum(seq_lens[:real_bsz])
    cu_seqlens_k = cu_seqlens_q.copy()

    # Pack every batch's tokens (draft tokens for pure-decode batches,
    # input ids otherwise) into a padding-free buffer.
    ids_remove_padding = np.zeros(token_num, dtype=np.int64)
    batch_id_per_token = np.zeros(token_num, dtype=np.int32)
    for bi in range(real_bsz):
        base = int(cu_seqlens_q[bi])
        decode_only = max_draft_tokens_per_batch > 0 and seq_lens_encoder[bi] <= 0
        for i in range(int(seq_lens[bi])):
            if decode_only:
                ids_remove_padding[base + i] = draft_tokens[bi * max_draft_tokens_per_batch + i]
            else:
                ids_remove_padding[base + i] = input_ids[bi * max_seq_len + i]
            batch_id_per_token[base + i] = bi

    # Per-batch output token counts: one output while prefilling (nonzero
    # seq_lens_encoder) or for single-token batches, else seq_lens tokens.
    seq_lens_output = np.zeros(real_bsz, dtype=np.int32)
    for bid in range(real_bsz):
        length = seq_lens[bid]
        if length <= 1:
            seq_lens_output[bid] = length
        elif seq_lens_encoder[bid] != 0:
            seq_lens_output[bid] = 1
        else:
            seq_lens_output[bid] = length

    # Output-side prefix sums, per-output-token batch ids, and total count.
    cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
    cu_seq_lens_q_output[1:] = np.cumsum(seq_lens_output)
    batch_id_per_token_output = np.zeros(real_bsz * max_draft_tokens_per_batch, dtype=np.int32)
    for bi in range(real_bsz):
        lo = int(cu_seq_lens_q_output[bi])
        hi = int(cu_seq_lens_q_output[bi + 1])
        batch_id_per_token_output[lo:hi] = bi
    real_output_token_num = np.array([cu_seq_lens_q_output[-1]], dtype=np.int32)

    return (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        seq_lens_output,
        cu_seq_lens_q_output,
        batch_id_per_token_output,
        real_output_token_num,
    )
def build_inputs(
    real_bsz,
    max_seq_len,
    max_draft_tokens,
    seq_lens_list,
    seq_lens_encoder_list,
    draft_tokens_data=None,
    input_ids_data=None,
    seed=42,
):
    """
    Helper to build test inputs from explicit seq_lens and seq_lens_encoder lists.
    draft_tokens_data and input_ids_data are optional; if None, random data is used.
    """
    rng = np.random.default_rng(seed)
    seq_lens = np.array(seq_lens_list, dtype=np.int32)
    seq_lens_encoder = np.array(seq_lens_encoder_list, dtype=np.int32)
    # Decoder lengths are part of the op signature but not used in kernel logic.
    seq_lens_decoder = np.zeros(real_bsz, dtype=np.int32)
    token_num = int(seq_lens.sum())

    # NOTE: RNG draw order (input_ids first, then draft_tokens) is part of
    # the helper's observable behavior for a fixed seed — keep it stable.
    if input_ids_data is None:
        input_ids = rng.integers(1, 1000, size=(real_bsz, max_seq_len), dtype=np.int64)
    else:
        input_ids = np.array(input_ids_data, dtype=np.int64).reshape(real_bsz, max_seq_len)
    if draft_tokens_data is None:
        draft_tokens = rng.integers(1, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
    else:
        draft_tokens = np.array(draft_tokens_data, dtype=np.int64).reshape(real_bsz, max_draft_tokens)

    return {
        "input_ids": input_ids,
        "seq_lens": seq_lens,
        "draft_tokens": draft_tokens,
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "max_seq_len": max_seq_len,
        "max_draft_tokens": max_draft_tokens,
        "token_num": token_num,
        "real_bsz": real_bsz,
    }
def run_and_compare(tc, inputs):
    """
    Run the speculate_pre_process GPU op and the numpy reference on the same
    inputs and assert every returned tensor matches.
    tc: unittest.TestCase instance (kept for call-site compatibility; numpy
        performs the actual assertions).
    """
    gpu_outs = speculate_pre_process(
        inputs["token_num"],
        paddle.to_tensor(inputs["input_ids"], dtype="int64"),
        paddle.to_tensor(inputs["seq_lens"], dtype="int32"),
        paddle.to_tensor(inputs["draft_tokens"], dtype="int64"),
        paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32"),
        paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32"),
    )
    ref_outs = speculate_pre_process_ref(
        input_ids=inputs["input_ids"].reshape(-1),
        seq_lens=inputs["seq_lens"],
        draft_tokens=inputs["draft_tokens"].reshape(-1),
        seq_lens_encoder=inputs["seq_lens_encoder"],
        max_seq_len=inputs["max_seq_len"],
        max_draft_tokens_per_batch=inputs["max_draft_tokens"],
        real_bsz=inputs["real_bsz"],
        token_num=inputs["token_num"],
    )
    # The GPU op returns 7 tensors while the reference returns 8 (with
    # seq_lens_output at ref index 4, which the op does not expose).
    # (name, gpu index, ref index) triples for the positional mapping:
    checks = [
        ("ids_remove_padding", 0, 0),
        ("batch_id_per_token", 1, 1),
        ("cu_seqlens_q", 2, 2),
        ("cu_seqlens_k", 3, 3),
        ("cu_seq_lens_q_output", 4, 5),
        ("batch_id_per_token_output", 5, 6),
        ("real_output_token_num", 6, 7),
    ]
    valid_len = int(ref_outs[7][0])  # real_output_token_num
    for name, gpu_idx, ref_idx in checks:
        got = gpu_outs[gpu_idx].numpy()
        want = ref_outs[ref_idx]
        if name == "batch_id_per_token_output":
            # The kernel only writes the first real_output_token_num slots;
            # anything beyond is undefined, so compare the valid prefix only.
            got = got[:valid_len]
            want = want[:valid_len]
        np.testing.assert_allclose(
            got,
            want,
            err_msg=f"Mismatch in output '{name}'",
        )
class TestSpeculatePreProcess(unittest.TestCase):
    """Unit tests for the speculate_pre_process custom operator."""

    def test_mixed_batch_all_branches(self):
        """Mixed batch covering all four seq_lens_output branches.

        bid=0: seq_lens=0            -> output=0 (skipped)
        bid=1: seq_lens=1, encoder=0 -> output=1, tokens from draft_tokens
        bid=2: seq_lens=5, encoder=3 -> output=1, tokens from input_ids (prefill)
        bid=3: seq_lens=4, encoder=0 -> output=4, tokens from draft_tokens (decode)
        bid=4: seq_lens=1, encoder=2 -> output=1, tokens from input_ids (prefill single)
        bid=5: seq_lens=8, encoder=0 -> output=8, tokens from draft_tokens (decode saturated)
        """
        run_and_compare(
            self,
            build_inputs(
                real_bsz=6,
                max_seq_len=16,
                max_draft_tokens=8,
                seq_lens_list=[0, 1, 5, 4, 1, 8],
                seq_lens_encoder_list=[0, 0, 3, 0, 2, 0],
            ),
        )

    def test_all_zero_seq_lens(self):
        """token_num=0 early-return path: no crash and exactly 7 outputs."""
        bsz = 3
        zero_args = (
            paddle.zeros([bsz, 8], dtype="int64"),  # input_ids
            paddle.zeros([bsz], dtype="int32"),     # seq_lens
            paddle.zeros([bsz, 4], dtype="int64"),  # draft_tokens
            paddle.zeros([bsz], dtype="int32"),     # seq_lens_encoder
            paddle.zeros([bsz], dtype="int32"),     # seq_lens_decoder
        )
        outs = speculate_pre_process(0, *zero_args)
        self.assertEqual(len(outs), 7)

    def test_exact_token_values(self):
        """Hand-checked token routing through ids_remove_padding.

        bid=0: encoder=0 (decode)  -> draft_tokens[0][0:3] == [10, 11, 12]
        bid=1: encoder=5 (prefill) -> input_ids[1][0:2]   == [200, 201]
        """
        inputs = build_inputs(
            real_bsz=2,
            max_seq_len=4,
            max_draft_tokens=4,
            seq_lens_list=[3, 2],
            seq_lens_encoder_list=[0, 5],
            draft_tokens_data=[[10, 11, 12, 13], [20, 21, 22, 23]],
            input_ids_data=[[100, 101, 102, 103], [200, 201, 202, 203]],
        )
        outs = speculate_pre_process(
            int(np.sum(inputs["seq_lens"])),
            paddle.to_tensor(inputs["input_ids"], dtype="int64"),
            paddle.to_tensor(inputs["seq_lens"], dtype="int32"),
            paddle.to_tensor(inputs["draft_tokens"], dtype="int64"),
            paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32"),
            paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32"),
        )
        np.testing.assert_allclose(outs[0].numpy(), [10, 11, 12, 200, 201])
        np.testing.assert_allclose(outs[1].numpy(), [0, 0, 0, 1, 1])
        np.testing.assert_allclose(outs[2].numpy(), [0, 3, 5])
        np.testing.assert_allclose(outs[6].numpy(), [4])  # real_output_token_num = 3 + 1

    def test_random_configs(self):
        """Randomized small & medium batch configs against the reference."""
        for cfg in (
            {"real_bsz": 7, "max_seq_len": 32, "max_draft_tokens": 8, "seed": 200},
            {"real_bsz": 32, "max_seq_len": 128, "max_draft_tokens": 16, "seed": 400},
        ):
            with self.subTest(**cfg):
                rng = np.random.default_rng(cfg["seed"])
                # Draw seq_lens first, then encoder lengths (draw order fixed
                # so the data is reproducible for a given seed).
                lens = rng.integers(0, cfg["max_draft_tokens"] + 1, size=cfg["real_bsz"]).tolist()
                enc = rng.integers(0, 3, size=cfg["real_bsz"]).tolist()
                inputs = build_inputs(
                    real_bsz=cfg["real_bsz"],
                    max_seq_len=cfg["max_seq_len"],
                    max_draft_tokens=cfg["max_draft_tokens"],
                    seq_lens_list=lens,
                    seq_lens_encoder_list=enc,
                    seed=cfg["seed"],
                )
                if inputs["token_num"] == 0:
                    continue
                run_and_compare(self, inputs)
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()