[Model] tp+ep support v1_loader (#5465)

* [Model] tp+ep support v1_loader

* fix

* fix mtp_linear

* fix mtp_linear

* fix

* fix

* fix v0 loader

* fix

* Add get_tensor for ep

* fix linear weight_loader

* fix typo

* fix
This commit is contained in:
Longzhi Wang
2025-12-18 14:31:54 +08:00
committed by GitHub
parent c89a62e550
commit d8587e987e
8 changed files with 48 additions and 20 deletions
@@ -130,6 +130,10 @@ class RMSNorm(nn.Layer):
dtype=self._norm_weight_dtype,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
    """
    Load a single checkpoint weight into ``param``.

    The incoming value is resolved through ``get_tensor``, cast to
    ``self._norm_weight_dtype``, and then written into ``param`` in place.

    Args:
        param: Destination parameter; overwritten via ``copy_``.
        loaded_weight: Checkpoint value to load; passed to ``get_tensor`` first.
        loaded_shard_id: Not used by this loader. Presumably kept so the
            signature matches other weight loaders -- TODO confirm.
    """
    resolved_weight = get_tensor(loaded_weight)
    casted_weight = resolved_weight.astype(self._norm_weight_dtype)
    # The second positional argument to copy_ stays False, exactly as before.
    param.copy_(casted_weight, False)
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.