clean code in _load_per_tensor_weight_scale (#6868)

Co-authored-by: “liuruian” <liuruian@baidu.com>
This commit is contained in:
周周周
2026-03-17 14:06:57 +08:00
committed by GitHub
parent 3b7507a4c2
commit ea998dd26f
@@ -429,19 +429,8 @@ class FusedMoE(nn.Layer):
expert_param = param[expert_id - self.expert_id_offset]
if shard_id in ["gate", "up"]:
idx = 0 if shard_id == "gate" else 1
if expert_param[idx].shape != loaded_weight.shape:
if len(expert_param[idx].shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param[idx].shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
expert_param[idx].set_value(loaded_weight)
elif shard_id == "down":
if expert_param.shape != loaded_weight.shape:
if len(expert_param.shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
expert_param.set_value(loaded_weight)
def _load_expert_weight(