[Feature] bad words support v1 scheduler and specifiy token ids (#3608)

* support bad_words_token_ids

* docs

* fix test

* fix

* bad words support kvcache v1 and token ids

* fix
This commit is contained in:
Sunny-bot1
2025-08-26 11:14:51 +08:00
committed by GitHub
parent c43a4bec00
commit c68c3c4b8b
16 changed files with 420 additions and 62 deletions
-3
View File
@@ -461,7 +461,6 @@ class LLMEngine:
request = Request.from_dict(task)
llm_logger.info(f"Receive request {request}")
if sampling_params is not None:
sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
request.sampling_params = sampling_params
request.preprocess_start_time = time.time()
@@ -762,8 +761,6 @@ class LLMEngine:
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
self.resource_manager.check_and_free_block_tables()
+3 -43
View File
@@ -20,8 +20,6 @@ import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from fastdeploy.utils import llm_logger as logger
@dataclass
class SamplingParams:
@@ -102,7 +100,7 @@ class SamplingParams:
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
_bad_words_token_ids: Optional[List[int]] = None
bad_words_token_ids: Optional[List[int]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -134,6 +132,7 @@ class SamplingParams:
min_tokens=1,
logprobs=None,
bad_words=None,
bad_words_token_ids=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
@@ -154,6 +153,7 @@ class SamplingParams:
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words,
bad_words_token_ids=bad_words_token_ids,
)
def __post_init__(self):
@@ -206,46 +206,6 @@ class SamplingParams:
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
"""Support bad words"""
if self.bad_words is None:
return
self._bad_words_token_ids = []
for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
if len(prompt_token_ids) != 1:
if not add_prefix_space:
logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > tokenizer.vocab_size:
if not add_prefix_space:
logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids not in self._bad_words_token_ids:
self._bad_words_token_ids.extend(prompt_token_ids)
@property
def bad_words_token_ids(self) -> Optional[List[list[int]]]:
return self._bad_words_token_ids
@dataclass
class BeamSearchParams:
@@ -426,6 +426,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = None
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params
# doc: start-completion-extra-params
@@ -566,6 +567,7 @@ class ChatCompletionRequest(BaseModel):
min_tokens: Optional[int] = None
include_stop_str_in_output: Optional[bool] = False
bad_words: Optional[List[str]] = None
bad_words_token_ids: Optional[List[int]] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
# doc: end-chat-completion-sampling-params
+53
View File
@@ -97,6 +97,12 @@ class ErnieProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is None and request.messages is None:
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
@@ -160,6 +166,13 @@ class ErnieProcessor(BaseDataProcessor):
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing bad_words
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt") is None and request.get("messages") is None:
@@ -481,3 +494,43 @@ class ErnieProcessor(BaseDataProcessor):
def process_logprob_response(self, token_ids, **kwargs):
full_text = self.tokenizer.decode(token_ids, **kwargs)
return full_text
def update_bad_words(self, bad_words, bad_words_token_ids):
"""Support bad words"""
token_ids = bad_words_token_ids
if token_ids is None:
token_ids = []
for bad_word in bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
data_processor_logger.debug(f"processed bad_words: {prompt}, {prompt_token_ids}")
if len(prompt_token_ids) != 1:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > self.tokenizer.vocab_size:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {self.tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids not in token_ids:
token_ids.extend(prompt_token_ids)
return token_ids
+6
View File
@@ -208,6 +208,12 @@ class ErnieMoEVLProcessor(ErnieProcessor):
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.get("prompt"):
multimodal_data = request.get("multimodal_data")
if multimodal_data is None:
+6
View File
@@ -212,6 +212,12 @@ class QwenVLProcessor(TextProcessor):
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.get("prompt"):
multimodal_data = request.get("multimodal_data")
if multimodal_data is None:
+52
View File
@@ -214,6 +214,12 @@ class DataProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
@@ -270,6 +276,13 @@ class DataProcessor(BaseDataProcessor):
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# processing bad_words
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
data_processor_logger.info(f"Processing request {request}")
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
@@ -652,3 +665,42 @@ class DataProcessor(BaseDataProcessor):
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
return stop_seqs, stop_seqs_len
def update_bad_words(self, bad_words, bad_words_token_ids):
"""Support bad words"""
token_ids = bad_words_token_ids
if token_ids is None:
token_ids = []
for bad_word in bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
prompt_token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
if len(prompt_token_ids) != 1:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"Bad words should be a single token."
f"Got tokens: {prompt_token_ids}."
)
continue
if prompt_token_ids[0] > self.tokenizer.vocab_size:
if not add_prefix_space:
data_processor_logger.warning(
f"Skip bad_words: <{prompt}>."
f"All token id values should be satisfying:"
f" 0 <= token_id < {self.tokenizer.vocab_size}."
f"Got token: {prompt_token_ids}."
)
continue
if prompt_token_ids not in token_ids:
token_ids.extend(prompt_token_ids)
return token_ids
+10
View File
@@ -339,6 +339,16 @@ class GPUModelRunner(ModelRunnerBase):
if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
+10
View File
@@ -322,6 +322,16 @@ class MetaxModelRunner(ModelRunnerBase):
if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
+10
View File
@@ -455,6 +455,16 @@ class XPUModelRunner(ModelRunnerBase):
if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids"))
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64"
)
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):