remove source in weight_loader in moe.py (#6892)

This commit is contained in:
周周周
2026-03-19 13:31:43 +08:00
committed by GitHub
parent dd93f8ffb4
commit c184a7cb69
+12 -9
View File
@@ -265,11 +265,7 @@ class FusedMoE(nn.Layer):
loaded_weight,
expert_id,
shard_id: Optional[str] = None,
source: Optional[str] = None,
):
"""
source:Avoid redundant transpose of fused weights when weight_loader is called iteratively
"""
if expert_id is None and shard_id is None:
# MoE experts has been fused in disk
self._load_fused_experts_weight(param, loaded_weight)
@@ -283,17 +279,20 @@ class FusedMoE(nn.Layer):
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
return
if not param._is_initialized():
param.initialize()
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.ep_size > 1 or weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
if shard_id is None:
# 1.gate up fused in disk
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
shard_offsets = [
# (shard_id, shard_offset, shard_size)
@@ -305,10 +304,14 @@ class FusedMoE(nn.Layer):
loaded_weight_shard = slice_fn(
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
self._load_expert_weight(
param=param,
expert_id=expert_id,
loaded_weight=loaded_weight_shard,
shard_id=shard_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
)
else:
if weight_need_transpose and source != "fused":
loaded_weight = loaded_weight.transpose([1, 0])
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
self._load_expert_weight(