[XPU] rm stop nums (#6651)

* rm stop nums

* fix conflict

---------

Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
This commit is contained in:
cmcamdy
2026-03-12 14:05:58 +08:00
committed by GitHub
parent 7d31a728d1
commit 3543088d3e
17 changed files with 45 additions and 88 deletions
@@ -39,16 +39,16 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto x_remove_padding = paddle::full(
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous");
PD_CHECK(draft_tokens.is_contiguous(),
@@ -36,7 +36,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens) {
namespace api = baidu::xpu::api;
@@ -79,7 +78,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
@@ -97,8 +95,7 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
// PD_BUILD_STATIC_OP(speculate_schedule_cache)
PD_BUILD_OP(speculate_schedule_cache)
PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
@@ -112,8 +109,7 @@ PD_BUILD_OP(speculate_schedule_cache)
"accept_num",
"accept_tokens",
"is_block_step",
"not_need_stop",
"stop_nums"})
"not_need_stop"})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out",
"block_tables_out",
+17 -4
View File
@@ -323,6 +323,23 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
bool save_each_rank,
bool skip_prefill);
void SpeculateScheduleCache(const paddle::Tensor& draft_tokens,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_flags,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_seq_lens_decoder,
const paddle::Tensor& step_draft_tokens,
const paddle::Tensor& step_seq_lens_this_time,
const paddle::Tensor& accept_num,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& not_need_stop,
const int block_size,
const int max_draft_tokens);
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
@@ -627,7 +644,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step);
@@ -641,7 +657,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& topk_ids,
const paddle::Tensor& input_ids,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step,
const int block_size);
@@ -1391,7 +1406,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("input_ids"),
py::arg("stop_nums"),
py::arg("next_tokens"),
py::arg("is_block_step"),
"Update inputs function");
@@ -1408,7 +1422,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("topk_ids"),
py::arg("input_ids"),
py::arg("block_tables"),
py::arg("stop_nums"),
py::arg("next_tokens"),
py::arg("is_block_step"),
py::arg("block_size"),
@@ -23,7 +23,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -47,7 +46,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
stop_nums.data<int64_t>(),
stop_flags.data<bool>(),
is_block_step.data<bool>(),
next_tokens.data<int64_t>(),
@@ -68,7 +66,6 @@ PD_BUILD_OP(update_inputs)
"seq_lens_encoder",
"seq_lens_decoder",
"input_ids",
"stop_nums",
"next_tokens",
"is_block_step"})
.Outputs({"not_need_stop_out",
@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& seq_lens_this_time,
@@ -27,7 +31,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& topk_ids,
const paddle::Tensor& input_ids,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step,
const int block_size) {
@@ -52,7 +55,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const_cast<int64_t*>(topk_ids.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<int*>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
@@ -68,7 +70,7 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(update_inputs_v1)
PD_BUILD_STATIC_OP(update_inputs_v1)
.Inputs({"stop_flags",
"not_need_stop",
"seq_lens_this_time",
@@ -79,7 +81,6 @@ PD_BUILD_OP(update_inputs_v1)
"topk_ids",
"input_ids",
"block_tables",
"stop_nums",
"next_tokens",
"is_block_step"})
.Attrs({"block_size: int"})
@@ -123,7 +123,6 @@ DLL_EXPORT int update_inputs(Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* input_ids,
const int64_t* stop_nums,
const bool* stop_flags,
const bool* is_block_step,
const int64_t* next_tokens,
@@ -264,7 +263,6 @@ DLL_EXPORT int update_inputs_v1(Context* ctx,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -618,7 +616,6 @@ DLL_EXPORT int speculate_schedule_cache(Context* ctx,
int64_t* accept_tokens,
bool* is_block_step,
bool* not_need_stop,
const int64_t* stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -53,7 +53,6 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -166,11 +165,8 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
// reduce stop_sum_now
if (cid == 0) {
int stop_sum = ClusterReduce(stop_flag_now_int_sm, 64);
int stop_nums_lm;
GM2LM_ASYNC(stop_nums, &stop_nums_lm, sizeof(int));
mfence();
sync_all();
if (stop_sum < stop_nums_lm) {
if (stop_sum < max_bsz) {
LM2GM_ASYNC(&value_true, not_need_stop, sizeof(bool));
} else {
LM2GM_ASYNC(&value_false, not_need_stop, sizeof(bool));
@@ -10,7 +10,6 @@ __global__ void update_inputs(bool *not_need_stop,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
@@ -66,9 +65,7 @@ __global__ void update_inputs(bool *not_need_stop,
stop_sum += stop_flags_int_sm[i];
}
stop_sum += (max_bsz - bsz);
int64_t stop_num;
GM2LM(stop_nums, &stop_num, sizeof(int64_t));
bool not_need_stop_update = stop_sum < static_cast<int>(stop_num);
bool not_need_stop_update = stop_sum < max_bsz;
mfence_lm();
LM2GM(&not_need_stop_update, not_need_stop, sizeof(bool));
}
@@ -19,7 +19,6 @@ __global__ void update_inputs_v1(bool* not_need_stop,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -143,9 +142,7 @@ __global__ void update_inputs_v1(bool* not_need_stop,
stop_sum += stop_flags_int_sm[i];
}
// printf("stop_sum : %d\n", stop_sum);
int64_t stop_num;
GM2LM(stop_nums, &stop_num, sizeof(int64_t));
bool not_need_stop_update = stop_sum < static_cast<int>(stop_num);
bool not_need_stop_update = stop_sum < max_bsz;
mfence_lm();
LM2GM(&not_need_stop_update, not_need_stop, sizeof(bool));
}
@@ -34,7 +34,6 @@ __attribute__((global)) void speculate_schedule_cache(
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -66,7 +65,6 @@ static int cpu_wrapper(Context *ctx,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -131,7 +129,7 @@ static int cpu_wrapper(Context *ctx,
// }
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum_now < stop_nums[0];
not_need_stop[0] = stop_sum_now < max_bsz;
return api::SUCCESS;
}
@@ -150,7 +148,6 @@ static int xpu3_wrapper(Context *ctx,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -177,7 +174,6 @@ static int xpu3_wrapper(Context *ctx,
reinterpret_cast<XPU_TI *>(accept_tokens),
is_block_step,
not_need_stop,
(const XPU_TI *)stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens,
@@ -205,7 +201,6 @@ int speculate_schedule_cache(Context *ctx,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
@@ -230,10 +225,9 @@ int speculate_schedule_cache(Context *ctx,
step_seq_lens_this_time,
accept_num,
accept_tokens);
WRAPPER_DUMP_PARAM6(ctx,
WRAPPER_DUMP_PARAM5(ctx,
is_block_step,
not_need_stop,
stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens);
@@ -276,7 +270,6 @@ int speculate_schedule_cache(Context *ctx,
accept_tokens,
is_block_step,
not_need_stop,
stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens,
@@ -302,7 +295,6 @@ int speculate_schedule_cache(Context *ctx,
accept_tokens,
is_block_step,
not_need_stop,
stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens,
@@ -170,7 +170,7 @@ int speculate_update(Context *ctx,
const int max_bsz,
const int max_draft_tokens) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update_v3", int);
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_update", int);
WRAPPER_DUMP_PARAM4(
ctx, seq_lens_encoder, seq_lens_decoder, not_need_stop, draft_tokens);
WRAPPER_DUMP_PARAM4(
@@ -25,7 +25,6 @@ __attribute__((global)) void update_inputs(bool *not_need_stop,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
@@ -47,7 +46,6 @@ static int cpu_wrapper(Context *ctx,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
@@ -75,7 +73,7 @@ static int cpu_wrapper(Context *ctx,
for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
stop_sum += stop_flag_now_int[i];
}
not_need_stop[0] = stop_sum < stop_nums[0];
not_need_stop[0] = stop_sum < max_bsz;
return api::SUCCESS;
}
@@ -85,7 +83,6 @@ static int xpu3_wrapper(Context *ctx,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
@@ -100,7 +97,6 @@ static int xpu3_wrapper(Context *ctx,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
@@ -117,7 +113,6 @@ int update_inputs(Context *ctx,
int *seq_lens_encoder,
int *seq_lens_decoder,
int64_t *input_ids,
const int64_t *stop_nums,
const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens,
@@ -132,7 +127,7 @@ int update_inputs(Context *ctx,
seq_lens_encoder,
seq_lens_decoder,
input_ids);
WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
@@ -142,7 +137,6 @@ int update_inputs(Context *ctx,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
@@ -157,7 +151,6 @@ int update_inputs(Context *ctx,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
@@ -29,7 +29,6 @@ __attribute__((global)) void update_inputs_v1(bool* not_need_stop,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -58,7 +57,6 @@ static int xpu3_wrapper(Context* ctx,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -86,7 +84,6 @@ static int xpu3_wrapper(Context* ctx,
reinterpret_cast<XPU_INT64*>(topk_ids),
reinterpret_cast<XPU_INT64*>(input_ids),
block_tables,
reinterpret_cast<const XPU_INT64*>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64*>(next_tokens),
@@ -110,7 +107,6 @@ int update_inputs_v1(Context* ctx,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -127,11 +123,10 @@ int update_inputs_v1(Context* ctx,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder);
WRAPPER_DUMP_PARAM5(
ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
WRAPPER_DUMP_PARAM4(ctx, prompt_lens, topk_ids, input_ids, block_tables);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM5(
ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
WRAPPER_DUMP_PARAM4(
ctx, bsz, input_ids_stride, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
@@ -147,7 +142,6 @@ int update_inputs_v1(Context* ctx,
topk_ids,
input_ids,
block_tables,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
@@ -37,7 +37,6 @@ def cpu_reference(
accept_tokens,
is_block_step,
not_need_stop,
stop_nums,
block_size,
max_draft_tokens,
):
@@ -82,7 +81,7 @@ def cpu_reference(
stop_flag_now_int[bid] = 0
stop_sum = int(stop_flag_now_int.sum())
not_need_stop[0] = stop_sum < int(stop_nums[0])
not_need_stop[0] = stop_sum < max_bsz
class TestSpeculateScheduleCache(unittest.TestCase):
@@ -141,8 +140,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
        # Set up the threshold case: bid0 triggers, bid1 has already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4
        # Set stop_nums to 5, so that not_need_stop = (4 < 5) = True
self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
        # Keep NumPy copies for comparison against the CPU reference implementation
self.np_draft_tokens = self.draft_tokens.numpy().copy()
@@ -159,7 +156,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.np_accept_tokens = self.accept_tokens.numpy().copy()
self.np_is_block_step = self.is_block_step.numpy().copy()
self.np_not_need_stop = self.not_need_stop.numpy().copy()
self.np_stop_nums = self.stop_nums.numpy().copy()
def test_correctness_against_cpu_reference(self):
# Run GPU kernel (in-place)
@@ -178,7 +174,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
@@ -199,7 +194,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.np_accept_tokens,
self.np_is_block_step,
self.np_not_need_stop,
self.np_stop_nums,
self.block_size,
self.max_draft_tokens,
)
@@ -233,7 +227,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.not_need_stop[:] = False
# For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3
# With stop_nums=5 -> True
speculate_schedule_cache(
self.draft_tokens,
self.block_tables,
@@ -249,7 +242,6 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
@@ -1346,10 +1346,7 @@ class MTPSampler(nn.Layer):
sampling_metadata.pre_token_ids,
)
probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
next_tokens = paddle.argmax(probs, axis=-1)
# TODO(chenhuan09): add support for logprobs
token_ids = None
logprobs_tensors = None
@@ -351,7 +351,6 @@ def xpu_post_process_normal(
sampled_token_ids,
model_output.input_ids,
share_inputs["block_tables"],
model_output.stop_nums,
model_output.next_tokens,
model_output.is_block_step,
block_size,
@@ -364,7 +363,6 @@ def xpu_post_process_normal(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.input_ids,
model_output.stop_nums,
sampled_token_ids,
model_output.is_block_step,
)
-3
View File
@@ -948,7 +948,6 @@ class XPUModelRunner(ModelRunnerBase):
[1], False, dtype="bool"
).cpu() # TODO(gongshaotian): move to pinnd memory
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
@@ -1607,7 +1606,6 @@ class XPUModelRunner(ModelRunnerBase):
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
@@ -1688,7 +1686,6 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["accept_tokens"],
self.share_inputs["is_block_step"],
self.share_inputs["not_need_stop"],
self.share_inputs["stop_nums"],
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens,
)