[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)

* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow

* [Docs] Fix thinking_budget markdown formatting

* [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
jackyYang6
2026-03-23 14:15:55 +08:00
committed by GitHub
parent 7a78001be2
commit 634d23a38a
10 changed files with 663 additions and 285 deletions
+20 -1
View File
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
f"request_ids: {request.request_id}, prompt: {prompt}, "
f"tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
task = request.to_dict()
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, request.get("logits_processors_args") or {}
)
request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if request.get("max_tokens") is None:
request.set("max_tokens", max(1, max_tokens))
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids:
if request.prompt:
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
)
request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
request.sampling_params.max_tokens = max(1, max_tokens)
@@ -216,6 +216,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
if request.prompt_token_ids:
messages = request.messages
if messages:
@@ -267,6 +272,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
# 截断超过长度限制的prompt
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
logits_processors_args = self._update_thinking_prompt_state(
request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
)
request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
+101 -75
View File
@@ -115,6 +115,9 @@ class BaseDataProcessor(ABC):
"""
Encode text into token ids with a small LRU cache.
"""
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
@@ -130,6 +133,25 @@ class BaseDataProcessor(ABC):
self._tokenize_cache.popitem(last=False)
return token_ids
def _encode_literal_text_with_cache(self, text):
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = ("literal_text", text)
cached = self._tokenize_cache.get(key)
if cached is not None:
self._tokenize_cache.move_to_end(key)
return cached
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
elif not isinstance(token_ids, list):
token_ids = list(token_ids)
self._tokenize_cache[key] = token_ids
if len(self._tokenize_cache) > self._tokenize_cache_capacity:
self._tokenize_cache.popitem(last=False)
return token_ids
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -142,6 +164,77 @@ class BaseDataProcessor(ABC):
"""
raise NotImplementedError
def _get_think_token_ids(self):
think_token_ids = getattr(self, "_think_token_ids", None)
if think_token_ids is not None:
return think_token_ids
tokenizer = getattr(self, "tokenizer", None)
vocab = tokenizer.get_vocab() if tokenizer is not None else {}
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
return logits_processors_args
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = False
ended = False
tokens_after_start = 0
last_token_id = None
in_thinking = False
for token_id in token_list:
if token_id == think_start_id:
started = True
ended = False
in_thinking = True
elif token_id == think_end_id and in_thinking:
ended = True
in_thinking = False
if started and token_list:
# Align with operator-level reasoning_max_tokens: prompt-side tokens
# inside <think> do not consume thinking budget.
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def ids2tokens(self, token_id, task_id=None):
"""
token ids to strings
@@ -221,65 +314,6 @@ class DataProcessor(BaseDataProcessor):
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
self._think_token_ids = None
def _get_think_token_ids(self):
if self._think_token_ids is not None:
return self._think_token_ids
vocab = self.tokenizer.get_vocab()
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
thinking_budget = logits_processors_args.get("thinking_budget")
if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
return logits_processors_args
if logits_processors_args.get("think_prompt_checked"):
return logits_processors_args
if prompt_token_ids is None:
return logits_processors_args
token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
if token_len == 0:
return logits_processors_args
think_start_id, think_end_id = self._get_think_token_ids()
if think_start_id < 0 or think_end_id < 0:
return logits_processors_args
if hasattr(prompt_token_ids, "tolist"):
token_list = prompt_token_ids.tolist()
else:
token_list = list(prompt_token_ids)
started = think_start_id in token_list
ended = False
tokens_after_start = 0
last_token_id = None
if started:
start_pos = token_list.index(think_start_id)
tokens_after = token_list[start_pos + 1 :]
if think_end_id in tokens_after:
end_pos = tokens_after.index(think_end_id)
tokens_after_start = end_pos + 1
ended = True
else:
tokens_after_start = len(tokens_after)
if token_list:
last_token_id = int(token_list[-1])
logits_processors_args["think_prompt_checked"] = True
logits_processors_args["think_prompt_started"] = started
logits_processors_args["think_prompt_ended"] = ended
logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
if last_token_id is not None:
logits_processors_args["think_prompt_last_token_id"] = last_token_id
else:
logits_processors_args.pop("think_prompt_last_token_id", None)
return logits_processors_args
def process_request(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -306,14 +340,10 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
logits_processors_args = request.get("logits_processors_args") or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request["logits_processors_args"] = logits_processors_args
logits_processors_args = self._prepare_think_stop_sentence(
request.get("logits_processors_args") or {}, max_model_len
)
request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
@@ -411,14 +441,10 @@ class DataProcessor(BaseDataProcessor):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
request.sampling_params.logits_processors_args = logits_processors_args
logits_processors_args = self._prepare_think_stop_sentence(
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
)
request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids: