mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Models][Feature] Support new ERNIE reward model and add return_token_ids to reward API (#6638)
* reward model
* Add support for pooling-based inference in the reward model
* bugfix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -1079,8 +1079,10 @@ PoolingChatRequest = EmbeddingChatRequest
|
||||
|
||||
class ChatRewardRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
messages: Union[List[Any], List[int]]
|
||||
user: Optional[str] = None
|
||||
return_token_ids: Optional[bool] = None
|
||||
|
||||
dimensions: Optional[int] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
@@ -1151,6 +1153,8 @@ class ChatRewardRequest(BaseModel):
|
||||
class ChatRewardData(BaseModel):
|
||||
index: Optional[int] = None
|
||||
object: str = "reward"
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
prompt_tokens: Optional[str] = None
|
||||
score: List[float]
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatRewardData,
|
||||
ChatRewardRequest,
|
||||
ChatRewardResponse,
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from fastdeploy.entrypoints.openai.serving_engine import ServeContext, ZmqOpenAIServing
|
||||
@@ -54,7 +55,7 @@ class OpenAIServingReward(ZmqOpenAIServing):
|
||||
request_dict["metrics"] = {}
|
||||
return request_dict
|
||||
else:
|
||||
request_obj = None
|
||||
request_obj: Request = None
|
||||
if hasattr(request, "to_pooling_params"):
|
||||
pooling_params: PoolingParams = request.to_pooling_params()
|
||||
pooling_params.verify("reward", self.cfg.model_config)
|
||||
@@ -90,10 +91,12 @@ class OpenAIServingReward(ZmqOpenAIServing):
|
||||
response: ChatRewardResponse = None
|
||||
generators: AsyncGenerator[ChatRewardResponse, None] = self.handle(ctx)
|
||||
async for r in generators:
|
||||
api_server_logger.info(f"engine pooling result:{r}")
|
||||
r.data[0].index = idx
|
||||
idx += 1
|
||||
if response is None:
|
||||
if response is None or isinstance(r, ErrorResponse):
|
||||
response = r
|
||||
break
|
||||
else:
|
||||
response.data.append(r.data[0])
|
||||
response.usage.prompt_tokens += r.usage.prompt_tokens
|
||||
@@ -109,8 +112,16 @@ class OpenAIServingReward(ZmqOpenAIServing):
|
||||
base = PoolingRequestOutput.from_dict(request_output)
|
||||
reward_res = RewardRequestOutput.from_base(base)
|
||||
|
||||
prompt_token_ids = None
|
||||
prompt_tokens = None
|
||||
if ctx.request.return_token_ids:
|
||||
prompt_token_ids = request_output.get("prompt_token_ids", None)
|
||||
prompt_tokens = ctx.preprocess_requests[0].get("prompt_tokens", None)
|
||||
|
||||
data = ChatRewardData(
|
||||
index=0,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_tokens=prompt_tokens,
|
||||
score=reward_res.outputs.score,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user