[Feature]Support tag phase token enforce generation (#6034)

* support tag phase token enforce generation

* optimize note and some feature

* fix sampler unit test

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
freeliuzc
2026-01-15 19:59:55 +08:00
committed by GitHub
parent 17866c028e
commit 49617d9832
12 changed files with 889 additions and 5 deletions
+18
View File
@@ -802,6 +802,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& topp,
const paddle::Tensor& reasoning_status,
int max_seq_len,
int verify_window,
bool enable_topp,
@@ -1107,6 +1108,19 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
int64_t think_end_id,
int64_t line_break_id);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1712,4 +1726,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"fused_neox_rope_embedding function");
m.def("gelu_tanh", &GeluTanh, "gelu_tanh function");
m.def("reasoning_phase_token_constraint",
&ReasoningPhaseTokenConstraint,
"reasoning_phase_token_constraint function");
}
@@ -0,0 +1,315 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_runtime.h>
#include "helper.h"
// ================================================================
// Reasoning Phase State Machine
//
// reasoning_status meanings:
//
// x = 0 : Thinking phase
// - Model is generating hidden reasoning content
// - No token constraint is applied
//
// Transition condition (x = 0 -> x = 1):
// - Check whether <think_end> token appears
// in the last 4 generated tokens
//
// ------------------------------------------------
//
// x = 1 : Generating "\n</think>\n\n"
// - Model is emitting the explicit boundary pattern
// - In non-MTP mode, accept_num is implicitly 1
// and does not need to be manually set
// - In MTP mode, accept_num must be 1 in verify kernel
//
// Transition condition (x = 1 -> x = 2):
// - step_idx >= 3
// - pre_ids[-4:] exactly match:
// "\n</think>\n\n"
//
// ------------------------------------------------
//
// x = 2 : Generating <response> / <tool_call> phase
// - Model starts generating visible response or tool calls
// - Token constraint is enforced at the first token of this phase
// - Logits are masked to allow only a predefined token set
//
// Kernel applied:
// - apply_token_enforce_generation_scores_kernel
//
// Transition condition (x = 2 -> x = 3):
// - Automatically advance after one step
//
// ------------------------------------------------
//
// x = 3 : End state
// - Reasoning boundary handling is complete
// - No further state transitions
//
// ================================================================
__global__ void update_reasoning_status_kernel(
const bool* stop_flags, // [bs]
const int* seq_lens_encoder, // [bs]
const int64_t* step_idx, // [bs]
const int64_t* pre_ids, // [bs, max_seq_len]
int32_t* reasoning_status, // [bs]
int32_t bs,
int32_t max_seq_len,
int64_t think_end_id,
int64_t line_break_id) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= bs) return;
int32_t status = reasoning_status[tid];
if (stop_flags[tid] || seq_lens_encoder[tid] > 0 || status == 3) return;
int64_t cur_step = step_idx[tid];
const int64_t* pre_ids_now = pre_ids + tid * max_seq_len;
int64_t t0 = (cur_step >= 0) ? pre_ids_now[cur_step] : -1;
int64_t t1 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
int64_t t2 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
int64_t t3 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
int32_t new_status = status;
// x = 0 -> x = 1
if (status == 0) {
if (t0 == think_end_id || t1 == think_end_id || t2 == think_end_id ||
t3 == think_end_id) {
new_status = 1;
}
}
// x = 1 -> x = 2 (include think_end_id)
// or x = 1 -> x = 3 (not include think_end_id)
// Here must be serial judge
if (new_status == 1 && cur_step >= 3) {
if (t3 == line_break_id && t2 == think_end_id && t1 == line_break_id &&
t0 == line_break_id) {
new_status = 2;
} else if (t3 != think_end_id && t2 != think_end_id && t1 != think_end_id &&
t0 != think_end_id) {
new_status = 3;
}
} else if (status == 2) {
// x = 2 -> x = 3
new_status = 3;
}
reasoning_status[tid] = new_status;
}
// ================================================================
// Kernel 2: apply enforce generation scores
// ================================================================
template <typename T>
__global__ void apply_token_enforce_generation_scores_kernel(
const T* __restrict__ logits_src, // logits_tmp (backup)
T* __restrict__ logits_dst, // logits (output)
const int64_t* __restrict__ allowed_tokens, // [allowed_len]
const int32_t* __restrict__ reasoning_status,
const int* output_padding_offset,
const int* output_cum_offsets,
const int max_bsz,
const int max_seq_len,
const int vocab_size,
const int allowed_tokens_len) {
int token_idx = blockIdx.x;
int tid = threadIdx.x;
const int bs_idx =
(token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int query_start_token_idx =
bs_idx * max_seq_len - output_cum_offsets[bs_idx];
bool is_batch_first_token = (token_idx == query_start_token_idx);
if (allowed_tokens_len == 0 || !is_batch_first_token) {
return;
}
if (bs_idx < max_bsz && reasoning_status[bs_idx] == 2) {
const T* src = logits_src + token_idx * vocab_size;
T* dst = logits_dst + token_idx * vocab_size;
// 1. clear all logits
for (int i = tid; i < vocab_size; i += blockDim.x) {
dst[i] = static_cast<T>(-1e10f);
}
__syncthreads();
// 2. restore allowed tokens
for (int i = tid; i < allowed_tokens_len; i += blockDim.x) {
int64_t token_id = allowed_tokens[i];
if ((unsigned)token_id < (unsigned)vocab_size) {
dst[token_id] = src[token_id];
}
}
}
}
// ================================================================
// C++ Launcher
// ================================================================
template <paddle::DataType D>
void reasoning_phase_token_constraint(
const paddle::Tensor& logits, // inplace output
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
int64_t think_end_id,
int64_t line_break_id) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto stream = logits.stream();
int bs = seq_lens_this_time.shape()[0];
int token_num = logits.shape()[0];
int vocab_size = logits.shape()[1];
int max_seq_len = pre_ids.shape()[1];
int allowed_tokens_len = allowed_tokens.shape()[0];
// ------------------------------------------------
// Kernel 1: update reasoning status
// ------------------------------------------------
// int block1 = (bs + 31) / 32 * 32;
const int block_size = 512;
const int gird_size = (bs + block_size - 1) / block_size;
update_reasoning_status_kernel<<<gird_size, block_size, 0, stream>>>(
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
step_idx.data<int64_t>(),
pre_ids.data<int64_t>(),
const_cast<int32_t*>(reasoning_status.data<int32_t>()),
bs,
max_seq_len,
think_end_id,
line_break_id);
// ------------------------------------------------
// backup logits
// ------------------------------------------------
auto logits_tmp = logits.copy_to(logits.place(), false);
// ------------------------------------------------
// Kernel 2: enforce generation
// ------------------------------------------------
int block_size_2 = (vocab_size + 31) / 32 * 32;
block_size_2 = std::min(block_size_2, 512);
apply_token_enforce_generation_scores_kernel<<<token_num,
block_size_2,
0,
stream>>>(
reinterpret_cast<DataType_*>(logits_tmp.data<data_t>()),
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
allowed_tokens.data<int64_t>(),
reasoning_status.data<int32_t>(),
output_padding_offset.data<int32_t>(),
output_cum_offsets.data<int32_t>(),
bs,
max_seq_len,
vocab_size,
allowed_tokens_len);
}
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
int64_t think_end_id,
int64_t line_break_id) {
switch (logits.type()) {
case paddle::DataType::FLOAT16:
return reasoning_phase_token_constraint<paddle::DataType::FLOAT16>(
logits,
pre_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id);
case paddle::DataType::BFLOAT16:
return reasoning_phase_token_constraint<paddle::DataType::BFLOAT16>(
logits,
pre_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id);
case paddle::DataType::FLOAT32:
return reasoning_phase_token_constraint<paddle::DataType::FLOAT32>(
logits,
pre_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id);
default:
PD_THROW("Unsupported data type.");
}
}
// ================================================================
// PD_BUILD_STATIC_OP
// ================================================================
PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
.Inputs({"logits",
"pre_ids",
"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"step_idx",
"allowed_tokens",
"reasoning_status",
"output_padding_offset",
"output_cum_offsets"})
.Outputs({"logits_out", "reasoning_status_out"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.SetInplaceMap({{"logits", "logits_out"},
{"reasoning_status", "reasoning_status_out"}})
.SetKernelFn(PD_KERNEL(ReasoningPhaseTokenConstraint));
@@ -84,6 +84,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
const bool *is_block_step,
const int *output_cum_offsets,
const int *actual_candidate_len,
const int *reasoning_status,
const int real_bsz,
const int max_draft_tokens,
const int end_length,
@@ -116,10 +117,8 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
// seq_lens_this_time[bid]-1);
for (; i < seq_lens_this_time[bid] - 1; i++) {
if (benchmark_mode) {
break;
}
if (seq_lens_encoder[bid] != 0) {
if (benchmark_mode || seq_lens_encoder[bid] != 0 ||
reasoning_status[bid] == 1) {
break;
}
if (accept_all_drafts) {
@@ -317,6 +316,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
const paddle::Tensor &reasoning_status,
int max_seq_len,
int verify_window,
bool enable_topp,
@@ -376,6 +376,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
@@ -408,6 +409,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
@@ -442,6 +444,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
@@ -474,6 +477,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
max_draft_tokens,
end_length,
@@ -508,7 +512,8 @@ PD_BUILD_STATIC_OP(speculate_verify)
"output_cum_offsets",
"actual_candidate_len",
"actual_draft_token_nums",
"topp"})
"topp",
"reasoning_status"})
.Outputs({"accept_tokens_out",
"accept_num_out",
"step_idx_out",
+1
View File
@@ -310,6 +310,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/gelu_tanh.cu",
"gpu_ops/reasoning_phase_token_constraint.cu",
]
# pd_disaggregation
+3
View File
@@ -739,6 +739,9 @@ class SpeculativeConfig:
# This means no tokens from MTP are accepted.
# This ensures that the specified simulation acceptance rate is not affected.
self.benchmark_mode: bool = False
# Enable token constraint enforcement in generation phase
# When enabled, enforces specific tokens after the reasoning phase boundary pattern
self.enf_gen_phase_tag: bool = False
self.num_extra_cache_layer = 0
@@ -17,6 +17,7 @@
from .apply_penalty_multi_scores import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
reasoning_phase_token_constraint,
)
from .speculate_logprob_utils import (
speculate_get_target_logits,
@@ -27,6 +28,7 @@ from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
__all__ = [
"apply_penalty_multi_scores",
"apply_speculative_penalty_multi_scores",
"reasoning_phase_token_constraint",
"top_k_top_p_sampling",
"min_p_sampling",
"speculate_get_target_logits",
@@ -207,3 +207,43 @@ def apply_speculative_penalty_multi_scores(
)
# inplace
return logits
def reasoning_phase_token_constraint(
logits: paddle.Tensor,
pre_token_ids: paddle.Tensor,
stop_flags: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
step_idx: paddle.Tensor,
reasoning_allowed_tokens: paddle.Tensor,
reasoning_status: paddle.Tensor,
output_padding_offset: paddle.Tensor,
output_cum_offsets: paddle.Tensor,
think_end_id: int,
line_break_id: int,
):
"""
reasoning_phase_token_constraint
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import reasoning_phase_token_constraint
reasoning_phase_token_constraint(
logits,
pre_token_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
reasoning_allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id,
)
else:
raise NotImplementedError
# inplace
return logits
@@ -36,6 +36,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
reasoning_phase_token_constraint,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
@@ -614,6 +615,9 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
self.think_end_id = fd_config.model_config.think_end_id
self.line_break_id = fd_config.model_config.line_break_id
self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -757,6 +761,22 @@ class SpeculativeSampler(nn.Layer):
max_model_len,
)
if self.enf_gen_phase_tag:
reasoning_phase_token_constraint(
logits,
sampling_metadata.pre_token_ids,
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["step_idx"],
share_inputs["reasoning_allowed_tokens"],
share_inputs["reasoning_status"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
self.think_end_id,
self.line_break_id,
)
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
@@ -797,6 +817,7 @@ class SpeculativeSampler(nn.Layer):
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
share_inputs["reasoning_status"],
max_model_len,
self.speculative_verify_window,
True, # enable_topp
+6
View File
@@ -1274,6 +1274,12 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
# NOTE(liuzichang): token after \n</think>\n\n must be <tool_call> 100973 or <response> 100975
# It is a hard code to cover up model's performance
# Detailed notes can be found in FastDeploy/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu
self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["reasoning_allowed_tokens"] = paddle.to_tensor([100973, 100975], dtype="int64")
# Initialize rotary position embedding
if not self.enable_mm:
self.share_inputs["rope_emb"] = get_rope(
+1
View File
@@ -141,6 +141,7 @@ def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab
share_inputs["draft_logits"] = paddle.full(
[max_num_seqs * (max_draft_token_num + 1), vocab_size], -1, dtype="float32"
)
share_inputs["reasoning_status"] = paddle.zeros([max_num_seqs], dtype="int32")
return share_inputs
@@ -0,0 +1,469 @@
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import (
reasoning_phase_token_constraint,
speculate_get_output_padding_offset,
)
class TestReasoningPhaseTokenConstraint(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
# ------------------------
# Basic config
# ------------------------
self.bs = 2
self.max_seq_len = 8
self.vocab_size = 16
self.think_end_id = 9
self.line_break_id = 10
# ------------------------
# seq / step
# ------------------------
self.step_idx = paddle.to_tensor([4, 4], dtype="int64")
self.seq_lens_this_time = paddle.to_tensor([2, 2], dtype="int32")
self.seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32")
self.stop_flags = paddle.to_tensor([False, False], dtype="bool")
# ------------------------
# pre_ids
#
# batch 0:
# ... \n <think_end> \n \n → status 1 -> 2
#
# batch 1:
# contains think_end, but pattern not complete → status 0 -> 1
# ------------------------
pre_ids = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
# batch 0
pre_ids[0, 1] = self.line_break_id
pre_ids[0, 2] = self.think_end_id
pre_ids[0, 3] = self.line_break_id
pre_ids[0, 4] = self.line_break_id
# batch 1
pre_ids[1, 3] = self.think_end_id
self.pre_ids = paddle.to_tensor(pre_ids, dtype="int64")
# ------------------------
# reasoning_status (init)
# ------------------------
self.reasoning_status = paddle.to_tensor([1, 0], dtype="int32")
# ------------------------
# allowed tokens
# ------------------------
self.allowed_tokens = paddle.to_tensor([2, 5, 7], dtype="int64")
# ------------------------
# speculative layout
#
# each batch has exactly 1 token this step
# token_idx == bs_idx
# ------------------------
self.token_num = paddle.sum(self.seq_lens_this_time)
seq_lens_output = paddle.to_tensor([2, 2], dtype="int32")
output_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(self.max_seq_len - seq_lens_output, dtype="int32")
self.output_padding_offset, self.output_cum_offsets = speculate_get_output_padding_offset(
output_cum_offsets_tmp,
output_token_num,
seq_lens_output,
self.max_seq_len,
)
# self.output_padding_offset = paddle.zeros([self.token_num], dtype="int32")
# self.output_cum_offsets = paddle.zeros([self.bs], dtype="int32")
# ------------------------
# logits
# ------------------------
np.random.seed(2024)
logits = np.random.randn(self.token_num, self.vocab_size).astype("float32")
self.logits = paddle.to_tensor(logits, dtype="float32")
def test_reasoning_status_and_logits_enforce(self):
logits_before = self.logits.numpy().copy()
# ------------------------
# call custom op
# ------------------------
reasoning_phase_token_constraint(
self.logits,
self.pre_ids,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_encoder,
self.step_idx,
self.allowed_tokens,
self.reasoning_status,
self.output_padding_offset,
self.output_cum_offsets,
self.think_end_id,
self.line_break_id,
)
logits_after = self.logits.numpy()
status_after = self.reasoning_status.numpy()
# ============================================================
# 1. reasoning_status check
# ============================================================
# batch 0: 1 -> 2
self.assertEqual(status_after[0], 2)
# batch 1: 0 -> 1
self.assertEqual(status_after[1], 1)
# ============================================================
# 2. logits enforce check
# ============================================================
# batch 0 should be enforced (status == 2)
for vid in range(self.vocab_size):
if vid in self.allowed_tokens.numpy():
self.assertAlmostEqual(
logits_after[0, vid],
logits_before[0, vid],
places=5,
)
else:
self.assertLess(logits_after[0, vid], -1e9)
# batch 1 should be untouched
np.testing.assert_allclose(
logits_after[1],
logits_before[1],
rtol=1e-5,
atol=1e-6,
)
def test_status_0_to_1_only(self):
"""
status == 0
recent tokens contain <think_end>
=> status: 0 -> 1
logits should NOT be enforced
"""
# ------------------------
# setup: only think_end appears
# ------------------------
pre_ids = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
# batch 0: think_end at cur_step - 1
pre_ids[0, 3] = self.think_end_id
# batch 1: no think_end
pre_ids[1, :] = 0
self.pre_ids = paddle.to_tensor(pre_ids, dtype="int64")
self.reasoning_status = paddle.to_tensor([0, 0], dtype="int32")
logits_before = self.logits.numpy().copy()
# ------------------------
# call op
# ------------------------
reasoning_phase_token_constraint(
self.logits,
self.pre_ids,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_encoder,
self.step_idx,
self.allowed_tokens,
self.reasoning_status,
self.output_padding_offset,
self.output_cum_offsets,
self.think_end_id,
self.line_break_id,
)
status_after = self.reasoning_status.numpy()
logits_after = self.logits.numpy()
# ============================================================
# 1. reasoning_status
# ============================================================
# batch 0: 0 -> 1
self.assertEqual(status_after[0], 1)
# batch 1: stays 0
self.assertEqual(status_after[1], 0)
# ============================================================
# 2. logits must be untouched
# ============================================================
np.testing.assert_allclose(
logits_after,
logits_before,
rtol=1e-5,
atol=1e-6,
)
def test_status_2_to_3_only(self):
# Force initial status = 2
self.reasoning_status = paddle.to_tensor([2, 2], dtype="int32")
logits_before = self.logits.numpy().copy()
reasoning_phase_token_constraint(
self.logits,
self.pre_ids,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_encoder,
self.step_idx,
self.allowed_tokens,
self.reasoning_status,
self.output_padding_offset,
self.output_cum_offsets,
self.think_end_id,
self.line_break_id,
)
status_after = self.reasoning_status.numpy()
logits_after = self.logits.numpy()
# status: 2 -> 3
self.assertTrue(np.all(status_after == 3))
# logits should NOT be changed
np.testing.assert_allclose(
logits_after,
logits_before,
rtol=1e-5,
atol=1e-6,
)
def test_status_1_to_2(self):
# batch 0 enforcebatch 1 not enforce
self.reasoning_status = paddle.to_tensor([1, 2], dtype="int32")
logits_before = self.logits.numpy().copy()
reasoning_phase_token_constraint(
self.logits,
self.pre_ids,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_encoder,
self.step_idx,
self.allowed_tokens,
self.reasoning_status,
self.output_padding_offset,
self.output_cum_offsets,
self.think_end_id,
self.line_break_id,
)
logits_after = self.logits.numpy()
# Find batch 0's token_idx
token_idx_batch0 = 0 # speculate_get_output_padding_offset 下,第一个 token 一定是 batch 0
# batch 0 first token should be enforced
for vid in range(self.vocab_size):
if vid in self.allowed_tokens.numpy():
self.assertAlmostEqual(
logits_after[token_idx_batch0, vid],
logits_before[token_idx_batch0, vid],
places=5,
)
else:
self.assertLess(logits_after[token_idx_batch0, vid], -1e9)
# batch 0 second token(如果存在)必须 untouched
if self.token_num > 1:
np.testing.assert_allclose(
logits_after[token_idx_batch0 + 1],
logits_before[token_idx_batch0 + 1],
rtol=1e-5,
atol=1e-6,
)
np.testing.assert_equal(self.reasoning_status.numpy(), [2, 3])
def test_empty_allowed_tokens(self):
empty_allowed = paddle.empty([0], dtype="int64")
logits_before = self.logits.numpy().copy()
reasoning_phase_token_constraint(
self.logits,
self.pre_ids,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_encoder,
self.step_idx,
empty_allowed,
self.reasoning_status,
self.output_padding_offset,
self.output_cum_offsets,
self.think_end_id,
self.line_break_id,
)
logits_after = self.logits.numpy()
np.testing.assert_allclose(
logits_after,
logits_before,
rtol=1e-5,
atol=1e-6,
)
def test_perf_bsz128_vocab100k_status2(self):
"""
Performance benchmark:
bsz = 128
vocab = 100k
all status == 2
all tokens are batch-first tokens
"""
paddle.set_device("gpu")
# ------------------------
# config
# ------------------------
bs = 256
vocab_size = 100000
max_seq_len = 1024
think_end_id = 9
line_break_id = 10
# ------------------------
# seq / step
# ------------------------
step_idx = paddle.full([bs], 4, dtype="int64")
seq_lens_this_time = paddle.full([bs], 1, dtype="int32")
seq_lens_encoder = paddle.zeros([bs], dtype="int32")
stop_flags = paddle.zeros([bs], dtype="bool")
# ------------------------
# pre_ids: force 1 -> 2 pattern
# ------------------------
pre_ids = np.zeros((bs, max_seq_len), dtype=np.int64)
for i in range(bs):
pre_ids[i, 1] = line_break_id
pre_ids[i, 2] = think_end_id
pre_ids[i, 3] = line_break_id
pre_ids[i, 4] = line_break_id
pre_ids = paddle.to_tensor(pre_ids, dtype="int64")
# ------------------------
# reasoning_status: start from 1
# ------------------------
reasoning_status = paddle.ones([bs], dtype="int32")
# ------------------------
# allowed tokens (small set)
# ------------------------
allowed_tokens = paddle.to_tensor([1, 5, 42, 999], dtype="int64")
# ------------------------
# speculative layout
# each batch exactly 1 token
# token_idx == bs_idx
# ------------------------
token_num = paddle.sum(seq_lens_this_time)
seq_lens_output = paddle.full(bs, 2, dtype="int32")
output_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(max_seq_len - seq_lens_output, dtype="int32")
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
output_cum_offsets_tmp,
output_token_num,
seq_lens_output,
max_seq_len,
)
# ------------------------
# logits
# ------------------------
logits = paddle.randn([token_num, vocab_size], dtype="float32")
# ------------------------
# warmup
# ------------------------
for _ in range(5):
reasoning_phase_token_constraint(
logits,
pre_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id,
)
paddle.device.cuda.synchronize()
# ------------------------
# timing
# ------------------------
iters = 20
start = paddle.device.cuda.Event(enable_timing=True)
end = paddle.device.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
reasoning_phase_token_constraint(
logits,
pre_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
think_end_id,
line_break_id,
)
end.record()
paddle.device.cuda.synchronize()
elapsed_ms = paddle.device.cuda.Event.elapsed_time(start, end)
avg_ms = elapsed_ms / iters
print(f"[PERF] bsz={bs}, vocab={vocab_size}, " f"avg latency = {avg_ms:.3f} ms")
# ------------------------
# correctness spot check
# ------------------------
logits_np = logits.numpy()
print(logits)
for b in [0, 100, 200]: # sample few batches
for vid in range(vocab_size):
if vid in allowed_tokens.numpy():
continue
# print(f"b: {b}, vid: {vid}")
self.assertLess(logits_np[b, vid], -1e9)
if __name__ == "__main__":
unittest.main()
+3
View File
@@ -66,6 +66,7 @@ def speculate_verify_ref(
actual_candidate_len,
actual_draft_token_nums,
topp,
reasoning_status,
max_seq_len,
verify_window,
enable_topp,
@@ -285,6 +286,7 @@ def gen_speculate_verify_inputs(
if enable_topp
else np.zeros(real_bsz, dtype=np.float32)
)
reasoning_status = np.zeros((real_bsz), dtype=np.int32)
# Output(inplace)
accept_tokens = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64)
@@ -311,6 +313,7 @@ def gen_speculate_verify_inputs(
"actual_candidate_len": actual_candidate_len,
"actual_draft_token_nums": actual_draft_token_nums,
"topp": topp,
"reasoning_status": reasoning_status,
"max_seq_len": max_seq_len,
"verify_window": verify_window,
"enable_topp": enable_topp,