mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support stop_token_ids (#5399)
* support stop_token_ids * fix * delete chinese * support both * delete print
This commit is contained in:
@@ -24,6 +24,7 @@ from fastdeploy.input.text_processor import BaseDataProcessor
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
from fastdeploy.input.utils import process_stop_token_ids
|
||||
|
||||
|
||||
class Ernie4_5Processor(BaseDataProcessor):
|
||||
@@ -92,12 +93,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
|
||||
request.eos_token_ids = self.eos_token_ids
|
||||
|
||||
# processing stop_sequences
|
||||
stop_sequences = request.get("stop", [])
|
||||
if stop_sequences is not None and len(stop_sequences) != 0:
|
||||
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
|
||||
request.set("stop_token_ids", stop_seqs)
|
||||
request.set("stop_seqs_len", stop_seqs_len)
|
||||
# processing stop_sequences and stop_token_ids
|
||||
process_stop_token_ids(request, self.update_stop_seq)
|
||||
|
||||
# processing bad_words
|
||||
bad_words = request.get("bad_words")
|
||||
@@ -173,12 +170,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
if not request.get("eos_token_ids"):
|
||||
request["eos_token_ids"] = self.eos_token_ids
|
||||
|
||||
# processing stop_sequences
|
||||
stop_sequences = request.get("stop", [])
|
||||
if stop_sequences:
|
||||
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
|
||||
request["stop_token_ids"] = stop_seqs
|
||||
request["stop_seqs_len"] = stop_seqs_len
|
||||
# processing stop_sequences and stop_token_ids
|
||||
process_stop_token_ids(request, self.update_stop_seq)
|
||||
|
||||
# processing bad_words
|
||||
bad_words = request.get("bad_words")
|
||||
|
||||
Reference in New Issue
Block a user