mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Add speculate_limit_thinking_content_length Op. (#6627)
* [XPU] Add speculate_limit_thinking_content_length OP for xpu. * add unittest. * format codes. * format codes. * format codes. * Fix unused kernel launch return value. --------- Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
#include "xpu/plugin.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
// Host-side wrapper for the speculate_limit_thinking_content_length XPU op.
// Enforces "thinking"-phase and reply-phase length limits on speculative
// decoding outputs by calling the XPU plugin kernel.
//
// In-place mutation: next_tokens, max_reply_lens, step_idx, limit_status and
// accept_num are written through const_cast. NOTE(review): only next_tokens
// is declared in the op's SetInplaceMap — confirm the framework tolerates
// the other, undeclared in-place writes.
void SpeculateLimitThinkingContentLength(const paddle::Tensor& next_tokens,
                                         const paddle::Tensor& max_think_lens,
                                         const paddle::Tensor& max_reply_lens,
                                         const paddle::Tensor& step_idx,
                                         const paddle::Tensor& limit_status,
                                         const paddle::Tensor& accept_num,
                                         const paddle::Tensor& stop_flags,
                                         const paddle::Tensor& eos_token_ids,
                                         const paddle::Tensor& inject_token_ids,
                                         const int64_t think_end_id,
                                         const bool splitwise_role_is_decode) {
  // Fetch the device context for the current XPU device.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  // Shapes: next_tokens is [batch, tokens_per_step]; the id lists are 1-D.
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  const int eos_token_id_len = eos_token_ids.shape()[0];
  const int inject_len = inject_token_ids.shape()[0];

  int r =
      baidu::xpu::api::plugin::speculate_limit_thinking_content_length_kernel(
          xpu_ctx->x_context(),
          const_cast<int64_t*>(next_tokens.data<int64_t>()),
          max_think_lens.data<int>(),
          const_cast<int*>(max_reply_lens.data<int>()),
          const_cast<int64_t*>(step_idx.data<int64_t>()),
          eos_token_ids.data<int64_t>(),
          const_cast<int*>(limit_status.data<int>()),
          const_cast<int*>(accept_num.data<int>()),
          stop_flags.data<bool>(),
          think_end_id,
          // inject_token_ids may be empty; pass nullptr in that case.
          (inject_len > 0) ? inject_token_ids.data<int64_t>() : nullptr,
          tokens_per_step,
          batch_size,
          eos_token_id_len,
          inject_len,
          splitwise_role_is_decode);
  // The plugin returns 0 (api::SUCCESS) on success.
  PD_CHECK(r == 0,
           "baidu::xpu::api::plugin::"
           "speculate_limit_thinking_content_length_kernel failed.");
}
|
||||
|
||||
// Op registration. think_end_id / splitwise_role_is_decode are attributes
// (scalars), not tensor inputs. next_tokens is aliased to next_tokens_out via
// the inplace map, so the op mutates it rather than allocating an output.
// NOTE(review): the kernel also writes max_reply_lens, step_idx, limit_status
// and accept_num in place, but they are not declared in SetInplaceMap —
// confirm this is acceptable to the framework.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length)
    .Inputs({"next_tokens",
             "max_think_lens",
             "max_reply_lens",
             "step_idx",
             "limit_status",
             "accept_num",
             "stop_flags",
             "eos_token_ids",
             "inject_token_ids"})
    .Attrs({"think_end_id: int64_t", "splitwise_role_is_decode: bool"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLength));
|
||||
@@ -645,17 +645,17 @@ DLL_EXPORT int speculate_update_v3(Context* ctx,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int speculate_update(Context* ctx,
|
||||
int* seq_lens_encoder, // 输入 [B_max, ]
|
||||
int* seq_lens_decoder, // 输出 [B_max, ]
|
||||
bool* not_need_stop, // [1,]
|
||||
int64_t* draft_tokens, // [B_max, T_max]
|
||||
int* actual_draft_token_nums, // [B_max, ]
|
||||
const int64_t* accept_tokens, // [B_max, T_max]
|
||||
const int* accept_num, // [B_max, ]
|
||||
const bool* stop_flags, // [B_max, ]
|
||||
const int* seq_lens_this_time, // [B_real,]
|
||||
const bool* is_block_step, // [B_max, ]
|
||||
int* mask_rollback, // [1,]
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
bool* not_need_stop,
|
||||
int64_t* draft_tokens,
|
||||
int* actual_draft_token_nums,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const bool* stop_flags,
|
||||
const int* seq_lens_this_time,
|
||||
const bool* is_block_step,
|
||||
int* mask_rollback,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_draft_tokens);
|
||||
@@ -704,6 +704,24 @@ DLL_EXPORT int update_attn_mask_offsets(Context* ctx,
|
||||
int real_bsz,
|
||||
int max_model_len,
|
||||
int decode_states_len);
|
||||
|
||||
// Applies thinking-length / reply-length limiting to speculative-decoding
// outputs. Returns 0 (api::SUCCESS) on success. next_tokens, max_reply_lens,
// step_idx, limit_status and accept_num are updated in place;
// inject_token_ids may be nullptr when inject_len == 0.
DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
    api::Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode);
|
||||
/*--------------------------------------- MTP end
|
||||
* --------------------------------------------*/
|
||||
|
||||
|
||||
+204
@@ -0,0 +1,204 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "xpu/kernel/cluster_debug.h"
|
||||
#include "xpu/kernel/cluster_partition.h"
|
||||
#include "xpu/kernel/xtdk.h"
|
||||
#include "xpu/kernel/xtdk_math.h"
|
||||
#include "xpu/kernel/xtdk_simd.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
// 1) 支持 inject_token_ids(任意长度)
|
||||
// 2) 支持限制回复长度 max_reply_lens
|
||||
// 3) 语义对齐非MTP limit_thinking_content_length:
|
||||
// - max_think_len < 0:不强制截断思考,但仍会监听 think_end_id
|
||||
// 来进入回复阶段
|
||||
// - max_reply_len >= 0:仅在"思考结束后的下一个 token"开始计回复长度
|
||||
//
|
||||
// 状态机(复用 limit_status)
|
||||
// - 0:思考生成阶段
|
||||
// - 1..inject_len:注入 inject_token_ids[status - 1]
|
||||
// - done_status = inject_len + 1(inject_len==0 时
|
||||
// done_status=1):思考结束点(本 token 不计入回复)
|
||||
// - reply_base = done_status + 1
|
||||
// - status >= reply_base:回复阶段计数,reply_len = status -
|
||||
// reply_base(已输出回复 token 数)
|
||||
// Device kernel: limits "thinking" content length for speculative decoding.
// Features:
// 1) supports inject_token_ids of arbitrary length
// 2) supports limiting reply length via max_reply_lens
// 3) semantics aligned with the non-MTP limit_thinking_content_length:
//    - max_think_len < 0: thinking is never force-truncated, but
//      think_end_id is still watched so the reply phase can begin
//    - max_reply_len >= 0: reply tokens are only counted starting from the
//      token AFTER thinking has ended
//
// State machine (stored in limit_status):
// - 0: thinking / generation phase
// - 1..inject_len: inject inject_token_ids[status - 1]
// - done_status = inject_len + 1 (done_status = 1 when inject_len == 0):
//   thinking just ended (that token does not count toward the reply)
// - reply_base = done_status + 1
// - status >= reply_base: reply phase; reply_len = status - reply_base
//   (number of reply tokens already emitted)
__global__ void speculate_limit_thinking_content_length_kernel(
    int64_t* next_tokens,          // [bs, tokens_per_step]
    const int* max_think_lens,     // [bs]
    int* max_reply_lens,           // [bs]
    int64_t* step_idx,             // [bs]
    const int64_t* eos_token_ids,  // [eos_len]
    int* limit_status,             // [bs]
    int* accept_num,               // [bs]
    const bool* stop_flags,        // [bs]
    const int64_t think_end_id,
    const int64_t* inject_token_ids,  // [inject_len]
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  const int total_threads = cluster_num() * core_num();
  const int tid = cluster_id() * core_num() + core_id();

  // Each hardware thread handles batch rows tid, tid + total_threads, ...
  for (int bid = tid; bid < bs; bid += total_threads) {
    const int original_accept_num = accept_num[bid];
    if (original_accept_num <= 0) continue;
    if (stop_flags[bid]) continue;

    const int max_think_len =
        max_think_lens[bid];  // <0: never force-truncate thinking
                              // (still watch think_end_id)
    int max_reply_len = max_reply_lens[bid];  // <0: reply length unlimited
    if (max_think_len < 0 && max_reply_len < 0) continue;

    // State constants (inject_len == 0 is allowed)
    const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
    const int reply_base = done_status + 1;

    int status = limit_status[bid];
    if (status < 0) status = 0;

    int new_accept_num = original_accept_num;

    // Absolute step index corresponding to token_offset 0 of this step
    const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

    for (int token_offset = 0; token_offset < original_accept_num;
         token_offset++) {
      const int token_idx = bid * tokens_per_step + token_offset;
      int64_t next_token = next_tokens[token_idx];
      const int64_t current_step = current_base_step + token_offset;

      const int prev_status = status;
      bool condition_triggered = false;

      // ======================= 1) Thinking phase: watch for think_end_id
      // (semantics aligned with the non-MTP op) =======================
      // NOTE: this must come before the injection logic. If the model
      // naturally emits </think>, that token moves status to done_status,
      // but the token itself must NOT count toward the reply.
      if (status == 0 && next_token == think_end_id) {
        status = done_status;
        // Do not truncate accept_num: later tokens may continue (reply
        // counting only starts from the NEXT token).
        if (max_reply_len >= 0) {
          max_reply_len += 2;
        }
      }

      // ======================= 2) Injection only triggers when thinking
      // truncation is enabled (max_think_len >= 0) =======================
      if (max_think_len >= 0 && status < reply_base) {
        if (max_think_len > 0) {
          // A) Length trigger: start injecting once max_think_len is
          // reached (this token becomes inject_token_ids[0])
          if (status == 0 &&
              (current_step - 1) ==
                  max_think_len) {  // current_step - 1 because
                                    // speculate_verify already incremented
                                    // step_idx
            status = (inject_len > 0) ? 1 : done_status;
          }
        } else if (max_think_len == 0) {
          // A) Length trigger at max_think_len == 0
          if (status == 0 && !splitwise_role_is_decode) {
            // Colocated or prefill (P) node: this token becomes
            // inject_token_ids[0]
            status = (inject_len > 0) ? 1 : done_status;
          } else if (status == 0 && splitwise_role_is_decode) {
            // Decode (D) node: this token becomes inject_token_ids[1]
            status = (inject_len > 0) ? 2 : done_status + 1;
          }
        }

        // B) Early EOS during thinking: start injecting (overwrite the eos
        // with inject_token_ids[0])
        if (status == 0 && inject_len > 0) {
          for (int i = 0; i < eos_token_id_len; i++) {
            if (eos_token_ids[i] == next_token) {
              status = 1;
              break;
            }
          }
        }

        // Injection sequence: while inside the injection range, overwrite
        // next_token and advance the state. Because a token is overwritten,
        // accept_num must be truncated so the remaining draft tokens of this
        // step are not accepted.
        if (inject_len > 0 && status >= 1 && status <= inject_len) {
          next_token = inject_token_ids[status - 1];
          status += 1;
          if (status > done_status) status = done_status;  // defensive clamp
          condition_triggered = true;
        }
      }

      // Did this token just transition into done_status?
      // Typical cases: the model emitted </think> (0 -> done_status), or the
      // final injection step completed (inject_len -> done_status).
      const bool became_done_this_token = (status == done_status) &&
                                          (prev_status != done_status) &&
                                          (prev_status < reply_base);

      // ======================= 3) Reply-length limit (aligned with the
      // non-MTP op) =======================
      // Key point: the token that just entered done_status (</think> or an
      // injected token) neither counts toward the reply nor starts reply
      // counting.
      if (max_reply_len >= 0) {
        if (!became_done_this_token) {
          // Only when the PREVIOUS token already reached done_status may the
          // current token enter reply_base and start counting.
          if (status == done_status) {
            status = reply_base;  // reply_len = 0
          }

          if (status >= reply_base) {
            int reply_len =
                status - reply_base;  // reply tokens already emitted
                                      // (excluding the current token)

            if (reply_len >= max_reply_len) {
              // Limit reached: force EOS and truncate accept_num
              if (eos_token_id_len > 0) next_token = eos_token_ids[0];
              status = reply_base + max_reply_len;
              condition_triggered = true;
            } else {
              // Emit the current token normally; reply count + 1
              status = reply_base + (reply_len + 1);
            }
          }
        }
      }

      // Write the (possibly overwritten) token back
      next_tokens[token_idx] = next_token;

      // If this token was overwritten (injection or forced EOS), truncate
      // accept_num so the remaining draft tokens are discarded.
      if (condition_triggered) {
        new_accept_num = token_offset + 1;
        break;
      }
    }

    // Update step_idx / accept_num (discarded tokens roll step_idx back)
    const int discarded_tokens = original_accept_num - new_accept_num;
    if (discarded_tokens > 0) {
      step_idx[bid] -= discarded_tokens;
    }

    accept_num[bid] = new_accept_num;
    limit_status[bid] = status;
    max_reply_lens[bid] = max_reply_len;
  }
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
+292
@@ -0,0 +1,292 @@
|
||||
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "xpu/plugin.h"
|
||||
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
// Forward declaration of the device kernel (defined in the device-side
// source) so the host-side wrapper below can launch it.
__attribute__((global)) void speculate_limit_thinking_content_length_kernel(
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
// CPU reference implementation. Mirrors the device kernel's state machine
// exactly: limit_status encodes thinking (0), injection (1..inject_len),
// thinking-done (done_status) and reply (>= reply_base) phases; accept_num
// is truncated whenever a token is overwritten. Keep this logic in lockstep
// with the device kernel and the Python reference in the unit tests.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       int* max_reply_lens,
                       int64_t* step_idx,
                       const int64_t* eos_token_ids,
                       int* limit_status,
                       int* accept_num,
                       const bool* stop_flags,
                       const int64_t think_end_id,
                       const int64_t* inject_token_ids,
                       const int tokens_per_step,
                       const int bs,
                       const int eos_token_id_len,
                       const int inject_len,
                       const bool splitwise_role_is_decode) {
  for (int bid = 0; bid < bs; bid++) {
    const int original_accept_num = accept_num[bid];
    if (original_accept_num <= 0) continue;
    if (stop_flags[bid]) continue;

    // <0 disables the corresponding limit; skip rows with neither limit.
    const int max_think_len = max_think_lens[bid];
    int max_reply_len = max_reply_lens[bid];
    if (max_think_len < 0 && max_reply_len < 0) continue;

    const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
    const int reply_base = done_status + 1;

    int status = limit_status[bid];
    if (status < 0) status = 0;

    int new_accept_num = original_accept_num;

    // Absolute step index of token_offset 0 in this step.
    const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

    for (int token_offset = 0; token_offset < original_accept_num;
         token_offset++) {
      const int token_idx = bid * tokens_per_step + token_offset;
      int64_t next_token = next_tokens[token_idx];
      const int64_t current_step = current_base_step + token_offset;

      const int prev_status = status;
      bool condition_triggered = false;

      // 1) Thinking phase: a naturally emitted think_end_id ends thinking;
      // the token itself does not count toward the reply.
      if (status == 0 && next_token == think_end_id) {
        status = done_status;
        if (max_reply_len >= 0) {
          max_reply_len += 2;
        }
      }

      // 2) Injection triggers (only when thinking truncation is enabled).
      if (max_think_len >= 0 && status < reply_base) {
        if (max_think_len > 0) {
          // Length trigger (current_step - 1: step_idx was pre-incremented
          // by speculate_verify).
          if (status == 0 && (current_step - 1) == max_think_len) {
            status = (inject_len > 0) ? 1 : done_status;
          }
        } else if (max_think_len == 0) {
          // max_think_len == 0: colocated/prefill starts at inject[0],
          // decode (D) node starts at inject[1].
          if (status == 0 && !splitwise_role_is_decode) {
            status = (inject_len > 0) ? 1 : done_status;
          } else if (status == 0 && splitwise_role_is_decode) {
            status = (inject_len > 0) ? 2 : done_status + 1;
          }
        }

        // Early EOS during thinking also starts injection.
        if (status == 0 && inject_len > 0) {
          for (int i = 0; i < eos_token_id_len; i++) {
            if (eos_token_ids[i] == next_token) {
              status = 1;
              break;
            }
          }
        }

        // Inside the injection range: overwrite the token, advance state,
        // and (below) truncate accept_num.
        if (inject_len > 0 && status >= 1 && status <= inject_len) {
          next_token = inject_token_ids[status - 1];
          status += 1;
          if (status > done_status) status = done_status;  // defensive clamp
          condition_triggered = true;
        }
      }

      // Did this token just transition into done_status?
      const bool became_done_this_token = (status == done_status) &&
                                          (prev_status != done_status) &&
                                          (prev_status < reply_base);

      // 3) Reply-length limit. The token that just ended thinking does not
      // count toward the reply and does not start reply counting.
      if (max_reply_len >= 0) {
        if (!became_done_this_token) {
          if (status == done_status) {
            status = reply_base;
          }

          if (status >= reply_base) {
            int reply_len = status - reply_base;

            if (reply_len >= max_reply_len) {
              // Limit reached: force EOS and truncate accept_num.
              if (eos_token_id_len > 0) next_token = eos_token_ids[0];
              status = reply_base + max_reply_len;
              condition_triggered = true;
            } else {
              status = reply_base + (reply_len + 1);
            }
          }
        }
      }

      next_tokens[token_idx] = next_token;

      // A token was overwritten: discard the remaining draft tokens.
      if (condition_triggered) {
        new_accept_num = token_offset + 1;
        break;
      }
    }

    // Roll step_idx back by the number of discarded tokens.
    const int discarded_tokens = original_accept_num - new_accept_num;
    if (discarded_tokens > 0) {
      step_idx[bid] -= discarded_tokens;
    }

    accept_num[bid] = new_accept_num;
    limit_status[bid] = status;
    max_reply_lens[bid] = max_reply_len;
  }
  return api::SUCCESS;
}
|
||||
|
||||
// XPU launch wrapper: reinterprets 64-bit pointers to the device index type
// and launches the kernel with <<<1, 64>>> on the context's stream. The
// kernel strides over batch rows internally, so this launch config covers
// any batch size.
static int xpu3_wrapper(Context* ctx,
                        int64_t* next_tokens,
                        const int* max_think_lens,
                        int* max_reply_lens,
                        int64_t* step_idx,
                        const int64_t* eos_token_ids,
                        int* limit_status,
                        int* accept_num,
                        const bool* stop_flags,
                        const int64_t think_end_id,
                        const int64_t* inject_token_ids,
                        const int tokens_per_step,
                        const int bs,
                        const int eos_token_id_len,
                        const int inject_len,
                        const bool splitwise_role_is_decode) {
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto kernel = xpu3::plugin::speculate_limit_thinking_content_length_kernel;
  // Keep the launch return value and assert on it instead of ignoring it.
  int32_t ret_xre = kernel<<<1, 64, ctx->xpu_stream>>>(
      reinterpret_cast<XPU_INT64*>(next_tokens),
      max_think_lens,
      max_reply_lens,
      reinterpret_cast<XPU_INT64*>(step_idx),
      reinterpret_cast<const XPU_INT64*>(eos_token_ids),
      limit_status,
      accept_num,
      stop_flags,
      think_end_id,
      reinterpret_cast<const XPU_INT64*>(inject_token_ids),
      tokens_per_step,
      bs,
      eos_token_id_len,
      inject_len,
      splitwise_role_is_decode);
  KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
  return api::SUCCESS;
}
|
||||
|
||||
// Public plugin entry point: validates the context, dumps parameters for
// debugging, then dispatches to the CPU reference implementation or the
// XPU2/XPU3 device wrapper based on the context's device type.
int speculate_limit_thinking_content_length_kernel(
    Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(
      ctx, "speculate_limit_thinking_content_length_kernel", int);
  // Parameters are dumped in groups of five (one macro call per group).
  WRAPPER_DUMP_PARAM5(ctx,
                      next_tokens,
                      max_think_lens,
                      max_reply_lens,
                      step_idx,
                      eos_token_ids);
  WRAPPER_DUMP_PARAM5(ctx,
                      limit_status,
                      accept_num,
                      stop_flags,
                      think_end_id,
                      inject_token_ids);
  WRAPPER_DUMP_PARAM5(ctx,
                      tokens_per_step,
                      bs,
                      eos_token_id_len,
                      inject_len,
                      splitwise_role_is_decode);
  WRAPPER_DUMP(ctx);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       next_tokens,
                       max_think_lens,
                       max_reply_lens,
                       step_idx,
                       eos_token_ids,
                       limit_status,
                       accept_num,
                       stop_flags,
                       think_end_id,
                       inject_token_ids,
                       tokens_per_step,
                       bs,
                       eos_token_id_len,
                       inject_len,
                       splitwise_role_is_decode);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        next_tokens,
                        max_think_lens,
                        max_reply_lens,
                        step_idx,
                        eos_token_ids,
                        limit_status,
                        accept_num,
                        stop_flags,
                        think_end_id,
                        inject_token_ids,
                        tokens_per_step,
                        bs,
                        eos_token_id_len,
                        inject_len,
                        splitwise_role_is_decode);
  }
  // Any other device type is unsupported.
  WRAPPER_UNIMPLEMENTED(ctx);
}
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
@@ -0,0 +1,433 @@
|
||||
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
LIMIT_THINKING_TEST_DEBUG = os.environ.get("LIMIT_THINKING_TEST_DEBUG", "0") == "1"
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
speculate_limit_thinking_content_length as xpu_op,
|
||||
)
|
||||
|
||||
HAS_XPU = True
|
||||
except ImportError:
|
||||
HAS_XPU = False
|
||||
|
||||
|
||||
def ref_impl(
    next_tokens,  # [bs, tokens_per_step] int64
    max_think_lens,  # [bs] int32
    max_reply_lens,  # [bs] int32
    step_idx,  # [bs] int64
    eos_token_ids,  # [eos_len] int64
    limit_status,  # [bs] int32
    accept_num,  # [bs] int32
    stop_flags,  # [bs] bool
    think_end_id,  # int
    inject_token_ids,  # [inject_len] int64, may be empty
    splitwise_role_is_decode,
):
    """Pure-numpy reference for the speculate_limit_thinking_content_length op.

    Per batch row, ``limit_status`` encodes a small state machine:
      0                -> thinking phase
      1..inject_len    -> injecting inject_token_ids[status - 1]
      done_status      -> thinking just ended (inject_len + 1, or 1 when
                          there are no inject tokens); that token does not
                          count toward the reply
      >= reply_base    -> reply phase; reply_len = status - reply_base

    All inputs are copied (the caller's arrays are untouched); the updated
    fields are returned as a dict. Keep this in lockstep with the XPU/CPU
    implementations.
    """
    next_tokens = next_tokens.copy()
    max_reply_lens = max_reply_lens.copy()
    step_idx = step_idx.copy()
    limit_status = limit_status.copy()
    accept_num = accept_num.copy()

    bs = len(accept_num)
    inject_len = len(inject_token_ids)
    eos_token_id_len = len(eos_token_ids)

    for bid in range(bs):
        original_accept_num = int(accept_num[bid])
        if original_accept_num <= 0:
            continue
        if stop_flags[bid]:
            continue

        # A negative value disables the corresponding limit.
        max_think_len = int(max_think_lens[bid])
        max_reply_len = int(max_reply_lens[bid])
        if max_think_len < 0 and max_reply_len < 0:
            continue

        done_status = (inject_len + 1) if inject_len > 0 else 1
        reply_base = done_status + 1

        status = int(limit_status[bid])
        if status < 0:
            status = 0

        new_accept_num = original_accept_num
        # Absolute step index of token_offset 0 within this step.
        current_base_step = int(step_idx[bid]) - original_accept_num + 1

        for token_offset in range(original_accept_num):
            next_token = int(next_tokens[bid, token_offset])
            current_step = current_base_step + token_offset

            prev_status = status
            condition_triggered = False

            # 1) Thinking phase: watch for think_end_id (must run before the
            # injection logic; this token must not count toward the reply).
            if status == 0 and next_token == think_end_id:
                status = done_status
                if max_reply_len >= 0:
                    max_reply_len += 2

            # 2) Injection triggers (only when max_think_len >= 0)
            if max_think_len >= 0 and status < reply_base:
                if max_think_len > 0:
                    # current_step - 1: step_idx was already incremented by
                    # speculate_verify.
                    if status == 0 and (current_step - 1) == max_think_len:
                        status = 1 if inject_len > 0 else done_status
                elif max_think_len == 0:
                    # Colocated/prefill starts at inject[0]; decode (D) node
                    # starts at inject[1].
                    if status == 0 and not splitwise_role_is_decode:
                        status = 1 if inject_len > 0 else done_status
                    elif status == 0 and splitwise_role_is_decode:
                        status = 2 if inject_len > 0 else done_status + 1

                # An early EOS during thinking also starts injection.
                if status == 0 and inject_len > 0:
                    for i in range(eos_token_id_len):
                        if eos_token_ids[i] == next_token:
                            status = 1
                            break

                # Injection sequence: overwrite the token and advance state;
                # the overwrite truncates accept_num below.
                if inject_len > 0 and 1 <= status <= inject_len:
                    next_token = int(inject_token_ids[status - 1])
                    status += 1
                    if status > done_status:
                        status = done_status
                    condition_triggered = True

            # Did this token just transition into done_status (</think> or
            # the final injected token)?
            became_done_this_token = status == done_status and prev_status != done_status and prev_status < reply_base

            # 3) Reply-length limit: the token that just ended thinking does
            # not count toward the reply, nor does it start reply counting.
            if max_reply_len >= 0:
                if not became_done_this_token:
                    if status == done_status:
                        status = reply_base
                    if status >= reply_base:
                        reply_len = status - reply_base
                        if reply_len >= max_reply_len:
                            # Limit reached: force EOS, truncate accept_num.
                            if eos_token_id_len > 0:
                                next_token = int(eos_token_ids[0])
                            status = reply_base + max_reply_len
                            condition_triggered = True
                        else:
                            status = reply_base + (reply_len + 1)

            next_tokens[bid, token_offset] = next_token

            # A token was overwritten: discard the remaining draft tokens.
            if condition_triggered:
                new_accept_num = token_offset + 1
                break

        # Roll step_idx back by the number of discarded tokens.
        discarded = original_accept_num - new_accept_num
        if discarded > 0:
            step_idx[bid] -= discarded

        accept_num[bid] = new_accept_num
        limit_status[bid] = status
        max_reply_lens[bid] = max_reply_len

    return {
        "next_tokens": next_tokens,
        "max_reply_lens": max_reply_lens,
        "step_idx": step_idx,
        "limit_status": limit_status,
        "accept_num": accept_num,
    }
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 工具函数
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
def run_op(np_inputs, think_end_id, splitwise_role_is_decode, device, op_fn):
    """Run the operator on the given device; return output fields as a numpy dict.

    The op mutates its input tensors in place (it is registered with an
    inplace map), so the mutated tensors are read back after the call.
    Inputs are copied first so the caller's arrays stay untouched.
    """
    paddle.set_device(device)
    next_tokens = paddle.to_tensor(np_inputs["next_tokens"].copy())
    max_think_lens = paddle.to_tensor(np_inputs["max_think_lens"].copy())
    max_reply_lens = paddle.to_tensor(np_inputs["max_reply_lens"].copy())
    step_idx = paddle.to_tensor(np_inputs["step_idx"].copy())
    limit_status = paddle.to_tensor(np_inputs["limit_status"].copy())
    accept_num = paddle.to_tensor(np_inputs["accept_num"].copy())
    stop_flags = paddle.to_tensor(np_inputs["stop_flags"].copy())
    eos_token_ids = paddle.to_tensor(np_inputs["eos_token_ids"].copy())
    inject_token_ids = paddle.to_tensor(np_inputs["inject_token_ids"].copy())

    op_fn(
        next_tokens,
        max_think_lens,
        max_reply_lens,
        step_idx,
        limit_status,
        accept_num,
        stop_flags,
        eos_token_ids,
        inject_token_ids,
        think_end_id,
        splitwise_role_is_decode,
    )
    # Read back every field the kernel may have modified in place.
    return {
        "next_tokens": next_tokens.numpy(),
        "max_reply_lens": max_reply_lens.numpy(),
        "step_idx": step_idx.numpy(),
        "limit_status": limit_status.numpy(),
        "accept_num": accept_num.numpy(),
    }
|
||||
|
||||
|
||||
def run_ref(np_inputs, think_end_id, splitwise_role_is_decode):
    """Run the pure-numpy reference on defensive copies of the inputs."""
    ordered_fields = (
        "next_tokens",
        "max_think_lens",
        "max_reply_lens",
        "step_idx",
        "eos_token_ids",
        "limit_status",
        "accept_num",
        "stop_flags",
    )
    copied = [np_inputs[name].copy() for name in ordered_fields]
    return ref_impl(
        *copied,
        think_end_id,
        np_inputs["inject_token_ids"].copy(),
        splitwise_role_is_decode,
    )
|
||||
|
||||
|
||||
def assert_equal(expected, actual, label):
    """Exact field-by-field comparison of two output dicts."""
    for key in expected:
        want, got = expected[key], actual[key]
        np.testing.assert_array_equal(
            want,
            got,
            err_msg=f"[{label}] field='{key}' mismatch",
        )
|
||||
|
||||
|
||||
def make_inputs(
    bs,
    tokens_per_step,
    next_tokens,
    max_think_lens,
    max_reply_lens,
    step_idx,
    limit_status,
    accept_num,
    stop_flags,
    eos_token_ids,
    inject_token_ids,
):
    """Assemble the op's numpy input dict with the dtypes the kernel expects.

    ``next_tokens`` is reshaped to [bs, tokens_per_step]; every other field
    stays a flat array of the listed dtype.
    """
    field_spec = [
        ("max_think_lens", max_think_lens, np.int32),
        ("max_reply_lens", max_reply_lens, np.int32),
        ("step_idx", step_idx, np.int64),
        ("limit_status", limit_status, np.int32),
        ("accept_num", accept_num, np.int32),
        ("stop_flags", stop_flags, bool),
        ("eos_token_ids", eos_token_ids, np.int64),
        ("inject_token_ids", inject_token_ids, np.int64),
    ]
    inputs = {
        "next_tokens": np.array(next_tokens, dtype=np.int64).reshape(bs, tokens_per_step),
    }
    for name, values, dtype in field_spec:
        inputs[name] = np.array(values, dtype=dtype)
    return inputs
|
||||
|
||||
|
||||
def run_all_and_compare(test_case, np_inputs, think_end_id, splitwise_role_is_decode=False):
    """Run the numpy reference and, when available, the XPU op, and compare.

    If the XPU build is not importable the test is skipped, so passing the
    reference alone never counts as success. Set LIMIT_THINKING_TEST_DEBUG=1
    to print inputs and both outputs.
    """

    if LIMIT_THINKING_TEST_DEBUG:
        print("\n========== [INPUT] ==========")
        print(f" think_end_id : {think_end_id}")
        print(f" splitwise_role_is_decode: {splitwise_role_is_decode}")
        for k, v in np_inputs.items():
            print(f" {k:25s}: {v}")

    ref_out = run_ref(np_inputs, think_end_id, splitwise_role_is_decode)

    if LIMIT_THINKING_TEST_DEBUG:
        print("---------- [REF OUTPUT] ----------")
        for k, v in ref_out.items():
            print(f" {k:25s}: {v}")

    if HAS_XPU:
        xpu_out = run_op(np_inputs, think_end_id, splitwise_role_is_decode, "xpu:0", xpu_op)

        if LIMIT_THINKING_TEST_DEBUG:
            print("---------- [XPU OUTPUT] ----------")
            for k, v in xpu_out.items():
                print(f" {k:25s}: {v}")

        assert_equal(ref_out, xpu_out, "XPU vs ref")
    else:
        test_case.skipTest("XPU is not available; only ref logic verified.")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 测试用例
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
THINK_END_ID = 100
|
||||
EOS_ID = 2
|
||||
|
||||
|
||||
class TestSpeculateLimitThinkingContentLength(unittest.TestCase):
    """End-to-end checks of speculate_limit_thinking_content_length against
    the NumPy reference implementation (and the XPU op when available)."""

    def test_think_end_natural(self):
        """Model emits think_end_id on its own: status 0 -> done_status,
        max_reply_len += 2. inject_len=0 -> done_status=1, reply_base=2."""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[THINK_END_ID],  # model outputs </think> exactly here
            max_think_lens=[-1],  # no forced truncation
            max_reply_lens=[5],
            step_idx=[3],
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_inject_truncation(self):
        """Exceeding max_think_len triggers the injection sequence and
        truncates accept_num at the current token.
        inject_len=2 -> done_status=3, reply_base=4.
        step_idx=6, current_step-1=5=max_think_len -> injection fires and the
        token is replaced with inject[0]=200."""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[5],
            max_reply_lens=[-1],
            step_idx=[6],  # current_base_step=6, current_step-1=5 == max_think_len
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[200, 201],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_reply_len_limit(self):
        """Reply counter reaches the max_reply_len cap: force-write EOS and
        truncate accept_num.
        inject_len=0 -> done_status=1, reply_base=2.
        status=4 (reply_len=2=max_reply_len) -> EOS forced."""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[-1],
            max_reply_lens=[2],
            step_idx=[10],
            limit_status=[4],  # reply_base(2) + reply_len(2): cap already reached
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_no_limit(self):
        """max_think_len<0 and max_reply_len<0: the whole batch is skipped
        and all outputs stay unchanged."""
        inputs = make_inputs(
            bs=2,
            tokens_per_step=1,
            next_tokens=[111, 222],
            max_think_lens=[-1, -1],
            max_reply_lens=[-1, -1],
            step_idx=[5, 8],
            limit_status=[0, 0],
            accept_num=[1, 1],
            stop_flags=[False, False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_inject_len_zero_v1_behavior(self):
        """inject_len=0 degrades to the v1 behavior: on timeout go straight
        to done_status=1 and the token is not replaced."""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[3],
            max_reply_lens=[-1],
            step_idx=[4],  # current_step-1=3 == max_think_len=3
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],  # inject_len=0
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_multi_token_per_step(self):
        """tokens_per_step=3: the 2nd token (offset=1) hits the reply limit,
        the first token is kept, accept_num is truncated to 2 and step_idx
        rolls back by 1.
        inject_len=0 -> done_status=1, reply_base=2.
        status=2 (reply_len=0), max_reply_len=1:
          offset=0: reply_len=0 < 1 -> normal output, status -> 3
          offset=1: reply_len=1 >= 1 -> force EOS and truncate"""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=3,
            next_tokens=[501, 502, 503],
            max_think_lens=[-1],
            max_reply_lens=[1],
            step_idx=[12],
            limit_status=[2],  # = reply_base, reply_len=0
            accept_num=[3],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_already_stopped(self):
        """Batches with stop_flags=True are skipped entirely; outputs stay
        unchanged."""
        inputs = make_inputs(
            bs=2,
            tokens_per_step=1,
            next_tokens=[111, 222],
            max_think_lens=[5, 5],
            max_reply_lens=[10, 10],
            step_idx=[6, 6],
            limit_status=[0, 0],
            accept_num=[1, 1],
            stop_flags=[True, False],  # batch 0 has already stopped
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, inputs, THINK_END_ID)

    def test_splitwise_decode_node(self):
        """splitwise_role_is_decode=True with max_think_len=0: the decode (D)
        node starts injecting from inject_token_ids[1] (status jumps straight
        to 2). inject_len=3 -> done_status=4, reply_base=5."""
        inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[0],  # triggers immediately
            max_reply_lens=[-1],
            step_idx=[1],
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[200, 201, 202],
        )
        run_all_and_compare(self, inputs, THINK_END_ID, splitwise_role_is_decode=True)
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user