mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
49617d9832
* support tag phase token enforce generation * optimize note and some feature * fix sampler unit test --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
531 lines
21 KiB
Plaintext
531 lines
21 KiB
Plaintext
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include <curand_kernel.h>
|
|
#include <cstdlib>
|
|
#include <string>
|
|
#include "helper.h" // NOLINT
|
|
|
|
__device__ inline bool is_in(const int64_t *candidates,
|
|
const int64_t draft,
|
|
const int candidate_len) {
|
|
for (int i = 0; i < candidate_len; i++) {
|
|
if (draft == candidates[i]) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
static uint64_t seed = 0;
|
|
static uint64_t offset = 0;
|
|
|
|
__device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
|
const float *candidate_scores,
|
|
curandState_t *dev_curand_states,
|
|
const int candidate_len,
|
|
const float topp) {
|
|
const int tid = threadIdx.x;
|
|
|
|
float sum_scores = 0.0f;
|
|
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
|
|
for (int i = 0; i < candidate_len; i++) {
|
|
sum_scores += candidate_scores[i];
|
|
if (rand_top_p <= sum_scores) {
|
|
return candidate_ids[i];
|
|
}
|
|
}
|
|
return candidate_ids[0];
|
|
}
|
|
|
|
__global__ void setup_kernel(curandState_t *state,
|
|
const uint64_t seed,
|
|
const uint64_t offset,
|
|
const int bs,
|
|
const bool need_batch_random) {
|
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
|
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
|
|
if (need_batch_random) {
|
|
curand_init(seed, i, offset, &state[i]);
|
|
} else {
|
|
curand_init(seed, 0, offset, &state[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
|
__global__ void speculate_verify(const int64_t *sampled_token_ids,
|
|
int64_t *accept_tokens,
|
|
int *accept_num,
|
|
int64_t *step_idx,
|
|
bool *stop_flags,
|
|
const int *seq_lens_encoder,
|
|
const int *seq_lens_decoder,
|
|
const int64_t *draft_tokens,
|
|
const int *actual_draft_token_nums,
|
|
curandState_t *dev_curand_states,
|
|
const float *topp,
|
|
const int *seq_lens_this_time,
|
|
const int64_t *verify_tokens,
|
|
const float *verify_scores,
|
|
const int64_t *max_dec_len,
|
|
const int64_t *end_tokens,
|
|
const bool *is_block_step,
|
|
const int *output_cum_offsets,
|
|
const int *actual_candidate_len,
|
|
const int *reasoning_status,
|
|
const int real_bsz,
|
|
const int max_draft_tokens,
|
|
const int end_length,
|
|
const int max_seq_len,
|
|
const int max_candidate_len,
|
|
const int verify_window,
|
|
const bool prefill_one_step_stop,
|
|
const bool benchmark_mode,
|
|
const bool accept_all_drafts,
|
|
const bool use_target_sampling) {
|
|
const int bid = threadIdx.x;
|
|
// verify and set stop flags
|
|
int accept_num_now = 1;
|
|
int stop_flag_now_int = 0;
|
|
|
|
if (!(is_block_step[bid] || bid >= real_bsz)) {
|
|
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
|
|
|
|
if (stop_flags[bid]) {
|
|
stop_flag_now_int = 1;
|
|
} else { // 这里prefill阶段也会进入,但是因为draft
|
|
// tokens会置零,因此会直接到最后的采样阶段
|
|
auto *verify_tokens_now =
|
|
verify_tokens + start_token_id * max_candidate_len;
|
|
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
|
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
|
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
|
|
|
|
int i = 0;
|
|
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
|
// seq_lens_this_time[bid]-1);
|
|
for (; i < seq_lens_this_time[bid] - 1; i++) {
|
|
if (benchmark_mode || seq_lens_encoder[bid] != 0 ||
|
|
reasoning_status[bid] == 1) {
|
|
break;
|
|
}
|
|
if (accept_all_drafts) {
|
|
// accept all draft tokens
|
|
step_idx[bid]++;
|
|
auto accept_token = draft_tokens_now[i + 1];
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
break;
|
|
} else {
|
|
accept_num_now++;
|
|
}
|
|
continue;
|
|
}
|
|
if (use_target_sampling) {
|
|
if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) {
|
|
step_idx[bid]++;
|
|
auto accept_token = draft_tokens_now[i + 1];
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
break;
|
|
} else {
|
|
accept_num_now++;
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
} else if (USE_TOPK) {
|
|
if (verify_tokens_now[i * max_candidate_len] ==
|
|
draft_tokens_now[i + 1]) {
|
|
// accept_num_now++;
|
|
step_idx[bid]++;
|
|
auto accept_token = draft_tokens_now[i + 1];
|
|
// printf("[USE_TOPK] bid %d Top 1 verify write accept
|
|
// %d is %lld\n", bid, i, accept_token);
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
// printf("[USE_TOPK] bid %d Top 1 verify write
|
|
// accept %d is %lld\n", bid, i, accept_token);
|
|
break;
|
|
} else {
|
|
accept_num_now++;
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
} else {
|
|
auto actual_candidate_len_value =
|
|
actual_candidate_len_now[i] > max_candidate_len
|
|
? max_candidate_len
|
|
: actual_candidate_len_now[i];
|
|
if (is_in(verify_tokens_now + i * max_candidate_len,
|
|
draft_tokens_now[i + 1],
|
|
actual_candidate_len_value)) {
|
|
// Top P verify
|
|
// accept_num_now++;
|
|
step_idx[bid]++;
|
|
auto accept_token = draft_tokens_now[i + 1];
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
// printf("bid %d Top P verify write accept %d is
|
|
// %lld\n", bid, i, accept_token);
|
|
break;
|
|
} else {
|
|
accept_num_now++;
|
|
}
|
|
} else {
|
|
// TopK verify
|
|
int ii = i;
|
|
if (max_candidate_len >= 2 &&
|
|
verify_tokens_now[ii * max_candidate_len + 1] ==
|
|
draft_tokens_now[ii + 1]) { // top-2
|
|
int j = 0;
|
|
ii += 1;
|
|
for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
|
|
j++, ii++) {
|
|
if (verify_tokens_now[ii * max_candidate_len] !=
|
|
draft_tokens_now[ii + 1]) {
|
|
break;
|
|
}
|
|
}
|
|
if (j >= verify_window) { // accept all
|
|
accept_num_now += verify_window + 1;
|
|
step_idx[bid] += verify_window + 1;
|
|
for (; i < ii; i++) {
|
|
auto accept_token = draft_tokens_now[i + 1];
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
// printf(
|
|
// "bid %d TopK verify write accept %d
|
|
// is "
|
|
// "%lld\n",
|
|
// bid,
|
|
// i,
|
|
// accept_token);
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
// printf("bid %d TopK verify write
|
|
// accept %d is %lld\n", bid, i,
|
|
// end_tokens[0]);
|
|
accept_num_now--;
|
|
step_idx[bid]--;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// sampling阶段
|
|
// 第一种,draft_token[i+1]被拒绝,需要从verify_tokens_now[i]中选一个
|
|
// 第二种,i == seq_lens_this_time[bid]-1,
|
|
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
|
|
if (!stop_flag_now_int) {
|
|
int64_t accept_token;
|
|
const float *verify_scores_now =
|
|
verify_scores + start_token_id * max_candidate_len;
|
|
step_idx[bid]++;
|
|
if (use_target_sampling) {
|
|
accept_token = sampled_token_id_now[i];
|
|
} else if (ENABLE_TOPP) {
|
|
auto actual_candidate_len_value =
|
|
actual_candidate_len_now[i] > max_candidate_len
|
|
? max_candidate_len
|
|
: actual_candidate_len_now[i];
|
|
|
|
accept_token =
|
|
topp_sampling_kernel(verify_tokens_now + i * max_candidate_len,
|
|
verify_scores_now + i * max_candidate_len,
|
|
dev_curand_states,
|
|
actual_candidate_len_value,
|
|
topp[bid]);
|
|
} else {
|
|
accept_token = verify_tokens_now[i * max_candidate_len];
|
|
}
|
|
accept_tokens[bid * max_draft_tokens + i] = accept_token;
|
|
if (prefill_one_step_stop) {
|
|
stop_flags[bid] = true;
|
|
}
|
|
if (is_in_end(accept_token, end_tokens, end_length) ||
|
|
step_idx[bid] >= max_dec_len[bid]) {
|
|
stop_flags[bid] = true;
|
|
stop_flag_now_int = 1;
|
|
if (step_idx[bid] >= max_dec_len[bid])
|
|
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
|
|
}
|
|
}
|
|
accept_num[bid] = accept_num_now;
|
|
}
|
|
}
|
|
}
|
|
|
|
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
|
const paddle::Tensor &accept_tokens,
|
|
const paddle::Tensor &accept_num,
|
|
const paddle::Tensor &step_idx,
|
|
const paddle::Tensor &stop_flags,
|
|
const paddle::Tensor &seq_lens_encoder,
|
|
const paddle::Tensor &seq_lens_decoder,
|
|
const paddle::Tensor &draft_tokens,
|
|
const paddle::Tensor &seq_lens_this_time,
|
|
const paddle::Tensor &verify_tokens,
|
|
const paddle::Tensor &verify_scores,
|
|
const paddle::Tensor &max_dec_len,
|
|
const paddle::Tensor &end_tokens,
|
|
const paddle::Tensor &is_block_step,
|
|
const paddle::Tensor &output_cum_offsets,
|
|
const paddle::Tensor &actual_candidate_len,
|
|
const paddle::Tensor &actual_draft_token_nums,
|
|
const paddle::Tensor &topp,
|
|
const paddle::Tensor &reasoning_status,
|
|
int max_seq_len,
|
|
int verify_window,
|
|
bool enable_topp,
|
|
bool benchmark_mode,
|
|
bool accept_all_drafts) {
|
|
// printf("Enter speculate update\n");
|
|
auto bsz = accept_tokens.shape()[0];
|
|
int real_bsz = seq_lens_this_time.shape()[0];
|
|
auto max_draft_tokens = draft_tokens.shape()[1];
|
|
auto end_length = end_tokens.shape()[0];
|
|
auto max_candidate_len = verify_tokens.shape()[1];
|
|
|
|
constexpr int BlockSize = 512;
|
|
|
|
curandState_t *dev_curand_states;
|
|
cudaMalloc(&dev_curand_states, sizeof(curandState_t) * bsz);
|
|
setup_kernel<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
|
dev_curand_states, seed, offset, bsz, true);
|
|
seed++;
|
|
offset++;
|
|
|
|
bool use_topk = false;
|
|
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
|
if (env_var) {
|
|
use_topk = static_cast<bool>(std::stoi(env_var));
|
|
}
|
|
bool use_target_sampling = false;
|
|
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
|
if (env_var_1) {
|
|
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
|
|
}
|
|
bool prefill_one_step_stop = false;
|
|
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
|
if (env_p[0] == '1') {
|
|
prefill_one_step_stop = true;
|
|
}
|
|
}
|
|
if (use_topk) {
|
|
if (enable_topp) {
|
|
speculate_verify<true, true><<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
|
sampled_token_ids.data<int64_t>(),
|
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
|
const_cast<int *>(accept_num.data<int>()),
|
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
|
const_cast<bool *>(stop_flags.data<bool>()),
|
|
seq_lens_encoder.data<int>(),
|
|
seq_lens_decoder.data<int>(),
|
|
draft_tokens.data<int64_t>(),
|
|
actual_draft_token_nums.data<int>(),
|
|
dev_curand_states,
|
|
topp.data<float>(),
|
|
seq_lens_this_time.data<int>(),
|
|
verify_tokens.data<int64_t>(),
|
|
verify_scores.data<float>(),
|
|
max_dec_len.data<int64_t>(),
|
|
end_tokens.data<int64_t>(),
|
|
is_block_step.data<bool>(),
|
|
output_cum_offsets.data<int>(),
|
|
actual_candidate_len.data<int>(),
|
|
reasoning_status.data<int>(),
|
|
real_bsz,
|
|
max_draft_tokens,
|
|
end_length,
|
|
max_seq_len,
|
|
max_candidate_len,
|
|
verify_window,
|
|
prefill_one_step_stop,
|
|
benchmark_mode,
|
|
accept_all_drafts,
|
|
use_target_sampling);
|
|
} else {
|
|
speculate_verify<false, true>
|
|
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
|
sampled_token_ids.data<int64_t>(),
|
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
|
const_cast<int *>(accept_num.data<int>()),
|
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
|
const_cast<bool *>(stop_flags.data<bool>()),
|
|
seq_lens_encoder.data<int>(),
|
|
seq_lens_decoder.data<int>(),
|
|
draft_tokens.data<int64_t>(),
|
|
actual_draft_token_nums.data<int>(),
|
|
dev_curand_states,
|
|
topp.data<float>(),
|
|
seq_lens_this_time.data<int>(),
|
|
verify_tokens.data<int64_t>(),
|
|
verify_scores.data<float>(),
|
|
max_dec_len.data<int64_t>(),
|
|
end_tokens.data<int64_t>(),
|
|
is_block_step.data<bool>(),
|
|
output_cum_offsets.data<int>(),
|
|
actual_candidate_len.data<int>(),
|
|
reasoning_status.data<int>(),
|
|
real_bsz,
|
|
max_draft_tokens,
|
|
end_length,
|
|
max_seq_len,
|
|
max_candidate_len,
|
|
verify_window,
|
|
prefill_one_step_stop,
|
|
benchmark_mode,
|
|
accept_all_drafts,
|
|
use_target_sampling);
|
|
}
|
|
} else {
|
|
if (enable_topp) {
|
|
speculate_verify<true, false>
|
|
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
|
sampled_token_ids.data<int64_t>(),
|
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
|
const_cast<int *>(accept_num.data<int>()),
|
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
|
const_cast<bool *>(stop_flags.data<bool>()),
|
|
seq_lens_encoder.data<int>(),
|
|
seq_lens_decoder.data<int>(),
|
|
draft_tokens.data<int64_t>(),
|
|
actual_draft_token_nums.data<int>(),
|
|
dev_curand_states,
|
|
topp.data<float>(),
|
|
seq_lens_this_time.data<int>(),
|
|
verify_tokens.data<int64_t>(),
|
|
verify_scores.data<float>(),
|
|
max_dec_len.data<int64_t>(),
|
|
end_tokens.data<int64_t>(),
|
|
is_block_step.data<bool>(),
|
|
output_cum_offsets.data<int>(),
|
|
actual_candidate_len.data<int>(),
|
|
reasoning_status.data<int>(),
|
|
real_bsz,
|
|
max_draft_tokens,
|
|
end_length,
|
|
max_seq_len,
|
|
max_candidate_len,
|
|
verify_window,
|
|
prefill_one_step_stop,
|
|
benchmark_mode,
|
|
accept_all_drafts,
|
|
use_target_sampling);
|
|
} else {
|
|
speculate_verify<false, false>
|
|
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
|
|
sampled_token_ids.data<int64_t>(),
|
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
|
const_cast<int *>(accept_num.data<int>()),
|
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
|
const_cast<bool *>(stop_flags.data<bool>()),
|
|
seq_lens_encoder.data<int>(),
|
|
seq_lens_decoder.data<int>(),
|
|
draft_tokens.data<int64_t>(),
|
|
actual_draft_token_nums.data<int>(),
|
|
dev_curand_states,
|
|
topp.data<float>(),
|
|
seq_lens_this_time.data<int>(),
|
|
verify_tokens.data<int64_t>(),
|
|
verify_scores.data<float>(),
|
|
max_dec_len.data<int64_t>(),
|
|
end_tokens.data<int64_t>(),
|
|
is_block_step.data<bool>(),
|
|
output_cum_offsets.data<int>(),
|
|
actual_candidate_len.data<int>(),
|
|
reasoning_status.data<int>(),
|
|
real_bsz,
|
|
max_draft_tokens,
|
|
end_length,
|
|
max_seq_len,
|
|
max_candidate_len,
|
|
verify_window,
|
|
prefill_one_step_stop,
|
|
benchmark_mode,
|
|
accept_all_drafts,
|
|
use_target_sampling);
|
|
}
|
|
}
|
|
|
|
cudaFree(dev_curand_states);
|
|
}
|
|
|
|
PD_BUILD_STATIC_OP(speculate_verify)
|
|
.Inputs({"sampled_token_ids",
|
|
"accept_tokens",
|
|
"accept_num",
|
|
"step_idx",
|
|
"seq_lens_encoder",
|
|
"seq_lens_decoder",
|
|
"stop_flags",
|
|
"draft_tokens",
|
|
"seq_lens_this_time",
|
|
"verify_tokens",
|
|
"verify_scores",
|
|
"max_dec_len",
|
|
"end_tokens",
|
|
"is_block_step",
|
|
"output_cum_offsets",
|
|
"actual_candidate_len",
|
|
"actual_draft_token_nums",
|
|
"topp",
|
|
"reasoning_status"})
|
|
.Outputs({"accept_tokens_out",
|
|
"accept_num_out",
|
|
"step_idx_out",
|
|
"stop_flags_out"})
|
|
.Attrs({"max_seq_len: int",
|
|
"verify_window: int",
|
|
"enable_topp: bool",
|
|
"benchmark_mode: bool",
|
|
"accept_all_drafts: bool"})
|
|
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
|
|
{"accept_num", "accept_num_out"},
|
|
{"step_idx", "step_idx_out"},
|
|
{"stop_flags", "stop_flags_out"}})
|
|
.SetKernelFn(PD_KERNEL(SpeculateVerify));
|