[V1 Loader] support weight_only (#3413)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled

* support wint4/wint8

* delete smoe case

* update ci

* print log
This commit is contained in:
bukejiyu
2025-08-23 13:13:41 +08:00
committed by GitHub
parent 93e1b63200
commit 77514e3e1e
24 changed files with 1055 additions and 524 deletions
+65 -60
View File
@@ -23,6 +23,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
@@ -78,6 +79,7 @@ class FusedMoE(nn.Layer):
routed_scaling_factor: float = 1.0,
layer_idx: int = -1,
moe_tag: str = "",
gate_correction_bias=None,
weight_key_map: dict = {},
):
"""
@@ -155,9 +157,10 @@ class FusedMoE(nn.Layer):
# It's for RL to build model
self.init_moe_weights()
else:
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None:
self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32")
if gate_correction_bias is not None:
self.gate_correction_bias = gate_correction_bias
else:
self.gate_correction_bias = None
if moe_quant_config:
if (
moe_quant_config
@@ -179,54 +182,72 @@ class FusedMoE(nn.Layer):
def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
from fastdeploy.platforms import current_platform
if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
elif current_platform.is_cuda():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
if not param._is_initialized():
param.initialize()
if shard_id is None:
# 1.gate up fused in disk
if self.tp_size > 1:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, self.moe_intermediate_size * self.tp_size),
("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
else:
expert_param = param[expert_id - self.expert_id_offset]
loaded_weight = get_tensor(loaded_weight)
expert_param.copy_(loaded_weight, False)
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size // 2 * self.tp_size),
("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
else:
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
if current_platform.is_cuda():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
self._load_expert_weight(
param=param,
expert_id=expert_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
loaded_weight=loaded_weight,
shard_id=shard_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
)
def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
tensor_size = expert_param.shape[shard_dim] // 2
if shard_id == "gate":
expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...]
elif shard_id == "up":
expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
dim = -1 if shard_dim else 0
if self.tp_size > 1:
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[-1]
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[-1]
size = loaded_weight.get_shape()[dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
param_shard_size = expert_param.shape[dim] // 2
if shard_id == "gate":
param_shard_offset = 0
else:
# shard_id == "up":
param_shard_offset = param_shard_size
expert_param = slice_fn(
expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size
)
if hasattr(param, "tensor_track"):
# for dyn quant
param.tensor_track.mark(
start=param_shard_offset,
end=param_shard_offset + param_shard_size,
batch_id=expert_id - self.expert_id_offset,
)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
@@ -235,17 +256,22 @@ class FusedMoE(nn.Layer):
)
expert_param.copy_(loaded_weight, False)
def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
if self.tp_size > 1:
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
if self.tp_size > 1 and shard_dim is not None:
dim = -1 if shard_dim else 0
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[shard_dim]
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[shard_dim]
size = loaded_weight.get_shape()[dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
if hasattr(param, "tensor_track"):
# for dyn quant
param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
@@ -258,15 +284,14 @@ class FusedMoE(nn.Layer):
self,
param,
expert_id,
shard_dim,
loaded_weight,
shard_id,
shard_dim=None,
):
expert_param = param[expert_id - self.expert_id_offset]
if shard_id == "down":
self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
elif shard_id in ["gate", "up"]:
self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id)
self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
@classmethod
def make_expert_params_mapping(
@@ -314,13 +339,6 @@ class FusedMoE(nn.Layer):
Combines weight shape initialization and parameter creation into a single function.
"""
# Initialize weight shapes
gate_correction_bias_shape = [1, self.num_experts]
if self.fd_config.model_config.moe_use_aux_free:
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_shape,
dtype="float32",
)
up_gate_proj_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
up_gate_proj_weight_shape = [
@@ -535,19 +553,6 @@ class FusedMoE(nn.Layer):
"""
load_state_dict function.
"""
if not is_rearrange:
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict
)
if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
else:
self.gate_correction_bias = None
else:
self.gate_correction_bias = None
if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method):
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):