mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
add qwen-2.5-7B-PRM/ernie-rm (#4319)
This commit is contained in:
@@ -48,6 +48,8 @@ def determine_model_category(class_name: str):
|
||||
return ModelCategory.MULTIMODAL
|
||||
elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]):
|
||||
return ModelCategory.EMBEDDING
|
||||
elif any(pattern in class_name for pattern in ["Reward"]):
|
||||
return ModelCategory.REWARD
|
||||
return ModelCategory.TEXT_GENERATION
|
||||
|
||||
|
||||
@@ -100,3 +102,11 @@ class FdModelForPooling(FdModel[T_co], Protocol[T_co]):
|
||||
"""
|
||||
pooler: Pooler
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
|
||||
def default_pooling_type(pooling_type: str):
    """Build a decorator that records a default pooling type on its target.

    Args:
        pooling_type: Value stored on the decorated object as its
            ``default_pooling_type`` attribute.

    Returns:
        A one-argument callable. When applied to an object (presumably a
        model class -- confirm against callers), it assigns
        ``default_pooling_type`` on that object and returns the very same
        object, so it can be used with ``@default_pooling_type(...)`` syntax.
    """

    def func(model):
        # The attribute is attached dynamically, hence the checker suppression.
        model.default_pooling_type = pooling_type  # type: ignore
        return model

    return func
|
||||
|
||||
Reference in New Issue
Block a user