[Models][Feature] Support new ERNIE reward model and add return_token_ids to reward API (#6638)

* Add the new ERNIE reward model

* Add support for pooling-based inference in the reward model

* bugfix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
SunLei
2026-03-06 18:51:00 +08:00
committed by GitHub
parent caf73e8131
commit 5d9524fc3c
5 changed files with 22 additions and 6 deletions
@@ -1079,8 +1079,10 @@ PoolingChatRequest = EmbeddingChatRequest
class ChatRewardRequest(BaseModel):
    model: Optional[str] = None
    prompt_token_ids: Optional[List[int]] = None
    messages: Union[List[Any], List[int]]
    user: Optional[str] = None
    return_token_ids: Optional[bool] = None
    dimensions: Optional[int] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
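
A minimal sketch of what a client payload using the new field might look like. Only the field names come from the `ChatRewardRequest` schema above; the model name is a made-up placeholder.

```python
# Hypothetical reward request payload; setting return_token_ids asks the
# server to echo the tokenized prompt back in the response.
import json

payload = {
    "model": "ernie-reward",  # placeholder model name, not from the diff
    "messages": [
        {"role": "user", "content": "Explain gradient descent."},
        {"role": "assistant", "content": "Gradient descent iteratively updates parameters..."},
    ],
    "return_token_ids": True,  # new field added by this commit
}
print(json.dumps(payload, indent=2))
```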
@@ -1151,6 +1153,8 @@ class ChatRewardRequest(BaseModel):
class ChatRewardData(BaseModel):
    index: Optional[int] = None
    object: str = "reward"
    prompt_token_ids: Optional[List[int]] = None
    prompt_tokens: Optional[str] = None
    score: List[float]
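
To show the resulting response shape, here is a standalone mirror of `ChatRewardData` as it appears in the diff. The example values are invented; only the field names and defaults match the schema above.

```python
# Self-contained copy of the response item schema, for illustration only.
from typing import List, Optional

from pydantic import BaseModel


class ChatRewardData(BaseModel):
    index: Optional[int] = None
    object: str = "reward"
    prompt_token_ids: Optional[List[int]] = None  # populated when return_token_ids is set
    prompt_tokens: Optional[str] = None           # detokenized prompt text, same condition
    score: List[float]


item = ChatRewardData(
    index=0,
    prompt_token_ids=[1, 15, 392],   # made-up token ids
    prompt_tokens="<|user|>Explain gradient descent.",
    score=[0.87],
)
print(item.model_dump())
```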
@@ -26,6 +26,7 @@ from fastdeploy.entrypoints.openai.protocol import (
    ChatRewardData,
    ChatRewardRequest,
    ChatRewardResponse,
    ErrorResponse,
    UsageInfo,
)
from fastdeploy.entrypoints.openai.serving_engine import ServeContext, ZmqOpenAIServing
@@ -54,7 +55,7 @@ class OpenAIServingReward(ZmqOpenAIServing):
request_dict["metrics"] = {}
return request_dict
else:
request_obj = None
request_obj: Request = None
if hasattr(request, "to_pooling_params"):
pooling_params: PoolingParams = request.to_pooling_params()
pooling_params.verify("reward", self.cfg.model_config)
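
The hunk above relies on duck typing: any request object that exposes `to_pooling_params()` is converted and validated before dispatch. The sketch below illustrates that pattern only; `PoolingParams`, its `verify` signature, and `DemoRewardRequest` are simplified stand-ins, not FastDeploy's actual classes.

```python
# Illustrative sketch of the duck-typed conversion, with stand-in types.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PoolingParams:
    task: Optional[str] = None
    dimensions: Optional[int] = None

    def verify(self, task: str, model_config=None) -> None:
        # The real implementation validates against the model config;
        # here we only record the task and reject non-positive dimensions.
        self.task = task
        if self.dimensions is not None and self.dimensions <= 0:
            raise ValueError("dimensions must be positive")


class DemoRewardRequest:
    dimensions = None

    def to_pooling_params(self) -> PoolingParams:
        return PoolingParams(dimensions=self.dimensions)


request = DemoRewardRequest()
if hasattr(request, "to_pooling_params"):
    params = request.to_pooling_params()
    params.verify("reward", None)
    print(params)
```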
@@ -90,10 +91,12 @@ class OpenAIServingReward(ZmqOpenAIServing):
        response: ChatRewardResponse = None
        generators: AsyncGenerator[ChatRewardResponse, None] = self.handle(ctx)
        async for r in generators:
            api_server_logger.info(f"engine pooling result:{r}")
            r.data[0].index = idx
            idx += 1
-           if response is None:
+           if response is None or isinstance(r, ErrorResponse):
                response = r
                break
            else:
                response.data.append(r.data[0])
                response.usage.prompt_tokens += r.usage.prompt_tokens
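
The loop above consumes per-prompt results from an async generator, assigns running indices, and folds data and token usage into a single response, stopping early when an `ErrorResponse` arrives. A self-contained sketch of that aggregation pattern follows; the `Result`/`Error`/`Usage` types are simplified stand-ins, and the error short-circuit is expressed as an early return rather than the diff's `break`.

```python
# Simplified aggregation over an async generator of per-prompt results.
import asyncio
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Usage:
    prompt_tokens: int = 0


@dataclass
class Result:
    data: List[dict] = field(default_factory=list)
    usage: Usage = field(default_factory=Usage)


@dataclass
class Error:
    message: str = ""


async def results():
    # Stand-in for the engine's pooling outputs, one per prompt.
    yield Result(data=[{"score": [0.9]}], usage=Usage(3))
    yield Result(data=[{"score": [0.4]}], usage=Usage(5))


async def aggregate() -> Union[Result, Error]:
    response = None
    idx = 0
    async for r in results():
        if isinstance(r, Error):
            return r                      # short-circuit on the first error
        r.data[0]["index"] = idx
        idx += 1
        if response is None:
            response = r                  # first chunk becomes the response
        else:
            response.data.append(r.data[0])
            response.usage.prompt_tokens += r.usage.prompt_tokens
    return response


print(asyncio.run(aggregate()))
```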
@@ -109,8 +112,16 @@ class OpenAIServingReward(ZmqOpenAIServing):
        base = PoolingRequestOutput.from_dict(request_output)
        reward_res = RewardRequestOutput.from_base(base)
        prompt_token_ids = None
        prompt_tokens = None
        if ctx.request.return_token_ids:
            prompt_token_ids = request_output.get("prompt_token_ids", None)
            prompt_tokens = ctx.preprocess_requests[0].get("prompt_tokens", None)
        data = ChatRewardData(
            index=0,
            prompt_token_ids=prompt_token_ids,
            prompt_tokens=prompt_tokens,
            score=reward_res.outputs.score,
        )
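
Putting it together, a client that opts into `return_token_ids` should find `prompt_token_ids` and `prompt_tokens` on each data item, as wired up above. The sketch below assumes a locally running server; the URL, endpoint path, and model name are placeholders, not confirmed by this commit.

```python
# Hypothetical end-to-end call against a running FastDeploy server.
import requests

resp = requests.post(
    "http://localhost:8000/v1/reward",  # assumed endpoint path
    json={
        "model": "ernie-reward",        # placeholder model name
        "messages": [{"role": "user", "content": "Rate this answer."}],
        "return_token_ids": True,
    },
    timeout=30,
)
for item in resp.json().get("data", []):
    # prompt_token_ids / prompt_tokens are present only when requested.
    print(item.get("index"), item.get("score"), item.get("prompt_token_ids"))
```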