mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-05-06 15:40:33 +08:00
[Model] tp+ep support v1_loader (#5465)
* [Model] tp+ep support v1_loader * fix * fix mtp_linear * fix mtp_linear * fix * fix * fix v0 loader * fix * Add get_tensor for ep * fix linear weight_loader * fix typo * fix
This commit is contained in:
@@ -130,6 +130,10 @@ class RMSNorm(nn.Layer):
|
||||
dtype=self._norm_weight_dtype,
|
||||
)
|
||||
|
||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||
loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
|
||||
param.copy_(loaded_weight, False)
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
Load the checkpoint state dictionary into the layer.
|
||||
|
||||
Reference in New Issue
Block a user