mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization]update prompt & prompt_token_ids (#6334)

* fix prompt
* add unit test
* add unit test
* fix
This commit is contained in:
@@ -107,21 +107,13 @@ class Ernie4_5Processor(BaseDataProcessor):
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
# prompt = request.prompt if request.prompt is not None else request.messages[0]
prompt = request.prompt
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"

if isinstance(prompt, list):  # if prompt is a token id list
request.prompt_token_ids = prompt
else:
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
)
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
task = request.to_dict()
chat_template_kwargs = kwargs.get("chat_template_kwargs", {})

||||
@@ -313,17 +313,10 @@ class DataProcessor(BaseDataProcessor):
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt"):
prompt = request.get("prompt")
add_special_tokens = request.get("add_special_tokens", False)
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list):  # if prompt is a token id list
request["prompt_token_ids"] = prompt
else:
request["prompt_token_ids"] = self.text2ids(
request["prompt"], max_model_len, add_special_tokens=add_special_tokens
).tolist()
request["prompt_token_ids"] = self.text2ids(
request["prompt"], max_model_len, add_special_tokens=add_special_tokens
).tolist()
elif request.get("messages"):
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")

||||
Reference in New Issue
Block a user