mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Speculate Decoding + PD, benchmark fix (#6036)
* fix mtp pd * fix kernel * fix code style * fix kernel * fix test / clear debug code * fix test / clear debug code * fix codestyle * fix codestyle * fix codestyle
This commit is contained in:
@@ -18,15 +18,7 @@
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct msgdata {
|
||||
int64_t mtype;
|
||||
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
|
||||
2]; // stop_flag, bsz, accept_num*bsz, tokens...
|
||||
};
|
||||
#include "speculate_msg.h"
|
||||
|
||||
void SpeculateGetOutput(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
@@ -49,7 +41,7 @@ void SpeculateGetOutput(const paddle::Tensor& x,
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
|
||||
static struct msgdata msg_rcv;
|
||||
static struct speculate_msgdata msg_rcv;
|
||||
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define MAX_BSZ 512
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct speculate_msgdata {
|
||||
|
||||
@@ -19,17 +19,9 @@
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "speculate_msg.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct msgdata {
|
||||
long mtype; // NOLINT
|
||||
int mtext[MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ +
|
||||
2]; // stop_flag, bsz, tokens
|
||||
};
|
||||
|
||||
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
@@ -59,7 +51,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
|
||||
#endif
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
static struct msgdata msg_sed;
|
||||
static struct speculate_msgdata msg_sed;
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "speculate_msg.h" // NOLINT
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
// 为不修改接口调用方式,入参暂不改变
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
namespace api = baidu::xpu::api;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||
api::Context *ctx = xpu_ctx->x_context();
|
||||
if (stop_flags.is_cpu()) {
|
||||
ctx = new api::Context(api::kCPU);
|
||||
}
|
||||
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
constexpr int BlockSize = MAX_BSZ; // bsz <= 512
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
|
||||
int r = baidu::xpu::api::plugin::speculate_schedule_cache(
|
||||
ctx,
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
prompt_lens.data<int64_t>(),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
// kernel launch
|
||||
PD_CHECK(r == 0, "speculate_free_and_reschedule failed.");
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
// PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
PD_BUILD_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"prompt_lens",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({
|
||||
{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
||||
@@ -61,7 +61,7 @@ void SpeculateStepSchedule(
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
const int length = input_ids.shape()[1];
|
||||
const int pre_id_length = pre_ids.shape()[1];
|
||||
constexpr int BlockSize = 256; // bsz <= 256
|
||||
constexpr int BlockSize = MAX_BSZ; // bsz <= MAX_BSZ
|
||||
const int max_decoder_block_num =
|
||||
length / block_size -
|
||||
encoder_decoder_block_num; // 最大输出长度对应的block -
|
||||
|
||||
@@ -574,6 +574,31 @@ DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
|
||||
const int max_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int speculate_schedule_cache(Context* ctx,
|
||||
const int64_t* draft_tokens,
|
||||
int* block_tables,
|
||||
bool* stop_flags,
|
||||
const int64_t* prompt_lens,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* step_draft_tokens,
|
||||
int* step_seq_lens_this_time,
|
||||
int* accept_num,
|
||||
int64_t* accept_tokens,
|
||||
bool* is_block_step,
|
||||
bool* not_need_stop,
|
||||
const int64_t* stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop);
|
||||
|
||||
DLL_EXPORT int speculate_update_v3(Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
#include "xpu/kernel/cluster.h"
|
||||
#include "xpu/kernel/cluster_partition.h"
|
||||
#include "xpu/kernel/cluster_primitive.h"
|
||||
#include "xpu/kernel/xtdk_io.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) {
|
||||
int res;
|
||||
v1 = vvadd_int32x16(v0, v1);
|
||||
auto v = vsrlp_int32x16(256, v1);
|
||||
v1 = vvadd_int32x16(v, v1);
|
||||
v = vsrlp_int32x16(128, v1);
|
||||
v1 = vvadd_int32x16(v, v1);
|
||||
v = vsrlp_int32x16(64, v1);
|
||||
v1 = vvadd_int32x16(v, v1);
|
||||
v = vsrlp_int32x16(32, v1);
|
||||
v1 = vvadd_int32x16(v, v1);
|
||||
res = vextract_int32x16(v1, 1);
|
||||
return res;
|
||||
}
|
||||
|
||||
static inline __device__ int ClusterReduce(
|
||||
const _shared_ptr_ int *stop_flag_now_int_sm, int len) {
|
||||
int sum = 0;
|
||||
if (core_id() == 0) {
|
||||
int32x16_t vec_x_0;
|
||||
int32x16_t vec_x_1;
|
||||
int32x16_t vec_y_0 = vzero<int>();
|
||||
int32x16_t vec_y_1 = vzero<int>();
|
||||
for (int i = 0; i < len; i += 32) {
|
||||
vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1);
|
||||
vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0);
|
||||
vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1);
|
||||
}
|
||||
sum = v_reduce(vec_y_0, vec_y_1);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
__global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
const int64_t *prompt_lens,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop) {
|
||||
const int cid = core_id();
|
||||
const int tid = core_id() * cluster_num() + cluster_id();
|
||||
const int nthreads = core_num() * cluster_num();
|
||||
|
||||
__shared__ int stop_flag_now_int_sm[64];
|
||||
int value_zero = 0;
|
||||
bool value_true = true;
|
||||
bool value_false = false;
|
||||
|
||||
int bid_start_core, bid_end_core;
|
||||
partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
|
||||
|
||||
int64_t draft_tokens_lm[draft_tokens_len];
|
||||
int64_t step_draft_tokens_lm[draft_tokens_len];
|
||||
int block_table_lm[block_num_per_seq];
|
||||
int64_t accept_tokens_lm[accept_tokens_len];
|
||||
|
||||
int *seq_lens_encoder_lm;
|
||||
int seq_lens_decoder_lm;
|
||||
int64_t prompt_len_lm;
|
||||
int seq_lens_this_time_lm;
|
||||
bool stop_flag_lm;
|
||||
|
||||
stop_flag_now_int_sm[cid] = 0;
|
||||
for (int bid = bid_start_core; bid < bid_end_core; bid++) {
|
||||
int stop_flag_now_int = 0;
|
||||
GM2LM_ASYNC(draft_tokens + bid * draft_tokens_len,
|
||||
draft_tokens_lm,
|
||||
draft_tokens_len * sizeof(int64_t));
|
||||
GM2LM_ASYNC(step_draft_tokens + bid * draft_tokens_len,
|
||||
step_draft_tokens_lm,
|
||||
draft_tokens_len * sizeof(int64_t));
|
||||
GM2LM_ASYNC(block_tables + bid * block_num_per_seq,
|
||||
block_table_lm,
|
||||
block_num_per_seq * sizeof(int));
|
||||
GM2LM_ASYNC(accept_tokens + bid * accept_tokens_len,
|
||||
accept_tokens_lm,
|
||||
accept_tokens_len * sizeof(int64_t));
|
||||
GM2LM_ASYNC(seq_lens_decoder + bid, &seq_lens_decoder_lm, sizeof(int));
|
||||
GM2LM_ASYNC(prompt_lens + bid, &prompt_len_lm, sizeof(int64_t));
|
||||
GM2LM_ASYNC(seq_lens_this_time + bid, &seq_lens_this_time_lm, sizeof(int));
|
||||
GM2LM_ASYNC(stop_flags + bid, &stop_flag_lm, sizeof(int));
|
||||
mfence();
|
||||
|
||||
if (!stop_flag_lm) {
|
||||
if (seq_lens_decoder_lm >= prompt_len_lm) {
|
||||
const int max_possible_block_idx =
|
||||
(seq_lens_decoder_lm + max_next_step_tokens) / block_size;
|
||||
if (prefill_one_step_stop) {
|
||||
LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_encoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
|
||||
mfence();
|
||||
stop_flag_now_int = 1;
|
||||
} else if (max_possible_block_idx < block_num_per_seq &&
|
||||
block_table_lm[max_possible_block_idx] == -1) {
|
||||
LM2GM_ASYNC(&value_true, is_block_step + bid, sizeof(bool));
|
||||
LM2GM_ASYNC(&seq_lens_this_time_lm,
|
||||
step_seq_lens_this_time + bid,
|
||||
sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
|
||||
stop_flag_now_int = 1;
|
||||
LM2GM_ASYNC(
|
||||
&seq_lens_decoder_lm, step_seq_lens_decoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
|
||||
mfence();
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_lm[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_lm[i] = draft_tokens_lm[i];
|
||||
}
|
||||
LM2GM_ASYNC(accept_tokens_lm,
|
||||
accept_tokens + bid * accept_tokens_len,
|
||||
accept_tokens_len * sizeof(int64_t));
|
||||
LM2GM_ASYNC(step_draft_tokens_lm,
|
||||
step_draft_tokens + bid * draft_tokens_len,
|
||||
draft_tokens_len * sizeof(int64_t));
|
||||
}
|
||||
} else {
|
||||
// prefill
|
||||
LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_decoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_encoder + bid, sizeof(int));
|
||||
LM2GM_ASYNC(&value_zero, accept_num + bid, sizeof(int));
|
||||
mfence();
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
stop_flag_now_int_sm[cid] += stop_flag_now_int;
|
||||
}
|
||||
mfence_sm();
|
||||
sync_all();
|
||||
// reduce stop_sum_now
|
||||
if (cid == 0) {
|
||||
int stop_sum = ClusterReduce(stop_flag_now_int_sm, 64);
|
||||
int stop_nums_lm;
|
||||
GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int));
|
||||
mfence();
|
||||
sync_all();
|
||||
if (stop_sum < stop_nums_lm) {
|
||||
LM2GM_ASYNC(&value_true, not_need_stop, sizeof(bool));
|
||||
} else {
|
||||
LM2GM_ASYNC(&value_false, not_need_stop, sizeof(bool));
|
||||
}
|
||||
mfence();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
@@ -0,0 +1,319 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
__attribute__((global)) void speculate_schedule_cache(
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
const int64_t *prompt_lens,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int cpu_wrapper(Context *ctx,
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
const int64_t *prompt_lens,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop) {
|
||||
int stop_sum_now = 0;
|
||||
for (int bid = 0; bid < real_bsz; bid++) {
|
||||
if (!stop_flags[bid]) {
|
||||
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
|
||||
int64_t *step_draft_tokens_now =
|
||||
step_draft_tokens + bid * draft_tokens_len;
|
||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
|
||||
if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
|
||||
const int max_possible_block_idx =
|
||||
(seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags[bid] = true;
|
||||
seq_lens_this_time[bid] = 0;
|
||||
seq_lens_decoder[bid] = 0;
|
||||
seq_lens_encoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
stop_sum_now += 1;
|
||||
} else if (max_possible_block_idx < block_num_per_seq &&
|
||||
block_table_now[max_possible_block_idx] == -1) {
|
||||
is_block_step[bid] = true;
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||
seq_lens_this_time[bid] = 0;
|
||||
stop_flags[bid] = true;
|
||||
stop_sum_now += 1;
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||
seq_lens_decoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_now[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// prefill
|
||||
stop_flags[bid] = true;
|
||||
seq_lens_this_time[bid] = 0;
|
||||
seq_lens_decoder[bid] = 0;
|
||||
seq_lens_encoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
stop_sum_now += 1;
|
||||
}
|
||||
} else {
|
||||
stop_sum_now += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// for (int bid = real_bsz; i < max_bsz; bid++) {
|
||||
// stop_sum_now += 1;
|
||||
// }
|
||||
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum_now < stop_nums[0];
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
const int64_t *prompt_lens,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
using XPU_TI = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
|
||||
(const XPU_TI *)draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
(const XPU_TI *)prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_TI *>(step_draft_tokens),
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
reinterpret_cast<XPU_TI *>(accept_tokens),
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
(const XPU_TI *)stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int speculate_schedule_cache(Context *ctx,
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
const int64_t *prompt_lens,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_schedule_cache", float);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
step_draft_tokens,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
|
||||
WRAPPER_ASSERT_GT(ctx, draft_tokens_len, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, accept_tokens_len, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, block_num_per_seq, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, real_bsz, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, block_size, 0);
|
||||
WRAPPER_ASSERT_GT(ctx, max_next_step_tokens, 0);
|
||||
WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, draft_tokens_len * real_bsz, draft_tokens);
|
||||
WRAPPER_CHECK_PTR(
|
||||
ctx, int64_t, draft_tokens_len * real_bsz, step_draft_tokens);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, accept_tokens_len * real_bsz, accept_tokens);
|
||||
WRAPPER_CHECK_PTR(ctx, int, block_num_per_seq *real_bsz, block_tables);
|
||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_decoder);
|
||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, prompt_lens);
|
||||
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
return cpu_wrapper(ctx,
|
||||
draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
step_draft_tokens,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
step_draft_tokens,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
@@ -126,7 +126,7 @@ def test_main(
|
||||
# 84, 100352, 12288, 1, 1, 54, 32768);
|
||||
|
||||
|
||||
def miain():
|
||||
def main():
|
||||
seed = np.random.randint(1, 1e9)
|
||||
print(f"random seed is {seed}")
|
||||
np.random.seed(seed)
|
||||
@@ -203,4 +203,4 @@ def miain():
|
||||
|
||||
if __name__ == "__main__":
|
||||
for i in range(10):
|
||||
miain()
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import speculate_schedule_cache
|
||||
|
||||
|
||||
def cpu_reference(
|
||||
draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
step_draft_tokens,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
block_size,
|
||||
max_draft_tokens,
|
||||
):
|
||||
"""Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).
|
||||
|
||||
Shapes are the same as inputs to the custom op. This mutates the provided
|
||||
NumPy arrays in-place, exactly like the kernel does.
|
||||
"""
|
||||
real_bsz = seq_lens_this_time.shape[0]
|
||||
max_bsz = stop_flags.shape[0]
|
||||
draft_tokens_len = draft_tokens.shape[1]
|
||||
block_num_per_seq = block_tables.shape[1]
|
||||
|
||||
max_next_step_tokens = 2 * max_draft_tokens + 2
|
||||
|
||||
# Block-local reduction input per thread (threadIdx.x -> bid)
|
||||
stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512
|
||||
|
||||
for bid in range(512):
|
||||
if bid < real_bsz:
|
||||
if not stop_flags[bid]:
|
||||
max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
|
||||
if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
|
||||
is_block_step[bid] = True
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
|
||||
seq_lens_this_time[bid] = 0
|
||||
stop_flags[bid] = True
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
|
||||
seq_lens_decoder[bid] = 0
|
||||
accept_num[bid] = 0
|
||||
accept_tokens[bid, :] = -1
|
||||
step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
|
||||
stop_flag_now_int[bid] = 1
|
||||
else:
|
||||
stop_flag_now_int[bid] = 0
|
||||
else:
|
||||
stop_flag_now_int[bid] = 1
|
||||
elif bid < max_bsz:
|
||||
# Threads in [real_bsz, max_bsz) contribute 1 to reduction
|
||||
stop_flag_now_int[bid] = 1
|
||||
else:
|
||||
stop_flag_now_int[bid] = 0
|
||||
|
||||
stop_sum = int(stop_flag_now_int.sum())
|
||||
not_need_stop[0] = stop_sum < int(stop_nums[0])
|
||||
|
||||
|
||||
class TestSpeculateScheduleCache(unittest.TestCase):
    """Validate the `speculate_schedule_cache` custom op against `cpu_reference`.

    The op mutates its Paddle tensors in place; `cpu_reference` mutates NumPy
    copies of the same data in place. After running both, every mutated array
    is compared element-wise.
    """

    @classmethod
    def setUpClass(cls):
        pass
        # paddle.device.set_device("cpu")

    def setUp(self):
        # --- Construct a deterministic case that exercises all branches ---
        # real_bsz < max_bsz to test the padding logic in the CUB reduction.
        self.real_bsz = 3  # actual batch size
        self.max_bsz = 5  # maximum batch size (only stop_flags uses it)

        self.draft_tokens_len = 6  # draft-token row length
        self.accept_tokens_len = 5  # accept-token row length
        self.block_size = 4  # tokens per KV-cache block
        self.block_num_per_seq = 3  # cache blocks per sequence
        self.max_draft_tokens = 2  # -> max_next_step_tokens = 2 * 2 + 2 = 6

        # Inputs: bid0 triggers, bid2 does not trigger, bid1 is already stopped.
        # (seq_lens_decoder + 6) // 4 -> candidate block indices [1, 1, 4];
        # index 4 is out of range -> bid2 cannot trigger.
        self.draft_tokens = paddle.to_tensor(
            np.array(
                [
                    [1, 1, 1, 1, 1, 1],  # draft tokens of bid0
                    [2, 2, 2, 2, 2, 2],  # draft tokens of bid1
                    [3, 3, 3, 3, 3, 3],  # draft tokens of bid2
                ],
                dtype=np.int64,
            )
        )
        self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
        # stop_flags has length max_bsz; the other inputs have length real_bsz.
        self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
        self.prompt_lens = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int64))
        self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
        self.seq_lens_encoder = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int32))
        self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))

        # Output tensors filled by the kernel (only for triggered batch ids).
        self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
        self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
        self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")

        # Deliberately non-zero so the test can verify that only the
        # triggered entries are cleared in place.
        self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
        self.accept_tokens = paddle.to_tensor(
            np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
                self.real_bsz, self.accept_tokens_len
            )
        )
        self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)

        # not_need_stop lives on CPU on the caller side; the kernel copies it
        # to the device internally.
        self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()

        # Threshold: bid0 triggers (1), bid1 is already stopped (1), padding
        # contributes (5 - 3) = 2 -> stop_sum = 1 + 1 + 2 = 4.
        # With stop_nums = 5, not_need_stop = (4 < 5) = True.
        self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)

        # NumPy copies consumed by the CPU reference implementation.
        self.np_draft_tokens = self.draft_tokens.numpy().copy()
        self.np_block_tables = self.block_tables.numpy().copy()
        self.np_stop_flags = self.stop_flags.numpy().copy()
        self.np_prompt_lens = self.prompt_lens.numpy().copy()
        self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
        self.np_seq_lens_encoder = self.seq_lens_encoder.numpy().copy()
        self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
        self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
        self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
        self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
        self.np_accept_num = self.accept_num.numpy().copy()
        self.np_accept_tokens = self.accept_tokens.numpy().copy()
        self.np_is_block_step = self.is_block_step.numpy().copy()
        self.np_not_need_stop = self.not_need_stop.numpy().copy()
        self.np_stop_nums = self.stop_nums.numpy().copy()

    def test_correctness_against_cpu_reference(self):
        # Run the device kernel (in-place).
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.prompt_lens,
            self.seq_lens_this_time,
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Compute the CPU reference (in-place on the NumPy copies).
        # NOTE: fixed to pass the NumPy copy, consistent with all other
        # arguments (previously the live Paddle tensor was passed here).
        cpu_reference(
            self.np_draft_tokens,
            self.np_block_tables,
            self.np_stop_flags,
            self.np_prompt_lens,
            self.np_seq_lens_this_time,
            self.np_seq_lens_encoder,
            self.np_seq_lens_decoder,
            self.np_step_seq_lens_decoder,
            self.np_step_draft_tokens,
            self.np_step_seq_lens_this_time,
            self.np_accept_num,
            self.np_accept_tokens,
            self.np_is_block_step,
            self.np_not_need_stop,
            self.np_stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Compare every mutated tensor against the reference result.
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
        np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
        np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
        np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
        np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
        np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
        np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
        self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))

    def test_no_trigger_path(self):
        # Make block_tables at the candidate index != -1 so nothing triggers.
        # The candidate index for bid 0/1 is 1; set it to 7.
        bt = self.block_tables.numpy()
        bt[:, 1] = 7
        self.block_tables = paddle.to_tensor(bt)

        # Reset outputs to distinctive values.
        self.step_seq_lens_decoder[:] = 0
        self.step_draft_tokens[:] = 0
        self.step_seq_lens_this_time[:] = 0
        self.accept_num[:] = -123
        self.accept_tokens[:] = -777
        self.is_block_step[:] = False
        self.not_need_stop[:] = False

        # For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2
        # -> stop_sum = 3. With stop_nums = 5 -> True.
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.prompt_lens,
            self.seq_lens_this_time,
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Nothing should have changed except not_need_stop.
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
        np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
        np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
        np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
        self.assertTrue(bool(self.not_need_stop.numpy()[0]))
|
||||
# Allow running this test module directly: `python <this_file>.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user