mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Merge Text processor (#7030)

* merge text processor
* update
* fix unit test
* merge messages2ids
* fix unit test
* delete duplicate code
* remove redundant code
* delete code
* fix unit test
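The merge collapses two processor classes into one. A minimal usage sketch, assuming a hypothetical import path and model directories (only the constructor signature, the ``tokenizer_type`` values, and ``text2ids`` come from the diff below):

    # import path is an assumption; the file containing TextProcessor is not named in this diff
    from fastdeploy.input.text_processor import TextProcessor

    # one class now replaces DataProcessor ("auto") and Ernie4_5Processor ("ernie4_5")
    processor = TextProcessor("/models/my-model", tokenizer_type="auto")
    ernie_processor = TextProcessor("/models/my-ernie-4.5", tokenizer_type="ernie4_5")

    # text2ids dispatches on tokenizer_type (see the override near the end of the diff)
    token_ids = processor.text2ids("Hello, FastDeploy!", max_model_len=1024)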
@@ -18,16 +18,10 @@ from abc import ABC, abstractmethod
-from collections import OrderedDict
-from collections.abc import Mapping
-
-import numpy as np
-from paddleformers.generation import GenerationConfig
-from paddleformers.transformers import Llama3Tokenizer, LlamaTokenizer
 
 from fastdeploy import envs
-from fastdeploy.input.utils import process_stop_token_ids
+from fastdeploy.input.base_processor import BaseTextProcessor
 from fastdeploy.utils import data_processor_logger
 
 _SAMPLING_EPS = 1e-5
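 # threshold below which temperature/top_p are treated as zero; process_request_dict clamps such requests to greedy decoding (top_k=1)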
 
 
 class BaseDataProcessor(ABC):
     """base class for data processor"""
@@ -245,421 +239,21 @@ class BaseDataProcessor(ABC):
         return None
 
 
-class DataProcessor(BaseDataProcessor):
+class DataProcessor(BaseTextProcessor):
+    """Legacy text processor, kept for backward compatibility.
+
+    New code should use ``TextProcessor`` instead.
+    """
 
     def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
-        """
-        Initializes the DataProcessor object.
-
-        Args:
-            model_name_or_path (str): The name or path of the pre-trained model to be loaded.
-                Can also be a path to a directory containing the pre-trained model file.
-
-        Returns:
-            None.
-
-        Raises:
-            None.
-        """
-        self.model_name_or_path = model_name_or_path
-
-        # Generation config
-        try:
-            self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
-        except Exception as e:
-            data_processor_logger.warning(
-                f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
-            )
-            self.generation_config = None
-
-        self.decode_status = dict()
-        self.model_status_dict = dict()
-        self.tool_parser_dict = dict()
-        self.tokenizer = self._load_tokenizer()
-        self._tokenize_cache = OrderedDict()
-        self._tokenize_cache_capacity = 128
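-        # bounded cache for tokenized prompts; OrderedDict ordering supports LRU-style eviction once capacity is hit (eviction code not shown in this hunk)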
-        data_processor_logger.info(
-            f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
-            eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
-        )
+        super().__init__(
+            model_name_or_path, reasoning_parser_obj=reasoning_parser_obj, tool_parser_obj=tool_parser_obj
+        )
-
-        try:
-            from paddleformers.trl.llm_utils import get_eos_token_id
-        except Exception:
-            from paddleformers.cli.utils.llm_utils import get_eos_token_id
-
-        self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
-        data_processor_logger.info(
-            f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
-        )
-        self.eos_token_id_len = len(self.eos_token_ids)
-        self.pad_token_id = self.get_pad_id()
-        self.reasoning_parser = None
-        self.tool_parser_obj = tool_parser_obj
-        if reasoning_parser_obj:
-            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
-        self.tokenizer.pad_token_id = self.pad_token_id
 
-    def process_request_dict(self, request, max_model_len=None, **kwargs):
-        """
-        Preprocess the request.
-
-        Args:
-            request (Dict): may contain text and messages fields
-
-        Returns:
-            Dict: the processed request
-        """
-        data_processor_logger.info(f"Start processing request dict: {request}")
-        request = self._apply_default_parameters(request)
-        if not request.get("eos_token_ids"):
-            request["eos_token_ids"] = self.eos_token_ids
-
-        # processing stop_sequences and stop_token_ids
-        process_stop_token_ids(request, self.update_stop_seq)
-
-        # processing bad_words
-        bad_words = request.get("bad_words")
-        bad_words_token_ids = request.get("bad_words_token_ids")
-        if bad_words:
-            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
-            request["bad_words_token_ids"] = bad_words_token_ids
-
-        logits_processors_args = self._prepare_think_stop_sentence(
-            request.get("logits_processors_args") or {}, max_model_len
-        )
-        request["logits_processors_args"] = logits_processors_args
-
-        # processing prompt_token_ids
-        if not request.get("prompt_token_ids"):
-            if request.get("prompt"):
-                prompt = request["prompt"]
-                assert isinstance(prompt, str) or (
-                    isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
-                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
-                if isinstance(prompt, list):
-                    request["prompt_token_ids"] = prompt
-                else:
-                    add_special_tokens = request.get("add_special_tokens", False)
-                    request["prompt_token_ids"] = self.text2ids(
-                        prompt, max_model_len, add_special_tokens=add_special_tokens
-                    ).tolist()
-            elif request.get("messages"):
-                if self.tokenizer.chat_template is None:
-                    raise ValueError("This model does not support chat_template.")
-                chat_template_kwargs = request.get("chat_template_kwargs", {})
-                if chat_template_kwargs:
-                    if isinstance(chat_template_kwargs, dict):
-                        for k, v in chat_template_kwargs.items():
-                            if k not in request:
-                                request[k] = v
-                    else:
-                        raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-                request.setdefault("enable_thinking", True)
-                request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
-            else:
-                raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
-
-        if len(request["prompt_token_ids"]) == 0:
-            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
-
-        # truncate prompts that exceed the length limit
-        if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
-            request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
-
-        logits_processors_args = request.get("logits_processors_args") or {}
-        logits_processors_args = self._update_thinking_prompt_state(
-            request["prompt_token_ids"], logits_processors_args
-        )
-        request["logits_processors_args"] = logits_processors_args
-
-        max_tokens = max_model_len - len(request["prompt_token_ids"])
-        if request.get("max_tokens") is None:
-            request["max_tokens"] = max(1, max_tokens)
-        else:
-            request["max_tokens"] = min(max_tokens, request["max_tokens"])
-        if request.get("temperature") < _SAMPLING_EPS:
-            # zero temperature means greedy decoding: set top_k=1 to force argmax
-            request["temperature"] = 1
-            request["top_k"] = 1
-        if request.get("top_p") < _SAMPLING_EPS:
-            request["top_p"] = _SAMPLING_EPS
-            request["top_k"] = 1
-        if self.reasoning_parser:
-            model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
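-            # request_id may carry a candidate index ("<id>_<index>"); fan the model status out to all n samples of that slot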
-            parts = request["request_id"].split("_")
-            if len(parts) > 1:
-                real_req_id = parts[0]
-                index = int(parts[1])
-                n = request.get("n", 1)
-                for idx in range(index * n, (index + 1) * n):
-                    self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
-            else:
-                self.model_status_dict[request["request_id"]] = model_status
-            request["enable_thinking"] = model_status == "think_start"
-
-        data_processor_logger.info(f"Processed request dict: {request}")
-        return request
-
-    def process_logprob_response(self, token_ids, **kwargs):
-        full_text = self.tokenizer.decode(token_ids, **kwargs)
-        return full_text
-
-    def process_response_dict_normal(self, response_dict, **kwargs):
-        """
-        Post-process a non-streaming response.
-
-        Args:
-            response_dict (Dict): response from the engine, contains token id fields
-
-        Returns:
-            Dict: response with text fields filled in
-        """
-        token_ids = response_dict["outputs"]["token_ids"]
-        is_end = response_dict["finished"]
-        req_id = response_dict["request_id"]
-        request = kwargs.get("request", None)
-        direct_decode = kwargs.get("direct_decode", False)
-        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] in self.eos_token_ids:
-                token_ids = token_ids[:-1]
-        if direct_decode:
-            delta_text = self.tokenizer.decode(token_ids)
-            previous_texts = ""
-        else:
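-            # incremental detokenization path: ids2tokens returns only the newly decoded text plus this request's accumulated history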
-            delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
-        if is_end:
-            full_text = previous_texts + delta_text
-            response_dict["outputs"]["completion_tokens"] = full_text
-            response_dict["outputs"]["text"] = full_text
-            if self.reasoning_parser:
-                reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
-                    full_text,
-                    request,
-                    self.model_status_dict[req_id],
-                )
-                response_dict["outputs"]["text"] = text
-                response_dict["outputs"]["reasoning_content"] = reasoning_content
-                reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
-                response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
-            if self.tool_parser_obj:
-                tool_parser = self.tool_parser_obj(self.tokenizer)
-                tool_call_info = tool_parser.extract_tool_calls(full_text, request)
-                if tool_call_info.tools_called:
-                    response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
-            if req_id in self.decode_status:
-                data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
-                del self.decode_status[req_id]
-            if req_id in self.model_status_dict:
-                del self.model_status_dict[req_id]
-        return response_dict
-
-    def process_response_dict_streaming(self, response_dict, **kwargs):
-        """
-        Post-process a streaming response chunk.
-
-        Args:
-            response_dict (Dict): response from the engine, contains token id fields
-
-        Returns:
-            Dict: response with text fields filled in
-        """
-        is_end = response_dict["finished"]
-        req_id = response_dict["request_id"]
-        token_ids = response_dict["outputs"]["token_ids"]
-        request = kwargs.get("request", None)
-
-        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] in self.eos_token_ids:
-                token_ids = token_ids[:-1]
-        delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
-        response_dict["outputs"]["text"] = delta_text
-        response_dict["outputs"]["completion_tokens"] = delta_text
-        response_dict["outputs"]["skipped"] = False
-        response_dict["outputs"]["tool_calls"] = None
-        response_dict["outputs"]["reasoning_content"] = ""
-        if self.reasoning_parser:
-            reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
-                previous_texts,
-                previous_texts + delta_text,
-                delta_text,
-                previous_token_ids,
-                previous_token_ids + token_ids,
-                token_ids,
-                self.model_status_dict[req_id],
-            )
-            if reasoning_delta_message:
-                reasoning_content = reasoning_delta_message.reasoning_content
-                reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
-                response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
-                response_dict["outputs"]["reasoning_content"] = reasoning_content or ""
-                response_dict["outputs"]["text"] = reasoning_delta_message.content or ""
-            else:
-                if not is_end:
-                    response_dict["outputs"]["skipped"] = True
-        if self.tool_parser_obj:
-            if req_id not in self.tool_parser_dict:
-                self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
-            tool_parser = self.tool_parser_dict[req_id]
-            tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
-                previous_texts,
-                previous_texts + delta_text,
-                delta_text,
-                previous_token_ids,
-                previous_token_ids + token_ids,
-                token_ids,
-                request,
-            )
-            if tool_call_delta_message:
-                if tool_call_delta_message.tool_calls:
-                    response_dict["outputs"]["text"] = tool_call_delta_message.content
-                    response_dict["outputs"]["tool_calls"] = tool_call_delta_message.tool_calls
-                    response_dict["outputs"]["skipped"] = False
-            else:
-                if not is_end:
-                    response_dict["outputs"]["skipped"] = True
-
-        if is_end:
-            data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
-            del self.decode_status[req_id]
-            if req_id in self.tool_parser_dict:
-                del self.tool_parser_dict[req_id]
-            if req_id in self.model_status_dict:
-                del self.model_status_dict[req_id]
-        return response_dict
-
-    def process_response_dict(self, response_dict, **kwargs):
-        """
-        Post-process a response, dispatching on the ``stream`` flag.
-
-        Args:
-            response_dict (Dict): response from the engine, contains token id fields
-
-        Returns:
-            Dict: response with text fields filled in
-        """
-        stream = kwargs.get("stream", True)
-        if stream:
-            return self.process_response_dict_streaming(response_dict, **kwargs)
-        else:
-            return self.process_response_dict_normal(
-                response_dict=response_dict,
-                **kwargs,
-            )
-
-    def text2ids(self, text, max_model_len, **kwargs):
-        """
-        Convert text to token ids.
-
-        Args:
-            text (str): input text
-
-        Returns:
-            List[int]: token id list
-        """
-        add_special_tokens = kwargs.get("add_special_tokens", False)
-        if envs.FD_USE_HF_TOKENIZER:
-            tokens = self.tokenizer(
-                text,
-                return_tensors="np",
-                padding=True,
-                truncation=True,
-            )
-        else:
-            text = [text] if isinstance(text, str) else text
-            tokens = self.tokenizer(
-                text,
-                return_tensors="np",
-                padding=True,
-                truncation=True,
-                max_length=max_model_len,
-                add_special_tokens=add_special_tokens,
-            )
-        return tokens["input_ids"][0]
-
-    def messages2ids(self, request, **kwargs):
-        """
-        Convert multi-turn messages into a token id sequence.
-
-        Args:
-            request (Dict): request carrying the multi-turn ``messages``.
-
-        Returns:
-            List[int]: token id sequence
-        """
-        if "add_generation_prompt" not in kwargs:
-            kwargs["add_generation_prompt"] = request.get("add_generation_prompt", True)
-
-        spliced_message = self.tokenizer.apply_chat_template(
-            request,
-            tokenize=False,
-            split_special_tokens=False,
-            add_special_tokens=False,
-            **kwargs,
-        )
-        request["prompt_tokens"] = spliced_message
-        req_id = None
-        tokens = self.tokenizer.tokenize(spliced_message)
-        if isinstance(request, dict):
-            req_id = request.get("request_id", None)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-        data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
-        return token_ids
-
-    def ids2tokens(self, token_id, task_id):
-        """
-        Convert token ids to strings, incrementally per task.
-
-        Args:
-            token_id (List[int]): newly generated token ids
-            task_id (str): task id
-
-        Returns:
-            str or Tuple[str, List[int], str]: the newly decoded text (HF path), or
-                (delta text, token id history, previously decoded text)
-        """
-        if envs.FD_USE_HF_TOKENIZER:
-            if task_id not in self.decode_status:
-                # history token ids & history token strings & previously decoded str
-                self.decode_status[task_id] = [[], [], ""]
-
-            status = self.decode_status[task_id]
-            status[0].extend(token_id)
-            decode_str = self.tokenizer.batch_decode(
-                [status[0]],
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False,
-            )
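-            # the HF tokenizer has no incremental decode here: re-decode the full id history, then strip the previously returned text to get the delta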
-            if isinstance(decode_str, list) and len(decode_str):
-                new_str = decode_str[0].replace(status[2], "", 1)
-                status[1].append(new_str)
-                status[2] = decode_str[0]
-            else:
-                new_str = ""
-            return new_str
-        else:
-            if task_id not in self.decode_status:
-                # prefix offset & read offset & history token ids & history token strings
-                self.decode_status[task_id] = [0, 0, [], ""]
-
-            status = self.decode_status[task_id]
-            previous_texts = status[3]
-
-            # Extend in-place first, then pass the full list to decode_token
-            # Avoids creating an O(n) temporary list every token
-            status[2].extend(token_id)
-
-            decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1])
-            status[0] = prefix_offset
-            status[1] = read_offset
-            status[3] += decode_str
-
-            return decode_str, status[2], previous_texts
-
-    def _load_tokenizer(self):
-        """
-        load tokenizer
@@ -676,114 +270,64 @@ class DataProcessor(BaseDataProcessor):
         return AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side="left", use_fast=True)
 
-    def clear_request_status(self, task_id):
-        """
-        clear request status
-
-        Args:
-            task_id (str): task id
+
+class TextProcessor(BaseTextProcessor):
+    """Unified text processor for both auto and ernie4_5 tokenizer types.
-
-        Returns:
-            results_all (str): all token strings
-        """
-        results_all = ""
-        if task_id in self.decode_status:
-            if envs.FD_USE_HF_TOKENIZER:
-                results_all = self.decode_status[task_id][2]
-            else:
-                results_all = "".join(self.decode_status[task_id][3])
-            del self.decode_status[task_id]
-        return results_all
+    Replaces ``DataProcessor`` (tokenizer_type="auto") and
+    ``Ernie4_5Processor`` (tokenizer_type="ernie4_5") with a single class.
-
-    def get_pad_id(self):
-        """
-        get pad_token_id; if there is no pad_token_id, use eos_token instead
+
+    Args:
+        model_name_or_path: Path or name of the pretrained model.
+        tokenizer_type: ``"auto"`` (default) or ``"ernie4_5"``.
+        reasoning_parser_obj: Optional reasoning-parser class.
+        tool_parser_obj: Optional tool-parser class.
+    """
-
-        Returns:
-            int: pad_token_id
-        """
-        if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
-            return self.tokenizer.eos_token
-        return self.tokenizer.pad_token_id
-
-    def pad_batch_data(
+
+    def __init__(
         self,
-        insts,
-        pad_id=0,
-        return_seq_len=False,
-        return_array=True,
-        pad_style="right",
+        model_name_or_path: str,
+        tokenizer_type: str = "auto",
+        reasoning_parser_obj=None,
+        tool_parser_obj=None,
     ):
-        """Pad the instances to the max sequence length in batch."""
-        if len(insts) == 0:
-            padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
-            if return_seq_len:
-                seq_len = np.array([], dtype=np.int64) if return_array else []
-                return padded_insts, seq_len
-            return padded_insts
+        super().__init__(model_name_or_path, tokenizer_type, reasoning_parser_obj, tool_parser_obj)
+
-        max_len = max(map(len, insts))
-        if pad_style == "left":
-            padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
+    # ------------------------------------------------------------------
+    # Abstract method implementations
+    # ------------------------------------------------------------------
+
+    def _load_tokenizer(self):
+        if self.tokenizer_type == "ernie4_5":
+            return self._load_ernie4_5_tokenizer()
+        return self._load_auto_tokenizer()
+
+    def _load_auto_tokenizer(self):
+        if envs.FD_USE_HF_TOKENIZER:
+            from transformers import AutoTokenizer
+
+            return AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=False)
         else:
-            padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
-        if return_array:
-            padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
+            from paddleformers.transformers import AutoTokenizer
 
-        if return_seq_len:
-            seq_len = [len(inst) for inst in insts]
-            if return_array:
-                seq_len = np.array(seq_len, dtype=np.int64).reshape(-1, 1)
-            return padded_insts, seq_len
-        return padded_insts
+            return AutoTokenizer.from_pretrained(self.model_name_or_path, padding_side="left", use_fast=True)
 
-    def update_stop_seq(self, stop_sequences):
-        """
-        Update stop sequences from request.
-        """
-        stop_seqs = []
-        for seq in stop_sequences:
-            if seq != self.tokenizer.eos_token_id:
-                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
-        stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
-        data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
-        return stop_seqs, stop_seqs_len
+    def _load_ernie4_5_tokenizer(self):
+        import os
 
-    def update_bad_words(self, bad_words, bad_words_token_ids):
-        """Support bad words"""
+        from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 
-        token_ids = bad_words_token_ids
+        vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"]
+        for name in vocab_file_names:
+            if os.path.exists(os.path.join(self.model_name_or_path, name)):
+                Ernie4_5Tokenizer.resource_files_names["vocab_file"] = name
+                break
+        return Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
-
-        if token_ids is None:
-            token_ids = []
-        for bad_word in bad_words:
-            # To prohibit words both at the beginning
-            # and in the middle of text
-            # (related to add_prefix_space tokenizer parameter)
-            for add_prefix_space in [False, True]:
-                prefix = " " if add_prefix_space else ""
-                prompt = prefix + bad_word.lstrip()
-                prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
+
+    def text2ids(self, text, max_model_len=None, **kwargs):
+        if self.tokenizer_type == "ernie4_5":
+            return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+        return super().text2ids(text, max_model_len, **kwargs)
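+        # ernie4_5 tokenizes raw text directly; "auto" goes through the padded/truncated base implementation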
-
-                if len(prompt_token_ids) != 1:
-                    if not add_prefix_space:
-                        data_processor_logger.warning(
-                            f"Skip bad_words: <{prompt}>. "
-                            f"Bad words should be a single token. "
-                            f"Got tokens: {prompt_token_ids}."
-                        )
-                    continue
-
-                if prompt_token_ids[0] > self.tokenizer.vocab_size:
-                    if not add_prefix_space:
-                        data_processor_logger.warning(
-                            f"Skip bad_words: <{prompt}>. "
-                            f"All token id values should satisfy: "
-                            f"0 <= token_id < {self.tokenizer.vocab_size}. "
-                            f"Got token: {prompt_token_ids}."
-                        )
-                    continue
-
-                if prompt_token_ids not in token_ids:
-                    token_ids.extend(prompt_token_ids)
-        return token_ids
+
+    def process_logprob_response(self, token_ids, **kwargs):
+        return self.tokenizer.decode(token_ids, **kwargs)