[Feature] Support enforced generation of phase-tag tokens (#6034)

* support enforced generation of phase-tag tokens

* optimize notes and some features

* fix sampler unit test

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
freeliuzc
2026-01-15 19:59:55 +08:00
committed by GitHub
parent 17866c028e
commit 49617d9832
12 changed files with 889 additions and 5 deletions
@@ -36,6 +36,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
reasoning_phase_token_constraint,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
@@ -614,6 +615,9 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
self.think_end_id = fd_config.model_config.think_end_id
self.line_break_id = fd_config.model_config.line_break_id
self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -757,6 +761,22 @@ class SpeculativeSampler(nn.Layer):
max_model_len,
)
if self.enf_gen_phase_tag:
reasoning_phase_token_constraint(
logits,
sampling_metadata.pre_token_ids,
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["step_idx"],
share_inputs["reasoning_allowed_tokens"],
share_inputs["reasoning_status"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
self.think_end_id,
self.line_break_id,
)
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
@@ -797,6 +817,7 @@ class SpeculativeSampler(nn.Layer):
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
share_inputs["reasoning_status"],
max_model_len,
self.speculative_verify_window,
True, # enable_topp