mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
[Feature]Support tag phase token enforce generation (#6034)
* support tag phase token enforce generation
* optimize note and some feature
* fix sampler unit test

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -36,6 +36,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores,
|
||||
apply_speculative_penalty_multi_scores,
|
||||
min_p_sampling,
|
||||
reasoning_phase_token_constraint,
|
||||
speculate_get_target_logits,
|
||||
speculate_insert_first_token,
|
||||
top_k_top_p_sampling,
|
||||
@@ -614,6 +615,9 @@ class SpeculativeSampler(nn.Layer):
|
||||
self.speculative_verify_window = fd_config.speculative_config.verify_window
|
||||
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
|
||||
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
|
||||
self.think_end_id = fd_config.model_config.think_end_id
|
||||
self.line_break_id = fd_config.model_config.line_break_id
|
||||
self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
@@ -757,6 +761,22 @@ class SpeculativeSampler(nn.Layer):
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
if self.enf_gen_phase_tag:
|
||||
reasoning_phase_token_constraint(
|
||||
logits,
|
||||
sampling_metadata.pre_token_ids,
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["reasoning_allowed_tokens"],
|
||||
share_inputs["reasoning_status"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
self.think_end_id,
|
||||
self.line_break_id,
|
||||
)
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
top_p, top_k, topp_seed = padding_sampling_params(
|
||||
@@ -797,6 +817,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
actual_candidate_len,
|
||||
share_inputs["actual_draft_token_num"],
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["reasoning_status"],
|
||||
max_model_len,
|
||||
self.speculative_verify_window,
|
||||
True, # enable_topp
|
||||
|
||||
Reference in New Issue
Block a user