[XPU] Speculate Decoding + PD, benchmark fix (#6036)

* fix mtp pd

* fix kernel

* fix code style

* fix kernel

* fix test / clear debug code

* fix test / clear debug code

* fix codestyle

* fix codestyle

* fix codestyle
This commit is contained in:
cmcamdy
2026-01-15 19:19:03 +08:00
committed by GitHub
parent 6619298b50
commit 59d8ae0a25
13 changed files with 995 additions and 31 deletions
@@ -18,15 +18,7 @@
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
int64_t mtype;
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, accept_num*bsz, tokens...
};
#include "speculate_msg.h"
void SpeculateGetOutput(const paddle::Tensor& x,
int64_t rank_id,
@@ -49,7 +41,7 @@ void SpeculateGetOutput(const paddle::Tensor& x,
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_rcv;
static struct speculate_msgdata msg_rcv;
static key_t key = ftok("./", msg_queue_id);
@@ -21,7 +21,7 @@
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 512
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct speculate_msgdata {
@@ -19,17 +19,9 @@
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#include "speculate_msg.h"
#include "xpu/plugin.h"
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6
struct msgdata {
long mtype; // NOLINT
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
2]; // stop_flag, bsz, tokens
};
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
@@ -59,7 +51,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static struct msgdata msg_sed;
static struct speculate_msgdata msg_sed;
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
@@ -0,0 +1,146 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <memory>
#include "paddle/phi/core/enforce.h"
#include "speculate_msg.h"  // NOLINT
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// To keep call sites unchanged, the parameter list is left as-is for now.
//
// Host wrapper for the speculate_schedule_cache XPU plugin kernel.
// Reads shapes from the input tensors, resolves the device context (with a
// CPU fallback when the tensors live on CPU), launches the plugin kernel,
// and copies the scalar `not_need_stop` result back to its CPU tensor.
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
                            const paddle::Tensor &block_tables,
                            const paddle::Tensor &stop_flags,
                            const paddle::Tensor &prompt_lens,
                            const paddle::Tensor &seq_lens_this_time,
                            const paddle::Tensor &seq_lens_encoder,
                            const paddle::Tensor &seq_lens_decoder,
                            const paddle::Tensor &step_seq_lens_decoder,
                            const paddle::Tensor &step_draft_tokens,
                            const paddle::Tensor &step_seq_lens_this_time,
                            const paddle::Tensor &accept_num,
                            const paddle::Tensor &accept_tokens,
                            const paddle::Tensor &is_block_step,
                            const paddle::Tensor &not_need_stop,
                            const paddle::Tensor &stop_nums,
                            const int block_size,
                            const int max_draft_tokens) {
  namespace api = baidu::xpu::api;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  api::Context *ctx = xpu_ctx->x_context();
  // CPU fallback: own the temporary context with a smart pointer so it is
  // released on return (the previous raw `new` leaked one Context per call).
  std::unique_ptr<api::Context> cpu_ctx;
  if (stop_flags.is_cpu()) {
    cpu_ctx = std::make_unique<api::Context>(api::kCPU);
    ctx = cpu_ctx.get();
  }
  // Shapes: stop_flags is sized max_bsz; the other per-request tensors are
  // sized real_bsz.
  const int real_bsz = seq_lens_this_time.shape()[0];
  const int max_bsz = stop_flags.shape()[0];
  const int accept_tokens_len = accept_tokens.shape()[1];
  const int draft_token_len = draft_tokens.shape()[1];
  const int block_num_per_seq = block_tables.shape()[1];
  // Worst-case tokens consumed by the next speculative step.
  const int max_next_step_tokens = 2 * max_draft_tokens + 2;
  bool prefill_one_step_stop = false;
  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
    if (env_p[0] == '1') {
      prefill_one_step_stop = true;
    }
  }
  // not_need_stop lives on CPU; stage a device-side copy for the kernel to
  // write, then copy the result back below.
  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
  int r = baidu::xpu::api::plugin::speculate_schedule_cache(
      ctx,
      draft_tokens.data<int64_t>(),
      const_cast<int *>(block_tables.data<int>()),
      const_cast<bool *>(stop_flags.data<bool>()),
      prompt_lens.data<int64_t>(),
      const_cast<int *>(seq_lens_this_time.data<int>()),
      const_cast<int *>(seq_lens_encoder.data<int>()),
      const_cast<int *>(seq_lens_decoder.data<int>()),
      const_cast<int *>(step_seq_lens_decoder.data<int>()),
      const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
      const_cast<int *>(step_seq_lens_this_time.data<int>()),
      const_cast<int *>(accept_num.data<int>()),
      const_cast<int64_t *>(accept_tokens.data<int64_t>()),
      const_cast<bool *>(is_block_step.data<bool>()),
      const_cast<bool *>(not_need_stop_gpu.data<bool>()),
      stop_nums.data<int64_t>(),
      real_bsz,
      max_bsz,
      max_next_step_tokens,
      draft_token_len,
      accept_tokens_len,
      block_size,
      block_num_per_seq,
      prefill_one_step_stop);
  // Fixed error message: it previously named speculate_free_and_reschedule.
  PD_CHECK(r == 0, "speculate_schedule_cache failed.");
  // Blocking copy so the host sees the kernel's result before reading it.
  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), true);
  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
// Dynamic-graph registration of the op; all tensors are updated in place
// (see SetInplaceMap), so each input maps to a same-named "_out" output.
// NOTE(review): PD_BUILD_STATIC_OP is defined above but the static-op
// registration below is disabled — confirm whether static-graph support
// should be enabled.
// PD_BUILD_STATIC_OP(speculate_schedule_cache)
PD_BUILD_OP(speculate_schedule_cache)
    .Inputs({"draft_tokens",
             "block_tables",
             "stop_flags",
             "prompt_lens",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_seq_lens_decoder",
             "step_draft_tokens",
             "step_seq_lens_this_time",
             "accept_num",
             "accept_tokens",
             "is_block_step",
             "not_need_stop",
             "stop_nums"})
    .Attrs({"block_size: int", "max_draft_tokens: int"})
    .Outputs({"draft_tokens_out",
              "block_tables_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "step_seq_lens_decoder_out",
              "step_draft_tokens_out",
              "step_seq_lens_this_time_out",
              "accept_num_out",
              "accept_tokens_out",
              "is_block_step_out",
              "not_need_stop_out"})
    .SetInplaceMap({
        {"draft_tokens", "draft_tokens_out"},
        {"block_tables", "block_tables_out"},
        {"stop_flags", "stop_flags_out"},
        {"seq_lens_this_time", "seq_lens_this_time_out"},
        {"seq_lens_encoder", "seq_lens_encoder_out"},
        {"seq_lens_decoder", "seq_lens_decoder_out"},
        {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
        {"step_draft_tokens", "step_draft_tokens_out"},
        {"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
        {"accept_num", "accept_num_out"},
        {"accept_tokens", "accept_tokens_out"},
        {"is_block_step", "is_block_step_out"},
        {"not_need_stop", "not_need_stop_out"},
    })
    .SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
@@ -61,7 +61,7 @@ void SpeculateStepSchedule(
const int block_num_per_seq = block_tables.shape()[1];
const int length = input_ids.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
constexpr int BlockSize = 256; // bsz <= 256
constexpr int BlockSize = MAX_BSZ; // bsz <= MAX_BSZ
const int max_decoder_block_num =
length / block_size -
encoder_decoder_block_num; // 最大输出长度对应的block -
@@ -574,6 +574,31 @@ DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
const int max_decoder_block_num,
const int max_draft_tokens);
DLL_EXPORT int speculate_schedule_cache(Context* ctx,
const int64_t* draft_tokens,
int* block_tables,
bool* stop_flags,
const int64_t* prompt_lens,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* step_draft_tokens,
int* step_seq_lens_this_time,
int* accept_num,
int64_t* accept_tokens,
bool* is_block_step,
bool* not_need_stop,
const int64_t* stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq,
const bool prefill_one_step_stop);
DLL_EXPORT int speculate_update_v3(Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
@@ -0,0 +1,182 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
// Horizontal sum of two int32x16 SIMD vectors (32 int lanes total).
// First folds v0 into v1, then performs a log-step reduction by repeatedly
// shifting the 512-bit register and accumulating, so every lane ends up
// holding the total; the result is extracted from lane 1.
// NOTE(review): assumes vsrlp_int32x16 shifts by the given number of bits of
// the 512-bit register — confirm against the XPU3 ISA reference.
static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) {
  int res;
  v1 = vvadd_int32x16(v0, v1);
  auto v = vsrlp_int32x16(256, v1);
  v1 = vvadd_int32x16(v, v1);
  v = vsrlp_int32x16(128, v1);
  v1 = vvadd_int32x16(v, v1);
  v = vsrlp_int32x16(64, v1);
  v1 = vvadd_int32x16(v, v1);
  v = vsrlp_int32x16(32, v1);
  v1 = vvadd_int32x16(v, v1);
  res = vextract_int32x16(v1, 1);
  return res;
}
// Sums `len` ints held in cluster shared memory. Only core 0 performs the
// vectorized accumulation (two int32x16 loads, i.e. 32 ints per iteration)
// and returns the total; every other core returns 0.
// NOTE(review): assumes len is a multiple of 32 — the caller passes 64.
static inline __device__ int ClusterReduce(
    const _shared_ptr_ int *stop_flag_now_int_sm, int len) {
  int sum = 0;
  if (core_id() == 0) {
    int32x16_t vec_x_0;
    int32x16_t vec_x_1;
    int32x16_t vec_y_0 = vzero<int>();
    int32x16_t vec_y_1 = vzero<int>();
    for (int i = 0; i < len; i += 32) {
      vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1);
      vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0);
      vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1);
    }
    sum = v_reduce(vec_y_0, vec_y_1);
  }
  return sum;
}
// XPU3 device kernel for speculate_schedule_cache.
//
// Each hardware thread processes a slice of batch ids in [0, max_bsz):
// live requests that would run out of cache blocks on the next speculative
// step are parked ("block step"): their lengths/draft tokens are snapshotted
// into the step_* buffers and the request is stopped. Requests still in
// prefill (seq_lens_decoder < prompt_lens) — or all requests when
// prefill_one_step_stop is set — are stopped outright. Per-core stop counts
// are reduced in shared memory and not_need_stop[0] is set to
// (stopped < stop_nums[0]).
//
// NOTE(review): the shared-memory reduction covers one cluster only; the
// host wrapper launches <<<1, 64>>>, so a single cluster sees every batch id.
__global__ void speculate_schedule_cache(const int64_t *draft_tokens,
                                         int *block_tables,
                                         bool *stop_flags,
                                         const int64_t *prompt_lens,
                                         int *seq_lens_this_time,
                                         int *seq_lens_encoder,
                                         int *seq_lens_decoder,
                                         int *step_seq_lens_decoder,
                                         int64_t *step_draft_tokens,
                                         int *step_seq_lens_this_time,
                                         int *accept_num,
                                         int64_t *accept_tokens,
                                         bool *is_block_step,
                                         bool *not_need_stop,
                                         const int64_t *stop_nums,
                                         const int real_bsz,
                                         const int max_bsz,
                                         const int max_next_step_tokens,
                                         const int draft_tokens_len,
                                         const int accept_tokens_len,
                                         const int block_size,
                                         const int block_num_per_seq,
                                         const bool prefill_one_step_stop) {
  const int cid = core_id();
  const int tid = core_id() * cluster_num() + cluster_id();
  const int nthreads = core_num() * cluster_num();
  // One stop-counter slot per core of this cluster.
  __shared__ int stop_flag_now_int_sm[64];
  int value_zero = 0;
  bool value_true = true;
  bool value_false = false;
  int bid_start_core, bid_end_core;
  partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
  // Local-memory staging buffers for one request's rows.
  int64_t draft_tokens_lm[draft_tokens_len];
  int64_t step_draft_tokens_lm[draft_tokens_len];
  int block_table_lm[block_num_per_seq];
  int64_t accept_tokens_lm[accept_tokens_len];
  int seq_lens_decoder_lm;
  int64_t prompt_len_lm;
  int seq_lens_this_time_lm;
  bool stop_flag_lm;
  stop_flag_now_int_sm[cid] = 0;
  for (int bid = bid_start_core; bid < bid_end_core; bid++) {
    int stop_flag_now_int = 0;
    // Stage this request's state into local memory.
    GM2LM_ASYNC(draft_tokens + bid * draft_tokens_len,
                draft_tokens_lm,
                draft_tokens_len * sizeof(int64_t));
    GM2LM_ASYNC(step_draft_tokens + bid * draft_tokens_len,
                step_draft_tokens_lm,
                draft_tokens_len * sizeof(int64_t));
    GM2LM_ASYNC(block_tables + bid * block_num_per_seq,
                block_table_lm,
                block_num_per_seq * sizeof(int));
    GM2LM_ASYNC(accept_tokens + bid * accept_tokens_len,
                accept_tokens_lm,
                accept_tokens_len * sizeof(int64_t));
    GM2LM_ASYNC(seq_lens_decoder + bid, &seq_lens_decoder_lm, sizeof(int));
    GM2LM_ASYNC(prompt_lens + bid, &prompt_len_lm, sizeof(int64_t));
    GM2LM_ASYNC(seq_lens_this_time + bid, &seq_lens_this_time_lm, sizeof(int));
    // Fix: load exactly one bool. The previous sizeof(int) copied 4 bytes
    // into a 1-byte local, overrunning both the flag array read and the
    // local variable.
    GM2LM_ASYNC(stop_flags + bid, &stop_flag_lm, sizeof(bool));
    mfence();
    if (!stop_flag_lm) {
      if (seq_lens_decoder_lm >= prompt_len_lm) {
        // Decode phase: index of the block the next step could touch.
        const int max_possible_block_idx =
            (seq_lens_decoder_lm + max_next_step_tokens) / block_size;
        if (prefill_one_step_stop) {
          // One-step-stop mode: terminate the request immediately.
          LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
          LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
          LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
          LM2GM_ASYNC(&value_zero, seq_lens_encoder + bid, sizeof(int));
          LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
          mfence();
          stop_flag_now_int = 1;
        } else if (max_possible_block_idx < block_num_per_seq &&
                   block_table_lm[max_possible_block_idx] == -1) {
          // Block step: the next block is unallocated — park the request,
          // snapshotting its live state into the step_* buffers.
          LM2GM_ASYNC(&value_true, is_block_step + bid, sizeof(bool));
          LM2GM_ASYNC(&seq_lens_this_time_lm,
                      step_seq_lens_this_time + bid,
                      sizeof(int));
          LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
          LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
          stop_flag_now_int = 1;
          LM2GM_ASYNC(
              &seq_lens_decoder_lm, step_seq_lens_decoder + bid, sizeof(int));
          LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
          LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
          mfence();
          for (int i = 0; i < accept_tokens_len; i++) {
            accept_tokens_lm[i] = -1;
          }
          for (int i = 0; i < draft_tokens_len; i++) {
            step_draft_tokens_lm[i] = draft_tokens_lm[i];
          }
          LM2GM_ASYNC(accept_tokens_lm,
                      accept_tokens + bid * accept_tokens_len,
                      accept_tokens_len * sizeof(int64_t));
          LM2GM_ASYNC(step_draft_tokens_lm,
                      step_draft_tokens + bid * draft_tokens_len,
                      draft_tokens_len * sizeof(int64_t));
        }
      } else {
        // Prefill phase: stop the request and clear its live lengths.
        LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
        LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
        LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
        LM2GM_ASYNC(&value_zero, seq_lens_encoder + bid, sizeof(int));
        LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
        mfence();
        stop_flag_now_int = 1;
      }
    } else {
      // Already stopped: just contributes to the stop count.
      stop_flag_now_int = 1;
    }
    stop_flag_now_int_sm[cid] += stop_flag_now_int;
  }
  mfence_sm();
  sync_all();
  // Reduce per-core stop counts and publish not_need_stop.
  if (cid == 0) {
    int stop_sum = ClusterReduce(stop_flag_now_int_sm, 64);
    // Fix: stop_nums is int64_t in global memory; the previous code loaded
    // only sizeof(int) bytes into an int, silently truncating the value.
    int64_t stop_nums_lm;
    GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int64_t));
    mfence();
    sync_all();  // NOTE(review): barrier inside a cid==0 branch — confirm
                 // XPU semantics permit a divergent sync_all here.
    if (stop_sum < stop_nums_lm) {
      LM2GM_ASYNC(&value_true, not_need_stop, sizeof(bool));
    } else {
      LM2GM_ASYNC(&value_false, not_need_stop, sizeof(bool));
    }
    mfence();
  }
}
} // namespace plugin
} // namespace xpu3
@@ -0,0 +1,319 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Forward declaration of the XPU3 device kernel (defined in the .xpu source);
// the host-side xpu3_wrapper below launches it.
__attribute__((global)) void speculate_schedule_cache(
    const int64_t *draft_tokens,
    int *block_tables,
    bool *stop_flags,
    const int64_t *prompt_lens,
    int *seq_lens_this_time,
    int *seq_lens_encoder,
    int *seq_lens_decoder,
    int *step_seq_lens_decoder,
    int64_t *step_draft_tokens,
    int *step_seq_lens_this_time,
    int *accept_num,
    int64_t *accept_tokens,
    bool *is_block_step,
    bool *not_need_stop,
    const int64_t *stop_nums,
    const int real_bsz,
    const int max_bsz,
    const int max_next_step_tokens,
    const int draft_tokens_len,
    const int accept_tokens_len,
    const int block_size,
    const int block_num_per_seq,
    const bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// CPU reference implementation of speculate_schedule_cache.
//
// For each live request: once decoding has consumed the prompt
// (seq_lens_decoder >= prompt_lens), a request whose next speculative step
// would need an unallocated cache block is parked ("block step") — its
// lengths and draft tokens are snapshotted into the step_* buffers and it is
// stopped. With prefill_one_step_stop set, or while still in prefill, the
// request is stopped outright. Finally
// not_need_stop[0] = (stopped requests) < stop_nums[0].
//
// NOTE(review): the XPU kernel iterates over max_bsz and can count padded
// slots in [real_bsz, max_bsz); this path scans only real_bsz — confirm both
// paths agree on not_need_stop for padded batches.
static int cpu_wrapper(Context *ctx,
                       const int64_t *draft_tokens,
                       int *block_tables,
                       bool *stop_flags,
                       const int64_t *prompt_lens,
                       int *seq_lens_this_time,
                       int *seq_lens_encoder,
                       int *seq_lens_decoder,
                       int *step_seq_lens_decoder,
                       int64_t *step_draft_tokens,
                       int *step_seq_lens_this_time,
                       int *accept_num,
                       int64_t *accept_tokens,
                       bool *is_block_step,
                       bool *not_need_stop,
                       const int64_t *stop_nums,
                       const int real_bsz,
                       const int max_bsz,
                       const int max_next_step_tokens,
                       const int draft_tokens_len,
                       const int accept_tokens_len,
                       const int block_size,
                       const int block_num_per_seq,
                       const bool prefill_one_step_stop) {
  int stop_sum_now = 0;
  for (int bid = 0; bid < real_bsz; bid++) {
    if (!stop_flags[bid]) {
      const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
      int64_t *step_draft_tokens_now =
          step_draft_tokens + bid * draft_tokens_len;
      int *block_table_now = block_tables + bid * block_num_per_seq;
      int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
      if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
        // Decode phase: block index the next step could touch.
        const int max_possible_block_idx =
            (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
        if (prefill_one_step_stop) {
          // One-step-stop mode: terminate the request immediately.
          stop_flags[bid] = true;
          seq_lens_this_time[bid] = 0;
          seq_lens_decoder[bid] = 0;
          seq_lens_encoder[bid] = 0;
          accept_num[bid] = 0;
          stop_sum_now += 1;
        } else if (max_possible_block_idx < block_num_per_seq &&
                   block_table_now[max_possible_block_idx] == -1) {
          // Block step: snapshot live state and park the request.
          is_block_step[bid] = true;
          step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
          seq_lens_this_time[bid] = 0;
          stop_flags[bid] = true;
          stop_sum_now += 1;
          step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
          seq_lens_decoder[bid] = 0;
          accept_num[bid] = 0;
          for (int i = 0; i < accept_tokens_len; i++) {
            accept_tokens_now[i] = -1;
          }
          for (int i = 0; i < draft_tokens_len; i++) {
            step_draft_tokens_now[i] = draft_tokens_now[i];
          }
        }
      } else {
        // Prefill phase: stop the request and clear its live lengths.
        stop_flags[bid] = true;
        seq_lens_this_time[bid] = 0;
        seq_lens_decoder[bid] = 0;
        seq_lens_encoder[bid] = 0;
        accept_num[bid] = 0;
        stop_sum_now += 1;
      }
    } else {
      stop_sum_now += 1;
    }
  }
  not_need_stop[0] = stop_sum_now < stop_nums[0];
  return api::SUCCESS;
}
// Launches the XPU3 device kernel on ctx's stream.
// Uses a single cluster of 64 cores (<<<1, 64>>>), matching the kernel's
// single-cluster shared-memory reduction.
static int xpu3_wrapper(Context *ctx,
                        const int64_t *draft_tokens,
                        int *block_tables,
                        bool *stop_flags,
                        const int64_t *prompt_lens,
                        int *seq_lens_this_time,
                        int *seq_lens_encoder,
                        int *seq_lens_decoder,
                        int *step_seq_lens_decoder,
                        int64_t *step_draft_tokens,
                        int *step_seq_lens_this_time,
                        int *accept_num,
                        int64_t *accept_tokens,
                        bool *is_block_step,
                        bool *not_need_stop,
                        const int64_t *stop_nums,
                        const int real_bsz,
                        const int max_bsz,
                        const int max_next_step_tokens,
                        const int draft_tokens_len,
                        const int accept_tokens_len,
                        const int block_size,
                        const int block_num_per_seq,
                        const bool prefill_one_step_stop) {
  // int64_t pointers are routed through XPUIndexType to match the
  // device-side index-type ABI. (Removed the unused XPU_INT64 alias.)
  using XPU_TI = typename XPUIndexType<int64_t>::type;
  xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
      (const XPU_TI *)draft_tokens,
      block_tables,
      stop_flags,
      (const XPU_TI *)prompt_lens,
      seq_lens_this_time,
      seq_lens_encoder,
      seq_lens_decoder,
      step_seq_lens_decoder,
      reinterpret_cast<XPU_TI *>(step_draft_tokens),
      step_seq_lens_this_time,
      accept_num,
      reinterpret_cast<XPU_TI *>(accept_tokens),
      is_block_step,
      not_need_stop,
      (const XPU_TI *)stop_nums,
      real_bsz,
      max_bsz,
      max_next_step_tokens,
      draft_tokens_len,
      accept_tokens_len,
      block_size,
      block_num_per_seq,
      prefill_one_step_stop);
  return api::SUCCESS;
}
// Public plugin entry point: validates arguments, dumps debug info, and
// dispatches to the CPU or XPU3 implementation based on ctx's device type.
// Returns api::SUCCESS on success; the WRAPPER_* macros report and return an
// error code on failure.
int speculate_schedule_cache(Context *ctx,
                             const int64_t *draft_tokens,
                             int *block_tables,
                             bool *stop_flags,
                             const int64_t *prompt_lens,
                             int *seq_lens_this_time,
                             int *seq_lens_encoder,
                             int *seq_lens_decoder,
                             int *step_seq_lens_decoder,
                             int64_t *step_draft_tokens,
                             int *step_seq_lens_this_time,
                             int *accept_num,
                             int64_t *accept_tokens,
                             bool *is_block_step,
                             bool *not_need_stop,
                             const int64_t *stop_nums,
                             const int real_bsz,
                             const int max_bsz,
                             const int max_next_step_tokens,
                             const int draft_tokens_len,
                             const int accept_tokens_len,
                             const int block_size,
                             const int block_num_per_seq,
                             const bool prefill_one_step_stop) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_schedule_cache", float);
  // Arguments are dumped in groups (the dump macros take at most six).
  WRAPPER_DUMP_PARAM6(ctx,
                      draft_tokens,
                      block_tables,
                      stop_flags,
                      prompt_lens,
                      seq_lens_this_time,
                      seq_lens_encoder);
  WRAPPER_DUMP_PARAM6(ctx,
                      seq_lens_decoder,
                      step_seq_lens_decoder,
                      step_draft_tokens,
                      step_seq_lens_this_time,
                      accept_num,
                      accept_tokens);
  WRAPPER_DUMP_PARAM6(ctx,
                      is_block_step,
                      not_need_stop,
                      stop_nums,
                      real_bsz,
                      max_bsz,
                      max_next_step_tokens);
  WRAPPER_DUMP_PARAM5(ctx,
                      draft_tokens_len,
                      accept_tokens_len,
                      block_size,
                      block_num_per_seq,
                      prefill_one_step_stop);
  // Scalar sanity checks.
  WRAPPER_ASSERT_GT(ctx, draft_tokens_len, 0);
  WRAPPER_ASSERT_GT(ctx, accept_tokens_len, 0);
  WRAPPER_ASSERT_GT(ctx, block_num_per_seq, 0);
  WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
  WRAPPER_ASSERT_GT(ctx, block_size, 0);
  WRAPPER_ASSERT_GT(ctx, max_next_step_tokens, 0);
  WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
  // Pointer/extent checks for the major buffers.
  WRAPPER_CHECK_PTR(ctx, int64_t, draft_tokens_len * real_bsz, draft_tokens);
  WRAPPER_CHECK_PTR(
      ctx, int64_t, draft_tokens_len * real_bsz, step_draft_tokens);
  WRAPPER_CHECK_PTR(ctx, int64_t, accept_tokens_len * real_bsz, accept_tokens);
  WRAPPER_CHECK_PTR(ctx, int, block_num_per_seq *real_bsz, block_tables);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, prompt_lens);
  WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
  WRAPPER_DUMP(ctx);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       draft_tokens,
                       block_tables,
                       stop_flags,
                       prompt_lens,
                       seq_lens_this_time,
                       seq_lens_encoder,
                       seq_lens_decoder,
                       step_seq_lens_decoder,
                       step_draft_tokens,
                       step_seq_lens_this_time,
                       accept_num,
                       accept_tokens,
                       is_block_step,
                       not_need_stop,
                       stop_nums,
                       real_bsz,
                       max_bsz,
                       max_next_step_tokens,
                       draft_tokens_len,
                       accept_tokens_len,
                       block_size,
                       block_num_per_seq,
                       prefill_one_step_stop);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        draft_tokens,
                        block_tables,
                        stop_flags,
                        prompt_lens,
                        seq_lens_this_time,
                        seq_lens_encoder,
                        seq_lens_decoder,
                        step_seq_lens_decoder,
                        step_draft_tokens,
                        step_seq_lens_this_time,
                        accept_num,
                        accept_tokens,
                        is_block_step,
                        not_need_stop,
                        stop_nums,
                        real_bsz,
                        max_bsz,
                        max_next_step_tokens,
                        draft_tokens_len,
                        accept_tokens_len,
                        block_size,
                        block_num_per_seq,
                        prefill_one_step_stop);
  }
  // Any other device type is unsupported.
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -126,7 +126,7 @@ def test_main(
# 84, 100352, 12288, 1, 1, 54, 32768);
def miain():
def main():
seed = np.random.randint(1, 1e9)
print(f"random seed is {seed}")
np.random.seed(seed)
@@ -203,4 +203,4 @@ def miain():
if __name__ == "__main__":
for i in range(10):
miain()
main()
@@ -0,0 +1,266 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_schedule_cache
def cpu_reference(
    draft_tokens,
    block_tables,
    stop_flags,
    prompt_lens,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_seq_lens_decoder,
    step_draft_tokens,
    step_seq_lens_this_time,
    accept_num,
    accept_tokens,
    is_block_step,
    not_need_stop,
    stop_nums,
    block_size,
    max_draft_tokens,
):
    """Pure-NumPy mirror of the device kernel's logic.

    Shapes are the same as the inputs to the custom op. This mutates the
    provided NumPy arrays in place, exactly like the kernel does.

    Fix over the previous version: the kernel (and the C++ CPU wrapper)
    first checks ``seq_lens_decoder >= prompt_lens`` — requests still in
    prefill are stopped outright. The old reference ignored ``prompt_lens``
    and ``seq_lens_encoder`` entirely, so it only mirrored the decode path.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    draft_tokens_len = draft_tokens.shape[1]
    block_num_per_seq = block_tables.shape[1]
    # Worst-case tokens the next speculative step could consume.
    max_next_step_tokens = 2 * max_draft_tokens + 2

    # Per-slot reduction input (slot index -> bid), mirroring the kernel's
    # fixed-size reduction buffer.
    stop_flag_now_int = np.zeros(512, dtype=np.int64)
    for bid in range(512):
        if bid < real_bsz:
            if not stop_flags[bid]:
                if seq_lens_decoder[bid] >= prompt_lens[bid]:
                    # Decode phase: park the request ("block step") if the
                    # next step would need an unallocated cache block.
                    max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
                    if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
                        is_block_step[bid] = True
                        step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
                        seq_lens_this_time[bid] = 0
                        stop_flags[bid] = True
                        step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
                        seq_lens_decoder[bid] = 0
                        accept_num[bid] = 0
                        accept_tokens[bid, :] = -1
                        step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
                        stop_flag_now_int[bid] = 1
                    else:
                        stop_flag_now_int[bid] = 0
                else:
                    # Prefill phase: stop the request and clear its lengths.
                    stop_flags[bid] = True
                    seq_lens_this_time[bid] = 0
                    seq_lens_decoder[bid] = 0
                    seq_lens_encoder[bid] = 0
                    accept_num[bid] = 0
                    stop_flag_now_int[bid] = 1
            else:
                stop_flag_now_int[bid] = 1
        elif bid < max_bsz:
            # Padded slots in [real_bsz, max_bsz) contribute 1 to the count.
            stop_flag_now_int[bid] = 1
        else:
            stop_flag_now_int[bid] = 0
    stop_sum = int(stop_flag_now_int.sum())
    not_need_stop[0] = stop_sum < int(stop_nums[0])
class TestSpeculateScheduleCache(unittest.TestCase):
    """Validates the `speculate_schedule_cache` custom op against the
    NumPy reference implementation (`cpu_reference`)."""

    @classmethod
    def setUpClass(cls):
        pass
        # paddle.device.set_device("cpu")

    def setUp(self):
        # --- Construct a deterministic case that exercises all branches ---
        # real_bsz < max_bsz to exercise the padding slots in the reduction.
        self.real_bsz = 3  # actual batch size
        self.max_bsz = 5  # maximum batch size (only stop_flags uses it)
        self.draft_tokens_len = 6  # draft-token row length
        self.accept_tokens_len = 5  # accept-token row length
        self.block_size = 4  # tokens per cache block
        self.block_num_per_seq = 3  # cache blocks per sequence
        self.max_draft_tokens = 2  # -> max_next_step_tokens = 6

        # Inputs: bid0 triggers the block-step path, bid2 does not, and bid1
        # is already stopped.
        # (seq_lens_decoder + 6) // 4 -> indices [1, 1, 4]; index 4 is out of
        # range, so bid2 does not trigger.
        self.draft_tokens = paddle.to_tensor(
            np.array(
                [
                    [1, 1, 1, 1, 1, 1],  # draft tokens for bid0
                    [2, 2, 2, 2, 2, 2],  # draft tokens for bid1
                    [3, 3, 3, 3, 3, 3],  # draft tokens for bid2
                ],
                dtype=np.int64,
            )
        )
        self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
        # stop_flags has length max_bsz; the other inputs have length real_bsz.
        self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
        self.prompt_lens = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int64))
        self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
        self.seq_lens_encoder = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int32))
        self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
        # Output tensors filled by the kernel (only for triggered batches).
        self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
        self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
        self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")
        # Deliberately non-zero so we can verify only the triggered entries
        # are zeroed in place.
        self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
        self.accept_tokens = paddle.to_tensor(
            np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
                self.real_bsz, self.accept_tokens_len
            )
        )
        self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
        # not_need_stop stays on CPU on the caller side; the kernel copies it
        # to the device internally.
        self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
        # Threshold: bid0 triggers, bid1 is stopped, padding (5-3)=2
        # -> stop_sum = 1 + 1 + 2 = 4.
        # With stop_nums = 5, not_need_stop = (4 < 5) = True.
        self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)

        # NumPy copies for comparison against the CPU reference.
        self.np_draft_tokens = self.draft_tokens.numpy().copy()
        self.np_block_tables = self.block_tables.numpy().copy()
        self.np_stop_flags = self.stop_flags.numpy().copy()
        self.np_prompt_lens = self.prompt_lens.numpy().copy()
        self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
        self.np_seq_lens_encoder = self.seq_lens_encoder.numpy().copy()
        self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
        self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
        self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
        self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
        self.np_accept_num = self.accept_num.numpy().copy()
        self.np_accept_tokens = self.accept_tokens.numpy().copy()
        self.np_is_block_step = self.is_block_step.numpy().copy()
        self.np_not_need_stop = self.not_need_stop.numpy().copy()
        self.np_stop_nums = self.stop_nums.numpy().copy()

    def test_correctness_against_cpu_reference(self):
        # Run the device kernel (in-place).
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.prompt_lens,
            self.seq_lens_this_time,
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Compute the CPU reference (in-place on the NumPy copies).
        # Fix: pass the NumPy copy of prompt_lens, not the paddle Tensor.
        cpu_reference(
            self.np_draft_tokens,
            self.np_block_tables,
            self.np_stop_flags,
            self.np_prompt_lens,
            self.np_seq_lens_this_time,
            self.np_seq_lens_encoder,
            self.np_seq_lens_decoder,
            self.np_step_seq_lens_decoder,
            self.np_step_draft_tokens,
            self.np_step_seq_lens_this_time,
            self.np_accept_num,
            self.np_accept_tokens,
            self.np_is_block_step,
            self.np_not_need_stop,
            self.np_stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Compare all mutated tensors.
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
        np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
        np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
        np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
        np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
        np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
        np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
        self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))

    def test_no_trigger_path(self):
        # Make block_tables at the candidate index != -1 so nothing triggers.
        # The candidate index for bid 0/1 is 1; set it to 7.
        bt = self.block_tables.numpy()
        bt[:, 1] = 7
        self.block_tables = paddle.to_tensor(bt)
        # Reset outputs to distinctive values.
        self.step_seq_lens_decoder[:] = 0
        self.step_draft_tokens[:] = 0
        self.step_seq_lens_this_time[:] = 0
        self.accept_num[:] = -123
        self.accept_tokens[:] = -777
        self.is_block_step[:] = False
        self.not_need_stop[:] = False
        # For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2
        # -> stop_sum = 3. With stop_nums = 5 -> True.
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.prompt_lens,
            self.seq_lens_this_time,
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Nothing should have changed except not_need_stop.
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
        np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
        np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
        np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
        self.assertTrue(bool(self.not_need_stop.numpy()[0]))