[Feature] Support w4afp8 with v1_loader and v0_loader (tp>1) (#5757)

* support

* fix

* support w4afp8 v1_loader and v0_loader

* fix

* fix test

* fix test

* fix test

* fix moe.py

* add test_ernie_4_5_w4afp8

* add test

* delete tensor

* fix test

* fix

* add

* fix test
This commit is contained in:
lizexu123
2025-12-30 14:11:52 +08:00
committed by GitHub
parent e78e22ebd5
commit 44a13e4557
7 changed files with 615 additions and 31 deletions
@@ -110,7 +110,11 @@ class Ernie4_5_MoE(nn.Layer):
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
if moe_quant_type == "w4a8" or (
moe_quant_type == "w4afp8"
and fd_config.model_config.is_quantized
and not fd_config.quant_config.moe_dynamic_quant
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
@@ -121,6 +125,19 @@ class Ernie4_5_MoE(nn.Layer):
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
elif (
moe_quant_type == "w4afp8"
and fd_config.model_config.is_quantized
and fd_config.quant_config.moe_dynamic_quant
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
}
elif moe_quant_type == "w4w2":
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
@@ -223,6 +240,7 @@ class Ernie4_5_MoE(nn.Layer):
gate=self.gate,
forward_meta=forward_meta,
)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)
out = out + s_x