[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)

* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow

* [Docs] Fix thinking_budget markdown formatting

* [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
jackyYang6
2026-03-23 14:15:55 +08:00
committed by GitHub
parent 7a78001be2
commit 634d23a38a
10 changed files with 663 additions and 285 deletions
@@ -42,14 +42,15 @@ class _ThinkingState:
class ThinkingBudgetLogitsProcessor(LogitsProcessor):
"""Limit the number of tokens generated in the thinking phase.
The processor tracks per-request thinking state and forces a newline token
when the budget is reached, followed by the thinking end token on the next step.
The processor tracks per-request thinking state and forces the thinking end
token when the budget is reached. If a stop sentence is configured, the
processor emits the stop sentence first and then the thinking end token.
Request-specific configuration is provided via logits_processors_args:
{"thinking_budget": <int>}
Requires model_config to provide think_start_id, think_end_id, and line_break_id.
If any of these are missing or invalid (-1), the processor will be disabled.
Requires model_config to provide think_start_id and think_end_id. If any of
these are missing or invalid (-1), the processor will be disabled.
"""
def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
self._enabled = (
self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
)
self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
if not self._enabled:
logger.warning(
"ThinkingBudgetLogitsProcessor disabled: missing token ids "
f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
"Ensure model vocab contains <think>, </think> tokens and line_break_id is configured."
f"(think_start={think_start_id}, think_end={think_end_id}). "
"Ensure model vocab contains <think> and </think> tokens."
)
self._states: Dict[str, _ThinkingState] = {}
self._active_req_ids: list[str] = []
self._active_budgets: list[int] = []
self._active_slots: list[int] = []
def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
started = False
ended = False
in_thinking = False
for token_id in prompt_slice:
if token_id == self.think_start_token_id:
started = True
ended = False
in_thinking = True
elif token_id == self.think_end_token_id and in_thinking:
ended = True
in_thinking = False
last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
return started, ended, 0, last_token_id
def update_state(self, share_inputs: dict) -> None:
if not self._enabled:
return
@@ -82,6 +96,7 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
req_ids = share_inputs["req_ids"]
logits_processors_args = share_inputs["logits_processors_args"]
prompt_ids = share_inputs.get("prompt_ids")
token_ids_all = share_inputs.get("token_ids_all")
prompt_lens = share_inputs.get("prompt_lens")
pre_ids = share_inputs.get("pre_ids")
next_tokens = share_inputs.get("next_tokens")
@@ -153,6 +168,8 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
for idx, slot_id in enumerate(candidate_slots):
next_token_by_slot[slot_id] = int(next_sel[idx])
prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
for idx, slot_id in enumerate(candidate_slots):
req_id = candidate_req_ids[idx]
logit_proc_args = candidate_args[idx]
@@ -190,25 +207,28 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
state.current_step_idx = current_step_idx
if not state.started and not state.prompt_checked:
if prompt_ids is not None and prompt_lens is not None:
if prompt_source is not None and prompt_lens is not None:
if prompt_lens_np is not None:
prompt_len = int(prompt_lens_np[slot_id])
else:
prompt_len = int(prompt_lens[slot_id])
prompt_slice = prompt_ids[slot_id, :prompt_len]
prompt_slice = prompt_slice.numpy().tolist()
if self.think_start_token_id in prompt_slice:
prompt_slice = prompt_source[slot_id, :prompt_len]
if hasattr(prompt_slice, "numpy"):
prompt_slice = prompt_slice.numpy().tolist()
elif hasattr(prompt_slice, "tolist"):
prompt_slice = prompt_slice.tolist()
else:
prompt_slice = list(prompt_slice)
if prompt_ids is None:
prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
self._scan_prompt_state(prompt_slice)
)
if prompt_started:
state.started = True
start_pos = prompt_slice.index(self.think_start_token_id)
tokens_after = prompt_slice[start_pos + 1 :]
if self.think_end_token_id in tokens_after:
end_pos = tokens_after.index(self.think_end_token_id)
state.tokens_after_start = end_pos + 1
state.ended = True
else:
state.tokens_after_start = len(tokens_after)
if prompt_slice:
state.last_token_id = int(prompt_slice[-1])
state.ended = prompt_ended
state.tokens_after_start = prompt_tokens_after_start
state.last_token_id = prompt_last_token_id
if current_step_idx is not None and state.last_step_idx is None:
state.last_step_idx = current_step_idx
state.prompt_checked = True
@@ -304,13 +324,6 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
if state.tokens_after_start < budget:
continue
if state.last_token_id != self.line_break_token_id:
logits[slot_id, :] = -float("inf")
logits[slot_id, self.line_break_token_id] = 0.0
state.last_token_id = self.line_break_token_id
state.last_step_idx = state.current_step_idx
continue
logits[slot_id, :] = -float("inf")
logits[slot_id, self.think_end_token_id] = 0.0
state.last_token_id = self.think_end_token_id