mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)
* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow
* [Docs] Fix thinking_budget markdown formatting
* [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request["bad_words_token_ids"] = bad_words_token_ids
|
||||
|
||||
logits_processors_args = self._prepare_think_stop_sentence(
|
||||
request.get("logits_processors_args") or {}, max_model_len
|
||||
)
|
||||
request["logits_processors_args"] = logits_processors_args
|
||||
|
||||
# processing prompt_token_ids
|
||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
||||
if request.prompt is not None:
|
||||
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
request.prompt_token_ids = token_ids
|
||||
data_processor_logger.debug(
|
||||
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
|
||||
f"request_ids: {request.request_id}, prompt: {prompt}, "
|
||||
f"tokens: {tokens}, token_ids: {token_ids}"
|
||||
)
|
||||
elif request.messages is not None:
|
||||
task = request.to_dict()
|
||||
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
||||
logits_processors_args = self._update_thinking_prompt_state(
|
||||
request.prompt_token_ids, request.get("logits_processors_args") or {}
|
||||
)
|
||||
request["logits_processors_args"] = logits_processors_args
|
||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
||||
if request.get("max_tokens") is None:
|
||||
request.set("max_tokens", max(1, max_tokens))
|
||||
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
||||
request.sampling_params.bad_words_token_ids = bad_words_token_ids
|
||||
|
||||
logits_processors_args = self._prepare_think_stop_sentence(
|
||||
getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
|
||||
)
|
||||
request.sampling_params.logits_processors_args = logits_processors_args
|
||||
|
||||
# processing prompt_token_ids
|
||||
if not request.prompt_token_ids:
|
||||
if request.prompt:
|
||||
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
# truncate prompts that exceed the length limit
|
||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
||||
logits_processors_args = self._update_thinking_prompt_state(
|
||||
request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
|
||||
)
|
||||
request.sampling_params.logits_processors_args = logits_processors_args
|
||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
||||
if getattr(request.sampling_params, "max_tokens", None) is None:
|
||||
request.sampling_params.max_tokens = max(1, max_tokens)
|
||||
|
||||
Reference in New Issue
Block a user