[Others] Support constrained decoding when enable_thinking is false (#6248)

* support constrained decoding when enable_thinking is false

* fix

* fix

* fix
This commit is contained in:
GoldPancake
2026-01-28 00:05:17 -08:00
committed by GitHub
parent 27f8799f04
commit 7d6c87c29e
6 changed files with 88 additions and 4 deletions
@@ -67,6 +67,7 @@ __global__ void update_reasoning_status_kernel(
const int* seq_lens_encoder, // [bs]
const int64_t* step_idx, // [bs]
const int64_t* pre_ids, // [bs, max_seq_len]
const bool* enable_thinking, // [bs]
int32_t* reasoning_status, // [bs]
int32_t bs,
int32_t max_seq_len,
@@ -74,8 +75,9 @@ __global__ void update_reasoning_status_kernel(
int64_t line_break_id) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= bs) return;
bool enable_thinking_flag = enable_thinking[tid];
int32_t status = reasoning_status[tid];
if (stop_flags[tid] || seq_lens_encoder[tid] > 0 || status == 3) return;
if (stop_flags[tid] || status == 3) return;
int64_t cur_step = step_idx[tid];
const int64_t* pre_ids_now = pre_ids + tid * max_seq_len;
@@ -88,8 +90,11 @@ __global__ void update_reasoning_status_kernel(
// x = 0 -> x = 1
if (status == 0) {
if (t0 == think_end_id || t1 == think_end_id || t2 == think_end_id ||
t3 == think_end_id) {
if (!enable_thinking_flag && seq_lens_encoder[tid] > 0 && cur_step == 0) {
// x = 0 -> x = 2 directly (skipping x = 1): taken only when enable_thinking is false, seq_lens_encoder > 0 and step_idx == 0
new_status = 2;
} else if (t0 == think_end_id || t1 == think_end_id || t2 == think_end_id ||
t3 == think_end_id) {
new_status = 1;
}
}
@@ -174,6 +179,7 @@ void reasoning_phase_token_constraint(
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id) {
typedef PDTraits<D> traits_;
@@ -201,6 +207,7 @@ void reasoning_phase_token_constraint(
seq_lens_encoder.data<int>(),
step_idx.data<int64_t>(),
pre_ids.data<int64_t>(),
enable_thinking.data<bool>(),
const_cast<int32_t*>(reasoning_status.data<int32_t>()),
bs,
max_seq_len,
@@ -244,6 +251,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id) {
switch (logits.type()) {
@@ -259,6 +267,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
reasoning_status,
output_padding_offset,
output_cum_offsets,
enable_thinking,
think_end_id,
line_break_id);
case paddle::DataType::BFLOAT16:
@@ -273,6 +282,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
reasoning_status,
output_padding_offset,
output_cum_offsets,
enable_thinking,
think_end_id,
line_break_id);
case paddle::DataType::FLOAT32:
@@ -287,6 +297,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
reasoning_status,
output_padding_offset,
output_cum_offsets,
enable_thinking,
think_end_id,
line_break_id);
default:
@@ -307,7 +318,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
"allowed_tokens",
"reasoning_status",
"output_padding_offset",
"output_cum_offsets"})
"output_cum_offsets",
"enable_thinking"})
.Outputs({"logits_out", "reasoning_status_out"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.SetInplaceMap({{"logits", "logits_out"},