mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[V1 Loader] Support Ernie text (MoE and dense) (#3110)
* new loader: support 0.3B
* fix weight
* support parallel load
* fix slice
* support moe
* delete unused code
* polish code
@@ -16,6 +16,7 @@
 from typing import Optional
 
+import numpy as np
 import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
 
@@ -110,13 +111,18 @@ class FusedMoE(nn.Layer):
         self.weight_key_map = weight_key_map
 
         self.use_method = envs.FD_MOE_BACKEND.lower()
         self.gate_correction_bias = None
         self.moe_tag = moe_tag
         if self.ep_size > 1:
             expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
 
         self.expert_id_offset = expert_id_offset
 
+        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
+        if self.gate_correction_bias_key is not None:
+            self.moe_use_gate_correction_bias = True
+        else:
+            self.moe_use_gate_correction_bias = False
+
         # used for deepseek_v3
         self.topk_method = topk_method
         self.topk_group = topk_group
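With this hunk, the gate-correction-bias lookup happens once at construction time: the key is read from weight_key_map in __init__ and the moe_use_gate_correction_bias flag is derived from it, rather than being recomputed inside load_state_dict (see the removed block in the last hunk below). A minimal sketch of the pattern, assuming a hypothetical key name (the real map is built by the model definition):

# Hypothetical weight_key_map entry; the actual key template comes from the model code.
weight_key_map = {
    "gate_correction_bias_key": "model.layers.0.mlp.moe_statics.e_score_correction_bias",
}

# Mirrors the new __init__ logic: the flag is derived once from the key map.
gate_correction_bias_key = weight_key_map.get("gate_correction_bias_key", None)
moe_use_gate_correction_bias = gate_correction_bias_key is not None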
@@ -175,20 +181,33 @@ class FusedMoE(nn.Layer):
 
         if shard_id is None:
             # 1.gate up fused in disk
-            return
-        # 2.gate up splited in disk
-        assert shard_id in ["gate", "down", "up"]
-        expert_param = param[expert_id]
-        if current_platform.is_cuda():
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
-        else:
-            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
-        self._load_expert_weight(
-            expert_param=expert_param,
-            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
-            loaded_weight=loaded_weight,
-            shard_id=shard_id,
-        )
+            if self.tp_size > 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("gate", 0, self.moe_intermediate_size * self.tp_size),
+                    ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
+            else:
+                expert_param = param[expert_id]
+                loaded_weight = get_tensor(loaded_weight)
+                expert_param.copy_(loaded_weight, False)
+        else:
+            # 2.gate up splited in disk
+            assert shard_id in ["gate", "down", "up"]
+            if current_platform.is_cuda():
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
+            else:
+                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
+            self._load_expert_weight(
+                param=param,
+                expert_id=expert_id,
+                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
+                loaded_weight=loaded_weight,
+                shard_id=shard_id,
+            )
 
     def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
         tensor_size = expert_param.shape[shard_dim] // 2
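The fused branch above is the core of the new loader. When shard_id is None, the checkpoint stores the gate and up projections concatenated along the last axis, so under tensor parallelism the loader splits the fused tensor into its two halves and re-enters weight_loader once per half; each recursive call then applies the per-rank slice inside _load_gate_up_weight. A self-contained sketch of that dispatch, using NumPy arrays in place of the real weights (all sizes are illustrative):

import numpy as np

# Illustrative sizes: per-rank intermediate size 4, tensor-parallel degree 2, hidden 8.
moe_intermediate_size, tp_size, hidden = 4, 2, 8

# Fused checkpoint tensor [hidden, gate | up]; each half spans the full
# un-sharded intermediate size, i.e. moe_intermediate_size * tp_size.
fused = np.zeros((hidden, 2 * moe_intermediate_size * tp_size))

shard_offsets = [
    # (shard_id, shard_offset, shard_size)
    ("gate", 0, moe_intermediate_size * tp_size),
    ("up", moe_intermediate_size * tp_size, moe_intermediate_size * tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
    shard = fused[..., shard_offset : shard_offset + shard_size]
    # The real code recurses here: self.weight_loader(param, shard, expert_id, shard_id),
    # which routes each half through _load_gate_up_weight for the TP slice.
    print(shard_id, shard.shape)  # prints "gate (8, 8)" then "up (8, 8)"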
@@ -198,7 +217,10 @@ class FusedMoE(nn.Layer):
             expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
 
         if self.tp_size > 1:
-            size = loaded_weight.get_shape()[-1]
+            if isinstance(loaded_weight, np.ndarray):
+                size = loaded_weight.shape[-1]
+            else:
+                size = loaded_weight.get_shape()[-1]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
@@ -215,7 +237,10 @@ class FusedMoE(nn.Layer):
 
     def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
         if self.tp_size > 1:
-            size = loaded_weight.get_shape()[shard_dim]
+            if isinstance(loaded_weight, np.ndarray):
+                size = loaded_weight.shape[shard_dim]
+            else:
+                size = loaded_weight.get_shape()[shard_dim]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
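These two hunks make the TP slicing tolerate both eager and lazy weights: an np.ndarray exposes .shape, while a lazily loaded tensor handle (e.g. a safetensors-style slice) exposes get_shape(), and the rank's slice bounds are then computed identically in both methods. A sketch of the shared arithmetic, with a hypothetical helper name:

import numpy as np

def tp_shard_bounds(loaded_weight, shard_dim, tp_rank, tp_size):
    # Half-open [shard_offset, shard_size) bounds along shard_dim for one rank;
    # note that, as in the diff, shard_size is the end index, not a length.
    if isinstance(loaded_weight, np.ndarray):
        size = loaded_weight.shape[shard_dim]
    else:
        size = loaded_weight.get_shape()[shard_dim]
    block_size = size // tp_size
    shard_offset = tp_rank * block_size
    shard_size = (tp_rank + 1) * block_size
    return shard_offset, shard_size

w = np.zeros((1024, 512))
print(tp_shard_bounds(w, shard_dim=1, tp_rank=1, tp_size=4))  # (128, 256)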
@@ -231,11 +256,13 @@ class FusedMoE(nn.Layer):
 
     def _load_expert_weight(
         self,
-        expert_param,
+        param,
+        expert_id,
         shard_dim,
         loaded_weight,
         shard_id,
     ):
+        expert_param = param[expert_id]
        if shard_id == "down":
             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
         elif shard_id in ["gate", "up"]:
@@ -244,29 +271,32 @@ class FusedMoE(nn.Layer):
 
     @classmethod
     def make_expert_params_mapping(
         cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        param_gate_up_proj_name: str,
-        param_down_proj_name: str,
         num_experts: int,
+        ckpt_gate_proj_name: Optional[str] = None,
+        ckpt_up_proj_name: Optional[str] = None,
+        ckpt_down_proj_name: Optional[str] = None,
+        ckpt_gate_up_proj_name: Optional[str] = None,
+        param_gate_up_proj_name: Optional[str] = None,
+        param_down_proj_name: Optional[str] = None,
         ckpt_expert_key_name: str = "experts",
     ) -> list[tuple[str, str, int, str]]:
-        param_name_maping = [
-            ("gate", ckpt_gate_proj_name),
-            ("down", ckpt_down_proj_name),
-            ("up", ckpt_up_proj_name),
-        ]
+        param_name_maping = []
+
+        if ckpt_gate_up_proj_name:
+            param_name_maping.append((None, ckpt_gate_up_proj_name))
+        if ckpt_gate_proj_name:
+            param_name_maping.append(("gate", ckpt_gate_proj_name))
+        if ckpt_down_proj_name:
+            param_name_maping.append(("down", ckpt_down_proj_name))
+        if ckpt_up_proj_name:
+            param_name_maping.append(("up", ckpt_up_proj_name))
 
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
                 (
                     param_gate_up_proj_name
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name, ckpt_gate_up_proj_name]
                     else param_down_proj_name
                 ),
                 f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
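The mapping builder now accepts any subset of per-projection checkpoint names, including a fused gate-up name that maps with shard_id None so weight_loader takes the fused branch above. A hypothetical call with Ernie-style names (the identifiers below are illustrative, not taken from the model code):

# Hypothetical names; the returned tuples follow the order of param_name_maping.
mapping = FusedMoE.make_expert_params_mapping(
    num_experts=2,
    ckpt_gate_up_proj_name="up_gate_proj",
    ckpt_down_proj_name="down_proj",
    param_gate_up_proj_name="experts.gate_up_proj_weight",
    param_down_proj_name="experts.down_proj_weight",
)
# Each entry is (param_name, weight_name, expert_id, shard_id), e.g.:
#   ("experts.gate_up_proj_weight", "experts.0.up_gate_proj.", 0, None)
#   ("experts.down_proj_weight",    "experts.0.down_proj.",    0, "down")
# plus the corresponding rows for expert_id 1.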
@@ -505,11 +535,6 @@ class FusedMoE(nn.Layer):
         load_state_dict function.
         """
-        if not is_rearrange:
-            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
-        if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
-            self.moe_use_gate_correction_bias = True
-        else:
-            self.moe_use_gate_correction_bias = False
         if self.moe_use_gate_correction_bias:
             gate_correction_bias_tensor = self.extract_gate_correction_bias(
                 self.gate_correction_bias_key, state_dict