Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Support ThinkingBudget Logits processor to control thinking content length (#6367)
* feat: add thinking budget logits processor
* add unittest
* fix pre-commit
* add unittest
* docs: clarify operator-level vs logits processor usage and conflict guidance

--------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
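The diff excerpted below only covers the request-preprocessing side of the feature; the budget-enforcing processor itself is not shown here. As a rough, hypothetical sketch of the underlying idea (the function name, signature, and masking strategy below are invented for illustration and are not FastDeploy's implementation), such a processor can force the thinking section to close once the budget is spent:

import numpy as np

# Hypothetical sketch, not FastDeploy's code: once `thinking_budget` tokens have
# been generated after <think> and </think> has not appeared yet, force the next
# token to be </think> by masking every other logit to -inf.
def apply_thinking_budget(logits, generated_ids, think_start_id, think_end_id, thinking_budget):
    if think_start_id not in generated_ids or think_end_id in generated_ids:
        return logits  # not thinking, or thinking already closed
    start = generated_ids.index(think_start_id)
    thinking_len = len(generated_ids) - start - 1
    if thinking_len >= thinking_budget:
        forced = np.full_like(logits, -np.inf)
        forced[think_end_id] = 0.0  # only </think> remains selectable
        return forced
    return logits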
@@ -15,6 +15,7 @@
 """
 
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 
 import numpy as np
 from paddleformers.generation import GenerationConfig
@@ -50,6 +51,8 @@ class BaseDataProcessor(ABC):
                 f"mask_token is {self.tokenizer.mask_token}, {self.tokenizer.mask_token_id}",
             )
         )
+        self._tokenize_cache = OrderedDict()
+        self._tokenize_cache_capacity = 128
 
     def _apply_default_parameters(self, request):
         """
@@ -108,6 +111,25 @@ class BaseDataProcessor(ABC):
         """
         raise NotImplementedError
 
+    def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
+        """
+        Encode text into token ids with a small LRU cache.
+        """
+        key = (text, bool(add_special_tokens))
+        cached = self._tokenize_cache.get(key)
+        if cached is not None:
+            self._tokenize_cache.move_to_end(key)
+            return cached
+        token_ids = self.text2ids(text, max_model_len, add_special_tokens=add_special_tokens)
+        if hasattr(token_ids, "tolist"):
+            token_ids = token_ids.tolist()
+        elif not isinstance(token_ids, list):
+            token_ids = list(token_ids)
+        self._tokenize_cache[key] = token_ids
+        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+            self._tokenize_cache.popitem(last=False)
+        return token_ids
+
     def messages2ids(self, messages):
         """
         Convert multi-turn messages into ID sequences.
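As a usage sketch of the cache added above (assuming `processor` is an already-constructed data processor; the texts are arbitrary):

# Hypothetical usage; the returned token ids depend on the loaded tokenizer.
ids_first = processor.encode_with_cache("\n")   # tokenizes and stores under the key ("\n", False)
ids_again = processor.encode_with_cache("\n")   # cache hit: re-tokenization is skipped
assert ids_first == ids_again

# Toggling add_special_tokens changes the key, so a separate entry is cached.
processor.encode_with_cache("\n", add_special_tokens=True)

# Once more than 128 distinct (text, add_special_tokens) keys are stored,
# the least recently used entry is evicted via popitem(last=False).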
@@ -175,6 +197,8 @@ class DataProcessor(BaseDataProcessor):
         self.model_status_dict = dict()
         self.tool_parser_dict = dict()
         self.tokenizer = self._load_tokenizer()
+        self._tokenize_cache = OrderedDict()
+        self._tokenize_cache_capacity = 128
         data_processor_logger.info(
             f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
             eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
@@ -194,6 +218,65 @@ class DataProcessor(BaseDataProcessor):
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
 
+        self._think_token_ids = None
+
+    def _get_think_token_ids(self):
+        if self._think_token_ids is not None:
+            return self._think_token_ids
+        vocab = self.tokenizer.get_vocab()
+        think_start_id = vocab.get("<think>", -1)
+        think_end_id = vocab.get("</think>", -1)
+        self._think_token_ids = (think_start_id, think_end_id)
+        return self._think_token_ids
+
+    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+        if not isinstance(logits_processors_args, dict):
+            return logits_processors_args
+        thinking_budget = logits_processors_args.get("thinking_budget")
+        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+            return logits_processors_args
+        if logits_processors_args.get("think_prompt_checked"):
+            return logits_processors_args
+        if prompt_token_ids is None:
+            return logits_processors_args
+        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+        if token_len == 0:
+            return logits_processors_args
+        think_start_id, think_end_id = self._get_think_token_ids()
+        if think_start_id < 0 or think_end_id < 0:
+            return logits_processors_args
+
+        if hasattr(prompt_token_ids, "tolist"):
+            token_list = prompt_token_ids.tolist()
+        else:
+            token_list = list(prompt_token_ids)
+
+        started = think_start_id in token_list
+        ended = False
+        tokens_after_start = 0
+        last_token_id = None
+        if started:
+            start_pos = token_list.index(think_start_id)
+            tokens_after = token_list[start_pos + 1 :]
+            if think_end_id in tokens_after:
+                end_pos = tokens_after.index(think_end_id)
+                tokens_after_start = end_pos + 1
+                ended = True
+            else:
+                tokens_after_start = len(tokens_after)
+        if token_list:
+            last_token_id = int(token_list[-1])
+
+        logits_processors_args["think_prompt_checked"] = True
+        logits_processors_args["think_prompt_started"] = started
+        logits_processors_args["think_prompt_ended"] = ended
+        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+        if last_token_id is not None:
+            logits_processors_args["think_prompt_last_token_id"] = last_token_id
+        else:
+            logits_processors_args.pop("think_prompt_last_token_id", None)
+        return logits_processors_args
+
     def process_request(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
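A worked toy example of the bookkeeping above, with invented token ids (assume `<think>` maps to 1001 and `</think>` to 1002 in the vocabulary):

# Toy illustration of the flags computed by _update_thinking_prompt_state;
# the ids 1001/1002 and the prompt below are invented for this example.
prompt_token_ids = [5, 6, 1001, 7, 8, 9]   # prompt ends inside an open <think> block
args = {"thinking_budget": 4}

# After _update_thinking_prompt_state(prompt_token_ids, args):
#   args["think_prompt_checked"]            -> True
#   args["think_prompt_started"]            -> True   (1001 occurs in the prompt)
#   args["think_prompt_ended"]              -> False  (no 1002 after it)
#   args["think_prompt_tokens_after_start"] -> 3      (tokens 7, 8, 9)
#   args["think_prompt_last_token_id"]      -> 9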
@@ -220,6 +303,15 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
+        logits_processors_args = request.get("logits_processors_args") or {}
+        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+        if isinstance(think_stop_sentence, str) and think_stop_sentence:
+            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
+            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
+            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
+            logits_processors_args.pop("think_stop_sentence", None)
+        request["logits_processors_args"] = logits_processors_args
+
         # processing prompt_token_ids
         if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
             if request.prompt is not None:
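To illustrate the handling above, here is a hypothetical request sketched as a plain dict (the real object supports attribute and .get/.set access, and the resulting token ids depend on the tokenizer):

# Hypothetical request payload; keys match the diff, values are illustrative.
request = {
    "prompt": "How many primes are below 100?",
    "logits_processors_args": {
        "thinking_budget": 256,
        "think_stop_sentence": "Time to stop thinking.",
    },
}

# After process_request, the plain-text sentence has been replaced by token ids:
# request["logits_processors_args"]["think_stop_sentence_token_ids"] ==
#     encode("\n") + encode("Time to stop thinking.")
# and "think_stop_sentence" itself has been popped from the dict.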
@@ -257,6 +349,9 @@ class DataProcessor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = request.get("logits_processors_args") or {}
+        logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
+        request["logits_processors_args"] = logits_processors_args
         if request.get("max_tokens") is None:
             request.set("max_tokens", max(1, max_model_len - len(request.prompt_token_ids)))
         if request.get("temperature") < _SAMPLING_EPS:
@@ -306,6 +401,15 @@ class DataProcessor(BaseDataProcessor):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request.sampling_params.bad_words_token_ids = bad_words_token_ids
 
+        logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
+        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+        if isinstance(think_stop_sentence, str) and think_stop_sentence:
+            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
+            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
+            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
+            logits_processors_args.pop("think_stop_sentence", None)
+        request.sampling_params.logits_processors_args = logits_processors_args
+
         # processing prompt_token_ids
         if not request.prompt_token_ids:
             if request.prompt:
@@ -346,6 +450,9 @@ class DataProcessor(BaseDataProcessor):
         # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
             request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+        logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
+        logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
+        request.sampling_params.logits_processors_args = logits_processors_args
         if request.sampling_params.max_tokens is None:
             request.sampling_params.max_tokens = max(1, max_model_len - len(request.prompt_token_ids))
         if request.sampling_params.temperature < _SAMPLING_EPS: