mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[v1 loader] Qwen offline FP8 (#4036)
* support offline fp8
* update ut
* update ut
* update ut
* fix
* update
* update
This commit is contained in:
@@ -206,20 +206,19 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
if shard_id is None:
|
||||
# 1. gate and up weights are fused together on disk
|
||||
model_format = getattr(param, "model_format", "")
|
||||
is_torch_model = model_format == "torch"
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
|
||||
per_rank = output_size // 2
|
||||
start = self.tp_rank * per_rank
|
||||
loaded_weight_shard_gate = slice_fn(
|
||||
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
|
||||
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
|
||||
)
|
||||
self._load_gate_up_weight(
|
||||
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
|
||||
)
|
||||
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
|
||||
loaded_weight_shard_up = slice_fn(
|
||||
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
|
||||
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
|
||||
)
|
||||
self._load_gate_up_weight(
|
||||
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
|
||||
@@ -236,10 +235,9 @@ class FusedMoE(nn.Layer):
|
||||
)
|
||||
|
||||
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
|
||||
model_format = getattr(param, "model_format", "")
|
||||
is_torch_model = model_format == "torch"
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
if self.tp_size > 1 and not is_sharded:
|
||||
tp_shard_dim = is_torch_model ^ shard_dim
|
||||
tp_shard_dim = weight_need_transpose ^ shard_dim
|
||||
weight_dim = -1 if tp_shard_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
size = loaded_weight.shape[weight_dim]
|
||||
@@ -275,13 +273,17 @@ class FusedMoE(nn.Layer):
|
||||
assert expert_param.shape == loaded_weight.shape, (
|
||||
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
|
||||
)
|
||||
if expert_param.dtype != loaded_weight.dtype:
|
||||
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
|
||||
loaded_weight = loaded_weight.view(expert_param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(expert_param.dtype)
|
||||
expert_param.copy_(loaded_weight, False)
|
||||
|
||||
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
|
||||
model_format = getattr(param, "model_format", "")
|
||||
is_torch_model = model_format == "torch"
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
if self.tp_size > 1 and shard_dim is not None:
|
||||
tp_shard_dim = is_torch_model ^ shard_dim
|
||||
tp_shard_dim = weight_need_transpose ^ shard_dim
|
||||
dim = -1 if tp_shard_dim else 0
|
||||
if isinstance(loaded_weight, paddle.Tensor):
|
||||
size = loaded_weight.shape[dim]
|
||||
@@ -302,6 +304,11 @@ class FusedMoE(nn.Layer):
|
||||
assert expert_param.shape == loaded_weight.shape, (
|
||||
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
|
||||
)
|
||||
if expert_param.dtype != loaded_weight.dtype:
|
||||
if loaded_weight.dtype == paddle.int8 and expert_param.dtype == paddle.float8_e4m3fn:
|
||||
loaded_weight = loaded_weight.view(expert_param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(expert_param.dtype)
|
||||
expert_param.copy_(loaded_weight, False)
|
||||
|
||||
def _load_expert_weight(
|
||||
|
||||
Reference in New Issue
Block a user