[Others] remove stop_nums (#6182)

This commit is contained in:
周周周
2026-01-26 12:12:47 +08:00
committed by GitHub
parent 84a1780814
commit 0966df78dc
13 changed files with 14 additions and 63 deletions
+1 -5
View File
@@ -24,7 +24,6 @@ __global__ void update_inputs_kernel_v1(bool* not_need_stop,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
@@ -103,7 +102,7 @@ __global__ void update_inputs_kernel_v1(bool* not_need_stop,
__syncthreads();
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (thread_idx == 0) {
not_need_stop[0] = stop_sum < stop_nums[0];
not_need_stop[0] = stop_sum < max_bsz;
}
}
@@ -117,7 +116,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& topk_ids,
const paddle::Tensor& input_ids,
const paddle::Tensor& block_tables,
const paddle::Tensor& stop_nums,
const paddle::Tensor& next_tokens,
const paddle::Tensor& is_block_step,
const int block_size) {
@@ -150,7 +148,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const_cast<int64_t*>(topk_ids.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<int*>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
@@ -177,7 +174,6 @@ PD_BUILD_STATIC_OP(update_inputs_v1)
"topk_ids",
"input_ids",
"block_tables",
"stop_nums",
"next_tokens",
"is_block_step"})
.Attrs({"block_size: int"})