mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
clean clean code in _load_per_tensor_weight_scale (#6868)
Co-authored-by: "liuruian" <liuruian@baidu.com>
This commit is contained in:
@@ -429,19 +429,8 @@ class FusedMoE(nn.Layer):
|
||||
# NOTE(review): fragment of a FusedMoE weight-loading method recovered from a
# diff rendering; the enclosing `def` is outside this view and the original
# indentation was reconstructed from the code's syntax.

# Map the global expert id to this rank's local expert slot.
expert_param = param[expert_id - self.expert_id_offset]
if shard_id in ["gate", "up"]:
    # Gate and up projections share one packed parameter: index 0 holds the
    # gate weight, index 1 the up weight.
    idx = 0 if shard_id == "gate" else 1
    if expert_param[idx].shape != loaded_weight.shape:
        if len(expert_param[idx].shape) != len(loaded_weight.shape):
            # Rank mismatch: reshape the checkpoint tensor into the
            # destination layout (presumably the element counts agree —
            # TODO confirm against the checkpoint format).
            loaded_weight = loaded_weight.reshape(expert_param[idx].shape)
        else:
            # Same rank but different shape: the checkpoint stores the 2-D
            # weight transposed relative to the destination parameter.
            loaded_weight = loaded_weight.transpose([1, 0])
    # Copy the (possibly reshaped/transposed) weight into the parameter slot.
    expert_param[idx].set_value(loaded_weight)
elif shard_id == "down":
    # Down projection occupies the whole per-expert parameter; same
    # reshape-vs-transpose normalization as the gate/up branch above.
    if expert_param.shape != loaded_weight.shape:
        if len(expert_param.shape) != len(loaded_weight.shape):
            loaded_weight = loaded_weight.reshape(expert_param.shape)
        else:
            loaded_weight = loaded_weight.transpose([1, 0])
    expert_param.set_value(loaded_weight)
|
||||
|
||||
def _load_expert_weight(
|
||||
|
||||
Reference in New Issue
Block a user