mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] fix speculate_verify (#6985)
This commit is contained in:
@@ -26,24 +26,24 @@
|
||||
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &verify_tokens,
|
||||
const paddle::Tensor &verify_scores,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const paddle::Tensor &actual_candidate_len,
|
||||
const paddle::Tensor &actual_draft_token_nums,
|
||||
const paddle::Tensor &topp,
|
||||
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& verify_tokens,
|
||||
const paddle::Tensor& verify_scores,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& end_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& actual_candidate_len,
|
||||
const paddle::Tensor& actual_draft_token_nums,
|
||||
const paddle::Tensor& topp,
|
||||
int max_seq_len,
|
||||
int verify_window,
|
||||
bool enable_topp,
|
||||
@@ -57,8 +57,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
api::Context *ctx =
|
||||
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
|
||||
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
|
||||
bool xpu_ctx_flag = true;
|
||||
if (draft_tokens.is_cpu()) {
|
||||
ctx = new api::Context(api::kCPU);
|
||||
@@ -66,17 +65,17 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
}
|
||||
|
||||
bool use_topk = false;
|
||||
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
||||
char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
||||
if (env_var) {
|
||||
use_topk = static_cast<bool>(std::stoi(env_var));
|
||||
}
|
||||
bool use_target_sampling = false;
|
||||
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
||||
char* env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
||||
if (env_var_1) {
|
||||
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
|
||||
}
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
@@ -91,29 +90,30 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
std::mt19937_64 engine(infer_seed[i]);
|
||||
dev_curand_states_cpu.push_back(dist(engine));
|
||||
}
|
||||
float *dev_curand_states_xpu;
|
||||
float* dev_curand_states = dev_curand_states_cpu.data();
|
||||
auto dev_curand_states_tensor =
|
||||
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
|
||||
paddle::DataType::FLOAT32,
|
||||
draft_tokens.place());
|
||||
int ret;
|
||||
if (xpu_ctx_flag) {
|
||||
xpu::ctx_guard RAII_GUARD(ctx);
|
||||
dev_curand_states_xpu =
|
||||
RAII_GUARD.alloc<float>(dev_curand_states_cpu.size());
|
||||
xpu_memcpy(dev_curand_states_xpu,
|
||||
dev_curand_states_cpu.data(),
|
||||
dev_curand_states_cpu.size() * sizeof(float),
|
||||
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
|
||||
ret = xpu::do_host2device(ctx,
|
||||
dev_curand_states_cpu.data(),
|
||||
dev_curand_states_tensor.data<float>(),
|
||||
dev_curand_states_cpu.size() * sizeof(float));
|
||||
PD_CHECK(ret == 0, "do_host2device failed.");
|
||||
dev_curand_states = dev_curand_states_tensor.data<float>();
|
||||
}
|
||||
|
||||
auto dev_curand_states =
|
||||
!xpu_ctx_flag ? dev_curand_states_cpu.data() : dev_curand_states_xpu;
|
||||
int ret;
|
||||
if (use_topk) {
|
||||
if (enable_topp) {
|
||||
ret = fastdeploy::plugin::speculate_verify<true, true>(
|
||||
ctx,
|
||||
sampled_token_ids.data<int64_t>(),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
@@ -143,10 +143,10 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
ret = fastdeploy::plugin::speculate_verify<false, true>(
|
||||
ctx,
|
||||
sampled_token_ids.data<int64_t>(),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
@@ -171,17 +171,17 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
benchmark_mode,
|
||||
accept_all_drafts,
|
||||
use_target_sampling);
|
||||
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||
}
|
||||
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||
} else {
|
||||
if (enable_topp) {
|
||||
ret = fastdeploy::plugin::speculate_verify<true, false>(
|
||||
ctx,
|
||||
sampled_token_ids.data<int64_t>(),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
@@ -211,10 +211,10 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
ret = fastdeploy::plugin::speculate_verify<false, false>(
|
||||
ctx,
|
||||
sampled_token_ids.data<int64_t>(),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
draft_tokens.data<int64_t>(),
|
||||
@@ -239,8 +239,8 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
benchmark_mode,
|
||||
accept_all_drafts,
|
||||
use_target_sampling);
|
||||
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||
}
|
||||
PD_CHECK(ret == 0, "speculate_verify failed.");
|
||||
}
|
||||
if (draft_tokens.is_cpu()) {
|
||||
delete ctx;
|
||||
|
||||
@@ -23,25 +23,25 @@ typedef uint32_t curandStatePhilox4_32_10_t;
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
__attribute__((global)) void speculate_verify(
|
||||
const int64_t *sampled_token_ids,
|
||||
int64_t *accept_tokens,
|
||||
int *accept_num,
|
||||
int64_t *step_idx,
|
||||
bool *stop_flags,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *draft_tokens,
|
||||
const int *actual_draft_token_nums,
|
||||
const float *dev_curand_states,
|
||||
const float *topp,
|
||||
const int *seq_lens_this_time,
|
||||
const int64_t *verify_tokens,
|
||||
const float *verify_scores,
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *output_cum_offsets,
|
||||
const int *actual_candidate_len,
|
||||
const int64_t* sampled_token_ids,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
int64_t* step_idx,
|
||||
bool* stop_flags,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int64_t* draft_tokens,
|
||||
const int* actual_draft_token_nums,
|
||||
const float* dev_curand_states,
|
||||
const float* topp,
|
||||
const int* seq_lens_this_time,
|
||||
const int64_t* verify_tokens,
|
||||
const float* verify_scores,
|
||||
const int64_t* max_dec_len,
|
||||
const int64_t* end_tokens,
|
||||
const bool* is_block_step,
|
||||
const int* output_cum_offsets,
|
||||
const int* actual_candidate_len,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens,
|
||||
const int end_length,
|
||||
@@ -58,7 +58,7 @@ namespace fastdeploy {
|
||||
namespace plugin {
|
||||
|
||||
static inline bool is_in_end(const int64_t id,
|
||||
const int64_t *end_ids,
|
||||
const int64_t* end_ids,
|
||||
int length) {
|
||||
bool flag = false;
|
||||
for (int i = 0; i < length; i++) {
|
||||
@@ -69,7 +69,7 @@ static inline bool is_in_end(const int64_t id,
|
||||
return flag;
|
||||
}
|
||||
|
||||
static inline bool is_in(const int64_t *candidates,
|
||||
static inline bool is_in(const int64_t* candidates,
|
||||
const int64_t draft,
|
||||
const int candidate_len) {
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
@@ -80,7 +80,7 @@ static inline bool is_in(const int64_t *candidates,
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline unsigned int xorwow(unsigned int &state) { // NOLINT
|
||||
static inline unsigned int xorwow(unsigned int& state) { // NOLINT
|
||||
state ^= state >> 7;
|
||||
state ^= state << 9;
|
||||
state ^= state >> 13;
|
||||
@@ -89,9 +89,9 @@ static inline unsigned int xorwow(unsigned int &state) { // NOLINT
|
||||
|
||||
typedef uint32_t curandStatePhilox4_32_10_t;
|
||||
|
||||
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
const float *candidate_scores,
|
||||
const float *dev_curand_states,
|
||||
static int64_t topp_sampling_kernel(const int64_t* candidate_ids,
|
||||
const float* candidate_scores,
|
||||
const float* dev_curand_states,
|
||||
const int candidate_len,
|
||||
const float topp,
|
||||
int tid) {
|
||||
@@ -111,26 +111,26 @@ static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
}
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
static int cpu_wrapper(api::Context *ctx,
|
||||
const int64_t *sampled_token_ids,
|
||||
int64_t *accept_tokens,
|
||||
int *accept_num,
|
||||
int64_t *step_idx,
|
||||
bool *stop_flags,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *draft_tokens,
|
||||
const int *actual_draft_token_nums,
|
||||
const float *dev_curand_states,
|
||||
const float *topp,
|
||||
const int *seq_lens_this_time,
|
||||
const int64_t *verify_tokens,
|
||||
const float *verify_scores,
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *output_cum_offsets,
|
||||
const int *actual_candidate_len,
|
||||
static int cpu_wrapper(api::Context* ctx,
|
||||
const int64_t* sampled_token_ids,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
int64_t* step_idx,
|
||||
bool* stop_flags,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int64_t* draft_tokens,
|
||||
const int* actual_draft_token_nums,
|
||||
const float* dev_curand_states,
|
||||
const float* topp,
|
||||
const int* seq_lens_this_time,
|
||||
const int64_t* verify_tokens,
|
||||
const float* verify_scores,
|
||||
const int64_t* max_dec_len,
|
||||
const int64_t* end_tokens,
|
||||
const bool* is_block_step,
|
||||
const int* output_cum_offsets,
|
||||
const int* actual_candidate_len,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens,
|
||||
const int end_length,
|
||||
@@ -155,11 +155,11 @@ static int cpu_wrapper(api::Context *ctx,
|
||||
stop_flag_now_int = 1;
|
||||
} else { // 这里prefill阶段也会进入,但是因为draft
|
||||
// tokens会置零,因此会直接到最后的采样阶段
|
||||
auto *verify_tokens_now =
|
||||
auto* verify_tokens_now =
|
||||
verify_tokens + start_token_id * max_candidate_len;
|
||||
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
||||
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
||||
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
|
||||
auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
||||
auto* actual_candidate_len_now = actual_candidate_len + start_token_id;
|
||||
auto* sampled_token_id_now = sampled_token_ids + start_token_id;
|
||||
|
||||
int i = 0;
|
||||
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
||||
@@ -306,7 +306,7 @@ static int cpu_wrapper(api::Context *ctx,
|
||||
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
|
||||
if (!stop_flag_now_int) {
|
||||
int64_t accept_token;
|
||||
const float *verify_scores_now =
|
||||
const float* verify_scores_now =
|
||||
verify_scores + start_token_id * max_candidate_len;
|
||||
step_idx[bid]++;
|
||||
if (use_target_sampling) {
|
||||
@@ -347,26 +347,26 @@ static int cpu_wrapper(api::Context *ctx,
|
||||
}
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
static int xpu3_wrapper(api::Context *ctx,
|
||||
const int64_t *sampled_token_ids,
|
||||
int64_t *accept_tokens,
|
||||
int *accept_num,
|
||||
int64_t *step_idx,
|
||||
bool *stop_flags,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *draft_tokens,
|
||||
const int *actual_draft_token_nums,
|
||||
const float *dev_curand_states,
|
||||
const float *topp,
|
||||
const int *seq_lens_this_time,
|
||||
const int64_t *verify_tokens,
|
||||
const float *verify_scores,
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *output_cum_offsets,
|
||||
const int *actual_candidate_len,
|
||||
static int xpu3_wrapper(api::Context* ctx,
|
||||
const int64_t* sampled_token_ids,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
int64_t* step_idx,
|
||||
bool* stop_flags,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int64_t* draft_tokens,
|
||||
const int* actual_draft_token_nums,
|
||||
const float* dev_curand_states,
|
||||
const float* topp,
|
||||
const int* seq_lens_this_time,
|
||||
const int64_t* verify_tokens,
|
||||
const float* verify_scores,
|
||||
const int64_t* max_dec_len,
|
||||
const int64_t* end_tokens,
|
||||
const bool* is_block_step,
|
||||
const int* output_cum_offsets,
|
||||
const int* actual_candidate_len,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens,
|
||||
const int end_length,
|
||||
@@ -380,22 +380,22 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
||||
int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
|
||||
reinterpret_cast<XPU_INT64 *>(accept_tokens),
|
||||
reinterpret_cast<const XPU_INT64*>(sampled_token_ids),
|
||||
reinterpret_cast<XPU_INT64*>(accept_tokens),
|
||||
accept_num,
|
||||
reinterpret_cast<XPU_INT64 *>(step_idx),
|
||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||
stop_flags,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64 *>(draft_tokens),
|
||||
reinterpret_cast<const XPU_INT64*>(draft_tokens),
|
||||
actual_draft_token_nums,
|
||||
dev_curand_states,
|
||||
topp,
|
||||
seq_lens_this_time,
|
||||
reinterpret_cast<const XPU_INT64 *>(verify_tokens),
|
||||
reinterpret_cast<const XPU_INT64*>(verify_tokens),
|
||||
verify_scores,
|
||||
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(end_tokens),
|
||||
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
||||
reinterpret_cast<const XPU_INT64*>(end_tokens),
|
||||
is_block_step,
|
||||
output_cum_offsets,
|
||||
actual_candidate_len,
|
||||
@@ -413,26 +413,26 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
return api::SUCCESS;
|
||||
}
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
int speculate_verify(api::Context *ctx,
|
||||
const int64_t *sampled_token_ids,
|
||||
int64_t *accept_tokens,
|
||||
int *accept_num,
|
||||
int64_t *step_idx,
|
||||
bool *stop_flags,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *draft_tokens,
|
||||
const int *actual_draft_token_nums,
|
||||
const float *dev_curand_states,
|
||||
const float *topp,
|
||||
const int *seq_lens_this_time,
|
||||
const int64_t *verify_tokens,
|
||||
const float *verify_scores,
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *output_cum_offsets,
|
||||
const int *actual_candidate_len,
|
||||
int speculate_verify(api::Context* ctx,
|
||||
const int64_t* sampled_token_ids,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
int64_t* step_idx,
|
||||
bool* stop_flags,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int64_t* draft_tokens,
|
||||
const int* actual_draft_token_nums,
|
||||
const float* dev_curand_states,
|
||||
const float* topp,
|
||||
const int* seq_lens_this_time,
|
||||
const int64_t* verify_tokens,
|
||||
const float* verify_scores,
|
||||
const int64_t* max_dec_len,
|
||||
const int64_t* end_tokens,
|
||||
const bool* is_block_step,
|
||||
const int* output_cum_offsets,
|
||||
const int* actual_candidate_len,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens,
|
||||
const int end_length,
|
||||
@@ -475,6 +475,7 @@ int speculate_verify(api::Context *ctx,
|
||||
benchmark_mode);
|
||||
WRAPPER_DUMP_PARAM2(ctx, accept_all_drafts, use_target_sampling);
|
||||
WRAPPER_DUMP(ctx);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, sampled_token_ids);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens);
|
||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx);
|
||||
@@ -483,7 +484,7 @@ int speculate_verify(api::Context *ctx,
|
||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, draft_tokens);
|
||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_draft_token_nums);
|
||||
WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states);
|
||||
// WRAPPER_CHECK_PTR(ctx, float, real_bsz, dev_curand_states);
|
||||
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
|
||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
|
||||
// WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, verify_tokens);
|
||||
@@ -574,36 +575,36 @@ int speculate_verify(api::Context *ctx,
|
||||
|
||||
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
|
||||
template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
||||
fastdeploy::plugin::api::Context *, /* xpu_ctx */ \
|
||||
const int64_t *, /* sampled_token_ids */ \
|
||||
int64_t *, /* accept_tokens */ \
|
||||
int *, /* accept_num */ \
|
||||
int64_t *, /* step_idx */ \
|
||||
bool *, /* stop_flags */ \
|
||||
const int *, /* seq_lens_encoder */ \
|
||||
const int *, /* seq_lens_decoder */ \
|
||||
const int64_t *, /* draft_tokens */ \
|
||||
const int *, /* actual_draft_token_nums */ \
|
||||
const float *, /* dev_curand_states or topp */ \
|
||||
const float *, /* topp or nullptr */ \
|
||||
const int *, /* seq_lens_this_time */ \
|
||||
const int64_t *, /* verify_tokens */ \
|
||||
const float *, /* verify_scores */ \
|
||||
const int64_t *, /* max_dec_len */ \
|
||||
const int64_t *, /* end_tokens */ \
|
||||
const bool *, /* is_block_step */ \
|
||||
const int *, /* output_cum_offsets */ \
|
||||
const int *, /* actual_candidate_len */ \
|
||||
int, /* real_bsz */ \
|
||||
int, /* max_draft_tokens */ \
|
||||
int, /* end_length */ \
|
||||
int, /* max_seq_len */ \
|
||||
int, /* max_candidate_len */ \
|
||||
int, /* verify_window */ \
|
||||
bool, /* prefill_one_step_stop */ \
|
||||
bool, /* benchmark_mode */ \
|
||||
bool, /* accept_all_drafts */ \
|
||||
bool /* use_target_sampling */ \
|
||||
fastdeploy::plugin::api::Context*, /* xpu_ctx */ \
|
||||
const int64_t*, /* sampled_token_ids */ \
|
||||
int64_t*, /* accept_tokens */ \
|
||||
int*, /* accept_num */ \
|
||||
int64_t*, /* step_idx */ \
|
||||
bool*, /* stop_flags */ \
|
||||
const int*, /* seq_lens_encoder */ \
|
||||
const int*, /* seq_lens_decoder */ \
|
||||
const int64_t*, /* draft_tokens */ \
|
||||
const int*, /* actual_draft_token_nums */ \
|
||||
const float*, /* dev_curand_states or topp */ \
|
||||
const float*, /* topp or nullptr */ \
|
||||
const int*, /* seq_lens_this_time */ \
|
||||
const int64_t*, /* verify_tokens */ \
|
||||
const float*, /* verify_scores */ \
|
||||
const int64_t*, /* max_dec_len */ \
|
||||
const int64_t*, /* end_tokens */ \
|
||||
const bool*, /* is_block_step */ \
|
||||
const int*, /* output_cum_offsets */ \
|
||||
const int*, /* actual_candidate_len */ \
|
||||
int, /* real_bsz */ \
|
||||
int, /* max_draft_tokens */ \
|
||||
int, /* end_length */ \
|
||||
int, /* max_seq_len */ \
|
||||
int, /* max_candidate_len */ \
|
||||
int, /* verify_window */ \
|
||||
bool, /* prefill_one_step_stop */ \
|
||||
bool, /* benchmark_mode */ \
|
||||
bool, /* accept_all_drafts */ \
|
||||
bool /* use_target_sampling */ \
|
||||
);
|
||||
|
||||
INSTANTIATE_SPECULATE_VERIFY(false, false)
|
||||
|
||||
Reference in New Issue
Block a user