[XPU] fix speculate_verify (#6985)

This commit is contained in:
zhupengyang
2026-03-24 18:55:09 +08:00
committed by GitHub
parent 6cff780fdb
commit 5780345646
2 changed files with 181 additions and 180 deletions
@@ -26,24 +26,24 @@
namespace api = baidu::xpu::api;
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens,
const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& verify_tokens,
const paddle::Tensor& verify_scores,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& topp,
int max_seq_len,
int verify_window,
bool enable_topp,
@@ -57,8 +57,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
bool xpu_ctx_flag = true;
if (draft_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU);
@@ -66,17 +65,17 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
}
bool use_topk = false;
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var));
}
bool use_target_sampling = false;
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
char* env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
if (env_var_1) {
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
}
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
@@ -91,29 +90,30 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
std::mt19937_64 engine(infer_seed[i]);
dev_curand_states_cpu.push_back(dist(engine));
}
float *dev_curand_states_xpu;
float* dev_curand_states = dev_curand_states_cpu.data();
auto dev_curand_states_tensor =
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
paddle::DataType::FLOAT32,
draft_tokens.place());
int ret;
if (xpu_ctx_flag) {
xpu::ctx_guard RAII_GUARD(ctx);
dev_curand_states_xpu =
RAII_GUARD.alloc<float>(dev_curand_states_cpu.size());
xpu_memcpy(dev_curand_states_xpu,
dev_curand_states_cpu.data(),
dev_curand_states_cpu.size() * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu::do_host2device(ctx,
dev_curand_states_cpu.data(),
dev_curand_states_tensor.data<float>(),
dev_curand_states_cpu.size() * sizeof(float));
PD_CHECK(ret == 0, "do_host2device failed.");
dev_curand_states = dev_curand_states_tensor.data<float>();
}
auto dev_curand_states =
!xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu;
int ret;
if (use_topk) {
if (enable_topp) {
ret = fastdeploy::plugin::speculate_verify<true, true>(
ctx,
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
@@ -143,10 +143,10 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<false, true>(
ctx,
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
@@ -171,17 +171,17 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
PD_CHECK(ret == 0, "speculate_verify failed.");
}
PD_CHECK(ret == 0, "speculate_verify failed.");
} else {
if (enable_topp) {
ret = fastdeploy::plugin::speculate_verify<true, false>(
ctx,
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
@@ -211,10 +211,10 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<false, false>(
ctx,
sampled_token_ids.data<int64_t>(),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(),
@@ -239,8 +239,8 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
PD_CHECK(ret == 0, "speculate_verify failed.");
}
PD_CHECK(ret == 0, "speculate_verify failed.");
}
if (draft_tokens.is_cpu()) {
delete ctx;
@@ -23,25 +23,25 @@ typedef uint32_t curandStatePhilox4_32_10_t;
template <bool ENABLE_TOPP, bool USE_TOPK>
__attribute__((global)) void speculate_verify(
const int64_t *sampled_token_ids,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int64_t* sampled_token_ids,
int64_t* accept_tokens,
int* accept_num,
int64_t* step_idx,
bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* draft_tokens,
const int* actual_draft_token_nums,
const float* dev_curand_states,
const float* topp,
const int* seq_lens_this_time,
const int64_t* verify_tokens,
const float* verify_scores,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
@@ -58,7 +58,7 @@ namespace fastdeploy {
namespace plugin {
static inline bool is_in_end(const int64_t id,
const int64_t *end_ids,
const int64_t* end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
@@ -69,7 +69,7 @@ static inline bool is_in_end(const int64_t id,
return flag;
}
static inline bool is_in(const int64_t *candidates,
static inline bool is_in(const int64_t* candidates,
const int64_t draft,
const int candidate_len) {
for (int i = 0; i < candidate_len; i++) {
@@ -80,7 +80,7 @@ static inline bool is_in(const int64_t *candidates,
return false;
}
static inline unsigned int xorwow(unsigned int &state) { // NOLINT
static inline unsigned int xorwow(unsigned int& state) { // NOLINT
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
@@ -89,9 +89,9 @@ static inline unsigned int xorwow(unsigned int &state) { // NOLINT
typedef uint32_t curandStatePhilox4_32_10_t;
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
const float *dev_curand_states,
static int64_t topp_sampling_kernel(const int64_t* candidate_ids,
const float* candidate_scores,
const float* dev_curand_states,
const int candidate_len,
const float topp,
int tid) {
@@ -111,26 +111,26 @@ static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
}
template <bool ENABLE_TOPP, bool USE_TOPK>
static int cpu_wrapper(api::Context *ctx,
const int64_t *sampled_token_ids,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
static int cpu_wrapper(api::Context* ctx,
const int64_t* sampled_token_ids,
int64_t* accept_tokens,
int* accept_num,
int64_t* step_idx,
bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* draft_tokens,
const int* actual_draft_token_nums,
const float* dev_curand_states,
const float* topp,
const int* seq_lens_this_time,
const int64_t* verify_tokens,
const float* verify_scores,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
@@ -155,11 +155,11 @@ static int cpu_wrapper(api::Context *ctx,
stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入,但是因为draft
// tokens会置零,因此会直接到最后的采样阶段
auto *verify_tokens_now =
auto* verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len;
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto* actual_candidate_len_now = actual_candidate_len + start_token_id;
auto* sampled_token_id_now = sampled_token_ids + start_token_id;
int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
@@ -306,7 +306,7 @@ static int cpu_wrapper(api::Context *ctx,
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
if (!stop_flag_now_int) {
int64_t accept_token;
const float *verify_scores_now =
const float* verify_scores_now =
verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++;
if (use_target_sampling) {
@@ -347,26 +347,26 @@ static int cpu_wrapper(api::Context *ctx,
}
template <bool ENABLE_TOPP, bool USE_TOPK>
static int xpu3_wrapper(api::Context *ctx,
const int64_t *sampled_token_ids,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
static int xpu3_wrapper(api::Context* ctx,
const int64_t* sampled_token_ids,
int64_t* accept_tokens,
int* accept_num,
int64_t* step_idx,
bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* draft_tokens,
const int* actual_draft_token_nums,
const float* dev_curand_states,
const float* topp,
const int* seq_lens_this_time,
const int64_t* verify_tokens,
const float* verify_scores,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
@@ -380,22 +380,22 @@ static int xpu3_wrapper(api::Context *ctx,
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
reinterpret_cast<XPU_INT64 *>(accept_tokens),
reinterpret_cast<const XPU_INT64*>(sampled_token_ids),
reinterpret_cast<XPU_INT64*>(accept_tokens),
accept_num,
reinterpret_cast<XPU_INT64 *>(step_idx),
reinterpret_cast<XPU_INT64*>(step_idx),
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(draft_tokens),
reinterpret_cast<const XPU_INT64*>(draft_tokens),
actual_draft_token_nums,
dev_curand_states,
topp,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64 *>(verify_tokens),
reinterpret_cast<const XPU_INT64*>(verify_tokens),
verify_scores,
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
reinterpret_cast<const XPU_INT64 *>(end_tokens),
reinterpret_cast<const XPU_INT64*>(max_dec_len),
reinterpret_cast<const XPU_INT64*>(end_tokens),
is_block_step,
output_cum_offsets,
actual_candidate_len,
@@ -413,26 +413,26 @@ static int xpu3_wrapper(api::Context *ctx,
return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
int speculate_verify(api::Context *ctx,
const int64_t *sampled_token_ids,
int64_t *accept_tokens,
int *accept_num,
int64_t *step_idx,
bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *draft_tokens,
const int *actual_draft_token_nums,
const float *dev_curand_states,
const float *topp,
const int *seq_lens_this_time,
const int64_t *verify_tokens,
const float *verify_scores,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
int speculate_verify(api::Context* ctx,
const int64_t* sampled_token_ids,
int64_t* accept_tokens,
int* accept_num,
int64_t* step_idx,
bool* stop_flags,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* draft_tokens,
const int* actual_draft_token_nums,
const float* dev_curand_states,
const float* topp,
const int* seq_lens_this_time,
const int64_t* verify_tokens,
const float* verify_scores,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
@@ -475,6 +475,7 @@ int speculate_verify(api::Context *ctx,
benchmark_mode);
WRAPPER_DUMP_PARAM2(ctx, accept_all_drafts, use_target_sampling);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, sampled_token_ids);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx);
@@ -483,7 +484,7 @@ int speculate_verify(api::Context *ctx,
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, draft_tokens);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_draft_token_nums);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states);
// WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
// WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, verify_tokens);
@@ -574,36 +575,36 @@ int speculate_verify(api::Context *ctx,
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
fastdeploy::plugin::api::Context *, /* xpu_ctx */ \
const int64_t *, /* sampled_token_ids */ \
int64_t *, /* accept_tokens */ \
int *, /* accept_num */ \
int64_t *, /* step_idx */ \
bool *, /* stop_flags */ \
const int *, /* seq_lens_encoder */ \
const int *, /* seq_lens_decoder */ \
const int64_t *, /* draft_tokens */ \
const int *, /* actual_draft_token_nums */ \
const float *, /* dev_curand_states or topp */ \
const float *, /* topp or nullptr */ \
const int *, /* seq_lens_this_time */ \
const int64_t *, /* verify_tokens */ \
const float *, /* verify_scores */ \
const int64_t *, /* max_dec_len */ \
const int64_t *, /* end_tokens */ \
const bool *, /* is_block_step */ \
const int *, /* output_cum_offsets */ \
const int *, /* actual_candidate_len */ \
int, /* real_bsz */ \
int, /* max_draft_tokens */ \
int, /* end_length */ \
int, /* max_seq_len */ \
int, /* max_candidate_len */ \
int, /* verify_window */ \
bool, /* prefill_one_step_stop */ \
bool, /* benchmark_mode */ \
bool, /* accept_all_drafts */ \
bool /* use_target_sampling */ \
fastdeploy::plugin::api::Context*, /* xpu_ctx */ \
const int64_t*, /* sampled_token_ids */ \
int64_t*, /* accept_tokens */ \
int*, /* accept_num */ \
int64_t*, /* step_idx */ \
bool*, /* stop_flags */ \
const int*, /* seq_lens_encoder */ \
const int*, /* seq_lens_decoder */ \
const int64_t*, /* draft_tokens */ \
const int*, /* actual_draft_token_nums */ \
const float*, /* dev_curand_states or topp */ \
const float*, /* topp or nullptr */ \
const int*, /* seq_lens_this_time */ \
const int64_t*, /* verify_tokens */ \
const float*, /* verify_scores */ \
const int64_t*, /* max_dec_len */ \
const int64_t*, /* end_tokens */ \
const bool*, /* is_block_step */ \
const int*, /* output_cum_offsets */ \
const int*, /* actual_candidate_len */ \
int, /* real_bsz */ \
int, /* max_draft_tokens */ \
int, /* end_length */ \
int, /* max_seq_len */ \
int, /* max_candidate_len */ \
int, /* verify_window */ \
bool, /* prefill_one_step_stop */ \
bool, /* benchmark_mode */ \
bool, /* accept_all_drafts */ \
bool /* use_target_sampling */ \
);
INSTANTIATE_SPECULATE_VERIFY(false, false)