mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] Support constrained decoding when enable_thinking is false (#6248)
* support constrained decoding when enable_thinking is false
* fix
* fix
* fix
This commit is contained in:
@@ -67,6 +67,7 @@ __global__ void update_reasoning_status_kernel(
|
||||
const int* seq_lens_encoder, // [bs]
|
||||
const int64_t* step_idx, // [bs]
|
||||
const int64_t* pre_ids, // [bs, max_seq_len]
|
||||
+ const bool* enable_thinking, // [bs]
|
||||
int32_t* reasoning_status, // [bs]
|
||||
int32_t bs,
|
||||
int32_t max_seq_len,
|
||||
@@ -74,8 +75,9 @@ __global__ void update_reasoning_status_kernel(
|
||||
int64_t line_break_id) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid >= bs) return;
|
||||
+ bool enable_thinking_flag = enable_thinking[tid];
|
||||
int32_t status = reasoning_status[tid];
|
||||
- if (stop_flags[tid] || seq_lens_encoder[tid] > 0 || status == 3) return;
|
||||
+ if (stop_flags[tid] || status == 3) return;
|
||||
|
||||
int64_t cur_step = step_idx[tid];
|
||||
const int64_t* pre_ids_now = pre_ids + tid * max_seq_len;
|
||||
@@ -88,8 +90,11 @@ __global__ void update_reasoning_status_kernel(
|
||||
|
||||
// x = 0 -> x = 1
|
||||
if (status == 0) {
|
||||
- if (t0 == think_end_id || t1 == think_end_id || t2 == think_end_id ||
|
||||
- t3 == think_end_id) {
|
||||
+ if (!enable_thinking_flag && seq_lens_encoder[tid] > 0 && cur_step == 0) {
|
||||
+ // x = 0 -> x = 2 (only for first token when thinking is disabled)
|
||||
+ new_status = 2;
|
||||
+ } else if (t0 == think_end_id || t1 == think_end_id || t2 == think_end_id ||
|
||||
+ t3 == think_end_id) {
|
||||
new_status = 1;
|
||||
}
|
||||
}
|
||||
@@ -174,6 +179,7 @@ void reasoning_phase_token_constraint(
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
+ const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id) {
|
||||
typedef PDTraits<D> traits_;
|
||||
@@ -201,6 +207,7 @@ void reasoning_phase_token_constraint(
|
||||
seq_lens_encoder.data<int>(),
|
||||
step_idx.data<int64_t>(),
|
||||
pre_ids.data<int64_t>(),
|
||||
+ enable_thinking.data<bool>(),
|
||||
const_cast<int32_t*>(reasoning_status.data<int32_t>()),
|
||||
bs,
|
||||
max_seq_len,
|
||||
@@ -244,6 +251,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
+ const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id) {
|
||||
switch (logits.type()) {
|
||||
@@ -259,6 +267,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
+ enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
case paddle::DataType::BFLOAT16:
|
||||
@@ -273,6 +282,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
+ enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
case paddle::DataType::FLOAT32:
|
||||
@@ -287,6 +297,7 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
+ enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
default:
|
||||
@@ -307,7 +318,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
|
||||
"allowed_tokens",
|
||||
"reasoning_status",
|
||||
"output_padding_offset",
|
||||
- "output_cum_offsets"})
|
||||
+ "output_cum_offsets",
|
||||
+ "enable_thinking"})
|
||||
.Outputs({"logits_out", "reasoning_status_out"})
|
||||
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
|
||||
.SetInplaceMap({{"logits", "logits_out"},
|
||||
|
||||
Reference in New Issue
Block a user