mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Cherry-Pick][OP][Feature] 统一 limit_thinking_content_length CUDA 算子,支持回复长度限制与注入序列 (#6506)
* Initial plan * feat: migrate core PR6493 changes to release 2.4 Co-authored-by: yuanlehome <23653004+yuanlehome@users.noreply.github.com> * fix ci * fix ci * fix ci * fix ci --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: yuanlehome <23653004+yuanlehome@users.noreply.github.com>
This commit is contained in:
@@ -1019,41 +1019,28 @@ void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank);
|
||||
|
||||
void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const int64_t think_end_id);
|
||||
void LimitThinkingContentLength(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& max_reply_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_status,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const paddle::Tensor& inject_token_ids,
|
||||
const int64_t think_end_id,
|
||||
const bool splitwise_role_is_decode);
|
||||
|
||||
void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id);
|
||||
|
||||
void SpeculateLimitThinkingContentLengthV1(
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const int64_t think_end_id);
|
||||
|
||||
void SpeculateLimitThinkingContentLengthV2(
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id);
|
||||
void SpeculateLimitThinkingContentLength(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& max_reply_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_status,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const paddle::Tensor& inject_token_ids,
|
||||
const int64_t think_end_id,
|
||||
const bool splitwise_role_is_decode);
|
||||
|
||||
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
|
||||
const paddle::Tensor& next_token_num,
|
||||
@@ -1665,20 +1652,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
|
||||
|
||||
m.def("limit_thinking_content_length_v1",
|
||||
&LimitThinkingContentLengthV1,
|
||||
"limit_thinking_content_length_v1 function");
|
||||
m.def("limit_thinking_content_length",
|
||||
&LimitThinkingContentLength,
|
||||
"limit_thinking_content_length function");
|
||||
|
||||
m.def("limit_thinking_content_length_v2",
|
||||
&LimitThinkingContentLengthV2,
|
||||
"limit_thinking_content_length_v2 function");
|
||||
|
||||
m.def("speculate_limit_thinking_content_length_v1",
|
||||
&SpeculateLimitThinkingContentLengthV1,
|
||||
"speculate limit thinking content length function");
|
||||
|
||||
// NOTE(review): the v2 binding below was missing its docstring argument and
// the closing ");" — a syntax error — restored here so both ops register.
m.def("speculate_limit_thinking_content_length_v2",
      &SpeculateLimitThinkingContentLengthV2,
      "speculate limit thinking content length function");

m.def("speculate_limit_thinking_content_length",
      &SpeculateLimitThinkingContentLength,
      "speculate limit thinking content length function");
|
||||
|
||||
m.def("speculate_get_logits",
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// Unified thinking/reply length-limit kernel (one thread per batch row).
// Per row it runs a state machine stored in limit_status[bid]:
//   - 0:               thinking phase
//   - 1..inject_len:   injecting inject_token_ids[status - 1]
//   - done_status:     thinking just ended (inject_len + 1, or 1 when there is
//                      no injection sequence); that token is NOT counted as
//                      reply output
//   - reply_base..:    reply phase; reply_len = status - reply_base
// next_tokens / limit_status / max_reply_lens are all updated in place.
__global__ void limit_thinking_content_length_kernel(
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    const int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  if (bid >= bs) return;
  if (stop_flags[bid]) return;

  const int max_think_len =
      max_think_lens[bid];  // <0: do not force-truncate thinking (the
                            // thinking phase still exists)
  int max_reply_len = max_reply_lens[bid];  // <0: no reply-length limit

  // Neither limit is enabled for this row: nothing to do. (If state should
  // always be maintained for debugging/statistics, this early return could
  // be dropped.)
  if (max_think_len < 0 && max_reply_len < 0) return;

  const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
  const int reply_base = done_status + 1;

  int status = limit_status[bid];
  if (status < 0) status = 0;
  const int prev_status = status;

  int64_t next_token = next_tokens[bid];
  const int64_t step = step_idx[bid];

  // ======================= 1) Thinking phase: always watch for think_end_id
  // ======================= Even when max_think_len < 0 (no forced
  // truncation), a model-emitted </think> must advance the state to
  // done_status.
  if (status == 0 && next_token == think_end_id) {
    status = done_status;
    if (max_reply_len >= 0) {
      max_reply_len += 2;
    }
  }

  // ======================= 2) Only when thinking truncation is enabled:
  // trigger injection / override an early EOS =======================
  if (max_think_len >= 0 && status < reply_base) {
    // A) over-length trigger: start injecting when step reaches max_think_len
    if (max_think_len > 0) {
      // A) over-length trigger: from this token on, emit
      // inject_token_ids[0]
      if (status == 0 && step == max_think_len) {
        status = (inject_len > 0) ? 1 : done_status;
      }
    } else if (max_think_len == 0) {
      // A) over-length trigger with a zero thinking budget
      if (status == 0 && !splitwise_role_is_decode) {
        // Centralized deployment or P node: start from inject_token_ids[0]
        status = (inject_len > 0) ? 1 : done_status;
      } else if (status == 0 && splitwise_role_is_decode) {
        // D node: start from inject_token_ids[1]
        status = (inject_len > 0) ? 2 : done_status + 1;
      }
    }

    // B) early EOS during thinking: start injecting (overriding the EOS)
    if (status == 0 && inject_len > 0) {
      for (int i = 0; i < eos_token_id_len; i++) {
        if (eos_token_ids[i] == next_token) {
          status = 1;
          break;
        }
      }
    }

    // Injection: overwrite the token and advance the state.
    if (inject_len > 0 && status >= 1 && status <= inject_len) {
      next_token = inject_token_ids[status - 1];
      status += 1;
      if (status > done_status) status = done_status;
    }
  }

  // Did this step just enter done_status?
  const bool became_done_this_step = (status == done_status) &&
                                     (prev_status != done_status) &&
                                     (prev_status < reply_base);

  // ======================= 3) Reply-length limit: only active after thinking
  // has ended =======================
  if (max_reply_len >= 0) {
    // Key point: the step that just emitted </think> (or just finished
    // injection) must not be counted as reply output.
    if (!became_done_this_step) {
      // Only when the previous step was already done_status may this step
      // switch to reply_base and start counting.
      if (status == done_status) {
        status = reply_base;  // reply_len = 0
      }

      if (status >= reply_base) {
        int reply_len = status - reply_base;

        if (reply_len >= max_reply_len) {
          // Force EOS; the downstream stop_flags pass decides the stop.
          if (eos_token_id_len > 0) next_token = eos_token_ids[0];
          status = reply_base + max_reply_len;
        } else {
          // Emit the current token normally and bump the reply counter.
          status = reply_base + (reply_len + 1);
        }
      }
    }
  }

  next_tokens[bid] = next_token;
  limit_status[bid] = status;
  max_reply_lens[bid] = max_reply_len;
}
|
||||
|
||||
// Host launcher for the unified limit_thinking_content_length op.
// Mutates next_tokens, limit_status and max_reply_lens in place (via
// const_cast, since they are declared as op inputs).
void LimitThinkingContentLength(const paddle::Tensor& next_tokens,
                                const paddle::Tensor& max_think_lens,
                                const paddle::Tensor& max_reply_lens,
                                const paddle::Tensor& step_idx,
                                const paddle::Tensor& limit_status,
                                const paddle::Tensor& stop_flags,
                                const paddle::Tensor& eos_token_ids,
                                const paddle::Tensor& inject_token_ids,
                                const int64_t think_end_id,
                                const bool splitwise_role_is_decode) {
  const int batch_size = next_tokens.shape()[0];
  const int eos_token_id_len = eos_token_ids.shape()[0];
  const int inject_len = inject_token_ids.shape()[0];

  const int threads = 256;
  const int blocks = (batch_size + threads - 1) / threads;

  limit_thinking_content_length_kernel<<<blocks,
                                         threads,
                                         0,
                                         next_tokens.stream()>>>(
      const_cast<int64_t*>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      const_cast<int*>(max_reply_lens.data<int>()),
      step_idx.data<int64_t>(),
      eos_token_ids.data<int64_t>(),
      const_cast<int*>(limit_status.data<int>()),
      stop_flags.data<bool>(),
      think_end_id,
      // FIX(consistency): pass nullptr for an empty injection sequence
      // instead of calling data<int64_t>() on a zero-size tensor — matches
      // the speculative launcher; the kernel never dereferences this pointer
      // when inject_len == 0.
      (inject_len > 0) ? inject_token_ids.data<int64_t>() : nullptr,
      batch_size,
      eos_token_id_len,
      inject_len,
      splitwise_role_is_decode);
}
|
||||
|
||||
// Register the unified limit_thinking_content_length custom op.
// next_tokens is updated in place (aliased to next_tokens_out); limit_status
// and max_reply_lens are also mutated by the kernel even though they are
// declared as inputs (the launcher const_casts them).
PD_BUILD_STATIC_OP(limit_thinking_content_length)
    .Inputs({"next_tokens",
             "max_think_lens",
             "max_reply_lens",
             "step_idx",
             "limit_status",
             "stop_flags",
             "eos_token_ids",
             "inject_token_ids"})
    .Attrs({"think_end_id: int64_t", "splitwise_role_is_decode: bool"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(LimitThinkingContentLength));
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// Enforce an upper bound on the thinking-content length (v1).
// State machine in limit_think_status[bid]:
//   0 = still thinking, 1 = forcing the end of thinking, 2 = reply phase.
// next_tokens / limit_think_status / stop_flags are mutated in place.
__global__ void limit_thinking_content_length_kernel_v1(
    int64_t *next_tokens,
    const int *max_think_lens,
    const int64_t *step_idx,
    const int64_t *eos_token_ids,
    int *limit_think_status,
    bool *stop_flags,
    const int64_t think_end_id,
    const int bs,
    const int eos_token_id_len) {
  // FIX: use the global thread id instead of threadIdx.x alone so the kernel
  // is correct under any launch configuration. Behavior is unchanged for the
  // current single-block <<<1, 1024>>> launch (blockIdx.x == 0 there).
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  if (bid >= bs) return;

  // A value < 0 (default -1) means thinking length is unlimited for this row.
  const int max_think_len = max_think_lens[bid];
  if (max_think_len < 0) return;
  int current_limit_think_status = limit_think_status[bid];
  // Already in the reply phase, or the row already stopped: nothing to do.
  if (current_limit_think_status == 2 || stop_flags[bid]) {
    return;
  }

  int64_t next_token = next_tokens[bid];
  const int64_t step = step_idx[bid];

  // ======================= thinking-phase control =======================
  // Phase 1: still thinking (status == 0) — check whether to force-stop.
  if (current_limit_think_status < 1) {
    // Thinking budget exhausted?
    if (step >= max_think_len) {
      // Force-replace the current token with the end-of-thinking token and
      // advance to status 1 ("ending thinking").
      next_token = think_end_id;
      current_limit_think_status = 1;
    } else {
      // Did the model emit an EOS while still thinking?
      for (int i = 0; i < eos_token_id_len; i++) {
        if (eos_token_ids[i] == next_token) {
          // Replace the EOS with the end-of-thinking token so generation
          // continues into the reply phase.
          next_token = think_end_id;
          current_limit_think_status = 1;
          // Defensive clear of the stop flag (unreachable here given the
          // early return above, kept for safety).
          if (stop_flags[bid]) {
            stop_flags[bid] = false;
          }
          break;
        }
      }
    }
  }
  // ======================= end-of-thinking handling =======================
  // Phase 2: check whether thinking has ended (status < 2). Covers:
  //   1. status == 0: the model emitted think_end_id on its own
  //   2. status == 1: think_end_id was force-injected above
  if (current_limit_think_status < 2) {
    if (next_token == think_end_id) {
      // Thinking confirmed over: advance to status 2 (reply phase).
      current_limit_think_status = 2;
    }
  }
  // Write back the (possibly replaced) token and the updated status.
  next_tokens[bid] = next_token;
  limit_think_status[bid] = current_limit_think_status;
}
|
||||
|
||||
// Host launcher for limit_thinking_content_length_v1.
// NOTE(review): launches a single block of 1024 threads, so rows with
// bid >= 1024 are never processed — confirm batch_size <= 1024 holds at
// every call site, or switch to a multi-block launch.
void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
                                  const paddle::Tensor &max_think_lens,
                                  const paddle::Tensor &step_idx,
                                  const paddle::Tensor &limit_think_status,
                                  const paddle::Tensor &stop_flags,
                                  const paddle::Tensor &eos_token_ids,
                                  const int64_t think_end_id) {
  const int batch_size = next_tokens.shape()[0];
  const int eos_token_id_len = eos_token_ids.shape()[0];
  // const_cast: these tensors are op inputs but are mutated in place by the
  // kernel (next_tokens, limit_think_status, stop_flags).
  limit_thinking_content_length_kernel_v1<<<1, 1024, 0, next_tokens.stream()>>>(
      const_cast<int64_t *>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      step_idx.data<int64_t>(),
      eos_token_ids.data<int64_t>(),
      const_cast<int *>(limit_think_status.data<int>()),
      const_cast<bool *>(stop_flags.data<bool>()),
      think_end_id,
      batch_size,
      eos_token_id_len);
}
|
||||
|
||||
// Register the limit_thinking_content_length_v1 custom op.
// next_tokens is updated in place; limit_think_status and stop_flags are also
// mutated by the kernel despite being declared as inputs.
PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
    .Inputs({"next_tokens",
             "max_think_lens",
             "step_idx",
             "limit_think_status",
             "stop_flags",
             "eos_token_ids"})
    .Attrs({"think_end_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV1));
|
||||
@@ -1,118 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// status == 0: 正常生成阶段
|
||||
// status == 1: 替换阶段
|
||||
// status == 2: 替换结束阶段
|
||||
// status == 3: 思考结束阶段
|
||||
__global__ void limit_thinking_content_length_kernel_v2(
|
||||
int64_t *next_tokens,
|
||||
const int *max_think_lens,
|
||||
const int64_t *step_idx,
|
||||
int *limit_think_status,
|
||||
const bool *stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs) {
|
||||
int bid = threadIdx.x;
|
||||
if (bid >= bs) return;
|
||||
// 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度
|
||||
const int max_think_len = max_think_lens[bid];
|
||||
if (max_think_len < 0) return;
|
||||
int current_limit_think_status = limit_think_status[bid];
|
||||
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行
|
||||
if (current_limit_think_status == 3 || stop_flags[bid]) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t next_token = next_tokens[bid];
|
||||
const int64_t step = step_idx[bid];
|
||||
|
||||
// ======================= 思考阶段控制 =======================
|
||||
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
|
||||
// 阶段 2: 在替换 (status == 1), 检查是否替换结束
|
||||
if (current_limit_think_status <= 1) {
|
||||
// 当开启思考长度控制时,检查是否超时
|
||||
if (step == max_think_len) {
|
||||
// 强制将当前token替换为结束思考的token
|
||||
next_token = line_break_id;
|
||||
current_limit_think_status = 1;
|
||||
} else if (step == max_think_len + 1) {
|
||||
// 强制将当前token替换为结束思考的token
|
||||
next_token = think_end_id;
|
||||
current_limit_think_status = 1;
|
||||
} else if (step == max_think_len + 2) {
|
||||
// 强制将当前token替换为结束思考的token
|
||||
next_token = line_break_id;
|
||||
current_limit_think_status = 1;
|
||||
} else if (step == max_think_len + 3) {
|
||||
// 强制将当前token替换为结束思考的token
|
||||
next_token = line_break_id;
|
||||
// 将状态推进到 1, 表示 "正在结束思考"
|
||||
current_limit_think_status = 2;
|
||||
}
|
||||
}
|
||||
// ======================= 思考结束处理 =======================
|
||||
// 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2)
|
||||
// 这种情况会处理两种场景:
|
||||
// 1. status == 0: 模型可能自己生成了 </think>
|
||||
// 2. status == 2: 上一阶段强制注入了 \n</think>\n\n
|
||||
if (current_limit_think_status == 0) {
|
||||
if (next_token == think_end_id) {
|
||||
// 确认思考结束,将状态推进到 3 (响应阶段)
|
||||
current_limit_think_status = 3;
|
||||
}
|
||||
}
|
||||
if (current_limit_think_status == 2) {
|
||||
// 确认思考结束,将状态推进到 3 (响应阶段)
|
||||
current_limit_think_status = 3;
|
||||
}
|
||||
// 写回更新后的 token
|
||||
next_tokens[bid] = next_token;
|
||||
// 更新全局状态
|
||||
limit_think_status[bid] = current_limit_think_status;
|
||||
}
|
||||
|
||||
// Host launcher for limit_thinking_content_length_v2.
// NOTE(review): launches a single block of 1024 threads, so rows with
// bid >= 1024 are never processed — confirm batch_size <= 1024 holds at
// every call site, or switch to a multi-block launch.
void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
                                  const paddle::Tensor &max_think_lens,
                                  const paddle::Tensor &step_idx,
                                  const paddle::Tensor &limit_think_status,
                                  const paddle::Tensor &stop_flags,
                                  const int64_t think_end_id,
                                  const int64_t line_break_id) {
  const int batch_size = next_tokens.shape()[0];
  // const_cast: next_tokens and limit_think_status are op inputs but are
  // mutated in place by the kernel.
  limit_thinking_content_length_kernel_v2<<<1, 1024, 0, next_tokens.stream()>>>(
      const_cast<int64_t *>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      step_idx.data<int64_t>(),
      const_cast<int *>(limit_think_status.data<int>()),
      stop_flags.data<bool>(),
      think_end_id,
      line_break_id,
      batch_size);
}
|
||||
|
||||
// Register the limit_thinking_content_length_v2 custom op.
// next_tokens is updated in place; limit_think_status is also mutated by the
// kernel despite being declared as an input.
PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
    .Inputs({"next_tokens",
             "max_think_lens",
             "step_idx",
             "limit_think_status",
             "stop_flags"})
    .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV2));
|
||||
@@ -0,0 +1,251 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// 1) supports inject_token_ids of arbitrary length
// 2) supports limiting the reply length via max_reply_lens
// 3) semantics aligned with the non-MTP limit_thinking_content_length:
//    - max_think_len < 0: thinking is not force-truncated, but think_end_id
//      is still monitored so the row can enter the reply phase
//    - max_reply_len >= 0: reply-length counting starts only at the token
//      *after* thinking ends
//
// State machine (reusing limit_status):
//  - 0: thinking phase
//  - 1..inject_len: injecting inject_token_ids[status - 1]
//  - done_status = inject_len + 1 (or 1 when inject_len == 0): the point
//    where thinking just ended (that token is not counted as reply)
//  - reply_base = done_status + 1
//  - status >= reply_base: reply phase; reply_len = status - reply_base
//    (number of reply tokens already emitted)
__global__ void speculate_limit_thinking_content_length_kernel(
    int64_t* next_tokens,          // [bs, tokens_per_step]
    const int* max_think_lens,     // [bs]
    int* max_reply_lens,           // [bs]
    int64_t* step_idx,             // [bs]
    const int64_t* eos_token_ids,  // [eos_len]
    int* limit_status,             // [bs]
    int* accept_num,               // [bs]
    const bool* stop_flags,        // [bs]
    const int64_t think_end_id,
    const int64_t* inject_token_ids,  // [inject_len]
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  const int bid = blockIdx.x * blockDim.x + threadIdx.x;
  if (bid >= bs) return;

  const int original_accept_num = accept_num[bid];
  if (original_accept_num <= 0) return;
  if (stop_flags[bid]) return;

  const int max_think_len =
      max_think_lens[bid];  // <0: no forced thinking truncation (but
                            // think_end_id is still monitored)
  int max_reply_len = max_reply_lens[bid];  // <0: no reply limit
  if (max_think_len < 0 && max_reply_len < 0) return;

  // State constants (inject_len == 0 is allowed).
  const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
  const int reply_base = done_status + 1;

  int status = limit_status[bid];
  if (status < 0) status = 0;

  int new_accept_num = original_accept_num;

  // Absolute step corresponding to token offset 0 of this verification step.
  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

  for (int token_offset = 0; token_offset < original_accept_num;
       token_offset++) {
    const int token_idx = bid * tokens_per_step + token_offset;
    int64_t next_token = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;

    const int prev_status = status;
    bool condition_triggered = false;

    // ======================= 1) thinking phase: watch for think_end_id
    // (semantics aligned with the non-MTP kernel) =======================
    // This must run before the injection-trigger logic: if the model emits
    // </think> naturally, that token moves status to done_status but "this
    // token is not counted as reply".
    if (status == 0 && next_token == think_end_id) {
      status = done_status;
      // accept_num is NOT truncated here: subsequent tokens may proceed
      // (reply counting only starts at the next token).
      if (max_reply_len >= 0) {
        max_reply_len += 2;
      }
    }

    // ======================= 2) only when thinking truncation is enabled
    // (max_think_len >= 0): trigger injection =======================
    if (max_think_len >= 0 && status < reply_base) {
      if (max_think_len > 0) {
        // A) over-length trigger: start injecting when the step reaches
        // max_think_len (this token becomes inject_token_ids[0])
        if (status == 0 &&
            (current_step - 1) ==
                max_think_len) {  // current_step - 1 because speculate_verify
                                  // already advanced step_idx by 1
          status = (inject_len > 0) ? 1 : done_status;
        }
      } else if (max_think_len == 0) {
        // A) over-length trigger with a zero thinking budget
        if (status == 0 && !splitwise_role_is_decode) {
          // Centralized deployment or P node: start from inject_token_ids[0]
          status = (inject_len > 0) ? 1 : done_status;
        } else if (status == 0 && splitwise_role_is_decode) {
          // D node: start from inject_token_ids[1]
          status = (inject_len > 0) ? 2 : done_status + 1;
        }
      }

      // B) early EOS during thinking: start injecting (overriding the EOS
      // with inject_token_ids[0])
      if (status == 0 && inject_len > 0) {
        for (int i = 0; i < eos_token_id_len; i++) {
          if (eos_token_ids[i] == next_token) {
            status = 1;
            break;
          }
        }
      }

      // Injection: overwrite next_token and advance the state. Because the
      // token was overwritten, accept_num must be truncated so the remaining
      // speculative tokens are not accepted wholesale.
      if (inject_len > 0 && status >= 1 && status <= inject_len) {
        next_token = inject_token_ids[status - 1];
        status += 1;
        if (status > done_status) status = done_status;  // defensive
        condition_triggered = true;
      }
    }

    // Did this token just enter done_status?
    // Typical cases: a natural </think> (0 -> done_status) or the final
    // injection step (inject_len -> done_status).
    const bool became_done_this_token = (status == done_status) &&
                                        (prev_status != done_status) &&
                                        (prev_status < reply_base);

    // ======================= 3) reply-length limit (aligned with non-MTP)
    // ======================= Key point: the token that just entered
    // done_status (</think> or an injected token) is neither counted as
    // reply nor starts the reply counter.
    if (max_reply_len >= 0) {
      if (!became_done_this_token) {
        // Only when the previous token was already done_status may the
        // current token switch to reply_base and start counting.
        if (status == done_status) {
          status = reply_base;  // reply_len = 0
        }

        if (status >= reply_base) {
          int reply_len =
              status - reply_base;  // reply tokens already emitted
                                    // (excluding the current token)

          if (reply_len >= max_reply_len) {
            // Limit reached: force EOS and truncate accept_num.
            if (eos_token_id_len > 0) next_token = eos_token_ids[0];
            status = reply_base + max_reply_len;
            condition_triggered = true;
          } else {
            // Emit the current token normally; bump the reply counter.
            status = reply_base + (reply_len + 1);
          }
        }
      }
    }

    // Write back the current token.
    next_tokens[token_idx] = next_token;

    // If this token was overwritten (injection or forced EOS), truncate
    // accept_num so the remaining speculative tokens are discarded.
    if (condition_triggered) {
      new_accept_num = token_offset + 1;
      break;
    }
  }

  // Update step_idx / accept_num (discarded tokens must roll back
  // step_idx).
  const int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[bid] -= discarded_tokens;
  }

  accept_num[bid] = new_accept_num;
  limit_status[bid] = status;
  max_reply_lens[bid] = max_reply_len;
}
|
||||
|
||||
void SpeculateLimitThinkingContentLength(
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& max_reply_lens, // 新增
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_status,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const paddle::Tensor& inject_token_ids, // 新增:支持任意长度注入序列
|
||||
const int64_t think_end_id,
|
||||
const bool splitwise_role_is_decode) {
|
||||
const int batch_size = next_tokens.shape()[0];
|
||||
const int tokens_per_step = next_tokens.shape()[1];
|
||||
const int eos_token_id_len = eos_token_ids.shape()[0];
|
||||
const int inject_len = inject_token_ids.shape()[0];
|
||||
|
||||
const int threads = 256;
|
||||
int blocks = (batch_size + threads - 1) / threads;
|
||||
if (blocks > 1024) blocks = 1024;
|
||||
|
||||
speculate_limit_thinking_content_length_kernel<<<blocks,
|
||||
threads,
|
||||
0,
|
||||
next_tokens.stream()>>>(
|
||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||
max_think_lens.data<int>(),
|
||||
const_cast<int*>(max_reply_lens.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
eos_token_ids.data<int64_t>(),
|
||||
const_cast<int*>(limit_status.data<int>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
stop_flags.data<bool>(),
|
||||
think_end_id,
|
||||
(inject_len > 0) ? inject_token_ids.data<int64_t>() : nullptr,
|
||||
tokens_per_step,
|
||||
batch_size,
|
||||
eos_token_id_len,
|
||||
inject_len,
|
||||
splitwise_role_is_decode);
|
||||
}
|
||||
|
||||
// Register the unified speculative limit_thinking_content_length custom op.
// next_tokens is updated in place; max_reply_lens, step_idx, limit_status and
// accept_num are also mutated by the kernel despite being declared as inputs.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length)
    .Inputs({"next_tokens",
             "max_think_lens",
             "max_reply_lens",
             "step_idx",
             "limit_status",
             "accept_num",
             "stop_flags",
             "eos_token_ids",
             "inject_token_ids"})
    .Attrs({"think_end_id: int64_t", "splitwise_role_is_decode: bool"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLength));
|
||||
@@ -1,152 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// Speculative-decoding variant of limit_thinking_content_length (v1).
// Walks the accepted tokens of each row, force-ends thinking once
// current_step >= max_think_len (or on an early EOS), and truncates
// accept_num / rolls back step_idx when a token is overwritten.
// NOTE(review): bid = threadIdx.x only — rows beyond blockDim.x are never
// processed; the launcher must cover batch_size with a single block.
__global__ void speculate_limit_thinking_content_length_kernel_v1(
    int64_t* next_tokens,
    const int* max_think_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_think_status,
    int* accept_num,
    bool* stop_flags,
    const int64_t think_end_id,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len) {
  int bid = threadIdx.x;
  if (bid >= bs) return;

  const int original_accept_num = accept_num[bid];
  if (original_accept_num <= 0) return;

  // A value < 0 (default -1) means thinking length is unlimited for this row.
  const int max_think_len = max_think_lens[bid];
  if (max_think_len < 0) return;
  int current_limit_think_status = limit_think_status[bid];
  // Already in the reply phase, or the row already stopped: nothing to do.
  if (current_limit_think_status == 2 || stop_flags[bid]) {
    return;
  }

  int new_accept_num = original_accept_num;

  // Absolute step corresponding to token offset 0 of this verification step.
  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

  for (int token_offset = 0; token_offset < original_accept_num;
       token_offset++) {
    const int token_idx = bid * tokens_per_step + token_offset;
    int64_t next_token = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;

    bool condition_triggered = false;

    // ======================= thinking-phase control =======================
    // Phase 1: still thinking (status == 0) — check whether to force-stop.
    // Phase 2: replacing (status == 1) — check whether replacement is done.
    if (current_limit_think_status < 1) {
      // Thinking budget exhausted?
      if (current_step >= max_think_len) {
        // Force-replace the current token with the end-of-thinking token.
        next_token = think_end_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified, must truncate
      } else {
        // Did the model emit an EOS while still thinking?
        for (int i = 0; i < eos_token_id_len; i++) {
          if (eos_token_ids[i] == next_token) {
            // Replace the EOS with the end-of-thinking token.
            next_token = think_end_id;
            current_limit_think_status = 1;
            condition_triggered = true;  // token was modified, must truncate
            if (stop_flags[bid]) {
              stop_flags[bid] = false;
            }
            break;
          }
        }
      }
    }

    // ======================= end-of-thinking handling =======================
    // Phase 3: check whether thinking has ended (status == 0 || status == 1).
    // Covers:
    //   1. status == 0: the model emitted </think> on its own
    //   2. status == 1: </think> was force-injected above
    if (current_limit_think_status < 2) {
      if (next_token == think_end_id) {
        // Thinking confirmed over: advance to status 2 (reply phase).
        current_limit_think_status = 2;
      }
    }

    next_tokens[token_idx] = next_token;

    // A token was overwritten: accept only up to and including it.
    if (condition_triggered) {
      new_accept_num = token_offset + 1;
      break;
    }
  }

  // Roll back step_idx by the number of discarded speculative tokens.
  int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[bid] -= discarded_tokens;
  }

  accept_num[bid] = new_accept_num;
  limit_think_status[bid] = current_limit_think_status;
}
|
||||
|
||||
// Host-side launcher for speculate_limit_thinking_content_length_kernel_v1.
//
// next_tokens: [batch_size, tokens_per_step] accepted draft tokens (updated
// in place); the remaining tensors are per-lane state described in the
// kernel's documentation.
//
// Fixes over the previous version:
//  - The kernel uses a single block with threadIdx.x as the batch lane, so a
//    batch larger than the block size (1024) was silently ignored; this is
//    now rejected explicitly.
//  - The kernel is launched on the tensor's stream instead of the default
//    stream, so it is correctly ordered with the rest of the model's work.
void SpeculateLimitThinkingContentLengthV1(
    const paddle::Tensor& next_tokens,
    const paddle::Tensor& max_think_lens,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& limit_think_status,
    const paddle::Tensor& accept_num,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& eos_token_ids,
    const int64_t think_end_id) {
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  const int eos_token_id_len = eos_token_ids.shape()[0];

  // One thread per lane in a single block: lanes beyond the block size would
  // silently be skipped, corrupting generation state.
  PD_CHECK(batch_size <= 1024,
           "speculate_limit_thinking_content_length_v1 supports at most 1024 "
           "sequences per launch.");

  speculate_limit_thinking_content_length_kernel_v1<<<1,
                                                      1024,
                                                      0,
                                                      next_tokens.stream()>>>(
      const_cast<int64_t*>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      eos_token_ids.data<int64_t>(),
      const_cast<int*>(limit_think_status.data<int>()),
      const_cast<int*>(accept_num.data<int>()),
      const_cast<bool*>(stop_flags.data<bool>()),
      think_end_id,
      tokens_per_step,
      batch_size,
      eos_token_id_len);
}
|
||||
|
||||
// Register the custom op `speculate_limit_thinking_content_length_v1`.
// `next_tokens` is modified in place and aliased to `next_tokens_out`.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
    .Inputs({"next_tokens",
             "max_think_lens",
             "step_idx",
             "limit_think_status",
             "accept_num",
             "stop_flags",
             "eos_token_ids"})
    .Attrs({"think_end_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV1));
|
||||
@@ -1,158 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// status == 0: normal generation (still thinking)
// status == 1: injecting the forced "\n</think>\n\n" replacement tokens
// status == 2: replacement just finished
// status == 3: thinking finished (response phase)
|
||||
// Enforce a per-sequence cap on the "thinking" section during speculative
// decoding by force-injecting the four-token sequence "\n</think>\n\n"
// (line break, think-end, line break, line break) starting at the step where
// the cap is reached.
//
// Status values (limit_think_status[bid]):
//   0 - still thinking          1 - injecting the replacement sequence
//   2 - replacement finished    3 - thinking finished (response phase)
//
// One thread handles one batch lane (bid == threadIdx.x); the kernel is
// launched as a single block, so bs must not exceed the block size.
__global__ void speculate_limit_thinking_content_length_kernel_v2(
    int64_t* next_tokens,       // [bs, tokens_per_step] accepted tokens, in place
    const int* max_think_lens,  // [bs] thinking budget per lane; -1 disables
    int64_t* step_idx,          // [bs] global step counter, rolled back on truncation
    int* limit_think_status,    // [bs] state machine, see legend above
    int* accept_num,            // [bs] number of accepted draft tokens, may shrink
    const bool* stop_flags,     // [bs] stop flags (read-only here)
    const int64_t think_end_id,   // token id of </think>
    const int64_t line_break_id,  // token id of "\n"
    const int tokens_per_step,
    const int bs) {
  int bid = threadIdx.x;
  if (bid >= bs) return;

  const int original_accept_num = accept_num[bid];
  if (original_accept_num <= 0) return;

  // If this sequence does not enable the thinking limit (default -1), return.
  const int max_think_len = max_think_lens[bid];
  if (max_think_len < 0) return;
  int current_limit_think_status = limit_think_status[bid];
  // Already in the response phase, or the sequence has stopped: nothing to do.
  if (current_limit_think_status == 3 || stop_flags[bid]) {
    return;
  }

  int new_accept_num = original_accept_num;

  // step_idx already counts all accepted tokens of this step; recover the
  // step index of the first accepted token.
  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

  for (int token_offset = 0; token_offset < original_accept_num;
       token_offset++) {
    const int token_idx = bid * tokens_per_step + token_offset;
    int64_t next_token = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;

    bool condition_triggered = false;

    // ======================= thinking-phase control =======================
    // status == 0: still thinking — check whether the cap is reached.
    // status == 1: in the middle of injecting the replacement sequence.
    // The four steps max_think_len .. max_think_len+3 emit "\n</think>\n\n".
    if (current_limit_think_status <= 1) {
      if (current_step == max_think_len) {
        // Inject token 1/4: "\n"
        next_token = line_break_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified -> truncate the rest
      } else if (current_step == max_think_len + 1) {
        // Inject token 2/4: "</think>"
        next_token = think_end_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified -> truncate
      } else if (current_step == max_think_len + 2) {
        // Inject token 3/4: "\n"
        next_token = line_break_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified -> truncate
      } else if (current_step == max_think_len + 3) {
        // Inject token 4/4: "\n" and mark the replacement finished (status 2).
        next_token = line_break_id;
        current_limit_think_status = 2;
        condition_triggered = true;  // token was modified -> truncate
      }
    }

    // ======================= end-of-thinking handling =======================
    // Two ways to reach the response phase (status 3):
    //   1. status == 0: the model produced </think> on its own
    //   2. status == 2: the forced "\n</think>\n\n" injection just completed
    if (current_limit_think_status == 0) {
      if (next_token == think_end_id) {
        current_limit_think_status = 3;
      }
    }
    if (current_limit_think_status == 2) {
      current_limit_think_status = 3;
    }

    next_tokens[token_idx] = next_token;

    if (condition_triggered) {
      // Keep tokens up to and including the replaced one; drop the rest.
      new_accept_num = token_offset + 1;
      break;
    }
  }

  // Publish the (possibly rolled back) global state.
  int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[bid] -= discarded_tokens;
  }

  accept_num[bid] = new_accept_num;
  limit_think_status[bid] = current_limit_think_status;
}
|
||||
|
||||
// Host-side launcher for speculate_limit_thinking_content_length_kernel_v2
// (the "\n</think>\n\n" injection strategy).
//
// Fixes over the previous version:
//  - The kernel uses a single block with threadIdx.x as the batch lane, so a
//    batch larger than the block size (1024) was silently ignored; this is
//    now rejected explicitly.
//  - The kernel is launched on the tensor's stream instead of the default
//    stream, so it is correctly ordered with the rest of the model's work.
void SpeculateLimitThinkingContentLengthV2(
    const paddle::Tensor& next_tokens,
    const paddle::Tensor& max_think_lens,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& limit_think_status,
    const paddle::Tensor& accept_num,
    const paddle::Tensor& stop_flags,
    const int64_t think_end_id,
    const int64_t line_break_id) {
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];

  // One thread per lane in a single block: lanes beyond the block size would
  // silently be skipped, corrupting generation state.
  PD_CHECK(batch_size <= 1024,
           "speculate_limit_thinking_content_length_v2 supports at most 1024 "
           "sequences per launch.");

  speculate_limit_thinking_content_length_kernel_v2<<<1,
                                                      1024,
                                                      0,
                                                      next_tokens.stream()>>>(
      const_cast<int64_t*>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<int*>(limit_think_status.data<int>()),
      const_cast<int*>(accept_num.data<int>()),
      stop_flags.data<bool>(),
      think_end_id,
      line_break_id,
      tokens_per_step,
      batch_size);
}
|
||||
|
||||
// Register the custom op `speculate_limit_thinking_content_length_v2`.
// `next_tokens` is modified in place and aliased to `next_tokens_out`.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
    .Inputs({"next_tokens",
             "max_think_lens",
             "step_idx",
             "limit_think_status",
             "accept_num",
             "stop_flags"})
    .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV2));
|
||||
@@ -304,8 +304,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/noaux_tc_redundant.cu",
|
||||
"gpu_ops/custom_all_reduce/all_reduce.cu",
|
||||
"gpu_ops/merge_prefill_decode_output.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v1.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v2.cu",
|
||||
"gpu_ops/limit_thinking_content_length.cu",
|
||||
"gpu_ops/update_attn_mask_offsets.cu",
|
||||
"gpu_ops/fused_neox_rope_embedding.cu",
|
||||
"gpu_ops/gelu_tanh.cu",
|
||||
@@ -553,8 +552,7 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
"gpu_ops/text_image_index_out.cu",
|
||||
"gpu_ops/text_image_gather_scatter.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v1.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v2.cu",
|
||||
"gpu_ops/limit_thinking_content_length.cu",
|
||||
"iluvatar_ops/moe_dispatch.cu",
|
||||
"iluvatar_ops/moe_reduce.cu",
|
||||
"iluvatar_ops/paged_attn.cu",
|
||||
@@ -620,8 +618,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/text_image_gather_scatter.cu",
|
||||
"gpu_ops/text_image_index_out.cu",
|
||||
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v1.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v2.cu",
|
||||
"gpu_ops/limit_thinking_content_length.cu",
|
||||
"gpu_ops/update_attn_mask_offsets.cu",
|
||||
"gpu_ops/append_attn/mla_cache_kernel.cu",
|
||||
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
|
||||
|
||||
@@ -241,6 +241,7 @@ class ModelConfig:
|
||||
self.think_end_id = args.get("think_end_id", -1)
|
||||
self.im_patch_id = args.get("image_patch_id", -1)
|
||||
self.line_break_id = args.get("line_break_id", -1)
|
||||
self.think_truncate_prompt_ids = args.get("think_truncate_prompt_ids", [-1])
|
||||
|
||||
num_max_logprobs = args.get("max_logprobs", None)
|
||||
if num_max_logprobs is not None and num_max_logprobs < -1:
|
||||
|
||||
@@ -518,6 +518,15 @@ class LLMEngine:
|
||||
llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
|
||||
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
|
||||
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
|
||||
try:
|
||||
think_truncate_prompt_ids = self.data_processor.tokenizer.convert_tokens_to_ids(
|
||||
self.data_processor.tokenizer.tokenize(self.data_processor.tokenizer.think_truncate_prompt)
|
||||
)
|
||||
except Exception:
|
||||
think_truncate_prompt_ids = self.data_processor.tokenizer.convert_tokens_to_ids(
|
||||
self.data_processor.tokenizer.tokenize(envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR)
|
||||
)
|
||||
llm_logger.info(f"Get think_truncate_prompt_ids {think_truncate_prompt_ids} from tokenizer.")
|
||||
|
||||
ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
|
||||
ips = None
|
||||
@@ -548,6 +557,7 @@ class LLMEngine:
|
||||
f" --think_end_id {think_end_id}"
|
||||
f" --image_patch_id {image_patch_id}"
|
||||
f" --line_break_id {line_break_id}"
|
||||
f" --think_truncate_prompt_ids '{json.dumps(think_truncate_prompt_ids)}'"
|
||||
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
|
||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||
|
||||
@@ -71,6 +71,7 @@ class SamplingParams:
|
||||
can complete the sequence.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
reasoning_max_tokens: Maximum number of tokens to generate for reasoning per output sequence.
|
||||
response_max_tokens: Maximum number of tokens to generate for response per output sequence.
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
@@ -97,6 +98,7 @@ class SamplingParams:
|
||||
stop_seqs_len: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
reasoning_max_tokens: Optional[int] = None
|
||||
response_max_tokens: Optional[int] = None
|
||||
min_tokens: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
prompt_logprobs: Optional[int] = None
|
||||
@@ -135,6 +137,7 @@ class SamplingParams:
|
||||
stop_token_ids=None,
|
||||
max_tokens=None,
|
||||
reasoning_max_tokens=None,
|
||||
response_max_tokens=None,
|
||||
min_tokens=1,
|
||||
logprobs=None,
|
||||
prompt_logprobs=None,
|
||||
@@ -159,6 +162,7 @@ class SamplingParams:
|
||||
stop_token_ids=stop_token_ids,
|
||||
max_tokens=max_tokens if max_tokens is not None else 8192,
|
||||
reasoning_max_tokens=reasoning_max_tokens,
|
||||
response_max_tokens=response_max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
|
||||
@@ -1104,8 +1104,6 @@ class ResourceManagerV1(ResourceManager):
|
||||
If can not, return False
|
||||
"""
|
||||
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
|
||||
if request.reasoning_max_tokens is not None:
|
||||
request.reasoning_max_tokens -= 1
|
||||
request.need_prefill_tokens = len(request.prompt_token_ids)
|
||||
need_prealloc_prefill_blocks = (
|
||||
request.need_prefill_tokens + self.config.cache_config.block_size - 1
|
||||
|
||||
@@ -418,13 +418,18 @@ class EngineClient:
|
||||
)
|
||||
|
||||
if data.get("reasoning_max_tokens") is not None:
|
||||
if data["reasoning_max_tokens"] < 1:
|
||||
raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be greater than 1")
|
||||
if data["reasoning_max_tokens"] < 0:
|
||||
raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be greater than 0")
|
||||
if data["reasoning_max_tokens"] > data["max_tokens"]:
|
||||
data["reasoning_max_tokens"] = data["max_tokens"]
|
||||
api_server_logger.warning(
|
||||
f"req_id: {data['request_id']}, reasoning_max_tokens exceeds max_tokens, the value of reasoning_max_tokens will be adjusted to {data['max_tokens']}"
|
||||
)
|
||||
|
||||
if data.get("response_max_tokens") is not None:
|
||||
if data["response_max_tokens"] <= 0:
|
||||
raise ParameterError("response_max_tokens", "response_max_tokens must be greater than 0")
|
||||
|
||||
if data.get("temperature") is not None and abs(data["temperature"]) < 1e-6:
|
||||
data["temperature"] = 1e-6
|
||||
# logprobs
|
||||
|
||||
@@ -663,6 +663,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
chat_template_kwargs: Optional[dict] = None
|
||||
chat_template: Optional[str] = None
|
||||
reasoning_max_tokens: Optional[int] = None
|
||||
response_max_tokens: Optional[int] = None
|
||||
structural_tag: Optional[str] = None
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = None
|
||||
guided_regex: Optional[str] = None
|
||||
|
||||
@@ -144,8 +144,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
||||
if request.get("max_tokens") is None:
|
||||
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
|
||||
request.set("max_tokens", max(1, max_tokens))
|
||||
else:
|
||||
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
request.set("temperature", 1)
|
||||
@@ -153,6 +156,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request.enable_thinking = True
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
data_processor_logger.info(f"Processed request: {request}")
|
||||
return request
|
||||
@@ -212,6 +217,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request[k] = v
|
||||
else:
|
||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||
request.setdefault("enable_thinking", True)
|
||||
request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
|
||||
@@ -222,8 +228,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
|
||||
max_tokens = max_model_len - len(request["prompt_token_ids"])
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
|
||||
request["max_tokens"] = max(1, max_tokens)
|
||||
else:
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
request["temperature"] = 1
|
||||
@@ -231,6 +240,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request["enable_thinking"] = True
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
data_processor_logger.info(f"Processed request dict: {request}")
|
||||
return request
|
||||
|
||||
@@ -234,6 +234,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
images = multimodal_data.get("image", None)
|
||||
videos = multimodal_data.get("video", None)
|
||||
request["prompt_tokens"] = request.get("prompt")
|
||||
request.setdefault("enable_thinking", True)
|
||||
outputs = self.ernie4_5_processor.text2ids(request["prompt"], images, videos)
|
||||
elif request.get("messages"):
|
||||
messages = request["messages"]
|
||||
@@ -275,16 +276,20 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
|
||||
|
||||
max_tokens = max_model_len - len(request["prompt_token_ids"])
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
|
||||
request["max_tokens"] = max(1, max_tokens)
|
||||
else:
|
||||
request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
if request.get("reasoning_max_tokens") is None:
|
||||
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
|
||||
|
||||
@@ -250,8 +250,11 @@ class PaddleOCRVLProcessor(TextProcessor):
|
||||
] # Leave space for at least 1 new token
|
||||
|
||||
# Set default max_tokens if not specified
|
||||
max_tokens = max_model_len - len(request["prompt_token_ids"])
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
||||
request["max_tokens"] = max(1, max_tokens)
|
||||
else:
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
|
||||
@@ -267,8 +267,11 @@ class QwenVLProcessor(TextProcessor):
|
||||
] # Leave space for at least 1 new token
|
||||
|
||||
# Set default max_tokens if not specified
|
||||
max_tokens = max_model_len - len(request["prompt_token_ids"])
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
||||
request["max_tokens"] = max(1, max_tokens)
|
||||
else:
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
|
||||
@@ -273,13 +273,18 @@ class DataProcessor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
||||
if request.get("max_tokens") is None:
|
||||
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
|
||||
request.set("max_tokens", max(1, max_tokens))
|
||||
else:
|
||||
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
data_processor_logger.info(f"Processed request: {request}")
|
||||
return request
|
||||
@@ -350,13 +355,18 @@ class DataProcessor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
|
||||
max_tokens = max_model_len - len(request["prompt_token_ids"])
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
|
||||
request["max_tokens"] = max(1, max_tokens)
|
||||
else:
|
||||
request["max_tokens"] = min(max_tokens, request["max_tokens"])
|
||||
if request.get("temperature") < _SAMPLING_EPS:
|
||||
# zero temperature is equivalent to greedy sampling
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if request.get("response_max_tokens") is not None and request.get("enable_thinking") is False:
|
||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
||||
|
||||
data_processor_logger.info(f"Processed request dict: {request}")
|
||||
return request
|
||||
|
||||
@@ -27,8 +27,7 @@ from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
get_padding_offset,
|
||||
limit_thinking_content_length_v1,
|
||||
limit_thinking_content_length_v2,
|
||||
limit_thinking_content_length,
|
||||
save_output,
|
||||
set_stop_value_multi_ends,
|
||||
step_paddle,
|
||||
@@ -52,12 +51,10 @@ elif current_platform.is_dcu():
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_padding_offset,
|
||||
limit_thinking_content_length_v1,
|
||||
limit_thinking_content_length_v2,
|
||||
limit_thinking_content_length,
|
||||
save_output,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_limit_thinking_content_length_v1,
|
||||
speculate_limit_thinking_content_length_v2,
|
||||
speculate_limit_thinking_content_length,
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
update_inputs_v1,
|
||||
@@ -86,10 +83,8 @@ else:
|
||||
step_reschedule,
|
||||
update_inputs_v1,
|
||||
speculate_step_reschedule,
|
||||
limit_thinking_content_length_v1,
|
||||
limit_thinking_content_length_v2,
|
||||
speculate_limit_thinking_content_length_v1,
|
||||
speculate_limit_thinking_content_length_v2,
|
||||
limit_thinking_content_length,
|
||||
speculate_limit_thinking_content_length,
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.entropy_utils import (
|
||||
@@ -107,85 +102,6 @@ from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOu
|
||||
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
|
||||
|
||||
|
||||
def limit_thinking_content_length(
    limit_strategy: str,
    sampled_token_ids: paddle.Tensor,
    max_think_lens: paddle.Tensor,
    step_idx: paddle.Tensor,
    limit_think_status: paddle.Tensor,
    stop_flags: paddle.Tensor,
    eos_token_ids: paddle.Tensor,
    think_end_id: int,
    line_break_id: int = None,
):
    """Dispatch to the CUDA op that caps thinking-content length.

    ``limit_strategy`` selects the truncation token sequence:
    ``"</think>"`` (ernie-45-vl, op v1) injects a single end-of-think token;
    ``"\\n</think>\\n\\n"`` (ernie-x1, op v2) additionally injects line
    breaks and therefore requires a valid ``line_break_id``.

    Raises:
        NotImplementedError: for any other strategy string.
    """
    if limit_strategy == "</think>":
        # for ernie-45-vl; eos_token_ids handles the model emitting an EOS
        # token while still inside the thinking section.
        return limit_thinking_content_length_v1(
            sampled_token_ids,
            max_think_lens,
            step_idx,
            limit_think_status,
            stop_flags,
            eos_token_ids,
            think_end_id,
        )

    if limit_strategy == "\n</think>\n\n":
        # for ernie-x1
        assert line_break_id > 0
        return limit_thinking_content_length_v2(
            sampled_token_ids,
            max_think_lens,
            step_idx,
            limit_think_status,
            stop_flags,
            think_end_id,
            line_break_id,
        )

    raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
|
||||
|
||||
|
||||
def speculate_limit_thinking_content_length(
    limit_strategy: str,
    accept_tokens: paddle.Tensor,
    max_think_lens: paddle.Tensor,
    step_idx: paddle.Tensor,
    limit_think_status: paddle.Tensor,
    accept_num: paddle.Tensor,
    stop_flags: paddle.Tensor,
    eos_token_ids: paddle.Tensor,
    think_end_id: int,
    line_break_id: int = None,
):
    """Dispatch to the speculative-decoding CUDA op that caps thinking length.

    Mirrors :func:`limit_thinking_content_length` but operates on the
    accepted draft tokens of a speculative step (``accept_tokens`` /
    ``accept_num``). ``limit_strategy`` selects the truncation sequence:
    ``"</think>"`` (ernie-45-vl, op v1) or ``"\\n</think>\\n\\n"``
    (ernie-x1, op v2, needs a valid ``line_break_id``).

    Raises:
        NotImplementedError: for any other strategy string.
    """
    if limit_strategy == "</think>":
        # for ernie-45-vl; eos_token_ids handles the model emitting an EOS
        # token while still inside the thinking section.
        return speculate_limit_thinking_content_length_v1(
            accept_tokens,
            max_think_lens,
            step_idx,
            limit_think_status,
            accept_num,
            stop_flags,
            eos_token_ids,
            think_end_id,
        )

    if limit_strategy == "\n</think>\n\n":
        # for ernie-x1
        assert line_break_id > 0
        return speculate_limit_thinking_content_length_v2(
            accept_tokens,
            max_think_lens,
            step_idx,
            limit_think_status,
            accept_num,
            stop_flags,
            think_end_id,
            line_break_id,
        )

    raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
|
||||
|
||||
|
||||
def pre_process(
|
||||
input_ids: paddle.Tensor,
|
||||
seq_lens_this_time: int,
|
||||
@@ -327,22 +243,23 @@ def post_process_normal(
|
||||
skip_save_output: bool = False,
|
||||
async_output_queue: queue.Queue = None,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
):
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
if think_end_id > 0:
|
||||
limit_thinking_content_length(
|
||||
limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR,
|
||||
sampled_token_ids=sampler_output.sampled_token_ids,
|
||||
max_think_lens=share_inputs["max_think_lens"],
|
||||
step_idx=share_inputs["step_idx"],
|
||||
limit_think_status=share_inputs["limit_think_status"],
|
||||
stop_flags=share_inputs["stop_flags"],
|
||||
eos_token_ids=share_inputs["eos_token_id"],
|
||||
think_end_id=think_end_id,
|
||||
line_break_id=line_break_id,
|
||||
sampler_output.sampled_token_ids,
|
||||
share_inputs["max_think_lens"],
|
||||
share_inputs["max_reply_lens"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["limit_think_status"],
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["eos_token_id"],
|
||||
share_inputs["inject_token_ids"],
|
||||
think_end_id,
|
||||
splitwise_role_is_decode,
|
||||
)
|
||||
# 1. Set stop value
|
||||
paddle.assign(
|
||||
@@ -482,22 +399,23 @@ def post_process_specualate(
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
):
|
||||
if think_end_id > 0:
|
||||
speculate_limit_thinking_content_length(
|
||||
limit_strategy=envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR,
|
||||
accept_tokens=share_inputs["accept_tokens"],
|
||||
max_think_lens=share_inputs["max_think_lens"],
|
||||
step_idx=share_inputs["step_idx"],
|
||||
limit_think_status=share_inputs["limit_think_status"],
|
||||
accept_num=share_inputs["accept_num"],
|
||||
stop_flags=share_inputs["stop_flags"],
|
||||
eos_token_ids=share_inputs["eos_token_id"],
|
||||
think_end_id=think_end_id,
|
||||
line_break_id=line_break_id,
|
||||
share_inputs["accept_tokens"],
|
||||
share_inputs["max_think_lens"],
|
||||
share_inputs["max_reply_lens"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["limit_think_status"],
|
||||
share_inputs["accept_num"],
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["eos_token_id"],
|
||||
share_inputs["inject_token_ids"],
|
||||
think_end_id,
|
||||
splitwise_role_is_decode,
|
||||
)
|
||||
speculate_set_stop_value_multi_seqs(
|
||||
model_output.accept_tokens,
|
||||
@@ -606,6 +524,7 @@ def post_process(
|
||||
async_output_queue: queue.Queue = None,
|
||||
think_end_id: int = -1,
|
||||
line_break_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
) -> None:
|
||||
@@ -632,7 +551,7 @@ def post_process(
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
splitwise_role_is_decode,
|
||||
enable_entropy,
|
||||
routing_replay_manager,
|
||||
)
|
||||
@@ -647,7 +566,7 @@ def post_process(
|
||||
skip_save_output,
|
||||
async_output_queue,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
splitwise_role_is_decode,
|
||||
enable_entropy,
|
||||
routing_replay_manager,
|
||||
)
|
||||
|
||||
@@ -638,13 +638,31 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids)
|
||||
|
||||
if not self.is_pooling_model:
|
||||
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
|
||||
# Enable thinking
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
if request.get("enable_thinking") is not None:
|
||||
enable_thinking = bool(request.get("enable_thinking"))
|
||||
logger.debug(f"request {request.request_id} with {enable_thinking=} at idx {idx}")
|
||||
self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking
|
||||
if enable_thinking:
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
if request.get("reasoning_max_tokens") is not None:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get(
|
||||
"reasoning_max_tokens"
|
||||
)
|
||||
else:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
if request.get("response_max_tokens") is not None:
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = request.get(
|
||||
"response_max_tokens"
|
||||
)
|
||||
else:
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
else:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
else:
|
||||
# Disable thinking
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
|
||||
if isinstance(request.prompt_token_ids, np.ndarray):
|
||||
@@ -918,13 +936,31 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
|
||||
|
||||
if not self.is_pooling_model:
|
||||
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
|
||||
# Enable thinking
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
if request.get("enable_thinking") is not None:
|
||||
enable_thinking = bool(request.get("enable_thinking"))
|
||||
logger.debug(f"request {request.request_id} with {enable_thinking=} at idx {idx}")
|
||||
self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking
|
||||
if enable_thinking:
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
if request.get("reasoning_max_tokens") is not None:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get(
|
||||
"reasoning_max_tokens"
|
||||
)
|
||||
else:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
if request.get("response_max_tokens") is not None:
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = request.get(
|
||||
"response_max_tokens"
|
||||
)
|
||||
else:
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
else:
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
else:
|
||||
# Disable thinking
|
||||
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1
|
||||
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
|
||||
|
||||
def get_attr_from_request(request, attr, default_value=None):
|
||||
@@ -1231,8 +1267,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
|
||||
|
||||
# Initialize thinking related buffers
|
||||
self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool")
|
||||
self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
|
||||
self.share_inputs["max_reply_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
|
||||
self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
self.share_inputs["inject_token_ids"] = paddle.to_tensor(
|
||||
self.model_config.think_truncate_prompt_ids, dtype="int64"
|
||||
).reshape([-1, 1])
|
||||
|
||||
# Initialize rotary position embedding
|
||||
if not self.enable_mm:
|
||||
@@ -1432,6 +1473,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
# Reset thinking related buffers
|
||||
fill_paddle_tensor(self.share_inputs, "enable_thinking", True)
|
||||
fill_paddle_tensor(self.share_inputs, "max_think_lens", -1)
|
||||
fill_paddle_tensor(self.share_inputs, "max_reply_lens", -1)
|
||||
fill_paddle_tensor(self.share_inputs, "limit_think_status", 0)
|
||||
|
||||
# Reset reasoning buffers
|
||||
@@ -1941,7 +1983,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
skip_save_output=True,
|
||||
async_output_queue=self.async_output_queue,
|
||||
think_end_id=self.model_config.think_end_id,
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
|
||||
)
|
||||
return pooler_output
|
||||
|
||||
@@ -2042,7 +2084,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
skip_save_output=True,
|
||||
async_output_queue=self.async_output_queue,
|
||||
think_end_id=self.model_config.think_end_id,
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
|
||||
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
|
||||
)
|
||||
if self.speculative_decoding:
|
||||
@@ -2565,7 +2607,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
skip_save_output=skip_save_output,
|
||||
async_output_queue=self.async_output_queue,
|
||||
think_end_id=self.model_config.think_end_id,
|
||||
line_break_id=self.model_config.line_break_id,
|
||||
splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
|
||||
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
|
||||
routing_replay_manager=self.routing_replay_manager,
|
||||
)
|
||||
|
||||
@@ -775,6 +775,7 @@ def parse_args():
|
||||
parser.add_argument("--think_end_id", type=int, default=-1)
|
||||
parser.add_argument("--image_patch_id", type=int, default=-1)
|
||||
parser.add_argument("--line_break_id", type=int, default=-1)
|
||||
parser.add_argument("--think_truncate_prompt_ids", type=json.loads, default=[])
|
||||
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
|
||||
@@ -53,6 +53,7 @@ class MockModelConfig:
|
||||
model_type = ["mock"]
|
||||
moe_phase = MoEPhase(phase="prefill")
|
||||
hidden_size = 1536
|
||||
think_truncate_prompt_ids = [-1]
|
||||
|
||||
|
||||
class MockCacheConfig:
|
||||
|
||||
@@ -629,10 +629,109 @@ def test_chat_with_reasoning_max_tokens(openai_client):
|
||||
except openai.InternalServerError as e:
|
||||
error_message = str(e)
|
||||
assertion_executed = True
|
||||
assert "reasoning_max_tokens must be greater than 1" in error_message
|
||||
assert "reasoning_max_tokens must be greater than 0" in error_message
|
||||
assert assertion_executed, "Assertion was not executed (no exception raised)"
|
||||
|
||||
|
||||
def test_chat_with_response_max_tokens(openai_client):
|
||||
"""
|
||||
Test response_max_tokens option in streaming chat functionality with the local service
|
||||
"""
|
||||
# Test error case: response_max_tokens <= 0 should raise an error
|
||||
assertion_executed = False
|
||||
try:
|
||||
openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
extra_body={"response_max_tokens": -1},
|
||||
max_tokens=10,
|
||||
stream=False,
|
||||
)
|
||||
except openai.InternalServerError as e:
|
||||
error_message = str(e)
|
||||
assertion_executed = True
|
||||
assert "response_max_tokens must be greater than 0" in error_message
|
||||
assert assertion_executed, "Assertion was not executed (no exception raised)"
|
||||
|
||||
# Test functional case: response_max_tokens limits response tokens in streaming mode
|
||||
response_max_tokens = 3
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
|
||||
temperature=1,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
"response_max_tokens": response_max_tokens,
|
||||
"return_token_ids": True,
|
||||
},
|
||||
stream=True,
|
||||
max_tokens=10,
|
||||
)
|
||||
completion_tokens = 1
|
||||
reasoning_tokens = 0
|
||||
total_tokens = 0
|
||||
for chunk_id, chunk in enumerate(response):
|
||||
if chunk_id == 0: # the first chunk is an extra chunk
|
||||
continue
|
||||
delta_message = chunk.choices[0].delta
|
||||
if delta_message.content != "" and delta_message.reasoning_content == "":
|
||||
completion_tokens += len(delta_message.completion_token_ids)
|
||||
elif delta_message.reasoning_content != "" and delta_message.content == "":
|
||||
reasoning_tokens += len(delta_message.completion_token_ids)
|
||||
total_tokens += len(delta_message.completion_token_ids)
|
||||
assert completion_tokens + reasoning_tokens == total_tokens
|
||||
assert completion_tokens <= response_max_tokens + 1
|
||||
|
||||
# Test disable-thinking case: response_max_tokens limits tokens when thinking is off
|
||||
response_max_tokens = 3
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
|
||||
temperature=1,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
"response_max_tokens": response_max_tokens,
|
||||
},
|
||||
stream=False,
|
||||
max_tokens=10,
|
||||
)
|
||||
assert response.choices[0].message.reasoning_content is None
|
||||
assert "</think>" not in response.choices[0].message.content
|
||||
assert response.usage.completion_tokens <= response_max_tokens
|
||||
|
||||
# Test enable-thinking case with both reasoning_max_tokens and response_max_tokens limited
|
||||
reasoning_max_tokens = 3
|
||||
response_max_tokens = 3
|
||||
response = openai_client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
|
||||
temperature=1,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
"reasoning_max_tokens": reasoning_max_tokens,
|
||||
"response_max_tokens": response_max_tokens,
|
||||
"return_token_ids": True,
|
||||
},
|
||||
stream=True,
|
||||
max_tokens=20,
|
||||
)
|
||||
completion_tokens = 1
|
||||
reasoning_tokens = 0
|
||||
total_tokens = 0
|
||||
for chunk_id, chunk in enumerate(response):
|
||||
if chunk_id == 0: # the first chunk is an extra chunk
|
||||
continue
|
||||
delta_message = chunk.choices[0].delta
|
||||
if delta_message.content != "" and delta_message.reasoning_content == "":
|
||||
completion_tokens += len(delta_message.completion_token_ids)
|
||||
elif delta_message.reasoning_content != "" and delta_message.content == "":
|
||||
reasoning_tokens += len(delta_message.completion_token_ids)
|
||||
total_tokens += len(delta_message.completion_token_ids)
|
||||
assert completion_tokens + reasoning_tokens == total_tokens
|
||||
assert reasoning_tokens <= reasoning_max_tokens
|
||||
assert completion_tokens <= response_max_tokens + 1
|
||||
|
||||
|
||||
def test_profile_reset_block_num():
|
||||
"""测试profile reset_block_num功能,与baseline diff不能超过5%"""
|
||||
log_file = "./log/config.log"
|
||||
|
||||
@@ -587,7 +587,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
||||
request = {"prompt_token_ids": [1, 2, 3], "max_tokens": 5}
|
||||
processed = self.processor.process_request_dict(request, max_model_len=6)
|
||||
self.assertEqual(processed["prompt_token_ids"], [1, 2, 3])
|
||||
self.assertEqual(processed["max_tokens"], 5)
|
||||
self.assertEqual(processed["max_tokens"], 3)
|
||||
|
||||
def test_process_request_dict_requires_chat_template(self):
|
||||
original_template = self.processor.tokenizer.chat_template
|
||||
|
||||
@@ -1,365 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for limit_thinking_content_length_v1 and limit_thinking_content_length_v2"""
|
||||
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
limit_thinking_content_length_v1,
|
||||
limit_thinking_content_length_v2,
|
||||
)
|
||||
|
||||
|
||||
class TestLimitThinkingContentLengthV1(unittest.TestCase):
|
||||
"""Tests for limit_thinking_content_length_v1 operator (</think> strategy)"""
|
||||
|
||||
def test_normal_thinking_phase_no_limit_reached(self):
|
||||
"""Test normal thinking phase when step < max_think_len"""
|
||||
next_tokens = paddle.to_tensor([[100], [200]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False, False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: tokens unchanged, status unchanged
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert next_tokens.numpy()[1, 0] == 200
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
assert limit_think_status.numpy()[1] == 0
|
||||
|
||||
def test_force_truncation_when_max_think_len_exceeded(self):
|
||||
"""Test force truncation when step >= max_think_len"""
|
||||
next_tokens = paddle.to_tensor([[100], [200]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([5, 8], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[5], [10]], dtype="int64") # Both exceed or equal limit
|
||||
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False, False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: tokens replaced with think_end_id, status changed to 2
|
||||
assert next_tokens.numpy()[0, 0] == 999 # Replaced
|
||||
assert next_tokens.numpy()[1, 0] == 999 # Replaced
|
||||
assert limit_think_status.numpy()[0] == 2 # Status updated
|
||||
assert limit_think_status.numpy()[1] == 2 # Status updated
|
||||
|
||||
def test_model_naturally_generates_think_end_id(self):
|
||||
"""Test when model naturally generates think_end_id"""
|
||||
next_tokens = paddle.to_tensor([[999]], dtype="int64") # Model generated think_end_id
|
||||
max_think_lens = paddle.to_tensor([10], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[3]], dtype="int64") # Still within limit
|
||||
limit_think_status = paddle.to_tensor([0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: token unchanged (already think_end_id), status changed to 2
|
||||
assert next_tokens.numpy()[0, 0] == 999
|
||||
assert limit_think_status.numpy()[0] == 2 # Move to response phase
|
||||
|
||||
def test_status_1_to_status_2_transition(self):
|
||||
"""Test transition from status 1 (injected) to status 2 (confirmed)"""
|
||||
next_tokens = paddle.to_tensor([[999]], dtype="int64") # think_end_id from previous injection
|
||||
max_think_lens = paddle.to_tensor([5], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[6]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([1], dtype="int32") # Status is 1
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: status changed to 2
|
||||
assert limit_think_status.numpy()[0] == 2
|
||||
|
||||
def test_disabled_feature_negative_max_think_len(self):
|
||||
"""Test that negative max_think_len disables the feature"""
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([-1], dtype="int32") # Disabled
|
||||
step_idx = paddle.to_tensor([[100]], dtype="int64") # Would exceed limit if enabled
|
||||
limit_think_status = paddle.to_tensor([0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: nothing changed
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
|
||||
def test_already_in_response_phase_status_2(self):
|
||||
"""Test that status 2 (response phase) is terminal"""
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([5], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[10]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([2], dtype="int32") # Already in response phase
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify: nothing changed
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 2
|
||||
|
||||
def test_mixed_batch(self):
|
||||
"""Test batch with different sequences in different states"""
|
||||
next_tokens = paddle.to_tensor([[100], [200], [999], [300]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([10, 5, 8, -1], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[3], [5], [4], [100]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0, 0, 0, 0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False, False, False, False], dtype="bool")
|
||||
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
|
||||
think_end_id = 999
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v1(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
|
||||
)
|
||||
|
||||
# Verify each sequence
|
||||
# Seq 0: step 3 < max 10, status 0, token unchanged
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
|
||||
# Seq 1: step 5 >= max 5, force inject think_end_id, status -> 2
|
||||
assert next_tokens.numpy()[1, 0] == 999
|
||||
assert limit_think_status.numpy()[1] == 2
|
||||
|
||||
# Seq 2: step 4 < max 8, but token is think_end_id, status -> 2
|
||||
assert next_tokens.numpy()[2, 0] == 999
|
||||
assert limit_think_status.numpy()[2] == 2
|
||||
|
||||
# Seq 3: disabled (max -1), unchanged
|
||||
assert next_tokens.numpy()[3, 0] == 300
|
||||
assert limit_think_status.numpy()[3] == 0
|
||||
|
||||
|
||||
class TestLimitThinkingContentLengthV2(unittest.TestCase):
|
||||
"""Tests for limit_thinking_content_length_v2 operator (\n</think>\n\n strategy)"""
|
||||
|
||||
def test_normal_thinking_phase_no_limit_reached(self):
|
||||
"""Test normal thinking phase when step < max_think_len"""
|
||||
next_tokens = paddle.to_tensor([[100], [200]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False, False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Verify: tokens unchanged, status unchanged
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert next_tokens.numpy()[1, 0] == 200
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
assert limit_think_status.numpy()[1] == 0
|
||||
|
||||
def test_force_truncation_sequence_injection(self):
|
||||
"""Test force truncation with \n</think>\n\n sequence injection"""
|
||||
# Test step == max_think_len (inject first \n)
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([5], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[5]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
|
||||
assert limit_think_status.numpy()[0] == 1
|
||||
|
||||
# Test step == max_think_len + 1 (inject </think>)
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
step_idx = paddle.to_tensor([[6]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([1], dtype="int32")
|
||||
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
assert next_tokens.numpy()[0, 0] == 999 # think_end_id
|
||||
assert limit_think_status.numpy()[0] == 1
|
||||
|
||||
# Test step == max_think_len + 2 (inject second \n)
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
step_idx = paddle.to_tensor([[7]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([1], dtype="int32")
|
||||
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
|
||||
assert limit_think_status.numpy()[0] == 1
|
||||
|
||||
# Test step == max_think_len + 3 (inject third \n and finish)
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
step_idx = paddle.to_tensor([[8]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([1], dtype="int32")
|
||||
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
|
||||
assert limit_think_status.numpy()[0] == 3 # Move to status 3
|
||||
|
||||
def test_model_naturally_generates_think_end_id(self):
|
||||
"""Test when model naturally generates think_end_id"""
|
||||
next_tokens = paddle.to_tensor([[999]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([10], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[3]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Verify: status changed to 3 (response phase)
|
||||
assert next_tokens.numpy()[0, 0] == 999
|
||||
assert limit_think_status.numpy()[0] == 3
|
||||
|
||||
def test_status_2_to_status_3_transition(self):
|
||||
"""Test transition from status 2 (replacement done) to status 3 (thinking ended)"""
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([5], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[9]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([2], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Verify: status changed to 3
|
||||
assert limit_think_status.numpy()[0] == 3
|
||||
|
||||
def test_disabled_feature_negative_max_think_len(self):
|
||||
"""Test that negative max_think_len disables the feature"""
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([-1], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[100]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Verify: nothing changed
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
|
||||
def test_already_in_response_phase_status_3(self):
|
||||
"""Test that status 3 (response phase) is terminal"""
|
||||
next_tokens = paddle.to_tensor([[100]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([5], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[10]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([3], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Verify: nothing changed
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 3
|
||||
|
||||
def test_mixed_batch_various_states(self):
|
||||
"""Test batch with sequences in different states"""
|
||||
next_tokens = paddle.to_tensor([[100], [200], [999], [300], [400]], dtype="int64")
|
||||
max_think_lens = paddle.to_tensor([10, 5, 8, -1, 6], dtype="int32")
|
||||
step_idx = paddle.to_tensor([[3], [5], [4], [100], [9]], dtype="int64")
|
||||
limit_think_status = paddle.to_tensor([0, 0, 0, 0, 2], dtype="int32")
|
||||
stop_flags = paddle.to_tensor([False, False, False, False, False], dtype="bool")
|
||||
think_end_id = 999
|
||||
line_break_id = 888
|
||||
|
||||
# Run operator
|
||||
limit_thinking_content_length_v2(
|
||||
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
|
||||
)
|
||||
|
||||
# Seq 0: step 3 < max 10, status 0, unchanged
|
||||
assert next_tokens.numpy()[0, 0] == 100
|
||||
assert limit_think_status.numpy()[0] == 0
|
||||
|
||||
# Seq 1: step 5 == max 5, inject line_break_id, status -> 1
|
||||
assert next_tokens.numpy()[1, 0] == 888
|
||||
assert limit_think_status.numpy()[1] == 1
|
||||
|
||||
# Seq 2: token is think_end_id, status 0 -> 3
|
||||
assert next_tokens.numpy()[2, 0] == 999
|
||||
assert limit_think_status.numpy()[2] == 3
|
||||
|
||||
# Seq 3: disabled, unchanged
|
||||
assert next_tokens.numpy()[3, 0] == 300
|
||||
assert limit_think_status.numpy()[3] == 0
|
||||
|
||||
# Seq 4: status 2 (replacement done), transition to 3
|
||||
assert limit_think_status.numpy()[4] == 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,558 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for speculate_limit_thinking_content_length_v1 and speculate_limit_thinking_content_length_v2"""
|
||||
|
||||
import unittest
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
speculate_limit_thinking_content_length_v1,
|
||||
speculate_limit_thinking_content_length_v2,
|
||||
)
|
||||
|
||||
|
||||
class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
    """Tests for speculate_limit_thinking_content_length_v1 operator (</think> strategy with speculative decoding)"""

    @staticmethod
    def _invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, end_id):
        # Thin wrapper around the operator; every tensor argument is mutated in place.
        speculate_limit_thinking_content_length_v1(
            tokens,
            think_limits,
            steps,
            status,
            accepted,
            stops,
            eos_ids,
            end_id,
        )

    def test_normal_thinking_phase_no_truncation(self):
        """Test normal thinking phase when all tokens are within limit"""
        # Batch 0 accepts 3 tokens, batch 1 accepts 2 tokens.
        tokens = paddle.to_tensor([[100, 101, 102], [200, 201, 0]], dtype="int64")
        think_limits = paddle.to_tensor([10, 15], dtype="int32")
        # step_idx represents the current step after accepting tokens.
        steps = paddle.to_tensor([5, 8], dtype="int64")
        status = paddle.to_tensor([0, 0], dtype="int32")
        accepted = paddle.to_tensor([3, 2], dtype="int32")
        stops = paddle.to_tensor([False, False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Tokens, accept_num, status and step_idx must all be left untouched.
        toks = tokens.numpy()
        assert toks[0, 0] == 100
        assert toks[0, 1] == 101
        assert toks[0, 2] == 102
        assert accepted.numpy()[0] == 3
        assert accepted.numpy()[1] == 2
        assert status.numpy()[0] == 0
        assert status.numpy()[1] == 0
        assert steps.numpy()[0] == 5
        assert steps.numpy()[1] == 8

    def test_force_truncation_when_exceeding_limit(self):
        """Test force truncation when tokens exceed max_think_len"""
        # 4 tokens accepted, but the limit is crossed partway through them.
        tokens = paddle.to_tensor([[100, 101, 102, 103]], dtype="int64")
        think_limits = paddle.to_tensor([10], dtype="int32")
        # Current step is 12 after accepting 4 tokens, so the base step is 12-4+1 = 9.
        # Token 0 lands on step 9; token 1 lands on step 10 (>= max_think_len=10) and
        # must be replaced with think_end_id.
        steps = paddle.to_tensor([12], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([4], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Position 1 becomes think_end_id and acceptance is cut to 2 tokens.
        toks = tokens.numpy()
        assert toks[0, 0] == 100  # token at step 9, kept
        assert toks[0, 1] == 999  # token at step 10, replaced with think_end_id
        assert accepted.numpy()[0] == 2  # only the first 2 tokens survive
        assert status.numpy()[0] == 2  # status advanced to 2
        assert steps.numpy()[0] == 10  # 12 - (4-2) = 10, step_idx rolled back

    def test_model_naturally_generates_think_end_id(self):
        """Test when model naturally generates think_end_id in accepted tokens"""
        tokens = paddle.to_tensor([[100, 999, 102]], dtype="int64")
        think_limits = paddle.to_tensor([20], dtype="int32")
        steps = paddle.to_tensor([5], dtype="int64")  # covers steps 3-5
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([3], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Thinking ends naturally: status moves to 2, all tokens kept.
        assert tokens.numpy()[0, 1] == 999
        assert status.numpy()[0] == 2  # thinking phase over
        assert accepted.numpy()[0] == 3  # nothing truncated

    def test_disabled_feature_negative_max_think_len(self):
        """Test that negative max_think_len disables the feature"""
        tokens = paddle.to_tensor([[100, 101, 102]], dtype="int64")
        think_limits = paddle.to_tensor([-1], dtype="int32")  # feature disabled
        steps = paddle.to_tensor([100], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([3], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Operator must be a no-op when disabled.
        assert tokens.numpy()[0, 0] == 100
        assert accepted.numpy()[0] == 3
        assert status.numpy()[0] == 0

    def test_zero_accept_num_early_return(self):
        """Test early return when accept_num is 0"""
        tokens = paddle.to_tensor([[100, 101]], dtype="int64")
        think_limits = paddle.to_tensor([5], dtype="int32")
        steps = paddle.to_tensor([10], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([0], dtype="int32")  # nothing accepted this step
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Early return: state unchanged.
        assert accepted.numpy()[0] == 0
        assert status.numpy()[0] == 0

    def test_already_in_response_phase_status_3(self):
        """Test that status 3 is terminal (note: v1 uses status 2 as terminal in comment, but code shows 3)"""
        tokens = paddle.to_tensor([[100, 101]], dtype="int64")
        think_limits = paddle.to_tensor([5], dtype="int32")
        steps = paddle.to_tensor([10], dtype="int64")
        status = paddle.to_tensor([3], dtype="int32")  # terminal status
        accepted = paddle.to_tensor([2], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Early return: terminal status is preserved.
        assert status.numpy()[0] == 3

    def test_status_transition_from_0_to_1_to_2(self):
        """Test status transition: 0 (thinking) -> 1 (injected) -> 2 (ended)"""
        # Inject think_end_id because the limit is exceeded mid-acceptance.
        tokens = paddle.to_tensor([[100, 101]], dtype="int64")
        think_limits = paddle.to_tensor([9], dtype="int32")
        steps = paddle.to_tensor([9], dtype="int64")  # base step = 9-2+1 = 8
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([2], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Token at step 8 is fine; token at step 9 hits the limit and is replaced.
        toks = tokens.numpy()
        assert toks[0, 0] == 100
        assert toks[0, 1] == 999  # replaced with think_end_id
        assert status.numpy()[0] == 2
        assert accepted.numpy()[0] == 2

    def test_mixed_batch_with_different_states(self):
        """Test batch with different sequences in various states"""
        tokens = paddle.to_tensor([[100, 101, 102], [200, 999, 202], [300, 301, 0]], dtype="int64")
        think_limits = paddle.to_tensor([10, 15, -1], dtype="int32")
        steps = paddle.to_tensor([6, 8, 50], dtype="int64")
        status = paddle.to_tensor([0, 0, 0], dtype="int32")
        accepted = paddle.to_tensor([3, 3, 2], dtype="int32")
        stops = paddle.to_tensor([False, False, False], dtype="bool")
        eos_ids = paddle.to_tensor([[2], [2]], dtype="int64")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, eos_ids, 999)

        # Sequence 0: within limit -> untouched.
        assert status.numpy()[0] == 0
        assert accepted.numpy()[0] == 3

        # Sequence 1: naturally emits think_end_id -> status 2.
        assert status.numpy()[1] == 2
        assert accepted.numpy()[1] == 3

        # Sequence 2: feature disabled -> untouched.
        assert status.numpy()[2] == 0
        assert accepted.numpy()[2] == 2
|
||||
class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
    """Tests for speculate_limit_thinking_content_length_v2 operator.

    Tests the \\n</think>\\n\\n strategy with speculative decoding.
    """

    @staticmethod
    def _invoke(tokens, think_limits, steps, status, accepted, stops, end_id, newline_id):
        # Thin wrapper around the operator; every tensor argument is mutated in place.
        speculate_limit_thinking_content_length_v2(
            tokens,
            think_limits,
            steps,
            status,
            accepted,
            stops,
            end_id,
            newline_id,
        )

    def test_normal_thinking_phase_no_truncation(self):
        """Test normal thinking phase when all tokens are within limit"""
        tokens = paddle.to_tensor([[100, 101, 102], [200, 201, 0]], dtype="int64")
        think_limits = paddle.to_tensor([10, 15], dtype="int32")
        steps = paddle.to_tensor([5, 8], dtype="int64")
        status = paddle.to_tensor([0, 0], dtype="int32")
        accepted = paddle.to_tensor([3, 2], dtype="int32")
        stops = paddle.to_tensor([False, False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Everything stays as it was.
        assert tokens.numpy()[0, 0] == 100
        assert accepted.numpy()[0] == 3
        assert status.numpy()[0] == 0

    def test_force_truncation_with_sequence_injection(self):
        """Test force truncation with the \\n</think>\\n\\n sequence injection"""
        # Several accepted tokens; the first one already sits at the limit.
        tokens = paddle.to_tensor([[100, 101, 102, 103, 104]], dtype="int64")
        think_limits = paddle.to_tensor([8], dtype="int32")
        # step_idx = 12, accept_num = 5, so base_step = 12-5+1 = 8.
        # Token 0 lands on step 8 (== max 8): the line break is injected there.
        steps = paddle.to_tensor([12], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([5], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # First token becomes line_break_id; acceptance truncated to one token.
        assert tokens.numpy()[0, 0] == 888  # line_break_id
        assert status.numpy()[0] == 1
        assert accepted.numpy()[0] == 1  # truncated after the 1st token
        assert steps.numpy()[0] == 8  # 12 - (5-1)

    def test_injection_sequence_steps(self):
        """Test each step of the injection sequence: \\n, </think>, \\n, \\n"""
        think_limits = paddle.to_tensor([5], dtype="int32")
        end_id, newline_id = 999, 888

        # Step 1: at max_think_len, the first \n is injected.
        tokens = paddle.to_tensor([[100]], dtype="int64")
        steps = paddle.to_tensor([5], dtype="int64")  # base_step = 5-1+1 = 5
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, end_id, newline_id)
        assert tokens.numpy()[0, 0] == 888
        assert status.numpy()[0] == 1

        # Step 2: at max_think_len+1, </think> is injected.
        tokens = paddle.to_tensor([[200]], dtype="int64")
        steps = paddle.to_tensor([6], dtype="int64")  # base_step = 6
        status = paddle.to_tensor([1], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, end_id, newline_id)
        assert tokens.numpy()[0, 0] == 999
        assert status.numpy()[0] == 1

        # Step 3: at max_think_len+2, the second \n is injected.
        tokens = paddle.to_tensor([[300]], dtype="int64")
        steps = paddle.to_tensor([7], dtype="int64")
        status = paddle.to_tensor([1], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, end_id, newline_id)
        assert tokens.numpy()[0, 0] == 888
        assert status.numpy()[0] == 1

        # Step 4: at max_think_len+3, the third \n is injected and status becomes 3.
        tokens = paddle.to_tensor([[400]], dtype="int64")
        steps = paddle.to_tensor([8], dtype="int64")
        status = paddle.to_tensor([1], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, end_id, newline_id)
        assert tokens.numpy()[0, 0] == 888
        assert status.numpy()[0] == 3

    def test_model_naturally_generates_think_end_id(self):
        """Test when model naturally generates think_end_id"""
        tokens = paddle.to_tensor([[100, 999, 102]], dtype="int64")
        think_limits = paddle.to_tensor([20], dtype="int32")
        steps = paddle.to_tensor([5], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([3], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Natural </think> jumps straight to the terminal status.
        assert status.numpy()[0] == 3

    def test_status_2_to_status_3_transition(self):
        """Test transition from status 2 to status 3"""
        tokens = paddle.to_tensor([[100]], dtype="int64")
        think_limits = paddle.to_tensor([5], dtype="int32")
        steps = paddle.to_tensor([10], dtype="int64")
        status = paddle.to_tensor([2], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Status 2 advances to 3.
        assert status.numpy()[0] == 3

    def test_disabled_feature_negative_max_think_len(self):
        """Test that negative max_think_len disables the feature"""
        tokens = paddle.to_tensor([[100, 101]], dtype="int64")
        think_limits = paddle.to_tensor([-1], dtype="int32")
        steps = paddle.to_tensor([100], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([2], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Operator must be a no-op when disabled.
        assert status.numpy()[0] == 0
        assert accepted.numpy()[0] == 2

    def test_zero_accept_num_early_return(self):
        """Test early return when accept_num is 0"""
        tokens = paddle.to_tensor([[100]], dtype="int64")
        think_limits = paddle.to_tensor([5], dtype="int32")
        steps = paddle.to_tensor([10], dtype="int64")
        status = paddle.to_tensor([0], dtype="int32")
        accepted = paddle.to_tensor([0], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Early return: state unchanged.
        assert accepted.numpy()[0] == 0
        assert status.numpy()[0] == 0

    def test_already_in_response_phase_status_3(self):
        """Test that status 3 is terminal"""
        tokens = paddle.to_tensor([[100]], dtype="int64")
        think_limits = paddle.to_tensor([5], dtype="int32")
        steps = paddle.to_tensor([10], dtype="int64")
        status = paddle.to_tensor([3], dtype="int32")
        accepted = paddle.to_tensor([1], dtype="int32")
        stops = paddle.to_tensor([False], dtype="bool")

        self._invoke(tokens, think_limits, steps, status, accepted, stops, 999, 888)

        # Early return: terminal status is preserved.
        assert status.numpy()[0] == 3
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly with the standard unittest runner.
    unittest.main()
|
||||
Reference in New Issue
Block a user