mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] rm stop nums (#6651)
* rm stop nums * fix conflict --------- Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
@@ -39,16 +39,16 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||
auto x_remove_padding = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||
auto padding_offset = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
auto batch_id_per_token = paddle::empty(
|
||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
auto x_remove_padding = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||
auto padding_offset = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto batch_id_per_token = paddle::full(
|
||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
|
||||
PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous");
|
||||
PD_CHECK(draft_tokens.is_contiguous(),
|
||||
|
||||
@@ -36,7 +36,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &not_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
namespace api = baidu::xpu::api;
|
||||
@@ -79,7 +78,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
@@ -97,8 +95,7 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
// PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
PD_BUILD_OP(speculate_schedule_cache)
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
@@ -112,8 +109,7 @@ PD_BUILD_OP(speculate_schedule_cache)
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
"not_need_stop"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
|
||||
@@ -323,6 +323,23 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
bool save_each_rank,
|
||||
bool skip_prefill);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_seq_lens_decoder,
|
||||
const paddle::Tensor& step_draft_tokens,
|
||||
const paddle::Tensor& step_seq_lens_this_time,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
@@ -627,7 +644,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step);
|
||||
|
||||
@@ -641,7 +657,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size);
|
||||
@@ -1391,7 +1406,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("stop_nums"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("is_block_step"),
|
||||
"Update inputs function");
|
||||
@@ -1408,7 +1422,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("topk_ids"),
|
||||
py::arg("input_ids"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("stop_nums"),
|
||||
py::arg("next_tokens"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("block_size"),
|
||||
|
||||
@@ -23,7 +23,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
@@ -47,7 +46,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
stop_flags.data<bool>(),
|
||||
is_block_step.data<bool>(),
|
||||
next_tokens.data<int64_t>(),
|
||||
@@ -68,7 +66,6 @@ PD_BUILD_OP(update_inputs)
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"input_ids",
|
||||
"stop_nums",
|
||||
"next_tokens",
|
||||
"is_block_step"})
|
||||
.Outputs({"not_need_stop_out",
|
||||
|
||||
@@ -17,6 +17,10 @@
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop, // only on cpu
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
@@ -27,7 +31,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& stop_nums,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const int block_size) {
|
||||
@@ -52,7 +55,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
const_cast<int64_t*>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<int*>(block_tables.data<int>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
next_tokens.data<int64_t>(),
|
||||
@@ -68,7 +70,7 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_OP(update_inputs_v1)
|
||||
PD_BUILD_STATIC_OP(update_inputs_v1)
|
||||
.Inputs({"stop_flags",
|
||||
"not_need_stop",
|
||||
"seq_lens_this_time",
|
||||
@@ -79,7 +81,6 @@ PD_BUILD_OP(update_inputs_v1)
|
||||
"topk_ids",
|
||||
"input_ids",
|
||||
"block_tables",
|
||||
"stop_nums",
|
||||
"next_tokens",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
|
||||
@@ -123,7 +123,6 @@ DLL_EXPORT int update_inputs(Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* input_ids,
|
||||
const int64_t* stop_nums,
|
||||
const bool* stop_flags,
|
||||
const bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -264,7 +263,6 @@ DLL_EXPORT int update_inputs_v1(Context* ctx,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -618,7 +616,6 @@ DLL_EXPORT int speculate_schedule_cache(Context* ctx,
|
||||
int64_t* accept_tokens,
|
||||
bool* is_block_step,
|
||||
bool* not_need_stop,
|
||||
const int64_t* stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
|
||||
+1
-5
@@ -53,7 +53,6 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
@@ -166,11 +165,8 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
// reduce stop_sum_now
|
||||
if (cid == 0) {
|
||||
int stop_sum = ClusterReduce(stop_flag_now_int_sm, 64);
|
||||
int stop_nums_lm;
|
||||
GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int));
|
||||
mfence();
|
||||
sync_all();
|
||||
if (stop_sum < stop_nums_lm) {
|
||||
if (stop_sum < max_bsz) {
|
||||
LM2GM_ASYNC(&value_true, not_need_stop, sizeof(bool));
|
||||
} else {
|
||||
LM2GM_ASYNC(&value_false, not_need_stop, sizeof(bool));
|
||||
|
||||
@@ -10,7 +10,6 @@ __global__ void update_inputs(bool *not_need_stop,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
@@ -66,9 +65,7 @@ __global__ void update_inputs(bool *not_need_stop,
|
||||
stop_sum += stop_flags_int_sm[i];
|
||||
}
|
||||
stop_sum += (max_bsz - bsz);
|
||||
int64_t stop_num;
|
||||
GM2LM(stop_nums, &stop_num, sizeof(int64_t));
|
||||
bool not_need_stop_update = stop_sum < static_cast<int>(stop_num);
|
||||
bool not_need_stop_update = stop_sum < max_bsz;
|
||||
mfence_lm();
|
||||
LM2GM(&not_need_stop_update, not_need_stop, sizeof(bool));
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ __global__ void update_inputs_v1(bool* not_need_stop,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -143,9 +142,7 @@ __global__ void update_inputs_v1(bool* not_need_stop,
|
||||
stop_sum += stop_flags_int_sm[i];
|
||||
}
|
||||
// printf("stop_sum : %d\n", stop_sum);
|
||||
int64_t stop_num;
|
||||
GM2LM(stop_nums, &stop_num, sizeof(int64_t));
|
||||
bool not_need_stop_update = stop_sum < static_cast<int>(stop_num);
|
||||
bool not_need_stop_update = stop_sum < max_bsz;
|
||||
mfence_lm();
|
||||
LM2GM(&not_need_stop_update, not_need_stop, sizeof(bool));
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ __attribute__((global)) void speculate_schedule_cache(
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
@@ -66,7 +65,6 @@ static int cpu_wrapper(Context *ctx,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
@@ -131,7 +129,7 @@ static int cpu_wrapper(Context *ctx,
|
||||
// }
|
||||
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum_now < stop_nums[0];
|
||||
not_need_stop[0] = stop_sum_now < max_bsz;
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -150,7 +148,6 @@ static int xpu3_wrapper(Context *ctx,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
@@ -177,7 +174,6 @@ static int xpu3_wrapper(Context *ctx,
|
||||
reinterpret_cast<XPU_TI *>(accept_tokens),
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
(const XPU_TI *)stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
@@ -205,7 +201,6 @@ int speculate_schedule_cache(Context *ctx,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
@@ -230,10 +225,9 @@ int speculate_schedule_cache(Context *ctx,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens);
|
||||
WRAPPER_DUMP_PARAM6(ctx,
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens);
|
||||
@@ -276,7 +270,6 @@ int speculate_schedule_cache(Context *ctx,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
@@ -302,7 +295,6 @@ int speculate_schedule_cache(Context *ctx,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
|
||||
@@ -170,7 +170,7 @@ int speculate_update(Context *ctx,
|
||||
const int max_bsz,
|
||||
const int max_draft_tokens) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update_v3", int);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update", int);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, seq_lens_encoder, seq_lens_decoder, not_need_stop, draft_tokens);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
|
||||
@@ -25,7 +25,6 @@ __attribute__((global)) void update_inputs(bool *not_need_stop,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
@@ -47,7 +46,6 @@ static int cpu_wrapper(Context *ctx,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
@@ -75,7 +73,7 @@ static int cpu_wrapper(Context *ctx,
|
||||
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
|
||||
stop_sum += stop_flag_now_int[i];
|
||||
}
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
not_need_stop[0] = stop_sum < max_bsz;
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -85,7 +83,6 @@ static int xpu3_wrapper(Context *ctx,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
@@ -100,7 +97,6 @@ static int xpu3_wrapper(Context *ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
@@ -117,7 +113,6 @@ int update_inputs(Context *ctx,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int64_t *input_ids,
|
||||
const int64_t *stop_nums,
|
||||
const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
@@ -132,7 +127,7 @@ int update_inputs(Context *ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids);
|
||||
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
@@ -142,7 +137,6 @@ int update_inputs(Context *ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
@@ -157,7 +151,6 @@ int update_inputs(Context *ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
input_ids,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
|
||||
@@ -29,7 +29,6 @@ __attribute__((global)) void update_inputs_v1(bool* not_need_stop,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -58,7 +57,6 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -86,7 +84,6 @@ static int xpu3_wrapper(Context* ctx,
|
||||
reinterpret_cast<XPU_INT64*>(topk_ids),
|
||||
reinterpret_cast<XPU_INT64*>(input_ids),
|
||||
block_tables,
|
||||
reinterpret_cast<const XPU_INT64*>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64*>(next_tokens),
|
||||
@@ -110,7 +107,6 @@ int update_inputs_v1(Context* ctx,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
@@ -127,11 +123,10 @@ int update_inputs_v1(Context* ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
|
||||
WRAPPER_DUMP_PARAM4(ctx, prompt_lens, topk_ids, input_ids, block_tables);
|
||||
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
|
||||
WRAPPER_DUMP_PARAM5(
|
||||
ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP_PARAM4(
|
||||
ctx, bsz, input_ids_stride, block_num_per_seq, block_size);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
@@ -147,7 +142,6 @@ int update_inputs_v1(Context* ctx,
|
||||
topk_ids,
|
||||
input_ids,
|
||||
block_tables,
|
||||
stop_nums,
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
next_tokens,
|
||||
|
||||
@@ -37,7 +37,6 @@ def cpu_reference(
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
block_size,
|
||||
max_draft_tokens,
|
||||
):
|
||||
@@ -82,7 +81,7 @@ def cpu_reference(
|
||||
stop_flag_now_int[bid] = 0
|
||||
|
||||
stop_sum = int(stop_flag_now_int.sum())
|
||||
not_need_stop[0] = stop_sum < int(stop_nums[0])
|
||||
not_need_stop[0] = stop_sum < max_bsz
|
||||
|
||||
|
||||
class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
@@ -141,8 +140,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
|
||||
|
||||
# 设置阈值:bid0触发,bid1已停止,填充(5-3)=2 -> stop_sum = 1+1+2 = 4
|
||||
# 设置stop_nums为5,使得not_need_stop = (4 < 5) = True
|
||||
self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
|
||||
|
||||
# 保存NumPy副本用于CPU参考实现对比
|
||||
self.np_draft_tokens = self.draft_tokens.numpy().copy()
|
||||
@@ -159,7 +156,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.np_accept_tokens = self.accept_tokens.numpy().copy()
|
||||
self.np_is_block_step = self.is_block_step.numpy().copy()
|
||||
self.np_not_need_stop = self.not_need_stop.numpy().copy()
|
||||
self.np_stop_nums = self.stop_nums.numpy().copy()
|
||||
|
||||
def test_correctness_against_cpu_reference(self):
|
||||
# Run GPU kernel (in-place)
|
||||
@@ -178,7 +174,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.accept_tokens,
|
||||
self.is_block_step,
|
||||
self.not_need_stop,
|
||||
self.stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
@@ -199,7 +194,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.np_accept_tokens,
|
||||
self.np_is_block_step,
|
||||
self.np_not_need_stop,
|
||||
self.np_stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
@@ -233,7 +227,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.not_need_stop[:] = False
|
||||
|
||||
# For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3
|
||||
# With stop_nums=5 -> True
|
||||
speculate_schedule_cache(
|
||||
self.draft_tokens,
|
||||
self.block_tables,
|
||||
@@ -249,7 +242,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
self.accept_tokens,
|
||||
self.is_block_step,
|
||||
self.not_need_stop,
|
||||
self.stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
|
||||
@@ -1346,10 +1346,7 @@ class MTPSampler(nn.Layer):
|
||||
sampling_metadata.pre_token_ids,
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
|
||||
)
|
||||
next_tokens = paddle.argmax(probs, axis=-1)
|
||||
# TODO(chenhuan09): add support for logprobs
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
|
||||
@@ -351,7 +351,6 @@ def xpu_post_process_normal(
|
||||
sampled_token_ids,
|
||||
model_output.input_ids,
|
||||
share_inputs["block_tables"],
|
||||
model_output.stop_nums,
|
||||
model_output.next_tokens,
|
||||
model_output.is_block_step,
|
||||
block_size,
|
||||
@@ -364,7 +363,6 @@ def xpu_post_process_normal(
|
||||
model_output.seq_lens_encoder,
|
||||
model_output.seq_lens_decoder,
|
||||
model_output.input_ids,
|
||||
model_output.stop_nums,
|
||||
sampled_token_ids,
|
||||
model_output.is_block_step,
|
||||
)
|
||||
|
||||
@@ -948,7 +948,6 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
[1], False, dtype="bool"
|
||||
).cpu() # TODO(gongshaotian): move to pinnd memory
|
||||
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
|
||||
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
|
||||
|
||||
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
@@ -1607,7 +1606,6 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
eos_token_id=self.share_inputs["eos_token_id"],
|
||||
not_need_stop=self.share_inputs["not_need_stop"],
|
||||
input_ids=self.share_inputs["input_ids"],
|
||||
stop_nums=self.share_inputs["stop_nums"],
|
||||
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||
is_block_step=self.share_inputs["is_block_step"],
|
||||
@@ -1688,7 +1686,6 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["accept_tokens"],
|
||||
self.share_inputs["is_block_step"],
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.share_inputs["stop_nums"],
|
||||
self.cache_config.block_size,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user