[Feature] support mtp distribution equivalence verification (#4699)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

This commit is contained in:
GoldPancake
2025-10-31 11:45:04 +08:00
committed by GitHub
parent 28de91b50f
commit 1f3ce65b58
6 changed files with 257 additions and 88 deletions
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h" // NOLINT
#include <cstdlib>
#include <curand_kernel.h>
#include <cstdlib>
#include <string>
#include "helper.h" // NOLINT
__device__ inline bool is_in(const int64_t *candidates, const int64_t draft,
__device__ inline bool is_in(const int64_t *candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
if (draft == candidates[i]) {
@@ -48,8 +49,10 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
return candidate_ids[0];
}
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
const uint64_t offset, const int bs,
__global__ void setup_kernel(curandState_t *state,
const uint64_t seed,
const uint64_t offset,
const int bs,
const bool need_batch_random) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
@@ -62,18 +65,35 @@ __global__ void setup_kernel(curandState_t *state, const uint64_t seed,
}
template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify(
int64_t *accept_tokens, int *accept_num, int64_t *step_idx,
bool *stop_flags, const int *seq_lens_encoder, const int *seq_lens_decoder,
const int64_t *draft_tokens, const int *actual_draft_token_nums,
curandState_t *dev_curand_states, const float *topp,
const int *seq_lens_this_time, const int64_t *verify_tokens,
const float *verify_scores, const int64_t *max_dec_len,
const int64_t *end_tokens, const bool *is_block_step,
const int *output_cum_offsets, const int *actual_candidate_len,
const int real_bsz, const int max_draft_tokens, const int end_length,
const int max_seq_len, const int max_candidate_len, const int verify_window,
const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
__global__ void speculate_verify(const int64_t *sampled_token_ids,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
curandState_t *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
const int max_seq_len,
const int max_candidate_len,
const int verify_window,
const bool prefill_one_step_stop,
const bool benchmark_mode,
const bool accept_all_drafts,
const bool use_target_sampling) {
const int bid = threadIdx.x;
// verify and set stop flags
int accept_num_now = 1;
@@ -84,12 +104,13 @@ __global__ void speculate_verify(
if (stop_flags[bid]) {
stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入,但是因为draft
// tokens会置零,因此会直接到最后的采样阶段
} else { // 这里prefill阶段也会进入,但是因为draft
// tokens会置零,因此会直接到最后的采样阶段
auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
@@ -119,7 +140,25 @@ __global__ void speculate_verify(
}
continue;
}
if (USE_TOPK) {
if (use_target_sampling) {
if (sampled_token_id_now[i] == draft_tokens_now[i + 1]) {
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
} else {
break;
}
} else if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
// accept_num_now++;
@@ -149,7 +188,8 @@ __global__ void speculate_verify(
? max_candidate_len
: actual_candidate_len_now[i];
if (is_in(verify_tokens_now + i * max_candidate_len,
draft_tokens_now[i + 1], actual_candidate_len_value)) {
draft_tokens_now[i + 1],
actual_candidate_len_value)) {
// Top P verify
// accept_num_now++;
step_idx[bid]++;
@@ -173,7 +213,7 @@ __global__ void speculate_verify(
int ii = i;
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
draft_tokens_now[ii + 1]) { // top-2
draft_tokens_now[ii + 1]) { // top-2
int j = 0;
ii += 1;
for (; j < verify_window && ii < seq_lens_this_time[bid] - 1;
@@ -183,7 +223,7 @@ __global__ void speculate_verify(
break;
}
}
if (j >= verify_window) { // accept all
if (j >= verify_window) { // accept all
accept_num_now += verify_window + 1;
step_idx[bid] += verify_window + 1;
for (; i < ii; i++) {
@@ -225,16 +265,20 @@ __global__ void speculate_verify(
const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (ENABLE_TOPP) {
if (use_target_sampling) {
accept_token = sampled_token_id_now[i];
} else if (ENABLE_TOPP) {
auto actual_candidate_len_value =
actual_candidate_len_now[i] > max_candidate_len
? max_candidate_len
: actual_candidate_len_now[i];
accept_token = topp_sampling_kernel(
verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len, dev_curand_states,
actual_candidate_len_value, topp[bid]);
accept_token =
topp_sampling_kernel(verify_tokens_now + i * max_candidate_len,
verify_scores_now + i * max_candidate_len,
dev_curand_states,
actual_candidate_len_value,
topp[bid]);
} else {
accept_token = verify_tokens_now[i * max_candidate_len];
}
@@ -255,19 +299,29 @@ __global__ void speculate_verify(
}
}
void SpeculateVerify(
const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens,
const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
int max_seq_len,
int verify_window,
bool enable_topp,
bool benchmark_mode,
bool accept_all_drafts) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
@@ -289,6 +343,11 @@ void SpeculateVerify(
if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var));
}
bool use_target_sampling = false;
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
if (env_var_1) {
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
}
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
if (env_p[0] == '1') {
@@ -298,69 +357,133 @@ void SpeculateVerify(
if (use_topk) {
if (enable_topp) {
speculate_verify<true, true><<<1, BlockSize, 0, accept_tokens.stream()>>>(
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(), seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(), verify_scores.data<float>(),
max_dec_len.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
end_length, max_seq_len, max_candidate_len, verify_window,
prefill_one_step_stop, benchmark_mode, accept_all_drafts);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
}
} else {
if (enable_topp) {
speculate_verify<true, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), actual_draft_token_nums.data<int>(),
dev_curand_states, topp.data<float>(),
seq_lens_this_time.data<int>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
actual_draft_token_nums.data<int>(),
dev_curand_states,
topp.data<float>(),
seq_lens_this_time.data<int>(),
verify_tokens.data<int64_t>(),
verify_scores.data<float>(),
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
max_seq_len,
max_candidate_len,
verify_window,
prefill_one_step_stop,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
}
}
@@ -368,14 +491,33 @@ void SpeculateVerify(
}
PD_BUILD_STATIC_OP(speculate_verify)
.Inputs({"accept_tokens", "accept_num", "step_idx", "seq_lens_encoder",
"seq_lens_decoder", "stop_flags", "draft_tokens",
"seq_lens_this_time", "verify_tokens", "verify_scores",
"max_dec_len", "end_tokens", "is_block_step", "output_cum_offsets",
"actual_candidate_len", "actual_draft_token_nums", "topp"})
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
.Inputs({"sampled_token_ids",
"accept_tokens",
"accept_num",
"step_idx",
"seq_lens_encoder",
"seq_lens_decoder",
"stop_flags",
"draft_tokens",
"seq_lens_this_time",
"verify_tokens",
"verify_scores",
"max_dec_len",
"end_tokens",
"is_block_step",
"output_cum_offsets",
"actual_candidate_len",
"actual_draft_token_nums",
"topp"})
.Outputs({"accept_tokens_out",
"accept_num_out",
"step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
.Attrs({"max_seq_len: int",
"verify_window: int",
"enable_topp: bool",
"benchmark_mode: bool",
"accept_all_drafts: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},