mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[feature] support reward api (#4518)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Co-authored-by: SunLei <sunlei5788@gmail.com>
This commit is contained in:
@@ -729,3 +729,44 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
finished=request_output.finished,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardOutput:
|
||||
"""The output data of one reward output of a request.
|
||||
|
||||
Args:
|
||||
reward: The score, which is a list of floats.
|
||||
Its length depends on the hidden dimension of the model.
|
||||
"""
|
||||
|
||||
score: list[float]
|
||||
|
||||
@staticmethod
|
||||
def from_base(pooling_output: PoolingOutput):
|
||||
pooled_data = pooling_output.data
|
||||
# if pooled_data.ndim != 1:
|
||||
# raise ValueError("pooled_data should be a 1-D embedding vector")
|
||||
|
||||
if isinstance(pooled_data, list):
|
||||
return RewardOutput(pooled_data)
|
||||
|
||||
return RewardOutput(pooled_data.tolist())
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return len(self.score)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RewardOutput(hidden_size={self.hidden_size})"
|
||||
|
||||
|
||||
class RewardRequestOutput(PoolingRequestOutput[RewardOutput]):
|
||||
@staticmethod
|
||||
def from_base(request_output: PoolingRequestOutput):
|
||||
return RewardRequestOutput(
|
||||
request_id=request_output.request_id,
|
||||
outputs=RewardOutput.from_base(request_output.outputs),
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
finished=request_output.finished,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user