Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-05-10 01:21:55 +08:00)
[Feature] bad_words: support the v1 scheduler and specifying token ids directly (#3608)
* support bad_words_token_ids
* docs
* fix test
* fix
* bad words support kvcache v1 and token ids
* fix
@@ -461,7 +461,6 @@ class LLMEngine:
        request = Request.from_dict(task)
        llm_logger.info(f"Receive request {request}")
        if sampling_params is not None:
            sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
            request.sampling_params = sampling_params
        request.preprocess_start_time = time.time()

@@ -762,8 +761,6 @@ class LLMEngine:
        for task in tasks:
            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
            if task.sampling_params.bad_words is not None:
                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)

        self.resource_manager.check_and_free_block_tables()

@@ -20,8 +20,6 @@ import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union

from fastdeploy.utils import llm_logger as logger


@dataclass
class SamplingParams:

@@ -102,7 +100,7 @@ class SamplingParams:
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False
    bad_words: Optional[List[str]] = None
    _bad_words_token_ids: Optional[List[int]] = None
    bad_words_token_ids: Optional[List[int]] = None

    @classmethod
    def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:

@@ -134,6 +132,7 @@ class SamplingParams:
        min_tokens=1,
        logprobs=None,
        bad_words=None,
        bad_words_token_ids=None,
    ) -> SamplingParams:
        """Create instance from command line arguments"""
        return cls(

@@ -154,6 +153,7 @@ class SamplingParams:
            min_tokens=min_tokens,
            logprobs=logprobs,
            bad_words=bad_words,
            bad_words_token_ids=bad_words_token_ids,
        )

    def __post_init__(self):

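For reference, a minimal offline-use sketch of the two new fields. This is hedged: the fastdeploy.LLM entry point, its generate() signature, and the model path are assumed from FastDeploy's existing offline examples, not from this diff.

    from fastdeploy import LLM, SamplingParams  # assumed import path, as in FastDeploy's offline examples

    # Human-readable bad words are tokenized downstream by the data processor ...
    params_words = SamplingParams(temperature=0.8, bad_words=["sorry"])
    # ... or single-token ids can now be passed directly and skip tokenization.
    params_ids = SamplingParams(temperature=0.8, bad_words_token_ids=[4077])  # hypothetical token id

    llm = LLM(model="/path/to/model")  # placeholder model path
    for output in llm.generate(["Write a short greeting."], params_words):
        print(output)
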
@@ -206,46 +206,6 @@ class SamplingParams:
            if not 0 <= self.seed <= 922337203685477580:
                raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

    def update_from_tokenizer(self, tokenizer):
        """Support bad words"""
        if self.bad_words is None:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]

                if len(prompt_token_ids) != 1:
                    if not add_prefix_space:
                        logger.warning(
                            f"Skip bad_words: <{prompt}>."
                            f"Bad words should be a single token."
                            f"Got tokens: {prompt_token_ids}."
                        )
                    continue

                if prompt_token_ids[0] > tokenizer.vocab_size:
                    if not add_prefix_space:
                        logger.warning(
                            f"Skip bad_words: <{prompt}>."
                            f"All token id values should be satisfying:"
                            f" 0 <= token_id < {tokenizer.vocab_size}."
                            f"Got token: {prompt_token_ids}."
                        )
                    continue

                if prompt_token_ids not in self._bad_words_token_ids:
                    self._bad_words_token_ids.extend(prompt_token_ids)

    @property
    def bad_words_token_ids(self) -> Optional[List[list[int]]]:
        return self._bad_words_token_ids


@dataclass
class BeamSearchParams:

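A hedged usage sketch of the update_from_tokenizer path shown in this hunk. The stub tokenizer below is purely illustrative: it only mimics the encode()-returns-a-dict-with-"input_ids" interface the method relies on, and the SamplingParams variant is the one displayed above.

    from fastdeploy.engine.sampling_params import SamplingParams  # assumed module path

    class StubTokenizer:
        """Stand-in exposing the minimal interface update_from_tokenizer uses."""
        vocab_size = 50000
        _ids = {"bad": 111, " bad": 222}  # single-token entries, with and without a leading space

        def encode(self, text, add_special_tokens=False):
            # Unknown strings pretend to split into two tokens, like a real multi-token phrase.
            return {"input_ids": [self._ids[text]] if text in self._ids else [1, 2]}

    params = SamplingParams(bad_words=["bad", "two words"])
    params.update_from_tokenizer(StubTokenizer())
    # "bad" contributes both variants (111 and 222); "two words" is skipped with a warning.
    print(params.bad_words_token_ids)  # [111, 222]
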
@@ -426,6 +426,7 @@ class CompletionRequest(BaseModel):
    min_tokens: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    bad_words: Optional[List[str]] = None
    bad_words_token_ids: Optional[List[int]] = None
    # doc: end-completion-sampling-params

    # doc: start-completion-extra-params

@@ -566,6 +567,7 @@ class ChatCompletionRequest(BaseModel):
    min_tokens: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    bad_words: Optional[List[str]] = None
    bad_words_token_ids: Optional[List[int]] = None
    repetition_penalty: Optional[float] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    # doc: end-chat-completion-sampling-params

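Since both request schemas above now carry the fields, a hedged serving-side example would pass them as extra body parameters. The endpoint, port, and model name are placeholders, and a running FastDeploy OpenAI-compatible server is assumed.

    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint
    resp = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
        extra_body={
            "bad_words": ["hello"],         # tokenized server-side by the data processor
            "bad_words_token_ids": [3782],  # hypothetical pre-computed single-token id
        },
    )
    print(resp.choices[0].message.content)
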
@@ -97,6 +97,12 @@ class ErnieProcessor(BaseDataProcessor):
            request.set("stop_token_ids", stop_seqs)
            request.set("stop_seqs_len", stop_seqs_len)

        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
            if request.prompt is None and request.messages is None:
                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")

@@ -160,6 +166,13 @@ class ErnieProcessor(BaseDataProcessor):
            request["stop_token_ids"] = stop_seqs
            request["stop_seqs_len"] = stop_seqs_len

        # processing bad_words
        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        # processing prompt_token_ids
        if not request.get("prompt_token_ids"):
            if request.get("prompt") is None and request.get("messages") is None:

@@ -481,3 +494,43 @@ class ErnieProcessor(BaseDataProcessor):
    def process_logprob_response(self, token_ids, **kwargs):
        full_text = self.tokenizer.decode(token_ids, **kwargs)
        return full_text

    def update_bad_words(self, bad_words, bad_words_token_ids):
        """Support bad words"""

        token_ids = bad_words_token_ids

        if token_ids is None:
            token_ids = []
            for bad_word in bad_words:
                # To prohibit words both at the beginning
                # and in the middle of text
                # (related to add_prefix_space tokenizer parameter)
                for add_prefix_space in [False, True]:
                    prefix = " " if add_prefix_space else ""
                    prompt = prefix + bad_word.lstrip()
                    prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
                    data_processor_logger.debug(f"processed bad_words: {prompt}, {prompt_token_ids}")

                    if len(prompt_token_ids) != 1:
                        if not add_prefix_space:
                            data_processor_logger.warning(
                                f"Skip bad_words: <{prompt}>."
                                f"Bad words should be a single token."
                                f"Got tokens: {prompt_token_ids}."
                            )
                        continue

                    if prompt_token_ids[0] > self.tokenizer.vocab_size:
                        if not add_prefix_space:
                            data_processor_logger.warning(
                                f"Skip bad_words: <{prompt}>."
                                f"All token id values should be satisfying:"
                                f" 0 <= token_id < {self.tokenizer.vocab_size}."
                                f"Got token: {prompt_token_ids}."
                            )
                        continue

                    if prompt_token_ids not in token_ids:
                        token_ids.extend(prompt_token_ids)
        return token_ids

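A short, hedged usage sketch of the helper above. Processor construction is elided; `processor` stands for an already-initialized ErnieProcessor, and the printed ids are illustrative only.

    # With no explicit ids, each bad word is tokenized with and without a leading space,
    # and only variants that map to a single in-vocabulary token are kept.
    ids = processor.update_bad_words(["sorry"], None)
    print(ids)  # e.g. [8257, 14734] -- one id per single-token variant of "sorry"

    # Client-supplied bad_words_token_ids seed the returned list.
    ids = processor.update_bad_words(["sorry"], [4077])
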
@@ -208,6 +208,12 @@ class ErnieMoEVLProcessor(ErnieProcessor):
            request["stop_token_ids"] = stop_seqs
            request["stop_seqs_len"] = stop_seqs_len

        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        if request.get("prompt"):
            multimodal_data = request.get("multimodal_data")
            if multimodal_data is None:

@@ -212,6 +212,12 @@ class QwenVLProcessor(TextProcessor):
            request["stop_token_ids"] = stop_seqs
            request["stop_seqs_len"] = stop_seqs_len

        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        if request.get("prompt"):
            multimodal_data = request.get("multimodal_data")
            if multimodal_data is None:

@@ -214,6 +214,12 @@ class DataProcessor(BaseDataProcessor):
            request.set("stop_token_ids", stop_seqs)
            request.set("stop_seqs_len", stop_seqs_len)

        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
            if request.prompt is not None:
                request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)

@@ -270,6 +276,13 @@ class DataProcessor(BaseDataProcessor):
            request["stop_token_ids"] = stop_seqs
            request["stop_seqs_len"] = stop_seqs_len

        # processing bad_words
        bad_words = request.get("bad_words")
        bad_words_token_ids = request.get("bad_words_token_ids")
        if bad_words:
            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
            request["bad_words_token_ids"] = bad_words_token_ids

        data_processor_logger.info(f"Processing request {request}")
        # processing prompt_token_ids
        if not request.get("prompt_token_ids"):

@@ -652,3 +665,42 @@ class DataProcessor(BaseDataProcessor):
        stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
        data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
        return stop_seqs, stop_seqs_len

    def update_bad_words(self, bad_words, bad_words_token_ids):
        """Support bad words"""

        token_ids = bad_words_token_ids

        if token_ids is None:
            token_ids = []
            for bad_word in bad_words:
                # To prohibit words both at the beginning
                # and in the middle of text
                # (related to add_prefix_space tokenizer parameter)
                for add_prefix_space in [False, True]:
                    prefix = " " if add_prefix_space else ""
                    prompt = prefix + bad_word.lstrip()
                    prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))

                    if len(prompt_token_ids) != 1:
                        if not add_prefix_space:
                            data_processor_logger.warning(
                                f"Skip bad_words: <{prompt}>."
                                f"Bad words should be a single token."
                                f"Got tokens: {prompt_token_ids}."
                            )
                        continue

                    if prompt_token_ids[0] > self.tokenizer.vocab_size:
                        if not add_prefix_space:
                            data_processor_logger.warning(
                                f"Skip bad_words: <{prompt}>."
                                f"All token id values should be satisfying:"
                                f" 0 <= token_id < {self.tokenizer.vocab_size}."
                                f"Got token: {prompt_token_ids}."
                            )
                        continue

                    if prompt_token_ids not in token_ids:
                        token_ids.extend(prompt_token_ids)
        return token_ids

@@ -339,6 +339,16 @@ class GPUModelRunner(ModelRunnerBase):
            if request.get("seed") is not None:
                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")

            if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
                bad_words_len = len(request.get("bad_words_token_ids"))
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
                self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
                    request.get("bad_words_token_ids"), dtype="int64"
                )
            else:
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
                self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")

            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

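The runner hunks here only stage the per-request bad_tokens / bad_tokens_len buffers; the suppression itself happens later in the sampling path, with -1 acting as the "no bad words" padding value. Purely as an illustration of that intent, a hedged numpy sketch follows; it is not FastDeploy's actual sampling kernel.

    import numpy as np

    def mask_bad_tokens(logits, bad_tokens, bad_tokens_len):
        """Push each request's bad token ids to -inf so they can never be sampled.

        logits:         [batch, vocab_size]
        bad_tokens:     [batch, max_bad_tokens], unused slots padded with -1
        bad_tokens_len: [batch, 1], number of valid ids per row
        """
        for i in range(logits.shape[0]):
            ids = bad_tokens[i, : int(bad_tokens_len[i, 0])]
            ids = ids[ids >= 0]  # drop the -1 padding written by the default branch
            logits[i, ids] = -np.inf
        return logits

    logits = np.zeros((2, 100), dtype="float32")
    bad_tokens = np.full((2, 8), -1, dtype="int64")
    bad_tokens[0, :2] = [7, 42]                 # request 0 bans two token ids
    lens = np.array([[2], [1]], dtype="int64")  # request 1 keeps the default single -1 entry
    print(mask_bad_tokens(logits, bad_tokens, lens)[0, 7])  # -inf
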
@@ -322,6 +322,16 @@ class MetaxModelRunner(ModelRunnerBase):
            if request.get("seed") is not None:
                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")

            if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
                bad_words_len = len(request.get("bad_words_token_ids"))
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
                self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
                    request.get("bad_words_token_ids"), dtype="int64"
                )
            else:
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
                self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")

            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

@@ -455,6 +455,16 @@ class XPUModelRunner(ModelRunnerBase):
            if request.get("seed") is not None:
                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")

            if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
                bad_words_len = len(request.get("bad_words_token_ids"))
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
                self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
                    request.get("bad_words_token_ids"), dtype="int64"
                )
            else:
                self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
                self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")

            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):