[v1 loader] qwen offline fp8 support (#4036)

* support offline fp8

* update ut

* update ut

* update ut

* fix

* update

* update
This commit is contained in:
bukejiyu
2025-09-15 13:44:11 +08:00
committed by GitHub
parent b1a5b756a3
commit 29ed617f0f
21 changed files with 440 additions and 138 deletions
+17 -10
View File
@@ -206,20 +206,19 @@ class FusedMoE(nn.Layer):
if shard_id is None:
# 1.gate up fused in disk
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
per_rank = output_size // 2
start = self.tp_rank * per_rank
loaded_weight_shard_gate = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
)
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
loaded_weight_shard_up = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
@@ -236,10 +235,9 @@ class FusedMoE(nn.Layer):
)
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and not is_sharded:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
weight_dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[weight_dim]
@@ -275,13 +273,17 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and shard_dim is not None:
tp_shard_dim = is_torch_model ^ shard_dim
tp_shard_dim = weight_need_transpose ^ shard_dim
dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
@@ -302,6 +304,11 @@ class FusedMoE(nn.Layer):
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
if expert_param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
def _load_expert_weight(