Files
FastDeploy/custom_ops/gpu_ops/speculate_decoding/verify_draft_tokens.cu
T
2026-03-31 11:05:51 +08:00

483 lines
18 KiB
Plaintext

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Verification kernel — outputs step_output_ids + step_output_len,
// and performs EOS / max_dec_len detection (read-only on step_idx).
// step_idx is NOT modified here; all state updates (including step_idx)
// are handled by unified_update_model_status.
//
// Verification strategies:
//   0 = TOPP         : draft token is in the top-p candidate set.
//                      (verify_window is passed to the kernel but is not
//                      currently consulted.)
//   1 = GREEDY       : draft token == top-1 token (strict argmax match).
//   2 = TARGET_MATCH : draft token == target model's sampled token.
#include <curand_kernel.h>
#include "helper.h" // NOLINT
// ============================================================
// Persistent curand state cache (host-side bookkeeping).
// VerifyDraftTokens allocates the device buffer lazily, grows it when a
// larger batch arrives, and reuses it across calls. It is only used by the
// TOPP strategy, which needs random numbers for Phase 2 sampling.
// ============================================================
static curandState_t *dev_curand_states{nullptr};  // device buffer, one state per slot
static int allocated_bsz{0};  // number of states dev_curand_states can hold
static uint64_t seed{0};      // advanced by one after every TOPP call
static uint64_t offset{0};    // advanced by one after every TOPP call
// Initialize curand states for slots [0, bs).
//   state             : device array with at least `bs` entries.
//   seed, offset      : forwarded unchanged to curand_init.
//   bs                : number of states to initialize.
//   need_batch_random : true  -> slot i uses subsequence i (distinct streams);
//                       false -> every slot uses subsequence 0 (identical
//                                states).
// Each thread starts at its flat global index and strides by the total
// thread count, so any launch geometry covers every slot.
__global__ void setup_seed_kernel(curandState_t *state,
                                  const uint64_t seed,
                                  const uint64_t offset,
                                  const int bs,
                                  const bool need_batch_random) {
  const int stride = gridDim.x * blockDim.x;
  for (int slot = blockIdx.x * blockDim.x + threadIdx.x; slot < bs;
       slot += stride) {
    const int subsequence = need_batch_random ? slot : 0;
    curand_init(seed, subsequence, offset, state + slot);
  }
}
// ============================================================
// Phase 1 helpers — single-step draft token verification
// ============================================================
// Linear membership test: true iff `draft` equals one of the first
// `candidate_len` entries of `candidates`. A non-positive candidate_len
// reads nothing and yields false.
__device__ inline bool is_in(const int64_t *candidates,
                             const int64_t draft,
                             const int candidate_len) {
  int k = 0;
  while (k < candidate_len) {
    if (candidates[k] == draft) {
      return true;
    }
    ++k;
  }
  return false;
}
// TOPP acceptance test: the draft token is accepted when it appears among
// the first `actual_cand_len` entries of this position's candidate row.
__device__ inline bool verify_one_topp(const int64_t *verify_tokens_row,
                                       int64_t draft_token,
                                       int actual_cand_len) {
  for (int k = 0; k < actual_cand_len; ++k) {
    if (verify_tokens_row[k] == draft_token) return true;
  }
  return false;
}
// GREEDY / TARGET_MATCH acceptance test: accept only when the draft token
// is exactly the target model's token for this position.
__device__ inline bool verify_one_match(int64_t target_token,
                                        int64_t draft_token) {
  const bool same_token = (draft_token == target_token);
  return same_token;
}
// ============================================================
// VerifyContext — per-batch mutable state + accept helpers.
// Centralizes the step-counter advance, EOS / max_dec_len detection,
// token replacement and output write used by Phase 1 and Phase 2.
// ============================================================
struct VerifyContext {
  // ---- Immutable per-batch fields (set once at kernel entry) ----
  int bid;                            // batch slot handled by this thread
  int max_step_tokens;                // row stride of step_output_ids
  int end_length;                     // number of entries in end_tokens
  const int64_t *end_tokens;          // end-of-sequence token ids
  const int64_t *max_dec_len;         // per-slot length limit, indexed by bid
  const int64_t *step_input_ids_now;  // this slot's draft-token row
  int64_t *step_output_ids;           // output buffer base (all slots)
  // ---- Mutable per-batch state ----
  int64_t cur_step_idx;  // step counter, advanced once per emitted token
  int output_len_now;    // number of tokens written for this slot so far
  bool stopped;          // true once EOS or max_dec_len has been reached

  // Emit `token` at output position `pos` (Phase 1, accepted draft token).
  // Advances cur_step_idx, checks for EOS / max_dec_len, writes the token
  // and increments output_len_now. When the length limit is hit on a
  // non-end token, the token is replaced by end_tokens[0]; if end_tokens is
  // empty (end_length == 0) the token is kept instead of reading past the
  // array.
  // Returns true (and sets `stopped`) if this sequence should stop.
  __device__ __forceinline__ bool emit_token(int pos, int64_t token) {
    ++cur_step_idx;
    const bool is_eos = is_in_end(token, end_tokens, end_length);
    const bool max_len_hit = cur_step_idx >= max_dec_len[bid];
    // Only the length-limit case is rewritten; a real EOS token is kept.
    if (max_len_hit && !is_eos && end_length > 0) {
      token = end_tokens[0];
    }
    step_output_ids[bid * max_step_tokens + pos] = token;
    ++output_len_now;
    const bool should_stop = is_eos || max_len_hit;
    if (should_stop) stopped = true;
    return should_stop;
  }

  // Emit the final token at output position `pos` (Phase 2). Identical
  // EOS / max_dec_len handling to emit_token; Phase 2 always produces one
  // additional token and the caller does not use the stop signal.
  __device__ __forceinline__ void emit_final_token(int pos, int64_t token) {
    (void)emit_token(pos, token);
  }
};
// ============================================================
// Phase 2 helpers — sample token for rejected/last position
// ============================================================
// Draw one token from a candidate row.
// A uniform draw from this thread's curand state is scaled by `topp`; the
// first candidate whose running score sum reaches that threshold is
// returned. If the running sum never reaches it (including when
// candidate_len <= 0), candidate_ids[0] is returned as a fallback.
//
// The curand state is selected by the flat global thread index. Under the
// single-block launch used by the host wrapper this equals threadIdx.x,
// i.e. the batch slot handled by the calling thread.
__device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
                                               const float *candidate_scores,
                                               curandState_t *curand_states,
                                               const int candidate_len,
                                               const float topp) {
  const int state_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const float threshold = topp * curand_uniform(&curand_states[state_idx]);
  float cumulative = 0.0f;
  int k = 0;
  while (k < candidate_len) {
    cumulative += candidate_scores[k];
    if (cumulative >= threshold) {
      return candidate_ids[k];
    }
    ++k;
  }
  return candidate_ids[0];
}
// ============================================================
// Main verification kernel
// ============================================================
//
// Launch layout: a single block, one thread per batch slot
// (bid = threadIdx.x), so blockDim.x must be >= max_bsz.
//
// Input parameter groups by strategy:
//   - target_tokens:         GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused
//   - candidate_ids/scores:  TOPP=full candidate set,
//                            GREEDY/TARGET_MATCH=unused
//   - candidate_lens:        TOPP=actual length per position,
//                            GREEDY/TARGET_MATCH=unused
//   - curand_states / topp:  TOPP only (curand_states is nullptr otherwise)
// Inputs a strategy does not use may be nullptr.
//
// Per slot the kernel:
//   1. resets step_output_len[bid] to 0 and step_output_ids[bid, :] to -1
//      (for every bid < max_bsz, including slots skipped afterwards);
//   2. Phase 1: walks the draft tokens and emits each accepted one, stopping
//      at the first rejection, EOS or max_dec_len hit;
//   3. Phase 2: if not stopped, emits one more token at the stop position
//      (TOPP: sampled from the candidate set; GREEDY / TARGET_MATCH:
//      target_tokens at that position);
//   4. stores the number of emitted tokens in step_output_len[bid].
// An unrecognized verify_strategy accepts no drafts (unless accept_all) and
// emits no Phase 2 token. max_seq_len and verify_window are not used.
//
__global__ void verify_draft_tokens(
    // Core I/O
    int64_t *step_output_ids,
    int *step_output_len,
    const int64_t *step_input_ids,  // draft tokens
    // Target model outputs (strategy-dependent interpretation)
    const int64_t
        *target_tokens,  // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
    // Candidate set (TOPP only)
    const int64_t *candidate_ids,
    const float *candidate_scores,
    const int *candidate_lens,
    // Sampling params
    curandState_t *curand_states,  // nullptr for GREEDY/TARGET_MATCH
    const float *topp,
    // Metadata
    const bool *stop_flags,
    const int *seq_lens_encoder,
    const int *seq_lens_this_time,
    const int64_t *end_tokens,
    const bool *is_block_step,
    const int *cu_seqlens_q_output,
    const int *reasoning_status,
    // max_dec_len / step_idx for EOS/max-len detection (read-only)
    const int64_t *max_dec_len,
    const int64_t *step_idx,
    // Dimensions and config
    const int max_bsz,
    const int real_bsz,
    const int max_step_tokens,
    const int end_length,
    const int max_seq_len,
    const int max_candidate_len,
    const int verify_window,
    const int verify_strategy,  // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
    const bool reject_all,
    const bool accept_all) {
  const int bid = threadIdx.x;
  if (bid >= max_bsz) return;

  // Reset this slot's outputs before any early return below.
  step_output_len[bid] = 0;
  int64_t *out_row = step_output_ids + bid * max_step_tokens;
  for (int k = 0; k < max_step_tokens; k++) {
    out_row[k] = -1;
  }

  if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) return;

  const int start_token_id = cu_seqlens_q_output[bid];
  // Pointers are strategy-dependent (nullptr for unused params).
  const int64_t *candidate_ids_now =
      candidate_ids ? candidate_ids + start_token_id * max_candidate_len
                    : nullptr;
  const float *candidate_scores_now =
      candidate_scores ? candidate_scores + start_token_id * max_candidate_len
                       : nullptr;
  const int *candidate_lens_now =
      candidate_lens ? candidate_lens + start_token_id : nullptr;
  const int64_t *target_tokens_now =
      target_tokens ? target_tokens + start_token_id : nullptr;

  // TOPP candidate count at position `pos`, clamped to max_candidate_len.
  auto clamped_cand_len = [&](int pos) {
    const int len = candidate_lens_now[pos];
    return len > max_candidate_len ? max_candidate_len : len;
  };

  // Initialize per-batch verification context.
  VerifyContext ctx;
  ctx.bid = bid;
  ctx.max_step_tokens = max_step_tokens;
  ctx.end_length = end_length;
  ctx.end_tokens = end_tokens;
  ctx.max_dec_len = max_dec_len;
  ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
  ctx.step_output_ids = step_output_ids;
  ctx.cur_step_idx = step_idx[bid];
  ctx.output_len_now = 0;
  ctx.stopped = false;

  // ======== Phase 1: Verify draft tokens ========
  const int num_draft_positions = seq_lens_this_time[bid] - 1;
  int i = 0;
  // The skip conditions do not depend on the position, so they are checked
  // once (short-circuited in the same order as before): reject-all, prefill
  // (seq_lens_encoder != 0) and reasoning_status == 1 skip verification.
  const bool verify_drafts = num_draft_positions > 0 && !reject_all &&
                             seq_lens_encoder[bid] == 0 &&
                             reasoning_status[bid] != 1;
  if (verify_drafts) {
    for (; i < num_draft_positions; i++) {
      const int64_t draft_token = ctx.step_input_ids_now[i + 1];

      // Accept-all override (debug/warmup/CUDA graph capture).
      if (accept_all) {
        // During dummy runs, replace end tokens with a non-end placeholder
        // (5) so the slot is not stopped; a stop would change the token
        // count and break CUDA graph capture.
        const int64_t token =
            is_in_end(draft_token, end_tokens, end_length) ? 5 : draft_token;
        if (ctx.emit_token(i, token)) break;
        continue;
      }

      // Strategy dispatch
      bool accepted = false;
      switch (verify_strategy) {
        case 0:  // TOPP
          accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
                                     draft_token,
                                     clamped_cand_len(i));
          break;
        case 1:  // GREEDY
        case 2:  // TARGET_MATCH
          accepted = verify_one_match(target_tokens_now[i], draft_token);
          break;
        default:  // unknown strategy: accept nothing
          break;
      }
      if (!accepted) break;  // reject: Phase 2 fills position i
      if (ctx.emit_token(i, draft_token)) break;
    }
  }

  // ======== Phase 2: Output token for rejected/last position ========
  if (!ctx.stopped) {
    int64_t output_token = 0;
    bool has_output = true;
    switch (verify_strategy) {
      case 0:  // TOPP — stochastic sampling from candidate set
        output_token =
            topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
                                 candidate_scores_now + i * max_candidate_len,
                                 curand_states,
                                 clamped_cand_len(i),
                                 topp[bid]);
        break;
      case 1:  // GREEDY — deterministic argmax from target_tokens
      case 2:  // TARGET_MATCH — target model's sampled token
        output_token = target_tokens_now[i];
        break;
      default:
        // Unknown strategy: there is no defined token to emit, so write
        // nothing rather than an uninitialized value.
        has_output = false;
        break;
    }
    if (has_output) ctx.emit_final_token(i, output_token);
  }
  step_output_len[bid] = ctx.output_len_now;
}
// ============================================================
// Host function
// ============================================================

// Throw if verify_strategy is unknown or if an optional input required by
// that strategy is missing. Optional inputs may be empty when there are no
// active sequences (empty_input_forward), so presence is only enforced for
// bsz > 0.
static void ValidateVerifyInputs(int verify_strategy,
                                 int64_t bsz,
                                 const int64_t *target_tokens_ptr,
                                 const int64_t *candidate_ids_ptr,
                                 const float *candidate_scores_ptr,
                                 const int *candidate_lens_ptr) {
  if (verify_strategy < 0 || verify_strategy > 2) {
    PD_THROW(
        "verify_strategy must be 0 (TOPP), 1 (GREEDY) or 2 (TARGET_MATCH), "
        "got ",
        verify_strategy);
  }
  if (bsz <= 0) return;
  if (verify_strategy == 0 /* TOPP */) {
    if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) {
      PD_THROW(
          "verify_strategy=TOPP (0) requires candidate_ids, "
          "candidate_scores, candidate_lens");
    }
  } else if (verify_strategy == 1 /* GREEDY */) {
    if (!target_tokens_ptr) {
      PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)");
    }
  } else /* verify_strategy == 2, TARGET_MATCH */ {
    if (!target_tokens_ptr) {
      PD_THROW(
          "verify_strategy=TARGET_MATCH (2) requires target_tokens "
          "(sampled)");
    }
  }
}

// Return the cached device buffer of curand states after queuing a re-seed
// of its first `bsz` entries on `stream`. The buffer is (re)allocated when it
// does not exist yet or holds fewer than bsz states; allocation failures
// throw and leave the cache empty. seed and offset advance on every call so
// successive launches draw different random streams.
static curandState_t *PrepareCurandStates(int64_t bsz,
                                          int block_size,
                                          cudaStream_t stream) {
  if (bsz > 0 && (dev_curand_states == nullptr || bsz > allocated_bsz)) {
    if (dev_curand_states != nullptr) {
      const cudaError_t free_err = cudaFree(dev_curand_states);
      dev_curand_states = nullptr;
      allocated_bsz = 0;
      if (free_err != cudaSuccess) {
        PD_THROW("verify_draft_tokens: cudaFree of curand states failed: ",
                 cudaGetErrorString(free_err));
      }
    }
    const cudaError_t alloc_err =
        cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
    if (alloc_err != cudaSuccess) {
      dev_curand_states = nullptr;
      PD_THROW("verify_draft_tokens: cudaMalloc of curand states for bsz=",
               bsz,
               " failed: ",
               cudaGetErrorString(alloc_err));
    }
    allocated_bsz = static_cast<int>(bsz);
  }
  if (bsz > 0) {
    setup_seed_kernel<<<1, block_size, 0, stream>>>(
        dev_curand_states, seed, offset, static_cast<int>(bsz), true);
  }
  seed++;
  offset++;
  return dev_curand_states;
}

// Host entry point of the verify_draft_tokens op. Validates inputs, prepares
// curand states for TOPP, then launches the verification kernel as a single
// block of 1024 threads (one per batch slot) on step_output_ids' stream.
// step_output_ids / step_output_len are written in place by the kernel.
// Throws if bsz > 1024, verify_strategy is unknown, a required optional input
// is missing, or the curand buffer cannot be (re)allocated.
void VerifyDraftTokens(
    // Core I/O
    const paddle::Tensor &step_output_ids,
    const paddle::Tensor &step_output_len,
    const paddle::Tensor &step_input_ids,
    // Target model outputs (optional, required for GREEDY / TARGET_MATCH)
    const paddle::optional<paddle::Tensor> &target_tokens,
    // Candidate set (optional, required for TOPP)
    const paddle::optional<paddle::Tensor> &candidate_ids,
    const paddle::optional<paddle::Tensor> &candidate_scores,
    const paddle::optional<paddle::Tensor> &candidate_lens,
    // Sampling params
    const paddle::Tensor &topp,
    // Metadata
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &end_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &cu_seqlens_q_output,
    const paddle::Tensor &reasoning_status,
    // max_dec_len / step_idx for EOS/max-len detection
    const paddle::Tensor &max_dec_len,
    const paddle::Tensor &step_idx,
    int max_seq_len,
    int verify_window,
    int verify_strategy,
    bool reject_all,
    bool accept_all) {
  auto bsz = step_output_ids.shape()[0];
  auto real_bsz = seq_lens_this_time.shape()[0];
  auto max_step_tokens = step_input_ids.shape()[1];
  auto end_length = end_tokens.shape()[0];
  // max_candidate_len: 1 unless candidate_ids is present and at least 2-D
  // (a lower-rank tensor has no shape()[1] to read).
  int max_candidate_len = 1;
  if (candidate_ids && candidate_ids->shape().size() >= 2) {
    max_candidate_len = static_cast<int>(candidate_ids->shape()[1]);
  }

  constexpr int BlockSize = 1024;
  PADDLE_ENFORCE_LE(bsz,
                    BlockSize,
                    phi::errors::InvalidArgument(
                        "verify_draft_tokens: bsz (%d) exceeds BlockSize (%d). "
                        "Increase BlockSize or reduce max_num_seqs.",
                        bsz,
                        BlockSize));

  // Get data pointers (nullptr if optional not provided)
  const int64_t *target_tokens_ptr =
      target_tokens ? target_tokens->data<int64_t>() : nullptr;
  const int64_t *candidate_ids_ptr =
      candidate_ids ? candidate_ids->data<int64_t>() : nullptr;
  const float *candidate_scores_ptr =
      candidate_scores ? candidate_scores->data<float>() : nullptr;
  const int *candidate_lens_ptr =
      candidate_lens ? candidate_lens->data<int>() : nullptr;

  // Validate before touching the cached curand state, so a rejected call
  // leaves seed/offset and the buffer unchanged.
  ValidateVerifyInputs(verify_strategy,
                       bsz,
                       target_tokens_ptr,
                       candidate_ids_ptr,
                       candidate_scores_ptr,
                       candidate_lens_ptr);

  auto stream = step_output_ids.stream();
  // curand state: only needed for TOPP (0); GREEDY and TARGET_MATCH produce
  // deterministic output and receive nullptr.
  curandState_t *curand_ptr = nullptr;
  if (verify_strategy == 0 /* TOPP */) {
    curand_ptr = PrepareCurandStates(bsz, BlockSize, stream);
  }

  verify_draft_tokens<<<1, BlockSize, 0, stream>>>(
      // Core I/O
      const_cast<int64_t *>(step_output_ids.data<int64_t>()),
      const_cast<int *>(step_output_len.data<int>()),
      step_input_ids.data<int64_t>(),
      // Target model outputs
      target_tokens_ptr,
      // Candidate set
      candidate_ids_ptr,
      candidate_scores_ptr,
      candidate_lens_ptr,
      // Sampling params
      curand_ptr,
      topp.data<float>(),
      // Metadata
      stop_flags.data<bool>(),
      seq_lens_encoder.data<int>(),
      seq_lens_this_time.data<int>(),
      end_tokens.data<int64_t>(),
      is_block_step.data<bool>(),
      cu_seqlens_q_output.data<int>(),
      reasoning_status.data<int>(),
      // max_dec_len / step_idx
      max_dec_len.data<int64_t>(),
      step_idx.data<int64_t>(),
      // Dimensions and config
      bsz,       // max_bsz
      real_bsz,  // real_bsz
      max_step_tokens,
      end_length,
      max_seq_len,
      max_candidate_len,
      verify_window,
      verify_strategy,
      reject_all,
      accept_all);
}
// Op registration for verify_draft_tokens.
// The inputs below are listed in the same order as the tensor parameters of
// VerifyDraftTokens, and the attrs in the same order as its trailing
// int/bool parameters — keep both lists in sync when editing either side.
PD_BUILD_STATIC_OP(verify_draft_tokens)
.Inputs({"step_output_ids",
"step_output_len",
"step_input_ids",
// Optional inputs; which ones each verify_strategy requires is checked in
// VerifyDraftTokens.
paddle::Optional("target_tokens"),
paddle::Optional("candidate_ids"),
paddle::Optional("candidate_scores"),
paddle::Optional("candidate_lens"),
"topp",
"stop_flags",
"seq_lens_encoder",
"seq_lens_this_time",
"end_tokens",
"is_block_step",
"cu_seqlens_q_output",
"reasoning_status",
"max_dec_len",
"step_idx"})
// Both outputs alias their inputs (see SetInplaceMap): the kernel writes
// step_output_ids / step_output_len directly.
.Outputs({"step_output_ids_out", "step_output_len_out"})
.Attrs({"max_seq_len: int",
"verify_window: int",
"verify_strategy: int",
"reject_all: bool",
"accept_all: bool"})
.SetInplaceMap({{"step_output_ids", "step_output_ids_out"},
{"step_output_len", "step_output_len_out"}})
.SetKernelFn(PD_KERNEL(VerifyDraftTokens));