mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support pooling model dummy_run (#4345)
* support qwen3-embedding * fix ci bug * support pooling dummy_run * fix * delete print * parallel_config.max_model_len * delete is_pooling_model in dummy_run * fix * fd_model * fix embedding load * fix * fix post_process
This commit is contained in:
@@ -180,7 +180,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
|
||||
def as_embedding_model(cls: _T) -> _T:
|
||||
"""
|
||||
Subclass an existing vLLM model to support embeddings.
|
||||
Subclass an existing FastDeploy model to support embeddings.
|
||||
|
||||
By default, the embeddings of the whole prompt are extracted from the
|
||||
normalized hidden state corresponding to the last token.
|
||||
|
||||
@@ -12,9 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Type
|
||||
from typing import ClassVar, Literal, Protocol, Type
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from typing_extensions import TypeVar, runtime_checkable
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.pooler import Pooler
|
||||
|
||||
T = TypeVar("T", default=paddle.Tensor)
|
||||
T_co = TypeVar("T_co", default=paddle.Tensor, covariant=True)
|
||||
|
||||
|
||||
def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
|
||||
@@ -24,13 +33,7 @@ def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
|
||||
|
||||
|
||||
def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
|
||||
class_name = model_cls.__name__
|
||||
pooling_indicators = ["Embedding", "ForSequenceClassification"]
|
||||
return (
|
||||
any(indicator in class_name for indicator in pooling_indicators)
|
||||
or hasattr(model_cls, "is_embedding_model")
|
||||
and model_cls.is_embedding_model
|
||||
)
|
||||
return getattr(model_cls, "is_pooling_model", False)
|
||||
|
||||
|
||||
def is_multimodal_model(class_name: str) -> bool:
|
||||
@@ -52,3 +55,48 @@ def get_default_pooling_type(model_cls: Type[nn.Layer] = None) -> str:
|
||||
if model_cls is not None:
|
||||
return getattr(model_cls, "default_pooling_type", "LAST")
|
||||
return "LAST"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class FdModel(Protocol[T_co]):
|
||||
"""The interface required for all models in FastDeploy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_metadata: ForwardMeta,
|
||||
) -> T_co:
|
||||
pass
|
||||
|
||||
|
||||
class FdModelForPooling(FdModel[T_co], Protocol[T_co]):
|
||||
"""The interface required for all pooling models in FastDeploy."""
|
||||
|
||||
is_pooling_model: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports pooling.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
default_pooling_type: ClassVar[str] = "LAST"
|
||||
"""
|
||||
Indicates the
|
||||
[fastdeploy.config.PoolerConfig.pooling_type][]
|
||||
to use by default.
|
||||
|
||||
You can use the
|
||||
[fastdeploy.model_executor.models.interfaces_base.default_pooling_type][]
|
||||
decorator to conveniently set this field.
|
||||
"""
|
||||
pooler: Pooler
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
@@ -303,7 +303,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
if model_param_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[model_param_name]
|
||||
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user