[V1 Loader] Support Ernie text(moe and dense) (#3110)

* new loader support 0.3B

* fix weight

* support parallel load

* support parallel load

* fix slice

* support moe

* delete code

* perfect code

* perfect code
Author: YuanRisheng (committed by GitHub)
Date: 2025-08-14 20:25:28 +08:00
Parent: ab60292f89
Commit: 09c979f3dd
6 changed files with 218 additions and 85 deletions
@@ -16,6 +16,7 @@
 from typing import Optional
+import numpy as np
 import paddle
 from paddle import nn
@@ -392,30 +393,48 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         )
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        # 1.fused gate_up in disk
-        # 2.split gate up
-        assert loaded_shard_id in ["gate", "up"]
-        output_dim = getattr(param, "output_dim", None)
-        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
-            dim = -1
-            size = loaded_weight.get_shape()[dim]
-            block_size = size // self.nranks
-            shard_offset = self.local_rank * block_size
-            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk.
+            if self.nranks != 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("gate", 0, self.output_size * self.nranks // 2),
+                    ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, shard_id)
+            else:
+                loaded_weight = get_tensor(loaded_weight)
+                param.copy_(loaded_weight, False)
+        else:
+            # 1.fused gate_up in disk
+            # 2.split gate up
+            assert loaded_shard_id in ["gate", "up"]
+            output_dim = getattr(param, "output_dim", None)
+            # Tensor parallelism splits the weight along the output_dim
+            if output_dim is not None:
+                dim = -1
+                if isinstance(loaded_weight, np.ndarray):
+                    size = loaded_weight.shape[dim]
+                else:
+                    size = loaded_weight.get_shape()[dim]
+                block_size = size // self.nranks
+                shard_offset = self.local_rank * block_size
+                shard_size = (self.local_rank + 1) * block_size
+                loaded_weight = loaded_weight[..., shard_offset:shard_size]
 
-        loaded_weight = get_tensor(loaded_weight)
+            loaded_weight = get_tensor(loaded_weight)
 
-        if loaded_shard_id == "gate":
-            param = param[:, : self.output_size // 2]
-        elif loaded_shard_id == "up":
-            param = param[:, self.output_size // 2 :]
+            if loaded_shard_id == "gate":
+                param = param[:, : self.output_size // 2]
+            elif loaded_shard_id == "up":
+                param = param[:, self.output_size // 2 :]
 
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+            assert param.shape == loaded_weight.shape, (
+                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+            )
+            param.copy_(loaded_weight, False)
 
     def load_state_dict(self, state_dict: dict):
         """
@@ -486,37 +505,57 @@ class QKVParallelLinear(ColumnParallelLinear):
         )
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        # 1.fused qkv in disk
-        # 2.split q k v
-        assert loaded_shard_id in ["q", "k", "v"]
-        output_dim = getattr(param, "output_dim", None)
-        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
-            dim = -1
-            size = loaded_weight.get_shape()[dim]
-            block_size = size // self.nranks
-            shard_offset = self.local_rank * block_size
-            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk
+            if self.nranks != 1:
+                shard_offsets = [
+                    # (shard_id, shard_offset, shard_size)
+                    ("q", 0, self.num_heads * self.head_dim),
+                    ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
+                    ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
+                ]
+                for shard_id, shard_offset, shard_size in shard_offsets:
+                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
+                    self.weight_loader(param, loaded_weight_shard, shard_id)
+            else:
+                loaded_weight = get_tensor(loaded_weight)
+                split_loaded_weight = loaded_weight
+                param.copy_(split_loaded_weight, False)
+        else:
+            # 1.fused qkv in disk
+            # 2.split q k v
+            assert loaded_shard_id in ["q", "k", "v"]
+            output_dim = getattr(param, "output_dim", None)
+            # Tensor parallelism splits the weight along the output_dim
+            if output_dim is not None:
+                dim = -1
+                if isinstance(loaded_weight, np.ndarray):
+                    size = loaded_weight.shape[dim]
+                else:
+                    size = loaded_weight.get_shape()[dim]
+                block_size = size // self.nranks
+                shard_offset = self.local_rank * block_size
+                shard_size = (self.local_rank + 1) * block_size
+                loaded_weight = loaded_weight[..., shard_offset:shard_size]
 
-        loaded_weight = get_tensor(loaded_weight)
+            loaded_weight = get_tensor(loaded_weight)
 
-        if loaded_shard_id == "q":
-            param = param[:, : self.num_heads_per_rank * self.head_dim]
-        elif loaded_shard_id == "k":
-            param = param[
-                :,
-                self.num_heads_per_rank
-                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                * self.head_dim,
-            ]
-        elif loaded_shard_id == "v":
-            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+            if loaded_shard_id == "q":
+                param = param[:, : self.num_heads_per_rank * self.head_dim]
+            elif loaded_shard_id == "k":
+                param = param[
+                    :,
+                    self.num_heads_per_rank
+                    * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
+                    * self.head_dim,
+                ]
+            elif loaded_shard_id == "v":
+                param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
 
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+            assert param.shape == loaded_weight.shape, (
+                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+            )
+            param.copy_(loaded_weight, False)
 
     def load_weight(self, state_dict: dict):
         """