[Feature] Support ThinkingBudget Logits processor to control thinking content length (#6367)

* feat: add thinking budget logits processor

* add unittest

* fix pre-commit

* add unittest

* docs: clarify operator-level vs logits processor usage and conflict guidance

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
jackyYang6
2026-02-25 14:17:09 +08:00
committed by GitHub
parent 1405d7d5d7
commit a29ee57e15
12 changed files with 1861 additions and 2 deletions
+107
View File
@@ -15,6 +15,7 @@
"""
from abc import ABC, abstractmethod
from collections import OrderedDict
import numpy as np
from paddleformers.generation import GenerationConfig
@@ -50,6 +51,8 @@ class BaseDataProcessor(ABC):
f"mask_token is {self.tokenizer.mask_token}, {self.tokenizer.mask_token_id}",
)
)
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = 128
def _apply_default_parameters(self, request):
"""
@@ -108,6 +111,25 @@ class BaseDataProcessor(ABC):
"""
raise NotImplementedError
def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
"""
Encode text into token ids with a small LRU cache.
"""
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
self._tokenize_cache.move_to_end(key)
return cached
token_ids = self.text2ids(text, max_model_len, add_special_tokens=add_special_tokens)
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
elif not isinstance(token_ids, list):
token_ids = list(token_ids)
self._tokenize_cache[key] = token_ids
if len(self._tokenize_cache) > self._tokenize_cache_capacity:
self._tokenize_cache.popitem(last=False)
return token_ids
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -175,6 +197,8 @@ class DataProcessor(BaseDataProcessor):
self.model_status_dict = dict()
self.tool_parser_dict = dict()
self.tokenizer = self._load_tokenizer()
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = 128
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
@@ -194,6 +218,65 @@ class DataProcessor(BaseDataProcessor):
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
self._think_token_ids = None
def _get_think_token_ids(self):
if self._think_token_ids is not None:
return self._think_token_ids
vocab = self.tokenizer.get_vocab()
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = think_start_id in token_list
ended = False
tokens_after_start = 0
last_token_id = None
if started:
start_pos = token_list.index(think_start_id)
tokens_after = token_list[start_pos + 1 :]
if think_end_id in tokens_after:
end_pos = tokens_after.index(think_end_id)
tokens_after_start = end_pos + 1
ended = True
else:
tokens_after_start = len(tokens_after)
if token_list:
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def process_request(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -220,6 +303,15 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = request.get("logits_processors_args") or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
@@ -257,6 +349,9 @@ class DataProcessor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = request.get("logits_processors_args") or {}
logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
request["logits_processors_args"] = logits_processors_args
if request.get("max_tokens") is None:
request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
if request.get("temperature") < _SAMPLING_EPS:
@@ -306,6 +401,15 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids:
if request.prompt:
@@ -346,6 +450,9 @@ class DataProcessor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
request.sampling_params.logits_processors_args = logits_processors_args
if request.sampling_params.max_tokens is None:
request.sampling_params.max_tokens = max(1, max_model_len - len(request.prompt_token_ids))
if request.sampling_params.temperature < _SAMPLING_EPS: