[Model] Support v1_loader with tp+ep (#5465)

* [Model] Support v1_loader with tp+ep

* fix

* fix mtp_linear

* fix mtp_linear

* fix

* fix

* fix v0 loader

* fix

* Add get_tensor for ep

* fix linear weight_loader

* fix typo

* fix
This commit is contained in:
Longzhi Wang
2025-12-18 14:31:54 +08:00
committed by GitHub
parent c89a62e550
commit d8587e987e
8 changed files with 48 additions and 20 deletions
@@ -229,6 +229,11 @@ class Attention(nn.Layer):
self.sinks.set_value(sinks_tensor)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
loaded_weight = get_tensor(loaded_weight).astype("float32")
param.copy_(loaded_weight, False)
return
loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp
loaded_weight = 1.0 / loaded_weight