mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)
* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow * [Docs] Fix thinking_budget markdown formatting * [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
@@ -42,14 +42,15 @@ class _ThinkingState:
|
||||
class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
"""Limit the number of tokens generated in the thinking phase.
|
||||
|
||||
The processor tracks per-request thinking state and forces a newline token
|
||||
when the budget is reached, followed by the thinking end token on the next step.
|
||||
The processor tracks per-request thinking state and forces the thinking end
|
||||
token when the budget is reached. If a stop sentence is configured, the
|
||||
processor emits the stop sentence first and then the thinking end token.
|
||||
Request-specific configuration is provided via logits_processors_args:
|
||||
|
||||
{"thinking_budget": <int>}
|
||||
|
||||
Requires model_config to provide think_start_id, think_end_id, and line_break_id.
|
||||
If any of these are missing or invalid (-1), the processor will be disabled.
|
||||
Requires model_config to provide think_start_id and think_end_id. If any of
|
||||
these are missing or invalid (-1), the processor will be disabled.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig) -> None:
|
||||
@@ -61,20 +62,33 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
|
||||
self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
|
||||
self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
|
||||
self._enabled = (
|
||||
self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
|
||||
)
|
||||
self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
|
||||
if not self._enabled:
|
||||
logger.warning(
|
||||
"ThinkingBudgetLogitsProcessor disabled: missing token ids "
|
||||
f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
|
||||
"Ensure model vocab contains <think>, </think> tokens and line_break_id is configured."
|
||||
f"(think_start={think_start_id}, think_end={think_end_id}). "
|
||||
"Ensure model vocab contains <think> and </think> tokens."
|
||||
)
|
||||
self._states: Dict[str, _ThinkingState] = {}
|
||||
self._active_req_ids: list[str] = []
|
||||
self._active_budgets: list[int] = []
|
||||
self._active_slots: list[int] = []
|
||||
|
||||
def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
|
||||
started = False
|
||||
ended = False
|
||||
in_thinking = False
|
||||
for token_id in prompt_slice:
|
||||
if token_id == self.think_start_token_id:
|
||||
started = True
|
||||
ended = False
|
||||
in_thinking = True
|
||||
elif token_id == self.think_end_token_id and in_thinking:
|
||||
ended = True
|
||||
in_thinking = False
|
||||
last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
|
||||
return started, ended, 0, last_token_id
|
||||
|
||||
def update_state(self, share_inputs: dict) -> None:
|
||||
if not self._enabled:
|
||||
return
|
||||
@@ -82,6 +96,7 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
req_ids = share_inputs["req_ids"]
|
||||
logits_processors_args = share_inputs["logits_processors_args"]
|
||||
prompt_ids = share_inputs.get("prompt_ids")
|
||||
token_ids_all = share_inputs.get("token_ids_all")
|
||||
prompt_lens = share_inputs.get("prompt_lens")
|
||||
pre_ids = share_inputs.get("pre_ids")
|
||||
next_tokens = share_inputs.get("next_tokens")
|
||||
@@ -153,6 +168,8 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
for idx, slot_id in enumerate(candidate_slots):
|
||||
next_token_by_slot[slot_id] = int(next_sel[idx])
|
||||
|
||||
prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
|
||||
|
||||
for idx, slot_id in enumerate(candidate_slots):
|
||||
req_id = candidate_req_ids[idx]
|
||||
logit_proc_args = candidate_args[idx]
|
||||
@@ -190,25 +207,28 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
state.current_step_idx = current_step_idx
|
||||
|
||||
if not state.started and not state.prompt_checked:
|
||||
if prompt_ids is not None and prompt_lens is not None:
|
||||
if prompt_source is not None and prompt_lens is not None:
|
||||
if prompt_lens_np is not None:
|
||||
prompt_len = int(prompt_lens_np[slot_id])
|
||||
else:
|
||||
prompt_len = int(prompt_lens[slot_id])
|
||||
prompt_slice = prompt_ids[slot_id, :prompt_len]
|
||||
prompt_slice = prompt_slice.numpy().tolist()
|
||||
if self.think_start_token_id in prompt_slice:
|
||||
prompt_slice = prompt_source[slot_id, :prompt_len]
|
||||
if hasattr(prompt_slice, "numpy"):
|
||||
prompt_slice = prompt_slice.numpy().tolist()
|
||||
elif hasattr(prompt_slice, "tolist"):
|
||||
prompt_slice = prompt_slice.tolist()
|
||||
else:
|
||||
prompt_slice = list(prompt_slice)
|
||||
if prompt_ids is None:
|
||||
prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
|
||||
prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
|
||||
self._scan_prompt_state(prompt_slice)
|
||||
)
|
||||
if prompt_started:
|
||||
state.started = True
|
||||
start_pos = prompt_slice.index(self.think_start_token_id)
|
||||
tokens_after = prompt_slice[start_pos + 1 :]
|
||||
if self.think_end_token_id in tokens_after:
|
||||
end_pos = tokens_after.index(self.think_end_token_id)
|
||||
state.tokens_after_start = end_pos + 1
|
||||
state.ended = True
|
||||
else:
|
||||
state.tokens_after_start = len(tokens_after)
|
||||
if prompt_slice:
|
||||
state.last_token_id = int(prompt_slice[-1])
|
||||
state.ended = prompt_ended
|
||||
state.tokens_after_start = prompt_tokens_after_start
|
||||
state.last_token_id = prompt_last_token_id
|
||||
if current_step_idx is not None and state.last_step_idx is None:
|
||||
state.last_step_idx = current_step_idx
|
||||
state.prompt_checked = True
|
||||
@@ -304,13 +324,6 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
|
||||
if state.tokens_after_start < budget:
|
||||
continue
|
||||
|
||||
if state.last_token_id != self.line_break_token_id:
|
||||
logits[slot_id, :] = -float("inf")
|
||||
logits[slot_id, self.line_break_token_id] = 0.0
|
||||
state.last_token_id = self.line_break_token_id
|
||||
state.last_step_idx = state.current_step_idx
|
||||
continue
|
||||
|
||||
logits[slot_id, :] = -float("inf")
|
||||
logits[slot_id, self.think_end_token_id] = 0.0
|
||||
state.last_token_id = self.think_end_token_id
|
||||
|
||||
Reference in New Issue
Block a user