[loader] support wint2 backend (#6139)

* support wint2

* update
This commit is contained in:
bukejiyu
2026-02-09 14:42:36 +08:00
committed by GitHub
parent f18f3b99ed
commit dc5917289d
20 changed files with 86 additions and 11 deletions
+3 -3
View File
@@ -316,7 +316,7 @@ class FusedMoE(nn.Layer):
)
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
-        if self.tp_size > 1 and not is_sharded:
+        if self.tp_size > 1 and not is_sharded and not self.fd_config.load_config.is_pre_sharded:
tp_shard_dim = shard_dim
weight_dim = -1 if tp_shard_dim else 0
size = loaded_weight.shape[weight_dim]
@@ -371,7 +371,7 @@ class FusedMoE(nn.Layer):
h2d_copy(dst=expert_param, src=loaded_weight)
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
-        if self.tp_size > 1 and shard_dim is not None:
+        if self.tp_size > 1 and shard_dim is not None and not self.fd_config.load_config.is_pre_sharded:
tp_shard_dim = shard_dim
dim = -1 if tp_shard_dim else 0
size = loaded_weight.shape[dim]
@@ -397,7 +397,7 @@ class FusedMoE(nn.Layer):
h2d_copy(dst=expert_param, src=loaded_weight)
def _load_fused_experts_weight(self, param, loaded_weight):
-        if self.tp_size > 1 and self.moe_quant_type != "mxfp4":
+        if self.tp_size > 1 and self.moe_quant_type != "mxfp4" and not self.fd_config.load_config.is_pre_sharded:
dim = -1
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]