[XPU] Add speculate_limit_thinking_content_length Op. (#6627)

* [XPU] Add speculate_limit_thinking_content_length OP for xpu.

* add unittest.

* format codes.

* format codes.

* format codes.

* Fix unused kernel launch return value.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-03-11 17:30:17 +08:00
committed by GitHub
parent 9f0778f991
commit 88c4fbf8e1
5 changed files with 1038 additions and 11 deletions
@@ -0,0 +1,80 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Host-side launcher for the speculative-decoding (MTP)
// limit-thinking-content-length op on XPU. For every batch entry the plugin
// kernel watches for the natural end of the "thinking" phase
// (think_end_id), optionally force-ends it by injecting a token sequence,
// and caps the reply length. next_tokens, max_reply_lens, step_idx,
// limit_status and accept_num are all updated in place on device.
void SpeculateLimitThinkingContentLength(const paddle::Tensor& next_tokens,
                                         const paddle::Tensor& max_think_lens,
                                         const paddle::Tensor& max_reply_lens,
                                         const paddle::Tensor& step_idx,
                                         const paddle::Tensor& limit_status,
                                         const paddle::Tensor& accept_num,
                                         const paddle::Tensor& stop_flags,
                                         const paddle::Tensor& eos_token_ids,
                                         const paddle::Tensor& inject_token_ids,
                                         const int64_t think_end_id,
                                         const bool splitwise_role_is_decode) {
  // Obtain the XPU device context for the current device.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  const int eos_token_id_len = eos_token_ids.shape()[0];
  const int inject_len = inject_token_ids.shape()[0];
  // const_cast is required because the kernel mutates these tensors in place.
  int r =
      baidu::xpu::api::plugin::speculate_limit_thinking_content_length_kernel(
          xpu_ctx->x_context(),
          const_cast<int64_t*>(next_tokens.data<int64_t>()),
          max_think_lens.data<int>(),
          const_cast<int*>(max_reply_lens.data<int>()),
          const_cast<int64_t*>(step_idx.data<int64_t>()),
          eos_token_ids.data<int64_t>(),
          const_cast<int*>(limit_status.data<int>()),
          const_cast<int*>(accept_num.data<int>()),
          stop_flags.data<bool>(),
          think_end_id,
          // No injection sequence configured -> pass nullptr.
          (inject_len > 0) ? inject_token_ids.data<int64_t>() : nullptr,
          tokens_per_step,
          batch_size,
          eos_token_id_len,
          inject_len,
          splitwise_role_is_decode);
  PD_CHECK(r == 0,
           "baidu::xpu::api::plugin::"
           "speculate_limit_thinking_content_length_kernel failed.");
}
// Static-op registration. Only "next_tokens" is declared in-place
// ("next_tokens_out"); the launcher additionally mutates max_reply_lens,
// step_idx, limit_status and accept_num in place via const_cast.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length)
    .Inputs({"next_tokens",
             "max_think_lens",
             "max_reply_lens",
             "step_idx",
             "limit_status",
             "accept_num",
             "stop_flags",
             "eos_token_ids",
             "inject_token_ids"})
    .Attrs({"think_end_id: int64_t", "splitwise_role_is_decode: bool"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLength));
@@ -645,17 +645,17 @@ DLL_EXPORT int speculate_update_v3(Context* ctx,
const int max_draft_tokens);
DLL_EXPORT int speculate_update(Context* ctx,
int* seq_lens_encoder, // 输入 [B_max, ]
int* seq_lens_decoder, // 输出 [B_max, ]
bool* not_need_stop, // [1,]
int64_t* draft_tokens, // [B_max, T_max]
int* actual_draft_token_nums, // [B_max, ]
const int64_t* accept_tokens, // [B_max, T_max]
const int* accept_num, // [B_max, ]
const bool* stop_flags, // [B_max, ]
const int* seq_lens_this_time, // [B_real,]
const bool* is_block_step, // [B_max, ]
int* mask_rollback, // [1,]
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* not_need_stop,
int64_t* draft_tokens,
int* actual_draft_token_nums,
const int64_t* accept_tokens,
const int* accept_num,
const bool* stop_flags,
const int* seq_lens_this_time,
const bool* is_block_step,
int* mask_rollback,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens);
@@ -704,6 +704,24 @@ DLL_EXPORT int update_attn_mask_offsets(Context* ctx,
int real_bsz,
int max_model_len,
int decode_states_len);
// Speculative (MTP) limit-thinking-content-length kernel. Watches/forces the
// end of the thinking phase, optionally injects a token sequence, and caps
// the reply length. Mutates next_tokens / max_reply_lens / step_idx /
// limit_status / accept_num in place.
DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
    api::Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode);
/*--------------------------------------- MTP end
* --------------------------------------------*/
@@ -0,0 +1,204 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/kernel/cluster_debug.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
namespace xpu3 {
namespace plugin {
// Speculative-decoding (MTP) version of limit_thinking_content_length.
// Compared with the non-MTP op it additionally:
// 1) supports an injection sequence inject_token_ids of arbitrary length;
// 2) supports a per-batch reply-length limit max_reply_lens;
// 3) keeps its semantics aligned with the non-MTP
//    limit_thinking_content_length:
//    - max_think_len < 0: thinking is never force-truncated, but think_end_id
//      is still watched so the reply phase can begin;
//    - max_reply_len >= 0: reply tokens are counted only from the token AFTER
//      the one that ended the thinking phase.
//
// State machine (kept in limit_status):
//    - 0: thinking phase;
//    - 1..inject_len: injecting inject_token_ids[status - 1];
//    - done_status = inject_len + 1 (done_status = 1 when inject_len == 0):
//      thinking just ended (that token does not count as reply);
//    - reply_base = done_status + 1;
//    - status >= reply_base: reply phase; reply_len = status - reply_base is
//      the number of reply tokens already emitted.
__global__ void speculate_limit_thinking_content_length_kernel(
    int64_t* next_tokens,          // [bs, tokens_per_step]
    const int* max_think_lens,     // [bs]
    int* max_reply_lens,           // [bs]
    int64_t* step_idx,             // [bs]
    const int64_t* eos_token_ids,  // [eos_len]
    int* limit_status,             // [bs]
    int* accept_num,               // [bs]
    const bool* stop_flags,        // [bs]
    const int64_t think_end_id,
    const int64_t* inject_token_ids,  // [inject_len]
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  const int total_threads = cluster_num() * core_num();
  const int tid = cluster_id() * core_num() + core_id();
  for (int bid = tid; bid < bs; bid += total_threads) {
    const int original_accept_num = accept_num[bid];
    if (original_accept_num <= 0) continue;
    if (stop_flags[bid]) continue;
    // < 0: do not force-truncate thinking (think_end_id is still watched).
    const int max_think_len = max_think_lens[bid];
    // < 0: no reply-length limit.
    int max_reply_len = max_reply_lens[bid];
    if (max_think_len < 0 && max_reply_len < 0) continue;
    // State constants (inject_len == 0 is allowed).
    const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
    const int reply_base = done_status + 1;
    int status = limit_status[bid];
    if (status < 0) status = 0;
    int new_accept_num = original_accept_num;
    // Absolute step of the first token accepted in this step.
    const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
    for (int token_offset = 0; token_offset < original_accept_num;
         token_offset++) {
      const int token_idx = bid * tokens_per_step + token_offset;
      int64_t next_token = next_tokens[token_idx];
      const int64_t current_step = current_base_step + token_offset;
      const int prev_status = status;
      bool condition_triggered = false;
      // ======================= 1) Thinking phase: watch for think_end_id
      // (aligned with the non-MTP op) =======================
      // NOTE: this must run BEFORE the injection logic. If the model emits
      // </think> naturally, that token moves status to done_status but is
      // itself NOT counted as reply.
      if (status == 0 && next_token == think_end_id) {
        status = done_status;
        // accept_num is not truncated here: later tokens may continue (reply
        // counting only starts from the NEXT token).
        if (max_reply_len >= 0) {
          max_reply_len += 2;
        }
      }
      // ======================= 2) Injection triggers only when thinking
      // truncation is enabled (max_think_len >= 0) =======================
      if (max_think_len >= 0 && status < reply_base) {
        if (max_think_len > 0) {
          // A) Over-length trigger: once max_think_len is reached, start
          // injecting (this token becomes inject_token_ids[0]).
          if (status == 0 &&
              (current_step - 1) ==
                  max_think_len) {  // current_step - 1 because speculate_verify
                                    // already advanced step_idx by 1.
            status = (inject_len > 0) ? 1 : done_status;
          }
        } else if (max_think_len == 0) {
          // A) Over-length trigger: fires immediately.
          if (status == 0 && !splitwise_role_is_decode) {
            // Centralized deployment or prefill (P) node: emit
            // inject_token_ids[0] starting from this token.
            status = (inject_len > 0) ? 1 : done_status;
          } else if (status == 0 && splitwise_role_is_decode) {
            // Decode (D) node: start from inject_token_ids[1].
            status = (inject_len > 0) ? 2 : done_status + 1;
          }
        }
        // B) Early EOS during thinking: start injecting (this token's EOS is
        // overwritten with inject_token_ids[0]).
        if (status == 0 && inject_len > 0) {
          for (int i = 0; i < eos_token_id_len; i++) {
            if (eos_token_ids[i] == next_token) {
              status = 1;
              break;
            }
          }
        }
        // Injection window: overwrite next_token and advance the state.
        // Because the token is overwritten, accept_num must be truncated so
        // the remaining speculative tokens are not accepted in one go.
        if (inject_len > 0 && status >= 1 && status <= inject_len) {
          next_token = inject_token_ids[status - 1];
          status += 1;
          if (status > done_status) status = done_status;  // defensive
          condition_triggered = true;
        }
      }
      // Did this token just enter done_status?
      // Typical cases: natural </think> (0 -> done_status) or the last
      // injection step (inject_len -> done_status).
      const bool became_done_this_token = (status == done_status) &&
                                          (prev_status != done_status) &&
                                          (prev_status < reply_base);
      // ======================= 3) Reply-length limit (aligned with the
      // non-MTP op) =======================
      // Key point: the token that just entered done_status (</think> or an
      // injected token) is not counted as reply and does not start the reply
      // counter on this token.
      if (max_reply_len >= 0) {
        if (!became_done_this_token) {
          // Only when the previous token was already done_status may the
          // current token move to reply_base and start counting.
          if (status == done_status) {
            status = reply_base;  // reply_len = 0
          }
          if (status >= reply_base) {
            // Reply tokens already emitted (excluding the current one).
            int reply_len =
                status - reply_base;
            if (reply_len >= max_reply_len) {
              // Limit reached: force EOS and truncate accept_num.
              if (eos_token_id_len > 0) next_token = eos_token_ids[0];
              status = reply_base + max_reply_len;
              condition_triggered = true;
            } else {
              // Emit the current token normally; reply count += 1.
              status = reply_base + (reply_len + 1);
            }
          }
        }
      }
      // Write the (possibly overwritten) token back.
      next_tokens[token_idx] = next_token;
      // If this token was overwritten (injection or forced EOS), truncate
      // accept_num so the remaining speculative tokens are discarded.
      if (condition_triggered) {
        new_accept_num = token_offset + 1;
        break;
      }
    }
    // Roll back step_idx for discarded tokens and publish the new state.
    const int discarded_tokens = original_accept_num - new_accept_num;
    if (discarded_tokens > 0) {
      step_idx[bid] -= discarded_tokens;
    }
    accept_num[bid] = new_accept_num;
    limit_status[bid] = status;
    max_reply_lens[bid] = max_reply_len;
  }
}
} // namespace plugin
} // namespace xpu3
@@ -0,0 +1,292 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Device-kernel entry point; the definition lives in the .xpu kernel file of
// this plugin.
__attribute__((global)) void speculate_limit_thinking_content_length_kernel(
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// CPU fallback: sequentially applies, per batch entry, the same
// thinking/injection/reply state machine as the XPU device kernel (see the
// .xpu kernel source of this plugin for the detailed state description).
// Mutates next_tokens / max_reply_lens / step_idx / limit_status /
// accept_num in place; returns api::SUCCESS.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       int* max_reply_lens,
                       int64_t* step_idx,
                       const int64_t* eos_token_ids,
                       int* limit_status,
                       int* accept_num,
                       const bool* stop_flags,
                       const int64_t think_end_id,
                       const int64_t* inject_token_ids,
                       const int tokens_per_step,
                       const int bs,
                       const int eos_token_id_len,
                       const int inject_len,
                       const bool splitwise_role_is_decode) {
  for (int bid = 0; bid < bs; bid++) {
    const int original_accept_num = accept_num[bid];
    if (original_accept_num <= 0) continue;
    if (stop_flags[bid]) continue;
    const int max_think_len = max_think_lens[bid];
    int max_reply_len = max_reply_lens[bid];
    if (max_think_len < 0 && max_reply_len < 0) continue;
    // done_status marks "thinking just ended"; reply counting starts at
    // reply_base.
    const int done_status = (inject_len > 0) ? (inject_len + 1) : 1;
    const int reply_base = done_status + 1;
    int status = limit_status[bid];
    if (status < 0) status = 0;
    int new_accept_num = original_accept_num;
    const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
    for (int token_offset = 0; token_offset < original_accept_num;
         token_offset++) {
      const int token_idx = bid * tokens_per_step + token_offset;
      int64_t next_token = next_tokens[token_idx];
      const int64_t current_step = current_base_step + token_offset;
      const int prev_status = status;
      bool condition_triggered = false;
      // 1) Thinking phase: natural </think> -> done_status (not counted as
      // reply); extend the reply budget by 2.
      if (status == 0 && next_token == think_end_id) {
        status = done_status;
        if (max_reply_len >= 0) {
          max_reply_len += 2;
        }
      }
      // 2) Injection, only when thinking truncation is enabled.
      if (max_think_len >= 0 && status < reply_base) {
        if (max_think_len > 0) {
          // current_step - 1: speculate_verify already advanced step_idx.
          if (status == 0 && (current_step - 1) == max_think_len) {
            status = (inject_len > 0) ? 1 : done_status;
          }
        } else if (max_think_len == 0) {
          if (status == 0 && !splitwise_role_is_decode) {
            status = (inject_len > 0) ? 1 : done_status;
          } else if (status == 0 && splitwise_role_is_decode) {
            // Decode node starts from inject_token_ids[1].
            status = (inject_len > 0) ? 2 : done_status + 1;
          }
        }
        // Early EOS during thinking also starts injection.
        if (status == 0 && inject_len > 0) {
          for (int i = 0; i < eos_token_id_len; i++) {
            if (eos_token_ids[i] == next_token) {
              status = 1;
              break;
            }
          }
        }
        // Inside the injection window: overwrite the token, advance state,
        // and truncate accept_num.
        if (inject_len > 0 && status >= 1 && status <= inject_len) {
          next_token = inject_token_ids[status - 1];
          status += 1;
          if (status > done_status) status = done_status;
          condition_triggered = true;
        }
      }
      const bool became_done_this_token = (status == done_status) &&
                                          (prev_status != done_status) &&
                                          (prev_status < reply_base);
      // 3) Reply-length limit; the token that just ended thinking is not
      // counted as reply.
      if (max_reply_len >= 0) {
        if (!became_done_this_token) {
          if (status == done_status) {
            status = reply_base;
          }
          if (status >= reply_base) {
            int reply_len = status - reply_base;
            if (reply_len >= max_reply_len) {
              // Limit reached: force EOS and truncate accept_num.
              if (eos_token_id_len > 0) next_token = eos_token_ids[0];
              status = reply_base + max_reply_len;
              condition_triggered = true;
            } else {
              status = reply_base + (reply_len + 1);
            }
          }
        }
      }
      next_tokens[token_idx] = next_token;
      if (condition_triggered) {
        new_accept_num = token_offset + 1;
        break;
      }
    }
    // Roll back step_idx for discarded tokens and publish the new state.
    const int discarded_tokens = original_accept_num - new_accept_num;
    if (discarded_tokens > 0) {
      step_idx[bid] -= discarded_tokens;
    }
    accept_num[bid] = new_accept_num;
    limit_status[bid] = status;
    max_reply_lens[bid] = max_reply_len;
  }
  return api::SUCCESS;
}
// Launch the device kernel on one cluster with 64 cores; each core strides
// over batch entries. int64 pointers are reinterpreted via XPU_INT64 for the
// device ABI.
static int xpu3_wrapper(Context* ctx,
                        int64_t* next_tokens,
                        const int* max_think_lens,
                        int* max_reply_lens,
                        int64_t* step_idx,
                        const int64_t* eos_token_ids,
                        int* limit_status,
                        int* accept_num,
                        const bool* stop_flags,
                        const int64_t think_end_id,
                        const int64_t* inject_token_ids,
                        const int tokens_per_step,
                        const int bs,
                        const int eos_token_id_len,
                        const int inject_len,
                        const bool splitwise_role_is_decode) {
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto kernel = xpu3::plugin::speculate_limit_thinking_content_length_kernel;
  int32_t ret_xre = kernel<<<1, 64, ctx->xpu_stream>>>(
      reinterpret_cast<XPU_INT64*>(next_tokens),
      max_think_lens,
      max_reply_lens,
      reinterpret_cast<XPU_INT64*>(step_idx),
      reinterpret_cast<const XPU_INT64*>(eos_token_ids),
      limit_status,
      accept_num,
      stop_flags,
      think_end_id,
      reinterpret_cast<const XPU_INT64*>(inject_token_ids),
      tokens_per_step,
      bs,
      eos_token_id_len,
      inject_len,
      splitwise_role_is_decode);
  // Surface kernel-launch failures instead of silently ignoring them.
  KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
  return api::SUCCESS;
}
// Public plugin entry point: validates the context, dumps parameters for
// debugging, and dispatches to the CPU or XPU implementation based on the
// device type.
int speculate_limit_thinking_content_length_kernel(
    Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    int* max_reply_lens,
    int64_t* step_idx,
    const int64_t* eos_token_ids,
    int* limit_status,
    int* accept_num,
    const bool* stop_flags,
    const int64_t think_end_id,
    const int64_t* inject_token_ids,
    const int tokens_per_step,
    const int bs,
    const int eos_token_id_len,
    const int inject_len,
    const bool splitwise_role_is_decode) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(
      ctx, "speculate_limit_thinking_content_length_kernel", int);
  WRAPPER_DUMP_PARAM5(ctx,
                      next_tokens,
                      max_think_lens,
                      max_reply_lens,
                      step_idx,
                      eos_token_ids);
  WRAPPER_DUMP_PARAM5(ctx,
                      limit_status,
                      accept_num,
                      stop_flags,
                      think_end_id,
                      inject_token_ids);
  WRAPPER_DUMP_PARAM5(ctx,
                      tokens_per_step,
                      bs,
                      eos_token_id_len,
                      inject_len,
                      splitwise_role_is_decode);
  WRAPPER_DUMP(ctx);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       next_tokens,
                       max_think_lens,
                       max_reply_lens,
                       step_idx,
                       eos_token_ids,
                       limit_status,
                       accept_num,
                       stop_flags,
                       think_end_id,
                       inject_token_ids,
                       tokens_per_step,
                       bs,
                       eos_token_id_len,
                       inject_len,
                       splitwise_role_is_decode);
  }
  // NOTE(review): kXPU2 devices are also routed to the xpu3-compiled kernel —
  // confirm the kernel binary actually supports XPU2.
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        next_tokens,
                        max_think_lens,
                        max_reply_lens,
                        step_idx,
                        eos_token_ids,
                        limit_status,
                        accept_num,
                        stop_flags,
                        think_end_id,
                        inject_token_ids,
                        tokens_per_step,
                        bs,
                        eos_token_id_len,
                        inject_len,
                        splitwise_role_is_decode);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
@@ -0,0 +1,433 @@
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
# Enable verbose input/output dumps by setting LIMIT_THINKING_TEST_DEBUG=1.
LIMIT_THINKING_TEST_DEBUG = os.environ.get("LIMIT_THINKING_TEST_DEBUG", "0") == "1"

# The XPU op is optional: when the fastdeploy XPU build is unavailable the
# tests fall back to verifying only the pure-python reference implementation.
try:
    from fastdeploy.model_executor.ops.xpu import (
        speculate_limit_thinking_content_length as xpu_op,
    )

    HAS_XPU = True
except ImportError:
    HAS_XPU = False
def ref_impl(
    next_tokens,  # [bs, tokens_per_step] int64
    max_think_lens,  # [bs] int32
    max_reply_lens,  # [bs] int32
    step_idx,  # [bs] int64
    eos_token_ids,  # [eos_len] int64
    limit_status,  # [bs] int32
    accept_num,  # [bs] int32
    stop_flags,  # [bs] bool
    think_end_id,  # int
    inject_token_ids,  # [inject_len] int64, may be empty
    splitwise_role_is_decode,
):
    """Pure-python reference for the XPU op.

    Walks the accepted tokens of each batch entry through the
    thinking/injection/reply state machine kept in ``limit_status`` and
    returns the mutated fields as a dict of fresh numpy arrays.
    """
    out_tokens = next_tokens.copy()
    out_reply_caps = max_reply_lens.copy()
    out_step_idx = step_idx.copy()
    out_status = limit_status.copy()
    out_accept = accept_num.copy()

    inject_len = len(inject_token_ids)
    eos_ids = [int(e) for e in eos_token_ids]
    # done_status: "thinking just ended"; reply counting starts at reply_base.
    done_status = inject_len + 1 if inject_len > 0 else 1
    reply_base = done_status + 1

    for bid in range(len(out_accept)):
        n_accept = int(out_accept[bid])
        if n_accept <= 0 or stop_flags[bid]:
            continue
        think_cap = int(max_think_lens[bid])  # < 0: never force-truncate
        reply_cap = int(out_reply_caps[bid])  # < 0: no reply limit
        if think_cap < 0 and reply_cap < 0:
            continue
        status = max(int(out_status[bid]), 0)
        kept = n_accept
        # Absolute step of the first token accepted this step.
        base_step = int(out_step_idx[bid]) - n_accept + 1
        for off in range(n_accept):
            tok = int(out_tokens[bid, off])
            step = base_step + off
            before = status
            truncate = False
            # 1) Thinking phase: natural </think> ends thinking; this token is
            # not counted as reply and the reply budget grows by 2.
            if status == 0 and tok == think_end_id:
                status = done_status
                if reply_cap >= 0:
                    reply_cap += 2
            # 2) Injection only fires when thinking truncation is enabled.
            if think_cap >= 0 and status < reply_base:
                if think_cap > 0:
                    # step - 1: speculate_verify already advanced step_idx.
                    if status == 0 and step - 1 == think_cap:
                        status = 1 if inject_len > 0 else done_status
                elif think_cap == 0:
                    if status == 0:
                        if splitwise_role_is_decode:
                            # Decode node starts from inject_token_ids[1].
                            status = 2 if inject_len > 0 else done_status + 1
                        else:
                            status = 1 if inject_len > 0 else done_status
                # Early EOS during thinking also starts injection.
                if status == 0 and inject_len > 0 and tok in eos_ids:
                    status = 1
                # Inside the injection window: overwrite the token, advance
                # state, and truncate the accepted tokens.
                if inject_len > 0 and 1 <= status <= inject_len:
                    tok = int(inject_token_ids[status - 1])
                    status = min(status + 1, done_status)
                    truncate = True
            just_done = status == done_status and before != done_status and before < reply_base
            # 3) Reply-length limit; the token that just ended thinking is not
            # counted as reply.
            if reply_cap >= 0 and not just_done:
                if status == done_status:
                    status = reply_base  # reply_len = 0
                if status >= reply_base:
                    emitted = status - reply_base
                    if emitted >= reply_cap:
                        # Limit hit: force EOS and truncate.
                        if eos_ids:
                            tok = eos_ids[0]
                        status = reply_base + reply_cap
                        truncate = True
                    else:
                        status = reply_base + emitted + 1
            out_tokens[bid, off] = tok
            if truncate:
                kept = off + 1
                break
        dropped = n_accept - kept
        if dropped > 0:
            out_step_idx[bid] -= dropped
        out_accept[bid] = kept
        out_status[bid] = status
        out_reply_caps[bid] = reply_cap
    return {
        "next_tokens": out_tokens,
        "max_reply_lens": out_reply_caps,
        "step_idx": out_step_idx,
        "limit_status": out_status,
        "accept_num": out_accept,
    }
# ─────────────────────────────────────────────────────────────────────────────
# 工具函数
# ─────────────────────────────────────────────────────────────────────────────
def run_op(np_inputs, think_end_id, splitwise_role_is_decode, device, op_fn):
    """Run the operator on the given device and return the mutated output
    fields as a dict of numpy arrays."""
    paddle.set_device(device)
    input_order = (
        "next_tokens",
        "max_think_lens",
        "max_reply_lens",
        "step_idx",
        "limit_status",
        "accept_num",
        "stop_flags",
        "eos_token_ids",
        "inject_token_ids",
    )
    # Copy the numpy inputs so the op's in-place writes don't leak back.
    tensors = {name: paddle.to_tensor(np_inputs[name].copy()) for name in input_order}
    op_fn(
        *(tensors[name] for name in input_order),
        think_end_id,
        splitwise_role_is_decode,
    )
    output_fields = ("next_tokens", "max_reply_lens", "step_idx", "limit_status", "accept_num")
    return {name: tensors[name].numpy() for name in output_fields}
def run_ref(np_inputs, think_end_id, splitwise_role_is_decode):
    """Run the pure-python reference implementation on copies of the inputs."""
    positional = (
        "next_tokens",
        "max_think_lens",
        "max_reply_lens",
        "step_idx",
        "eos_token_ids",
        "limit_status",
        "accept_num",
        "stop_flags",
    )
    args = [np_inputs[name].copy() for name in positional]
    return ref_impl(
        *args,
        think_end_id,
        np_inputs["inject_token_ids"].copy(),
        splitwise_role_is_decode,
    )
def assert_equal(expected, actual, label):
    """Assert that every field of `expected` matches `actual` element-wise."""
    for key, want in expected.items():
        np.testing.assert_array_equal(
            want,
            actual[key],
            err_msg=f"[{label}] field='{key}' mismatch",
        )
def make_inputs(
    bs,
    tokens_per_step,
    next_tokens,
    max_think_lens,
    max_reply_lens,
    step_idx,
    limit_status,
    accept_num,
    stop_flags,
    eos_token_ids,
    inject_token_ids,
):
    """Pack python lists into the numpy dict consumed by run_op / run_ref."""
    # (field name, dtype) in the canonical order used everywhere else.
    specs = [
        ("max_think_lens", max_think_lens, np.int32),
        ("max_reply_lens", max_reply_lens, np.int32),
        ("step_idx", step_idx, np.int64),
        ("limit_status", limit_status, np.int32),
        ("accept_num", accept_num, np.int32),
        ("stop_flags", stop_flags, bool),
        ("eos_token_ids", eos_token_ids, np.int64),
        ("inject_token_ids", inject_token_ids, np.int64),
    ]
    packed = {
        "next_tokens": np.array(next_tokens, dtype=np.int64).reshape(bs, tokens_per_step),
    }
    for name, values, dtype in specs:
        packed[name] = np.array(values, dtype=dtype)
    return packed
def run_all_and_compare(test_case, np_inputs, think_end_id, splitwise_role_is_decode=False):
    """Compute the reference result and, when an XPU build is available,
    compare the XPU op's output against it; otherwise skip the test."""
    if LIMIT_THINKING_TEST_DEBUG:
        print("\n========== [INPUT] ==========")
        print(f" think_end_id : {think_end_id}")
        print(f" splitwise_role_is_decode: {splitwise_role_is_decode}")
        for k, v in np_inputs.items():
            print(f" {k:25s}: {v}")
    expected = run_ref(np_inputs, think_end_id, splitwise_role_is_decode)
    if LIMIT_THINKING_TEST_DEBUG:
        print("---------- [REF OUTPUT] ----------")
        for k, v in expected.items():
            print(f" {k:25s}: {v}")
    if not HAS_XPU:
        test_case.skipTest("XPU is not available; only ref logic verified.")
        return
    actual = run_op(np_inputs, think_end_id, splitwise_role_is_decode, "xpu:0", xpu_op)
    if LIMIT_THINKING_TEST_DEBUG:
        print("---------- [XPU OUTPUT] ----------")
        for k, v in actual.items():
            print(f" {k:25s}: {v}")
    assert_equal(expected, actual, "XPU vs ref")
# ─────────────────────────────────────────────────────────────────────────────
# 测试用例
# ─────────────────────────────────────────────────────────────────────────────
# Token ids shared by all test cases below.
THINK_END_ID = 100
EOS_ID = 2


class TestSpeculateLimitThinkingContentLength(unittest.TestCase):
    def test_think_end_natural(self):
        """The model naturally emits think_end_id: status 0 -> done_status and
        max_reply_len += 2. inject_len=0 -> done_status=1, reply_base=2."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[THINK_END_ID],  # model emits </think> itself
            max_think_lens=[-1],  # no forced truncation
            max_reply_lens=[5],
            step_idx=[3],
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_inject_truncation(self):
        """Exceeding max_think_len triggers the injection sequence; accept_num
        is truncated at the current token.
        inject_len=2 -> done_status=3, reply_base=4.
        step_idx=6, current_step-1=5=max_think_len -> inject; the token is
        replaced with inject[0]=200."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[5],
            max_reply_lens=[-1],
            step_idx=[6],  # current_base_step=6, current_step-1=5 == max_think_len
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[200, 201],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_reply_len_limit(self):
        """Reply count reaches max_reply_len: EOS is forced and accept_num is
        truncated. inject_len=0 -> done_status=1, reply_base=2.
        status=4 (reply_len=2=max_reply_len) -> forced EOS."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[-1],
            max_reply_lens=[2],
            step_idx=[10],
            limit_status=[4],  # reply_base(2) + reply_len(2), limit reached
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_no_limit(self):
        """max_think_len<0 and max_reply_len<0: whole batch is skipped and the
        outputs stay unchanged."""
        np_inputs = make_inputs(
            bs=2,
            tokens_per_step=1,
            next_tokens=[111, 222],
            max_think_lens=[-1, -1],
            max_reply_lens=[-1, -1],
            step_idx=[5, 8],
            limit_status=[0, 0],
            accept_num=[1, 1],
            stop_flags=[False, False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_inject_len_zero_v1_behavior(self):
        """inject_len=0 degrades to v1 behavior: on timeout, go straight to
        done_status=1 without replacing the token."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[3],
            max_reply_lens=[-1],
            step_idx=[4],  # current_step-1=3 == max_think_len=3
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],  # inject_len=0
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_multi_token_per_step(self):
        """tokens_per_step=3; the 2nd token (offset=1) hits the reply limit.
        The first token is kept, accept_num is truncated to 2, step_idx rolls
        back by 1. inject_len=0 -> done_status=1, reply_base=2.
        status=2 (reply_len=0), max_reply_len=1:
        offset=0: reply_len=0 < 1 -> normal output, status -> 3
        offset=1: reply_len=1 >= 1 -> forced EOS, truncate."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=3,
            next_tokens=[501, 502, 503],
            max_think_lens=[-1],
            max_reply_lens=[1],
            step_idx=[12],
            limit_status=[2],  # = reply_base, reply_len=0
            accept_num=[3],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_already_stopped(self):
        """Batches with stop_flags=True are skipped; outputs stay unchanged."""
        np_inputs = make_inputs(
            bs=2,
            tokens_per_step=1,
            next_tokens=[111, 222],
            max_think_lens=[5, 5],
            max_reply_lens=[10, 10],
            step_idx=[6, 6],
            limit_status=[0, 0],
            accept_num=[1, 1],
            stop_flags=[True, False],  # batch 0 already stopped
            eos_token_ids=[EOS_ID],
            inject_token_ids=[],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID)

    def test_splitwise_decode_node(self):
        """splitwise_role_is_decode=True with max_think_len=0:
        the decode (D) node starts injecting from inject_token_ids[1]
        (status jumps straight to 2).
        inject_len=3 -> done_status=4, reply_base=5."""
        np_inputs = make_inputs(
            bs=1,
            tokens_per_step=1,
            next_tokens=[999],
            max_think_lens=[0],  # triggers immediately
            max_reply_lens=[-1],
            step_idx=[1],
            limit_status=[0],
            accept_num=[1],
            stop_flags=[False],
            eos_token_ids=[EOS_ID],
            inject_token_ids=[200, 201, 202],
        )
        run_all_and_compare(self, np_inputs, THINK_END_ID, splitwise_role_is_decode=True)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()