[XPU] add verify draft tokens (#6947)

* [XPU] add verify draft tokens

* fix test

* fix code style

* use sync cpy

* fix code style

* fix kernel check

* fix ramdom seed

* fix test

* fix check

* fix eos set

* fix verify

* fix verify
This commit is contained in:
cmcamdy
2026-04-15 10:18:33 +08:00
committed by GitHub
parent e0a1653b26
commit 13b9fe7299
8 changed files with 2320 additions and 7 deletions
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <atomic>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include "paddle/common/flags.h"
@@ -26,6 +27,10 @@
namespace api = baidu::xpu::api;
// Persistent seed/offset — mirrors GPU curand state lifecycle.
static std::atomic<uint64_t> g_seed{0};
static std::atomic<uint64_t> g_offset{0};
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
@@ -82,13 +87,15 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
prefill_one_step_stop = true;
}
}
// random
int random_seed = 0;
std::vector<int64_t> infer_seed(bsz, random_seed);
// random — use persistent seed/offset so each call and batch element
// produce distinct random numbers (mirrors GPU curand lifecycle).
uint64_t cur_seed = g_seed++;
uint64_t cur_offset = g_offset++;
std::uniform_real_distribution<float> dist(0.0, 1.0);
std::vector<float> dev_curand_states_cpu;
for (int i = 0; i < bsz; i++) {
std::mt19937_64 engine(infer_seed[i]);
std::mt19937_64 engine(cur_seed + i);
engine.discard(cur_offset);
dev_curand_states_cpu.push_back(dist(engine));
}
float *dev_curand_states = dev_curand_states_cpu.data();
@@ -0,0 +1,217 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Verification kernel — outputs step_output_ids + step_output_len,
// and performs EOS / max_dec_len detection (read-only on step_idx).
// step_idx is NOT modified here; all state updates (including step_idx)
// are handled by unified_update_model_status.
//
// Verification strategies:
// 0 = TOPP : draft token in top-p candidate set (+ verify_window
// fallback) 1 = GREEDY : draft token == top-1 token (strict argmax
// match) 2 = TARGET_MATCH : draft token == target model's sampled token
#include <atomic>
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
// Persistent seed/offset — mirrors GPU curand state lifecycle.
static std::atomic<uint64_t> g_seed{0};
static std::atomic<uint64_t> g_offset{0};
// ============================================================
// Host function
// ============================================================
void VerifyDraftTokens(
// Core I/O
const paddle::Tensor &step_output_ids,
const paddle::Tensor &step_output_len,
const paddle::Tensor &step_input_ids,
// Target model outputs (optional, required for TARGET_MATCH)
const paddle::optional<paddle::Tensor> &target_tokens,
// Candidate set (optional, required for TOPP/GREEDY)
const paddle::optional<paddle::Tensor> &candidate_ids,
const paddle::optional<paddle::Tensor> &candidate_scores,
const paddle::optional<paddle::Tensor> &candidate_lens,
// Sampling params
const paddle::Tensor &topp,
// Metadata
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &cu_seqlens_q_output,
const paddle::Tensor &reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection
const paddle::Tensor &max_dec_len,
const paddle::Tensor &step_idx,
int max_seq_len,
int verify_window,
int verify_strategy,
bool reject_all,
bool accept_all) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
bool xpu_ctx_flag = true;
if (step_output_ids.is_cpu()) {
ctx = new api::Context(api::kCPU);
xpu_ctx_flag = false;
}
auto bsz = step_output_ids.shape()[0];
auto real_bsz = seq_lens_this_time.shape()[0];
auto max_step_tokens = step_input_ids.shape()[1];
auto end_length = end_tokens.shape()[0];
// max_candidate_len: 1 if candidate_ids not provided, else from shape
int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;
// curand state: only needed for TOPP(0) strategy (stochastic sampling)
// Use persistent seed/offset (mirrors GPU curand lifecycle) so that
// each call and each batch element produce distinct random numbers.
uint64_t cur_seed = g_seed++;
uint64_t cur_offset = g_offset++;
std::uniform_real_distribution<float> dist(0.0, 1.0);
std::vector<float> dev_curand_states_cpu;
for (int i = 0; i < bsz; i++) {
std::mt19937_64 engine(cur_seed + i);
engine.discard(cur_offset);
dev_curand_states_cpu.push_back(dist(engine));
}
float *dev_curand_states = dev_curand_states_cpu.data();
auto dev_curand_states_tensor =
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
paddle::DataType::FLOAT32,
seq_lens_this_time.place());
int ret;
if (xpu_ctx_flag) {
ret = api::do_host2device(ctx,
dev_curand_states_cpu.data(),
dev_curand_states_tensor.data<float>(),
dev_curand_states_cpu.size() * sizeof(float));
PD_CHECK(ret == 0, "do_host2device failed.");
dev_curand_states = dev_curand_states_tensor.data<float>();
}
// Get data pointers (nullptr if optional not provided)
const int64_t *target_tokens_ptr =
target_tokens ? target_tokens->data<int64_t>() : nullptr;
const int64_t *candidate_ids_ptr =
candidate_ids ? candidate_ids->data<int64_t>() : nullptr;
const float *candidate_scores_ptr =
candidate_scores ? candidate_scores->data<float>() : nullptr;
const int *candidate_lens_ptr =
candidate_lens ? candidate_lens->data<int>() : nullptr;
// Validate parameters based on verify_strategy.
// Note: empty_input_forward may lead to empty optional tensors — only
// validate when bsz > 0 (i.e. there are active sequences).
if (bsz > 0) {
if (verify_strategy == 0 /* TOPP */) {
if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) {
PD_THROW(
"verify_strategy=TOPP (0) requires candidate_ids, "
"candidate_scores, candidate_lens");
}
} else if (verify_strategy == 1 /* GREEDY */) {
if (!target_tokens_ptr) {
PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)");
}
} else if (verify_strategy == 2 /* TARGET_MATCH */) {
if (!target_tokens_ptr) {
PD_THROW(
"verify_strategy=TARGET_MATCH (2) requires target_tokens "
"(sampled)");
}
}
}
ret = fastdeploy::plugin::verify_draft_tokens(
ctx,
// Core I/O
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
const_cast<int *>(step_output_len.data<int>()),
step_input_ids.data<int64_t>(),
// Target model outputs
target_tokens_ptr,
// Candidate set
candidate_ids_ptr,
candidate_scores_ptr,
candidate_lens_ptr,
// Sampling params
dev_curand_states,
topp.data<float>(),
// Metadata
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
seq_lens_this_time.data<int>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
cu_seqlens_q_output.data<int>(),
reasoning_status.data<int>(),
// max_dec_len / step_idx
max_dec_len.data<int64_t>(),
step_idx.data<int64_t>(),
// Dimensions and config
bsz, // max_bsz
real_bsz, // real_bsz
max_step_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
verify_strategy,
reject_all,
accept_all);
if (step_output_ids.is_cpu()) {
delete ctx;
}
PD_CHECK(ret == 0, "verify_draft_tokens failed.");
}
PD_BUILD_STATIC_OP(verify_draft_tokens)
.Inputs({"step_output_ids",
"step_output_len",
"step_input_ids",
paddle::Optional("target_tokens"),
paddle::Optional("candidate_ids"),
paddle::Optional("candidate_scores"),
paddle::Optional("candidate_lens"),
"topp",
"stop_flags",
"seq_lens_encoder",
"seq_lens_this_time",
"end_tokens",
"is_block_step",
"cu_seqlens_q_output",
"reasoning_status",
"max_dec_len",
"step_idx"})
.Outputs({"step_output_ids_out", "step_output_len_out"})
.Attrs({"max_seq_len: int",
"verify_window: int",
"verify_strategy: int",
"reject_all: bool",
"accept_all: bool"})
.SetInplaceMap({{"step_output_ids", "step_output_ids_out"},
{"step_output_len", "step_output_len_out"}})
.SetKernelFn(PD_KERNEL(VerifyDraftTokens));
@@ -701,6 +701,36 @@ std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor& x,
const int32_t arch,
const int32_t group_size);
void VerifyDraftTokens(
// Core I/O
const paddle::Tensor& step_output_ids,
const paddle::Tensor& step_output_len,
const paddle::Tensor& step_input_ids,
// Target model outputs (optional, required for TARGET_MATCH)
const paddle::optional<paddle::Tensor>& target_tokens,
// Candidate set (optional, required for TOPP/GREEDY)
const paddle::optional<paddle::Tensor>& candidate_ids,
const paddle::optional<paddle::Tensor>& candidate_scores,
const paddle::optional<paddle::Tensor>& candidate_lens,
// Sampling params
const paddle::Tensor& topp,
// Metadata
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& end_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection
const paddle::Tensor& max_dec_len,
const paddle::Tensor& step_idx,
int max_seq_len,
int verify_window,
int verify_strategy,
bool reject_all,
bool accept_all);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("adjust_batch",
&AdjustBatch,
@@ -1234,6 +1264,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("accept_all_drafts"),
"Perform speculative verification for decoding");
m.def("verify_draft_tokens",
&VerifyDraftTokens,
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("step_input_ids"),
py::arg("target_tokens"),
py::arg("candidate_ids"),
py::arg("candidate_scores"),
py::arg("candidate_lens"),
py::arg("topp"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_this_time"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("cu_seqlens_q_output"),
py::arg("reasoning_status"),
py::arg("max_dec_len"),
py::arg("step_idx"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("verify_strategy"),
py::arg("reject_all"),
py::arg("accept_all"),
"Perform speculative verification for decoding v2");
m.def("speculate_save_output",
&SpeculateSaveWithOutputMsgStatic,
py::arg("accept_tokens"),
@@ -766,6 +766,45 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
const int eos_token_id_len,
const int inject_len,
const bool splitwise_role_is_decode);
DLL_EXPORT int verify_draft_tokens(
api::Context* ctx,
// Core I/O
int64_t* step_output_ids,
int* step_output_len,
const int64_t* step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t*
target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t* candidate_ids,
const float* candidate_scores,
const int* candidate_lens,
// Sampling params
const float* curand_states, // nullptr for GREEDY/TARGET_MATCH
const float* topp,
// Metadata
const bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_this_time,
const int64_t* end_tokens,
const bool* is_block_step,
const int* cu_seqlens_q_output,
const int* reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t* max_dec_len,
const int64_t* step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all);
/*--------------------------------------- MTP end
* --------------------------------------------*/
@@ -0,0 +1,344 @@
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
namespace fd_xpu3 {
static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) {
int res;
v1 = vvadd_int32x16(v0, v1);
auto v = vsrlp_int32x16(256, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(128, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(64, v1);
v1 = vvadd_int32x16(v, v1);
v = vsrlp_int32x16(32, v1);
v1 = vvadd_int32x16(v, v1);
res = vextract_int32x16(v1, 1);
return res;
}
static inline __device__ int ClusterReduce(
const _shared_ptr_ int *stop_flag_now_int_sm, int len) {
int sum = 0;
if (core_id() == 0) {
int32x16_t vec_x_0;
int32x16_t vec_x_1;
int32x16_t vec_y_0 = vzero<int>();
int32x16_t vec_y_1 = vzero<int>();
for (int i = 0; i < len; i += 32) {
vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1);
vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0);
vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1);
}
sum = v_reduce(vec_y_0, vec_y_1);
}
return sum;
}
__device__ bool is_in_end(const int64_t id,
__global_ptr__ const int64_t *end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return flag;
}
__device__ inline bool is_in(__global_ptr__ const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
}
return false;
}
static __device__ inline unsigned int xorwow(unsigned int &state) {
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
return state;
}
__device__ int64_t
topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
__global_ptr__ const float *candidate_scores,
__global_ptr__ const float *dev_curand_states,
const int candidate_len,
const float topp) {
const int tid = core_id();
float sum_scores = 0.0f;
float rand_top_p = *dev_curand_states * topp;
// printf("debug rand_top_p:%f\n",rand_top_p);
for (int i = 0; i < candidate_len; i++) {
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
}
return candidate_ids[0];
}
// GREEDY / TARGET_MATCH: exact single-token match
__device__ inline bool verify_one_match(int64_t target_token,
int64_t draft_token) {
return target_token == draft_token;
}
__device__ inline bool verify_one_topp(
__global_ptr__ const int64_t *verify_tokens_row,
int64_t draft_token,
int actual_cand_len) {
return is_in(verify_tokens_row, draft_token, actual_cand_len);
}
// ============================================================
// VerifyContext — per-batch mutable state + accept helpers.
// Eliminates repeated EOS/max_dec_len check and output write
// patterns across Phase 1 and Phase 2.
// ============================================================
struct VerifyContext {
// Immutable per-batch (set once at kernel entry)
int bid;
int max_step_tokens;
int end_length;
__global_ptr__ const int64_t *end_tokens;
__global_ptr__ const int64_t *max_dec_len;
__global_ptr__ const int64_t *step_input_ids_now;
__global_ptr__ int64_t *step_output_ids;
// Mutable per-batch state
int64_t cur_step_idx;
int output_len_now;
bool stopped;
// Emit a token at position `pos` to output in Phase 1.
// Performs: step_idx check, EOS detection, token replacement, output write.
// Returns true if this sequence should stop (EOS or max_dec_len hit).
__device__ bool emit_token(int pos, int64_t token) {
cur_step_idx++;
bool is_eos = is_in_end(token, end_tokens, end_length);
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
if (max_len_hit && !is_eos) {
token = end_tokens[0];
}
step_output_ids[bid * max_step_tokens + pos] = token;
output_len_now++;
if (is_eos || max_len_hit) {
stopped = true;
return true;
}
return false;
}
// Emit the final token at position `pos` in Phase 2.
// Same EOS/max_dec_len logic. Increments output_len_now since
// Phase 2 produces one additional token.
__device__ void emit_final_token(int pos, int64_t token) {
cur_step_idx++;
bool is_eos = is_in_end(token, end_tokens, end_length);
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
if (max_len_hit && !is_eos) {
token = end_tokens[0];
}
step_output_ids[bid * max_step_tokens + pos] = token;
output_len_now++;
}
// TOPP-only: verify_window bulk-accept fallback.
//
// When draft token is NOT in top-p set but IS the top-2 token,
// check verify_window consecutive positions for top-1 match.
// If all match, bulk-accept from position i through ii.
//
// Returns the new loop position (i) after handling.
// Sets *rejected=true if fallback was not triggered (caller should break).
__device__ int try_verify_window_fallback(
int i,
bool *rejected,
__global_ptr__ const int64_t *verify_tokens_now,
int seq_len_this_time,
int max_candidate_len,
int verify_window) {
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
step_input_ids_now[ii + 1]) {
// top-2 matches — scan verify_window consecutive top-1 matches
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
step_input_ids_now[ii + 1]) {
break;
}
}
if (j >= verify_window) {
// Bulk accept all tokens from i to ii
for (; i < ii; i++) {
if (emit_token(i, step_input_ids_now[i + 1])) return i;
}
return i; // continue outer loop from position ii
}
}
// Fallback not triggered or insufficient window — reject
*rejected = true;
return i;
}
};
__global__ void verify_draft_tokens(
// Core I/O
int64_t *step_output_ids,
int *step_output_len,
const int64_t *step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t *candidate_ids,
const float *candidate_scores,
const int *candidate_lens,
// Sampling params
const float *curand_states, // nullptr for GREEDY/TARGET_MATCH
const float *topp,
// Metadata
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_this_time,
const int64_t *end_tokens,
const bool *is_block_step,
const int *cu_seqlens_q_output,
const int *reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t *max_dec_len,
const int64_t *step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all) {
const int64_t tid = core_id() * cluster_num() + cluster_id();
const int64_t nthreads = cluster_num() * core_num();
for (int64_t bid = tid; bid < real_bsz; bid += nthreads) {
step_output_len[bid] = 0;
if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue;
const int start_token_id = cu_seqlens_q_output[bid];
// Pointers are strategy-dependent (may be nullptr for unused params)
auto *candidate_ids_now =
candidate_ids ? candidate_ids + start_token_id * max_candidate_len
: nullptr;
auto *candidate_scores_now =
candidate_scores ? candidate_scores + start_token_id * max_candidate_len
: nullptr;
auto *candidate_lens_now =
candidate_lens ? candidate_lens + start_token_id : nullptr;
auto *target_tokens_now =
target_tokens ? target_tokens + start_token_id : nullptr;
// Initialize per-batch verification context
VerifyContext ctx;
ctx.bid = bid;
ctx.max_step_tokens = max_step_tokens;
ctx.end_length = end_length;
ctx.end_tokens = end_tokens;
ctx.max_dec_len = max_dec_len;
ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
ctx.step_output_ids = step_output_ids;
ctx.cur_step_idx = step_idx[bid];
ctx.output_len_now = 0;
ctx.stopped = false;
// ======== Phase 1: Verify draft tokens ========
int i = 0;
for (; i < seq_lens_this_time[bid] - 1; i++) {
// Early exit conditions: reject-all, prefill, reasoning
if (reject_all || seq_lens_encoder[bid] != 0 ||
reasoning_status[bid] == 1) {
break;
}
// Accept-all override (debug/warmup)
if (accept_all) {
if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break;
continue;
}
// Strategy dispatch
bool accepted = false;
switch (verify_strategy) {
case 0: { // TOPP
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
? max_candidate_len
: candidate_lens_now[i];
accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
ctx.step_input_ids_now[i + 1],
actual_cand_len);
if (!accepted) {
bool rejected = false;
i = ctx.try_verify_window_fallback(i,
&rejected,
candidate_ids_now,
seq_lens_this_time[bid],
max_candidate_len,
verify_window);
if (ctx.stopped || rejected) goto phase1_done;
continue; // bulk accept succeeded, continue from new i
}
break;
}
case 1: // GREEDY
case 2: // TARGET_MATCH
accepted = verify_one_match(target_tokens_now[i],
ctx.step_input_ids_now[i + 1]);
break;
}
if (accepted) {
if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break;
} else {
break; // reject
}
}
phase1_done:
// ======== Phase 2: Output token for rejected/last position ========
if (!ctx.stopped) {
int64_t output_token;
switch (verify_strategy) {
case 0: { // TOPP — stochastic sampling from candidate set
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
? max_candidate_len
: candidate_lens_now[i];
output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states + bid,
actual_cand_len,
topp[bid]);
break;
}
case 1: // GREEDY — deterministic argmax from target_tokens
case 2: // TARGET_MATCH — target model's sampled token
output_token = target_tokens_now[i];
break;
}
ctx.emit_final_token(i, output_token);
}
step_output_len[bid] = ctx.output_len_now;
}
}
} // namespace fd_xpu3
@@ -19,7 +19,6 @@
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace fd_xpu3 {
typedef uint32_t curandStatePhilox4_32_10_t;
template <bool ENABLE_TOPP, bool USE_TOPK>
__attribute__((global)) void speculate_verify(
@@ -87,8 +86,6 @@ static inline unsigned int xorwow(unsigned int &state) { // NOLINT
return state;
}
typedef uint32_t curandStatePhilox4_32_10_t;
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
const float *dev_curand_states,
@@ -0,0 +1,614 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace fd_xpu3 {
__attribute__((global)) void verify_draft_tokens(
// Core I/O
int64_t *step_output_ids,
int *step_output_len,
const int64_t *step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t *candidate_ids,
const float *candidate_scores,
const int *candidate_lens,
// Sampling params
const float *curand_states, // nullptr for GREEDY/TARGET_MATCH
const float *topp,
// Metadata
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_this_time,
const int64_t *end_tokens,
const bool *is_block_step,
const int *cu_seqlens_q_output,
const int *reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t *max_dec_len,
const int64_t *step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all);
} // namespace fd_xpu3
namespace fastdeploy {
namespace plugin {
// ============================================================
// Phase 1 helpers — single-step draft token verification
// ============================================================
// Check if draft_token appears in the candidate set
static inline bool is_in(const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
return true;
}
}
return false;
}
// TOPP: draft in top-p filtered candidate set
static inline bool verify_one_topp(const int64_t *verify_tokens_row,
int64_t draft_token,
int actual_cand_len) {
return is_in(verify_tokens_row, draft_token, actual_cand_len);
}
// GREEDY / TARGET_MATCH: exact single-token match
static inline bool verify_one_match(int64_t target_token, int64_t draft_token) {
return target_token == draft_token;
}
static inline bool is_in_end(const int64_t id,
const int64_t *end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return flag;
}
// ============================================================
// VerifyContext — per-batch mutable state + accept helpers.
// Eliminates repeated EOS/max_dec_len check and output write
// patterns across Phase 1 and Phase 2.
// ============================================================
struct VerifyContext {
// Immutable per-batch (set once at kernel entry)
int bid;
int max_step_tokens;
int end_length;
const int64_t *end_tokens;
const int64_t *max_dec_len;
const int64_t *step_input_ids_now;
int64_t *step_output_ids;
// Mutable per-batch state
int64_t cur_step_idx;
int output_len_now;
bool stopped;
// Emit a token at position `pos` to output in Phase 1.
// Performs: step_idx check, EOS detection, token replacement, output write.
// Returns true if this sequence should stop (EOS or max_dec_len hit).
bool emit_token(int pos, int64_t token) {
cur_step_idx++;
bool is_eos = is_in_end(token, end_tokens, end_length);
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
if ((is_eos || max_len_hit) && !is_eos) {
token = end_tokens[0];
}
step_output_ids[bid * max_step_tokens + pos] = token;
output_len_now++;
if (is_eos || max_len_hit) {
stopped = true;
return true;
}
return false;
}
// Emit the final token at position `pos` in Phase 2.
// Same EOS/max_dec_len logic. Increments output_len_now since
// Phase 2 produces one additional token.
void emit_final_token(int pos, int64_t token) {
cur_step_idx++;
bool is_eos = is_in_end(token, end_tokens, end_length);
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
if ((is_eos || max_len_hit) && !is_eos) {
token = end_tokens[0];
}
step_output_ids[bid * max_step_tokens + pos] = token;
output_len_now++;
}
// TOPP-only: verify_window bulk-accept fallback.
//
// When draft token is NOT in top-p set but IS the top-2 token,
// check verify_window consecutive positions for top-1 match.
// If all match, bulk-accept from position i through ii.
//
// Returns the new loop position (i) after handling.
// Sets *rejected=true if fallback was not triggered (caller should break).
int try_verify_window_fallback(int i,
bool *rejected,
const int64_t *verify_tokens_now,
int seq_len_this_time,
int max_candidate_len,
int verify_window) {
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
step_input_ids_now[ii + 1]) {
// top-2 matches — scan verify_window consecutive top-1 matches
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
step_input_ids_now[ii + 1]) {
break;
}
}
if (j >= verify_window) {
// Bulk accept all tokens from i to ii
for (; i < ii; i++) {
if (emit_token(i, step_input_ids_now[i + 1])) return i;
}
return i; // continue outer loop from position ii
}
}
// Fallback not triggered or insufficient window — reject
*rejected = true;
return i;
}
};
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
const float *dev_curand_states,
const int candidate_len,
const float topp,
int tid) {
// const int tid = core_id();
float sum_scores = 0.0f;
float rand_top_p = *dev_curand_states * topp;
for (int i = 0; i < candidate_len; i++) {
// printf("debug cpu sample i:%d scores:%f,ids:%ld
// rand_top_p:%f,candidate_len:%d\n",
// i,candidate_scores[i],candidate_ids[i],rand_top_p,candidate_len);
sum_scores += candidate_scores[i];
if (rand_top_p <= sum_scores) {
return candidate_ids[i];
}
}
return candidate_ids[0];
}
static int cpu_wrapper(
api::Context *ctx,
// Core I/O
int64_t *step_output_ids,
int *step_output_len,
const int64_t *step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t *candidate_ids,
const float *candidate_scores,
const int *candidate_lens,
// Sampling params
const float *curand_states, // nullptr for GREEDY/TARGET_MATCH
const float *topp,
// Metadata
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_this_time,
const int64_t *end_tokens,
const bool *is_block_step,
const int *cu_seqlens_q_output,
const int *reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t *max_dec_len,
const int64_t *step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all) {
for (int bid = 0; bid < max_bsz; bid++) {
step_output_len[bid] = 0;
if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue;
const int start_token_id = cu_seqlens_q_output[bid];
// Pointers are strategy-dependent (may be nullptr for unused params)
auto *candidate_ids_now =
candidate_ids ? candidate_ids + start_token_id * max_candidate_len
: nullptr;
auto *candidate_scores_now =
candidate_scores ? candidate_scores + start_token_id * max_candidate_len
: nullptr;
auto *candidate_lens_now =
candidate_lens ? candidate_lens + start_token_id : nullptr;
auto *target_tokens_now =
target_tokens ? target_tokens + start_token_id : nullptr;
// Initialize per-batch verification context
VerifyContext v_ctx;
v_ctx.bid = bid;
v_ctx.max_step_tokens = max_step_tokens;
v_ctx.end_length = end_length;
v_ctx.end_tokens = end_tokens;
v_ctx.max_dec_len = max_dec_len;
v_ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
v_ctx.step_output_ids = step_output_ids;
v_ctx.cur_step_idx = step_idx[bid];
v_ctx.output_len_now = 0;
v_ctx.stopped = false;
// ======== Phase 1: Verify draft tokens ========
int i = 0;
for (; i < seq_lens_this_time[bid] - 1; i++) {
// Early exit conditions: reject-all, prefill, reasoning
if (reject_all || seq_lens_encoder[bid] != 0 ||
reasoning_status[bid] == 1) {
break;
}
// Accept-all override (debug/warmup)
if (accept_all) {
if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break;
continue;
}
// Strategy dispatch
bool accepted = false;
switch (verify_strategy) {
case 0: { // TOPP
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
? max_candidate_len
: candidate_lens_now[i];
accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
v_ctx.step_input_ids_now[i + 1],
actual_cand_len);
if (!accepted) {
bool rejected = false;
i = v_ctx.try_verify_window_fallback(i,
&rejected,
candidate_ids_now,
seq_lens_this_time[bid],
max_candidate_len,
verify_window);
if (v_ctx.stopped || rejected) goto phase1_done;
continue; // bulk accept succeeded, continue from new i
}
break;
}
case 1: // GREEDY
case 2: // TARGET_MATCH
accepted = verify_one_match(target_tokens_now[i],
v_ctx.step_input_ids_now[i + 1]);
break;
}
if (accepted) {
if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break;
} else {
break; // reject
}
}
phase1_done:
// ======== Phase 2: Output token for rejected/last position ========
if (!v_ctx.stopped) {
int64_t output_token = 0;
switch (verify_strategy) {
case 0: { // TOPP — stochastic sampling from candidate set
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
? max_candidate_len
: candidate_lens_now[i];
output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states + bid,
actual_cand_len,
topp[bid],
bid);
break;
}
case 1: // GREEDY — deterministic argmax from target_tokens
case 2: // TARGET_MATCH — target model's sampled token
output_token = target_tokens_now[i];
break;
}
v_ctx.emit_final_token(i, output_token);
}
step_output_len[bid] = v_ctx.output_len_now;
}
return api::SUCCESS;
}
static int xpu3_wrapper(
api::Context *ctx,
// Core I/O
int64_t *step_output_ids,
int *step_output_len,
const int64_t *step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t *candidate_ids,
const float *candidate_scores,
const int *candidate_lens,
// Sampling params
const float *curand_states, // nullptr for GREEDY/TARGET_MATCH
const float *topp,
// Metadata
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_this_time,
const int64_t *end_tokens,
const bool *is_block_step,
const int *cu_seqlens_q_output,
const int *reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t *max_dec_len,
const int64_t *step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all) {
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
int32_t ret_xre =
fd_xpu3::verify_draft_tokens<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(step_output_ids),
step_output_len,
reinterpret_cast<const XPU_INT64 *>(step_input_ids),
reinterpret_cast<const XPU_INT64 *>(target_tokens),
reinterpret_cast<const XPU_INT64 *>(candidate_ids),
candidate_scores,
candidate_lens,
curand_states,
topp,
stop_flags,
seq_lens_encoder,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64 *>(end_tokens),
is_block_step,
cu_seqlens_q_output,
reasoning_status,
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
reinterpret_cast<const XPU_INT64 *>(step_idx),
max_bsz,
real_bsz,
max_step_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
verify_strategy,
reject_all,
accept_all);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
int verify_draft_tokens(
api::Context *ctx,
// Core I/O
int64_t *step_output_ids,
int *step_output_len,
const int64_t *step_input_ids, // draft tokens
// Target model outputs (strategy-dependent interpretation)
const int64_t
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
const int64_t *candidate_ids,
const float *candidate_scores,
const int *candidate_lens,
// Sampling params
const float *curand_states, // nullptr for GREEDY/TARGET_MATCH
const float *topp,
// Metadata
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_this_time,
const int64_t *end_tokens,
const bool *is_block_step,
const int *cu_seqlens_q_output,
const int *reasoning_status,
// max_dec_len / step_idx for EOS/max-len detection (read-only)
const int64_t *max_dec_len,
const int64_t *step_idx,
// Dimensions and config
const int max_bsz,
const int real_bsz,
const int max_step_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
const bool reject_all,
const bool accept_all) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "verify_draft_tokens", int64_t);
WRAPPER_DUMP_PARAM6(ctx,
step_output_ids,
step_output_len,
step_input_ids,
target_tokens,
candidate_ids,
candidate_scores);
WRAPPER_DUMP_PARAM6(ctx,
candidate_lens,
curand_states,
topp,
stop_flags,
seq_lens_encoder,
seq_lens_this_time);
WRAPPER_DUMP_PARAM6(ctx,
end_tokens,
is_block_step,
cu_seqlens_q_output,
reasoning_status,
max_dec_len,
step_idx);
WRAPPER_DUMP_PARAM6(ctx,
max_bsz,
real_bsz,
max_step_tokens,
end_length,
max_seq_len,
max_candidate_len);
WRAPPER_DUMP_PARAM4(
ctx, verify_window, verify_strategy, reject_all, accept_all);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_output_ids);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_input_ids);
// len(target_tokens) = cu_seqlens_q_output[-1]
WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, target_tokens);
WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, candidate_lens);
WRAPPER_CHECK_PTR_OR_NULL(
ctx, int64_t, real_bsz * max_candidate_len, candidate_ids);
WRAPPER_CHECK_PTR_OR_NULL(
ctx, float, real_bsz *max_candidate_len, candidate_scores);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, curand_states);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_seqlens_q_output);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, reasoning_status);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx);
// param check sm size limit
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
WRAPPER_ASSERT_LE(ctx, real_bsz, 1024);
WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048);
WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, step_output_len);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
step_output_ids,
step_output_len,
step_input_ids,
target_tokens,
candidate_ids,
candidate_scores,
candidate_lens,
curand_states,
topp,
stop_flags,
seq_lens_encoder,
seq_lens_this_time,
end_tokens,
is_block_step,
cu_seqlens_q_output,
reasoning_status,
max_dec_len,
step_idx,
max_bsz,
real_bsz,
max_step_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
verify_strategy,
reject_all,
accept_all);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
step_output_ids,
step_output_len,
step_input_ids,
target_tokens,
candidate_ids,
candidate_scores,
candidate_lens,
curand_states,
topp,
stop_flags,
seq_lens_encoder,
seq_lens_this_time,
end_tokens,
is_block_step,
cu_seqlens_q_output,
reasoning_status,
max_dec_len,
step_idx,
max_bsz,
real_bsz,
max_step_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
verify_strategy,
reject_all,
accept_all);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace fastdeploy
File diff suppressed because it is too large Load Diff