Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)
* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow
* [Docs] Fix thinking_budget markdown formatting
* [Test] Align ernie thinking budget test with process_request_dict
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
+
         # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
                 token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                 request.prompt_token_ids = token_ids
                 data_processor_logger.debug(
-                    f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+                    f"request_ids: {request.request_id}, prompt: {prompt}, "
+                    f"tokens: {tokens}, token_ids: {token_ids}"
                 )
             elif request.messages is not None:
                 task = request.to_dict()
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, request.get("logits_processors_args") or {}
+        )
+        request["logits_processors_args"] = logits_processors_args
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if request.get("max_tokens") is None:
             request.set("max_tokens", max(1, max_tokens))
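
The new _update_thinking_prompt_state call sits between prompt truncation and the max_tokens derivation, so it inspects the final (possibly truncated) prompt ids. A toy walkthrough of the surrounding bookkeeping:

    max_model_len = 8
    prompt_token_ids = list(range(10))  # 10 tokens: over the limit
    if max_model_len is not None and len(prompt_token_ids) > max_model_len:
        prompt_token_ids = prompt_token_ids[: max_model_len - 1]  # keep 7
    max_tokens = max_model_len - len(prompt_token_ids)  # 8 - 7 = 1
    print(len(prompt_token_ids), max(1, max_tokens))    # 7 1
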
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
+
         # processing prompt_token_ids
         if not request.prompt_token_ids:
             if request.prompt:
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if getattr(request.sampling_params, "max_tokens", None) is None:
             request.sampling_params.max_tokens = max(1, max_tokens)
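
A small idiom recurs in these sampling_params hunks: `getattr(..., None) or {}` tolerates both a missing attribute and an explicit None, so the helpers always receive a real dict:

    class SamplingParams:  # illustrative stub, not the FastDeploy class
        pass

    sp = SamplingParams()
    print(getattr(sp, "logits_processors_args", None) or {})  # {} (missing)
    sp.logits_processors_args = None
    print(getattr(sp, "logits_processors_args", None) or {})  # {} (None)
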
@@ -216,6 +216,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
+
         if request.prompt_token_ids:
             messages = request.messages
             if messages:
@@ -267,6 +272,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = self._update_thinking_prompt_state(
+            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
 
         max_tokens = max_model_len - len(request.prompt_token_ids)
         if getattr(request.sampling_params, "max_tokens", None) is None:
@@ -115,6 +115,9 @@ class BaseDataProcessor(ABC):
         """
         Encode text into token ids with a small LRU cache.
         """
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
         key = (text, bool(add_special_tokens))
         cached = self._tokenize_cache.get(key)
         if cached is not None:
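
The hasattr guard above creates the cache lazily with a default capacity of 128. The eviction discipline is the standard OrderedDict LRU pattern; a self-contained sketch (names illustrative, not FastDeploy API):

    from collections import OrderedDict

    cache, capacity = OrderedDict(), 2

    def lookup(key, compute):
        if key in cache:
            cache.move_to_end(key)     # refresh recency on a hit
            return cache[key]
        value = cache[key] = compute(key)
        if len(cache) > capacity:
            cache.popitem(last=False)  # evict the least-recently-used entry
        return value

    lookup("a", str.upper); lookup("b", str.upper)
    lookup("a", str.upper)             # touch "a" so "b" becomes LRU
    lookup("c", str.upper)             # evicts "b"
    print(list(cache))                 # ['a', 'c']
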
@@ -130,6 +133,25 @@ class BaseDataProcessor(ABC):
                 self._tokenize_cache.popitem(last=False)
         return token_ids
 
+    def _encode_literal_text_with_cache(self, text):
+        if not hasattr(self, "_tokenize_cache"):
+            self._tokenize_cache = OrderedDict()
+            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
+        key = ("literal_text", text)
+        cached = self._tokenize_cache.get(key)
+        if cached is not None:
+            self._tokenize_cache.move_to_end(key)
+            return cached
+        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+        if hasattr(token_ids, "tolist"):
+            token_ids = token_ids.tolist()
+        elif not isinstance(token_ids, list):
+            token_ids = list(token_ids)
+        self._tokenize_cache[key] = token_ids
+        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+            self._tokenize_cache.popitem(last=False)
+        return token_ids
+
     def messages2ids(self, messages):
         """
         Convert multi-turn messages into ID sequences.
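
_encode_literal_text_with_cache shares the OrderedDict above with encode_with_cache but namespaces its entries under ("literal_text", text), while the general path keys on (text, bool(add_special_tokens)). Since the second tuple element is a string in one shape and a bool in the other, the entries cannot collide:

    general_key = ("some text", False)           # encode_with_cache shape
    literal_key = ("literal_text", "some text")  # literal-text shape
    print(general_key == literal_key)            # False

It also normalizes array-like tokenizer outputs to plain Python lists via tolist().
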
@@ -142,6 +164,77 @@ class BaseDataProcessor(ABC):
         """
         raise NotImplementedError
 
+    def _get_think_token_ids(self):
+        think_token_ids = getattr(self, "_think_token_ids", None)
+        if think_token_ids is not None:
+            return think_token_ids
+        tokenizer = getattr(self, "tokenizer", None)
+        vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+        think_start_id = vocab.get("<think>", -1)
+        think_end_id = vocab.get("</think>", -1)
+        self._think_token_ids = (think_start_id, think_end_id)
+        return self._think_token_ids
+
+    def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+        if isinstance(think_stop_sentence, str) and think_stop_sentence:
+            sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+            logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+            logits_processors_args.pop("think_stop_sentence", None)
+        return logits_processors_args
+
+    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        thinking_budget = logits_processors_args.get("thinking_budget")
+        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+            return logits_processors_args
+        if logits_processors_args.get("think_prompt_checked"):
+            return logits_processors_args
+        if prompt_token_ids is None:
+            return logits_processors_args
+        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+        if token_len == 0:
+            return logits_processors_args
+        think_start_id, think_end_id = self._get_think_token_ids()
+        if think_start_id < 0 or think_end_id < 0:
+            return logits_processors_args
+
+        if hasattr(prompt_token_ids, "tolist"):
+            token_list = prompt_token_ids.tolist()
+        else:
+            token_list = list(prompt_token_ids)
+
+        started = False
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        in_thinking = False
+        for token_id in token_list:
+            if token_id == think_start_id:
+                started = True
+                ended = False
+                in_thinking = True
+            elif token_id == think_end_id and in_thinking:
+                ended = True
+                in_thinking = False
+        if started and token_list:
+            # Align with operator-level reasoning_max_tokens: prompt-side tokens
+            # inside <think> do not consume thinking budget.
+            last_token_id = int(token_list[-1])
+
+        logits_processors_args["think_prompt_checked"] = True
+        logits_processors_args["think_prompt_started"] = started
+        logits_processors_args["think_prompt_ended"] = ended
+        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+        if last_token_id is not None:
+            logits_processors_args["think_prompt_last_token_id"] = last_token_id
+        else:
+            logits_processors_args.pop("think_prompt_last_token_id", None)
+        return logits_processors_args
+
     def ids2tokens(self, token_id, task_id=None):
         """
         token ids to strings
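
The base-class scan records only whether the prompt opened and/or closed a <think> span; per the inline comment, think_prompt_tokens_after_start is deliberately left at 0 so prompt-side thinking tokens do not consume the budget. A standalone sketch of the flag logic, with 100/101 as placeholder ids for <think>/</think>:

    THINK_START, THINK_END = 100, 101

    def scan(tokens):
        started = ended = in_thinking = False
        for t in tokens:
            if t == THINK_START:
                started, ended, in_thinking = True, False, True
            elif t == THINK_END and in_thinking:
                ended, in_thinking = True, False
        return started, ended

    print(scan([1, 2, 100, 3, 4]))    # (True, False): span still open
    print(scan([1, 100, 3, 101, 5]))  # (True, True):  span already closed
    print(scan([1, 2, 3]))            # (False, False): no <think> in prompt
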
@@ -221,6 +314,65 @@ class DataProcessor(BaseDataProcessor):
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
 
+        self._think_token_ids = None
+
+    def _get_think_token_ids(self):
+        if self._think_token_ids is not None:
+            return self._think_token_ids
+        vocab = self.tokenizer.get_vocab()
+        think_start_id = vocab.get("<think>", -1)
+        think_end_id = vocab.get("</think>", -1)
+        self._think_token_ids = (think_start_id, think_end_id)
+        return self._think_token_ids
+
+    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        thinking_budget = logits_processors_args.get("thinking_budget")
+        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+            return logits_processors_args
+        if logits_processors_args.get("think_prompt_checked"):
+            return logits_processors_args
+        if prompt_token_ids is None:
+            return logits_processors_args
+        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+        if token_len == 0:
+            return logits_processors_args
+        think_start_id, think_end_id = self._get_think_token_ids()
+        if think_start_id < 0 or think_end_id < 0:
+            return logits_processors_args
+
+        if hasattr(prompt_token_ids, "tolist"):
+            token_list = prompt_token_ids.tolist()
+        else:
+            token_list = list(prompt_token_ids)
+
+        started = think_start_id in token_list
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        if started:
+            start_pos = token_list.index(think_start_id)
+            tokens_after = token_list[start_pos + 1 :]
+            if think_end_id in tokens_after:
+                end_pos = tokens_after.index(think_end_id)
+                tokens_after_start = end_pos + 1
+                ended = True
+            else:
+                tokens_after_start = len(tokens_after)
+            if token_list:
+                last_token_id = int(token_list[-1])
+
+        logits_processors_args["think_prompt_checked"] = True
+        logits_processors_args["think_prompt_started"] = started
+        logits_processors_args["think_prompt_ended"] = ended
+        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+        if last_token_id is not None:
+            logits_processors_args["think_prompt_last_token_id"] = last_token_id
+        else:
+            logits_processors_args.pop("think_prompt_last_token_id", None)
+        return logits_processors_args
+
     def process_request(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
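
Unlike the base-class scan, this DataProcessor override locates the first <think> by index and does count the tokens after it, up to and including </think> when present. A sketch of that counting, with the same placeholder ids:

    THINK_START, THINK_END = 100, 101

    def count_after_start(tokens):
        if THINK_START not in tokens:
            return False, False, 0
        after = tokens[tokens.index(THINK_START) + 1:]
        if THINK_END in after:
            return True, True, after.index(THINK_END) + 1
        return True, False, len(after)

    print(count_after_start([1, 100, 7, 8, 101, 9]))  # (True, True, 3)
    print(count_after_start([1, 100, 7, 8]))          # (True, False, 2)
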
@@ -306,14 +340,10 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
-        logits_processors_args = request.get("logits_processors_args") or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-            logits_processors_args.pop("think_stop_sentence", None)
-        request["logits_processors_args"] = logits_processors_args
+        logits_processors_args = self._prepare_think_stop_sentence(
+            request.get("logits_processors_args") or {}, max_model_len
+        )
+        request["logits_processors_args"] = logits_processors_args
 
         # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
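
One behavioral change to flag in this hunk: the removed inline code prepended the token ids of "\n" to the stop sentence, while the shared helper tokenizes only the literal sentence, so think_stop_sentence_token_ids no longer carries an implicit leading newline. A toy contrast (character-ordinal encoder as stand-in for the tokenizer):

    encode = lambda s: [ord(c) for c in s]
    old_ids = encode("\n") + encode("stop")  # before: newline ids prepended
    new_ids = encode("stop")                 # after: literal sentence only
    print(old_ids)  # [10, 115, 116, 111, 112]
    print(new_ids)  # [115, 116, 111, 112]
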
@@ -411,14 +441,10 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
-        logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-            logits_processors_args.pop("think_stop_sentence", None)
-        request.sampling_params.logits_processors_args = logits_processors_args
+        logits_processors_args = self._prepare_think_stop_sentence(
+            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+        )
+        request.sampling_params.logits_processors_args = logits_processors_args
 
         # processing prompt_token_ids
         if not request.prompt_token_ids: