mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture * delete debug log * optimize spec_method usage && fix unit_test * add claude unit-test skill * fix some ugly bug * enhance robustness and bounds check * unify method & spec_method to method to avoid bug * activate CI * fix unit test * Unify logprobs computation for naive and speculative decoding, fix CUDA kernel * fix logprob bug && optimize verify kernel * fix exist_decode() judge
This commit is contained in:
@@ -78,6 +78,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
}
|
||||
|
||||
// multi_end
|
||||
// TODO(liuzichang): Don't check eos in future
|
||||
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
|
||||
prefill_one_step_stop) {
|
||||
stop_flags[tid] = true;
|
||||
|
||||
@@ -100,6 +100,7 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits,
|
||||
const_cast<int*>(&cu_batch_token_offset.data<int>()[1]),
|
||||
real_bsz,
|
||||
cu_stream);
|
||||
cudaFree(temp_storage1);
|
||||
|
||||
void* temp_storage2 = nullptr;
|
||||
size_t temp_storage_bytes2 = 0;
|
||||
@@ -118,6 +119,7 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits,
|
||||
const_cast<int*>(&cu_next_token_offset.data<int>()[1]),
|
||||
real_bsz,
|
||||
cu_stream);
|
||||
cudaFree(temp_storage2);
|
||||
|
||||
constexpr int PackSize = VEC_16B / sizeof(float);
|
||||
dim3 grid_dim(real_bsz);
|
||||
@@ -184,7 +186,7 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
|
||||
|
||||
template <int VecSize>
|
||||
__global__ void speculate_get_target_logits_kernel(
|
||||
float* target_logtis,
|
||||
float* target_logits,
|
||||
const float* logits,
|
||||
const int* cu_batch_token_offset,
|
||||
const int* ori_cu_batch_token_offset,
|
||||
@@ -197,18 +199,18 @@ __global__ void speculate_get_target_logits_kernel(
|
||||
const int bid = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
if (bid < real_bsz) {
|
||||
auto* target_logtis_now =
|
||||
target_logtis + cu_batch_token_offset[bid] * vocab_size;
|
||||
auto* target_logits_now =
|
||||
target_logits + cu_batch_token_offset[bid] * vocab_size;
|
||||
auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size;
|
||||
for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) {
|
||||
if (seq_lens_encoder[bid] > 0) {
|
||||
Load<float, VecSize>(&logits_now[i], &src_vec);
|
||||
Store<float, VecSize>(src_vec, &target_logtis_now[i]);
|
||||
Store<float, VecSize>(src_vec, &target_logits_now[i]);
|
||||
} else {
|
||||
for (int j = 0; j < accept_num[bid]; j++) {
|
||||
Load<float, VecSize>(&logits_now[j * vocab_size + i], &src_vec);
|
||||
Store<float, VecSize>(src_vec,
|
||||
&target_logtis_now[j * vocab_size + i]);
|
||||
&target_logits_now[j * vocab_size + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/**
|
||||
* @file unified_update_model_status.cu
|
||||
* @brief Unified kernel for updating model status after token generation.
|
||||
*
|
||||
* Launched as a single block of 1024 threads (max_bsz <= 1024).
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Check if token is an end token.
|
||||
*/
|
||||
__device__ __forceinline__ bool is_end_token(int64_t token,
|
||||
const int64_t *end_tokens,
|
||||
int num_end_tokens) {
|
||||
#pragma unroll 4
|
||||
for (int i = 0; i < num_end_tokens; i++) {
|
||||
if (token == end_tokens[i]) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Main unified update kernel.
|
||||
*/
|
||||
template <int BLOCK_SIZE>
|
||||
__global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
bool *has_running_seqs,
|
||||
int *mask_rollback,
|
||||
int64_t *step_input_ids,
|
||||
int *adaptive_step_input_len,
|
||||
int64_t *step_output_ids,
|
||||
int *step_output_len,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
const bool *is_paused,
|
||||
int64_t *token_ids_all,
|
||||
const int64_t *prompt_lens,
|
||||
int64_t *step_idx,
|
||||
const int64_t *end_tokens,
|
||||
const int64_t *max_dec_len,
|
||||
int real_bsz,
|
||||
int max_bsz,
|
||||
int max_step_tokens,
|
||||
int max_model_len,
|
||||
int num_end_tokens,
|
||||
bool is_naive_mode,
|
||||
bool prefill_one_step_stop) {
|
||||
const int batch_id = blockIdx.x * BLOCK_SIZE + threadIdx.x;
|
||||
const bool is_valid_slot = batch_id < max_bsz;
|
||||
int stop_flag_int = 0;
|
||||
|
||||
if (is_valid_slot) {
|
||||
// Read state
|
||||
int cur_seq_len_encoder = seq_lens_encoder[batch_id];
|
||||
int cur_seq_len_decoder = seq_lens_decoder[batch_id];
|
||||
bool cur_stop_flag = stop_flags[batch_id];
|
||||
int output_len = 0;
|
||||
int64_t cur_step_idx = step_idx[batch_id];
|
||||
bool cur_is_paused = is_paused[batch_id];
|
||||
|
||||
bool is_running = !cur_stop_flag && !cur_is_paused;
|
||||
|
||||
// Compute output length
|
||||
if (is_running) {
|
||||
if (is_naive_mode) {
|
||||
output_len = 1;
|
||||
} else {
|
||||
output_len = step_output_len[batch_id];
|
||||
}
|
||||
}
|
||||
|
||||
// EOS detection
|
||||
if (is_running && output_len > 0) {
|
||||
bool hit_stop = false;
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
cur_step_idx++;
|
||||
int64_t token = output_ids[i];
|
||||
bool is_eos = is_end_token(token, end_tokens, num_end_tokens);
|
||||
bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]);
|
||||
|
||||
if (is_eos || max_len_hit) {
|
||||
if (!is_eos) output_ids[i] = end_tokens[0];
|
||||
output_len = i + 1;
|
||||
cur_stop_flag = true;
|
||||
hit_stop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
|
||||
cur_stop_flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Update state and write back
|
||||
if (is_running) {
|
||||
if (cur_stop_flag) {
|
||||
stop_flag_int = 1;
|
||||
if (output_len == 0) cur_seq_len_decoder = 0;
|
||||
stop_flags[batch_id] = true;
|
||||
mask_rollback[batch_id] = 0;
|
||||
} else if (cur_seq_len_encoder == 0) {
|
||||
cur_seq_len_decoder += output_len;
|
||||
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
|
||||
} else {
|
||||
mask_rollback[batch_id] = 0;
|
||||
}
|
||||
|
||||
if (cur_seq_len_encoder > 0) {
|
||||
cur_seq_len_decoder += cur_seq_len_encoder;
|
||||
cur_seq_len_encoder = 0;
|
||||
}
|
||||
|
||||
seq_lens_encoder[batch_id] = cur_seq_len_encoder;
|
||||
seq_lens_decoder[batch_id] = cur_seq_len_decoder;
|
||||
step_output_len[batch_id] = output_len;
|
||||
step_idx[batch_id] = cur_step_idx;
|
||||
|
||||
// Write history to token_ids_all
|
||||
if (cur_step_idx > 0 && output_len > 0) {
|
||||
// Bounds check: highest write index is prompt_lens + cur_step_idx
|
||||
if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
|
||||
int64_t *token_ids_all_now =
|
||||
&token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
token_ids_all_now[cur_step_idx - i] =
|
||||
output_ids[output_len - 1 - i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup next input
|
||||
if (output_len > 0) {
|
||||
step_input_ids[batch_id * max_step_tokens] =
|
||||
step_output_ids[batch_id * max_step_tokens + output_len - 1];
|
||||
}
|
||||
|
||||
if (is_naive_mode) {
|
||||
seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
|
||||
}
|
||||
} else if (batch_id >= real_bsz) {
|
||||
// Padding slot: just count as stopped, don't modify state
|
||||
stop_flag_int = 1;
|
||||
} else {
|
||||
// Stopped or paused slot (batch_id < real_bsz)
|
||||
stop_flag_int = 1;
|
||||
stop_flags[batch_id] = true;
|
||||
seq_lens_decoder[batch_id] = 0;
|
||||
seq_lens_this_time[batch_id] = 0;
|
||||
step_output_len[batch_id] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Simple block-level reduction using shared memory
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, BLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
has_running_seqs[0] = stop_sum < max_bsz;
|
||||
}
|
||||
}
|
||||
|
||||
// Host interface
|
||||
void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &has_running_seqs,
|
||||
const paddle::Tensor &step_input_ids,
|
||||
const paddle::Tensor &adaptive_step_input_len,
|
||||
const paddle::Tensor &step_output_ids,
|
||||
const paddle::Tensor &step_output_len,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &is_paused,
|
||||
const paddle::Tensor &mask_rollback,
|
||||
const paddle::Tensor &token_ids_all,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const bool is_naive_mode,
|
||||
const bool prefill_one_step_stop) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
max_bsz,
|
||||
1024,
|
||||
phi::errors::InvalidArgument(
|
||||
"unified_update_model_status: max_bsz (%d) must be <= 1024 "
|
||||
"(single-block launch limit).",
|
||||
max_bsz));
|
||||
const int max_step_tokens = step_input_ids.shape()[1];
|
||||
const int max_model_len = token_ids_all.shape()[1];
|
||||
const int num_end_tokens = end_tokens.shape()[0];
|
||||
|
||||
constexpr int BlockSize = 1024;
|
||||
|
||||
// has_running_seqs is CPU tensor, need to copy to GPU first
|
||||
auto has_running_seqs_gpu =
|
||||
has_running_seqs.copy_to(seq_lens_this_time.place(), false);
|
||||
unified_update_model_status_kernel<BlockSize>
|
||||
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
|
||||
const_cast<int *>(mask_rollback.data<int>()),
|
||||
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
|
||||
const_cast<int *>(adaptive_step_input_len.data<int>()),
|
||||
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||
const_cast<int *>(step_output_len.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<bool *>(is_paused.data<bool>()),
|
||||
const_cast<int64_t *>(token_ids_all.data<int64_t>()),
|
||||
prompt_lens.data<int64_t>(),
|
||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||
end_tokens.data<int64_t>(),
|
||||
max_dec_len.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_step_tokens,
|
||||
max_model_len,
|
||||
num_end_tokens,
|
||||
is_naive_mode,
|
||||
prefill_one_step_stop);
|
||||
// Copy result back to CPU
|
||||
auto has_running_seqs_cpu =
|
||||
has_running_seqs_gpu.copy_to(has_running_seqs.place(), false);
|
||||
bool *out_data = const_cast<bool *>(has_running_seqs.data<bool>());
|
||||
out_data[0] = has_running_seqs_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"has_running_seqs",
|
||||
"step_input_ids",
|
||||
"adaptive_step_input_len",
|
||||
"step_output_ids",
|
||||
"step_output_len",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"is_paused",
|
||||
"mask_rollback",
|
||||
"token_ids_all",
|
||||
"prompt_lens",
|
||||
"step_idx",
|
||||
"end_tokens",
|
||||
"max_dec_len"})
|
||||
.Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"})
|
||||
.Outputs({"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"has_running_seqs_out",
|
||||
"step_input_ids_out",
|
||||
"adaptive_step_input_len_out",
|
||||
"step_output_ids_out",
|
||||
"step_output_len_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"mask_rollback_out",
|
||||
"token_ids_all_out",
|
||||
"step_idx_out"})
|
||||
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"has_running_seqs", "has_running_seqs_out"},
|
||||
{"step_input_ids", "step_input_ids_out"},
|
||||
{"adaptive_step_input_len", "adaptive_step_input_len_out"},
|
||||
{"step_output_ids", "step_output_ids_out"},
|
||||
{"step_output_len", "step_output_len_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"mask_rollback", "mask_rollback_out"},
|
||||
{"token_ids_all", "token_ids_all_out"},
|
||||
{"step_idx", "step_idx_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus));
|
||||
@@ -0,0 +1,525 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Verification kernel — outputs step_output_ids + step_output_len,
|
||||
// and performs EOS / max_dec_len detection (read-only on step_idx).
|
||||
// step_idx is NOT modified here; all state updates (including step_idx)
|
||||
// are handled by unified_update_model_status.
|
||||
//
|
||||
// Verification strategies:
|
||||
// 0 = TOPP : draft token in top-p candidate set (+ verify_window
|
||||
// fallback) 1 = GREEDY : draft token == top-1 token (strict argmax
|
||||
// match) 2 = TARGET_MATCH : draft token == target model's sampled token
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include "helper.h" // NOLINT
|
||||
|
||||
// ============================================================
|
||||
// Persistent curand state — allocated once, reused across calls.
|
||||
// Only needed for TOPP strategy (Phase 2 stochastic sampling).
|
||||
// ============================================================
|
||||
static curandState_t *dev_curand_states = nullptr;
|
||||
static int allocated_bsz = 0;
|
||||
static uint64_t seed = 0;
|
||||
static uint64_t offset = 0;
|
||||
|
||||
__global__ void setup_seed_kernel(curandState_t *state,
|
||||
const uint64_t seed,
|
||||
const uint64_t offset,
|
||||
const int bs,
|
||||
const bool need_batch_random) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
|
||||
if (need_batch_random) {
|
||||
curand_init(seed, i, offset, &state[i]);
|
||||
} else {
|
||||
curand_init(seed, 0, offset, &state[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Phase 1 helpers — single-step draft token verification
|
||||
// ============================================================
|
||||
|
||||
// Check if draft_token appears in the candidate set
|
||||
__device__ inline bool is_in(const int64_t *candidates,
|
||||
const int64_t draft,
|
||||
const int candidate_len) {
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
if (draft == candidates[i]) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TOPP: draft in top-p filtered candidate set
|
||||
__device__ inline bool verify_one_topp(const int64_t *verify_tokens_row,
|
||||
int64_t draft_token,
|
||||
int actual_cand_len) {
|
||||
return is_in(verify_tokens_row, draft_token, actual_cand_len);
|
||||
}
|
||||
|
||||
// GREEDY / TARGET_MATCH: exact single-token match
|
||||
__device__ inline bool verify_one_match(int64_t target_token,
|
||||
int64_t draft_token) {
|
||||
return target_token == draft_token;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// VerifyContext — per-batch mutable state + accept helpers.
|
||||
// Eliminates repeated EOS/max_dec_len check and output write
|
||||
// patterns across Phase 1 and Phase 2.
|
||||
// ============================================================
|
||||
struct VerifyContext {
|
||||
// Immutable per-batch (set once at kernel entry)
|
||||
int bid;
|
||||
int max_step_tokens;
|
||||
int end_length;
|
||||
const int64_t *end_tokens;
|
||||
const int64_t *max_dec_len;
|
||||
const int64_t *step_input_ids_now;
|
||||
int64_t *step_output_ids;
|
||||
|
||||
// Mutable per-batch state
|
||||
int64_t cur_step_idx;
|
||||
int output_len_now;
|
||||
bool stopped;
|
||||
|
||||
// Emit a token at position `pos` to output in Phase 1.
|
||||
// Performs: step_idx check, EOS detection, token replacement, output write.
|
||||
// Returns true if this sequence should stop (EOS or max_dec_len hit).
|
||||
__device__ __forceinline__ bool emit_token(int pos, int64_t token) {
|
||||
cur_step_idx++;
|
||||
bool is_eos = is_in_end(token, end_tokens, end_length);
|
||||
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
|
||||
if ((is_eos || max_len_hit) && !is_eos) {
|
||||
token = end_tokens[0];
|
||||
}
|
||||
step_output_ids[bid * max_step_tokens + pos] = token;
|
||||
output_len_now++;
|
||||
if (is_eos || max_len_hit) {
|
||||
stopped = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Emit the final token at position `pos` in Phase 2.
|
||||
// Same EOS/max_dec_len logic, but does NOT increment output_len_now
|
||||
// (Phase 2's token is already counted in the initial output_len_now=1).
|
||||
__device__ __forceinline__ void emit_final_token(int pos, int64_t token) {
|
||||
cur_step_idx++;
|
||||
bool is_eos = is_in_end(token, end_tokens, end_length);
|
||||
bool max_len_hit = (cur_step_idx >= max_dec_len[bid]);
|
||||
if ((is_eos || max_len_hit) && !is_eos) {
|
||||
token = end_tokens[0];
|
||||
}
|
||||
step_output_ids[bid * max_step_tokens + pos] = token;
|
||||
}
|
||||
|
||||
// TOPP-only: verify_window bulk-accept fallback.
|
||||
//
|
||||
// When draft token is NOT in top-p set but IS the top-2 token,
|
||||
// check verify_window consecutive positions for top-1 match.
|
||||
// If all match, bulk-accept from position i through ii.
|
||||
//
|
||||
// Returns the new loop position (i) after handling.
|
||||
// Sets *rejected=true if fallback was not triggered (caller should break).
|
||||
__device__ __forceinline__ int try_verify_window_fallback(
|
||||
int i,
|
||||
bool *rejected,
|
||||
const int64_t *verify_tokens_now,
|
||||
int seq_len_this_time,
|
||||
int max_candidate_len,
|
||||
int verify_window) {
|
||||
int ii = i;
|
||||
if (max_candidate_len >= 2 &&
|
||||
verify_tokens_now[ii * max_candidate_len + 1] ==
|
||||
step_input_ids_now[ii + 1]) {
|
||||
// top-2 matches — scan verify_window consecutive top-1 matches
|
||||
int j = 0;
|
||||
ii += 1;
|
||||
for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) {
|
||||
if (verify_tokens_now[ii * max_candidate_len] !=
|
||||
step_input_ids_now[ii + 1]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j >= verify_window) {
|
||||
// Bulk accept all tokens from i to ii
|
||||
for (; i < ii; i++) {
|
||||
if (emit_token(i, step_input_ids_now[i + 1])) return i;
|
||||
}
|
||||
return i; // continue outer loop from position ii
|
||||
}
|
||||
}
|
||||
// Fallback not triggered or insufficient window — reject
|
||||
*rejected = true;
|
||||
return i;
|
||||
}
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Phase 2 helpers — sample token for rejected/last position
|
||||
// ============================================================
|
||||
|
||||
__device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
const float *candidate_scores,
|
||||
curandState_t *curand_states,
|
||||
const int candidate_len,
|
||||
const float topp) {
|
||||
// Use bid (blockIdx.x-based) index, not threadIdx.x — curand_states is
|
||||
// allocated with size bsz, and each batch element uses one thread.
|
||||
const int bid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
float sum_scores = 0.0f;
|
||||
float rand_top_p = curand_uniform(curand_states + bid) * topp;
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
sum_scores += candidate_scores[i];
|
||||
if (rand_top_p <= sum_scores) {
|
||||
return candidate_ids[i];
|
||||
}
|
||||
}
|
||||
return candidate_ids[0];
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Main verification kernel
|
||||
// ============================================================
|
||||
//
|
||||
// Input parameter groups by strategy:
|
||||
// - target_tokens: GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused
|
||||
// (None)
|
||||
// - candidate_ids/scores: TOPP=full candidate set, GREEDY/TARGET_MATCH=unused
|
||||
// (None)
|
||||
// - candidate_lens: TOPP=actual length per position,
|
||||
// GREEDY/TARGET_MATCH=unused (None)
|
||||
//
|
||||
// All parameters may be empty tensors for strategies that don't use them.
|
||||
//
|
||||
__global__ void verify_draft_tokens(
|
||||
// Core I/O
|
||||
int64_t *step_output_ids,
|
||||
int *step_output_len,
|
||||
const int64_t *step_input_ids, // draft tokens
|
||||
// Target model outputs (strategy-dependent interpretation)
|
||||
const int64_t
|
||||
*target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused
|
||||
// Candidate set for TOPP/GREEDY (TARGET_MATCH: unused)
|
||||
const int64_t *candidate_ids,
|
||||
const float *candidate_scores,
|
||||
const int *candidate_lens,
|
||||
// Sampling params
|
||||
curandState_t *curand_states, // nullptr for GREEDY/TARGET_MATCH
|
||||
const float *topp,
|
||||
// Metadata
|
||||
const bool *stop_flags,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_this_time,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int *reasoning_status,
|
||||
// max_dec_len / step_idx for EOS/max-len detection (read-only)
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *step_idx,
|
||||
// Dimensions and config
|
||||
const int max_bsz,
|
||||
const int real_bsz,
|
||||
const int max_step_tokens,
|
||||
const int end_length,
|
||||
const int max_seq_len,
|
||||
const int max_candidate_len,
|
||||
const int verify_window,
|
||||
const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH
|
||||
const bool reject_all,
|
||||
const bool accept_all) {
|
||||
const int bid = threadIdx.x;
|
||||
|
||||
// Initialize step_output_len to 0 for ALL slots
|
||||
if (bid < max_bsz) {
|
||||
step_output_len[bid] = 0;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) return;
|
||||
|
||||
const int start_token_id = cu_seqlens_q_output[bid];
|
||||
// Pointers are strategy-dependent (may be nullptr for unused params)
|
||||
auto *candidate_ids_now =
|
||||
candidate_ids ? candidate_ids + start_token_id * max_candidate_len
|
||||
: nullptr;
|
||||
auto *candidate_scores_now =
|
||||
candidate_scores ? candidate_scores + start_token_id * max_candidate_len
|
||||
: nullptr;
|
||||
auto *candidate_lens_now =
|
||||
candidate_lens ? candidate_lens + start_token_id : nullptr;
|
||||
auto *target_tokens_now =
|
||||
target_tokens ? target_tokens + start_token_id : nullptr;
|
||||
|
||||
// Initialize per-batch verification context
|
||||
VerifyContext ctx;
|
||||
ctx.bid = bid;
|
||||
ctx.max_step_tokens = max_step_tokens;
|
||||
ctx.end_length = end_length;
|
||||
ctx.end_tokens = end_tokens;
|
||||
ctx.max_dec_len = max_dec_len;
|
||||
ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
|
||||
ctx.step_output_ids = step_output_ids;
|
||||
ctx.cur_step_idx = step_idx[bid];
|
||||
ctx.output_len_now = 1;
|
||||
ctx.stopped = false;
|
||||
|
||||
// ======== Phase 1: Verify draft tokens ========
|
||||
int i = 0;
|
||||
for (; i < seq_lens_this_time[bid] - 1; i++) {
|
||||
// Early exit conditions: reject-all, prefill, reasoning
|
||||
if (reject_all || seq_lens_encoder[bid] != 0 ||
|
||||
reasoning_status[bid] == 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Accept-all override (debug/warmup)
|
||||
if (accept_all) {
|
||||
if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Strategy dispatch
|
||||
bool accepted = false;
|
||||
switch (verify_strategy) {
|
||||
case 0: { // TOPP
|
||||
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
|
||||
? max_candidate_len
|
||||
: candidate_lens_now[i];
|
||||
accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
|
||||
ctx.step_input_ids_now[i + 1],
|
||||
actual_cand_len);
|
||||
if (!accepted) {
|
||||
bool rejected = false;
|
||||
i = ctx.try_verify_window_fallback(i,
|
||||
&rejected,
|
||||
candidate_ids_now,
|
||||
seq_lens_this_time[bid],
|
||||
max_candidate_len,
|
||||
verify_window);
|
||||
if (ctx.stopped || rejected) goto phase1_done;
|
||||
continue; // bulk accept succeeded, continue from new i
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: // GREEDY
|
||||
case 2: // TARGET_MATCH
|
||||
accepted = verify_one_match(target_tokens_now[i],
|
||||
ctx.step_input_ids_now[i + 1]);
|
||||
break;
|
||||
}
|
||||
|
||||
if (accepted) {
|
||||
if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break;
|
||||
} else {
|
||||
break; // reject
|
||||
}
|
||||
}
|
||||
phase1_done:
|
||||
|
||||
// ======== Phase 2: Output token for rejected/last position ========
|
||||
if (!ctx.stopped) {
|
||||
int64_t output_token;
|
||||
switch (verify_strategy) {
|
||||
case 0: { // TOPP — stochastic sampling from candidate set
|
||||
auto actual_cand_len = candidate_lens_now[i] > max_candidate_len
|
||||
? max_candidate_len
|
||||
: candidate_lens_now[i];
|
||||
output_token =
|
||||
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
|
||||
candidate_scores_now + i * max_candidate_len,
|
||||
curand_states,
|
||||
actual_cand_len,
|
||||
topp[bid]);
|
||||
break;
|
||||
}
|
||||
case 1: // GREEDY — deterministic argmax from target_tokens
|
||||
case 2: // TARGET_MATCH — target model's sampled token
|
||||
output_token = target_tokens_now[i];
|
||||
break;
|
||||
}
|
||||
ctx.emit_final_token(i, output_token);
|
||||
}
|
||||
step_output_len[bid] = ctx.output_len_now;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Host function
|
||||
// ============================================================
|
||||
void VerifyDraftTokens(
|
||||
// Core I/O
|
||||
const paddle::Tensor &step_output_ids,
|
||||
const paddle::Tensor &step_output_len,
|
||||
const paddle::Tensor &step_input_ids,
|
||||
// Target model outputs (optional, required for TARGET_MATCH)
|
||||
const paddle::optional<paddle::Tensor> &target_tokens,
|
||||
// Candidate set (optional, required for TOPP/GREEDY)
|
||||
const paddle::optional<paddle::Tensor> &candidate_ids,
|
||||
const paddle::optional<paddle::Tensor> &candidate_scores,
|
||||
const paddle::optional<paddle::Tensor> &candidate_lens,
|
||||
// Sampling params
|
||||
const paddle::Tensor &topp,
|
||||
// Metadata
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &cu_seqlens_q_output,
|
||||
const paddle::Tensor &reasoning_status,
|
||||
// max_dec_len / step_idx for EOS/max-len detection
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const paddle::Tensor &step_idx,
|
||||
int max_seq_len,
|
||||
int verify_window,
|
||||
int verify_strategy,
|
||||
bool reject_all,
|
||||
bool accept_all) {
|
||||
auto bsz = step_output_ids.shape()[0];
|
||||
auto real_bsz = seq_lens_this_time.shape()[0];
|
||||
auto max_step_tokens = step_input_ids.shape()[1];
|
||||
auto end_length = end_tokens.shape()[0];
|
||||
// max_candidate_len: 1 if candidate_ids not provided, else from shape
|
||||
int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;
|
||||
|
||||
constexpr int BlockSize = 1024;
|
||||
PADDLE_ENFORCE_LE(bsz,
|
||||
BlockSize,
|
||||
phi::errors::InvalidArgument(
|
||||
"verify_draft_tokens: bsz (%d) exceeds BlockSize (%d). "
|
||||
"Increase BlockSize or reduce max_num_seqs.",
|
||||
bsz,
|
||||
BlockSize));
|
||||
auto stream = step_output_ids.stream();
|
||||
|
||||
// curand state: only needed for TOPP(0) strategy (stochastic sampling)
|
||||
curandState_t *curand_ptr = nullptr;
|
||||
if (verify_strategy ==
|
||||
0 /* TOPP only - GREEDY and TARGET_MATCH use deterministic output */) {
|
||||
if (dev_curand_states == nullptr || bsz > allocated_bsz) {
|
||||
if (dev_curand_states) cudaFree(dev_curand_states);
|
||||
cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
|
||||
allocated_bsz = bsz;
|
||||
}
|
||||
setup_seed_kernel<<<1, BlockSize, 0, stream>>>(
|
||||
dev_curand_states, seed, offset, bsz, true);
|
||||
seed++;
|
||||
offset++;
|
||||
curand_ptr = dev_curand_states;
|
||||
}
|
||||
|
||||
// Get data pointers (nullptr if optional not provided)
|
||||
const int64_t *target_tokens_ptr =
|
||||
target_tokens ? target_tokens->data<int64_t>() : nullptr;
|
||||
const int64_t *candidate_ids_ptr =
|
||||
candidate_ids ? candidate_ids->data<int64_t>() : nullptr;
|
||||
const float *candidate_scores_ptr =
|
||||
candidate_scores ? candidate_scores->data<float>() : nullptr;
|
||||
const int *candidate_lens_ptr =
|
||||
candidate_lens ? candidate_lens->data<int>() : nullptr;
|
||||
|
||||
// Validate parameters based on verify_strategy.
|
||||
// Note: empty_input_forward may lead to empty optional tensors — only
|
||||
// validate when bsz > 0 (i.e. there are active sequences).
|
||||
if (bsz > 0) {
|
||||
if (verify_strategy == 0 /* TOPP */) {
|
||||
if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) {
|
||||
PD_THROW(
|
||||
"verify_strategy=TOPP (0) requires candidate_ids, "
|
||||
"candidate_scores, candidate_lens");
|
||||
}
|
||||
} else if (verify_strategy == 1 /* GREEDY */) {
|
||||
if (!target_tokens_ptr) {
|
||||
PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)");
|
||||
}
|
||||
} else if (verify_strategy == 2 /* TARGET_MATCH */) {
|
||||
if (!target_tokens_ptr) {
|
||||
PD_THROW(
|
||||
"verify_strategy=TARGET_MATCH (2) requires target_tokens "
|
||||
"(sampled)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verify_draft_tokens<<<1, BlockSize, 0, stream>>>(
|
||||
// Core I/O
|
||||
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||
const_cast<int *>(step_output_len.data<int>()),
|
||||
step_input_ids.data<int64_t>(),
|
||||
// Target model outputs
|
||||
target_tokens_ptr,
|
||||
// Candidate set
|
||||
candidate_ids_ptr,
|
||||
candidate_scores_ptr,
|
||||
candidate_lens_ptr,
|
||||
// Sampling params
|
||||
curand_ptr,
|
||||
topp.data<float>(),
|
||||
// Metadata
|
||||
stop_flags.data<bool>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
end_tokens.data<int64_t>(),
|
||||
is_block_step.data<bool>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
reasoning_status.data<int>(),
|
||||
// max_dec_len / step_idx
|
||||
max_dec_len.data<int64_t>(),
|
||||
step_idx.data<int64_t>(),
|
||||
// Dimensions and config
|
||||
bsz, // max_bsz
|
||||
real_bsz, // real_bsz
|
||||
max_step_tokens,
|
||||
end_length,
|
||||
max_seq_len,
|
||||
max_candidate_len,
|
||||
verify_window,
|
||||
verify_strategy,
|
||||
reject_all,
|
||||
accept_all);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(verify_draft_tokens)
|
||||
.Inputs({"step_output_ids",
|
||||
"step_output_len",
|
||||
"step_input_ids",
|
||||
paddle::Optional("target_tokens"),
|
||||
paddle::Optional("candidate_ids"),
|
||||
paddle::Optional("candidate_scores"),
|
||||
paddle::Optional("candidate_lens"),
|
||||
"topp",
|
||||
"stop_flags",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_this_time",
|
||||
"end_tokens",
|
||||
"is_block_step",
|
||||
"cu_seqlens_q_output",
|
||||
"reasoning_status",
|
||||
"max_dec_len",
|
||||
"step_idx"})
|
||||
.Outputs({"step_output_ids_out", "step_output_len_out"})
|
||||
.Attrs({"max_seq_len: int",
|
||||
"verify_window: int",
|
||||
"verify_strategy: int",
|
||||
"reject_all: bool",
|
||||
"accept_all: bool"})
|
||||
.SetInplaceMap({{"step_output_ids", "step_output_ids_out"},
|
||||
{"step_output_len", "step_output_len_out"}})
|
||||
.SetKernelFn(PD_KERNEL(VerifyDraftTokens));
|
||||
Reference in New Issue
Block a user