mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
remove source in weight_loader in moe.py (#6892)
This commit is contained in:
@@ -265,11 +265,7 @@ class FusedMoE(nn.Layer):
|
||||
loaded_weight,
|
||||
expert_id,
|
||||
shard_id: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
source:Avoid redundant transpose of fused weights when weight_loader is called iteratively
|
||||
"""
|
||||
if expert_id is None and shard_id is None:
|
||||
# MoE experts has been fused in disk
|
||||
self._load_fused_experts_weight(param, loaded_weight)
|
||||
@@ -283,17 +279,20 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
|
||||
return
|
||||
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
|
||||
if self.ep_size > 1 or weight_need_transpose:
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
|
||||
if weight_need_transpose:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
|
||||
if shard_id is None:
|
||||
# 1.gate up fused in disk
|
||||
if weight_need_transpose:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
@@ -305,10 +304,14 @@ class FusedMoE(nn.Layer):
|
||||
loaded_weight_shard = slice_fn(
|
||||
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
|
||||
self._load_expert_weight(
|
||||
param=param,
|
||||
expert_id=expert_id,
|
||||
loaded_weight=loaded_weight_shard,
|
||||
shard_id=shard_id,
|
||||
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
|
||||
)
|
||||
else:
|
||||
if weight_need_transpose and source != "fused":
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# 2.gate up splited in disk
|
||||
assert shard_id in ["gate", "down", "up"]
|
||||
self._load_expert_weight(
|
||||
|
||||
Reference in New Issue
Block a user