[Quantization] Support loading static quant ue8m0 scales of DeepGEMM via v0_loader (#6433)

* Support loading static quant ue8m0 scales of DeepGEMM via v0_loader

* [Fix] Fix ue8m0 scale pack dimension calculation and block size validation

1. Fix the pack-dimension calculation in fused_moe_triton_backend.py:
   - Changed `ceil_div(...) // 4` to `(num_scales + 3) // 4` so the division by 4 also rounds up
   - This allocates enough packs when num_scales is not a multiple of 4 (see the first sketch after this list)

2. Fix the hardcoded block size in block_wise_fp8.py:
   - Use `self.quant_config.weight_block_size` instead of a hardcoded `[128, 128]`
   - Add an assertion that weight_block_size is `[128, 128]` when ue8m0 scales are used (see the second sketch after this list)
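
A minimal sketch of the ceiling-division arithmetic from item 1 (hidden_size here is just an example value; the num_scales / num_scale_packs names mirror the diff below):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

hidden_size = 7168                # example value only
weight_block_size = [128, 128]

# One block-wise fp8 scale per 128-wide block along the packed axis.
num_scales = ceil_div(hidden_size, weight_block_size[1])  # 56

# ue8m0 scales are packed four per int32 element (the scale parameters in the
# diff below use dtype "int32"), so the pack count must also round up.
num_scale_packs = (num_scales + 3) // 4                   # 14

# With num_scales = 57, truncating via 57 // 4 gives 14 packs, while
# (57 + 3) // 4 = 15 packs are actually needed.
```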
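
And a sketch of the block-size check from item 2, assuming the quant_config fields used in the diff below; the helper function itself is illustrative, not code from this commit:

```python
def resolve_scale_block_size(quant_config):
    # Read the block size from the config instead of hardcoding [128, 128] ...
    block_size = quant_config.weight_block_size
    # ... but the ue8m0 packing path assumes 128x128 blocks, so fail loudly otherwise.
    if quant_config.deepgemm_scale_ue8m0:
        assert block_size == [128, 128], (
            f"ue8m0 DeepGEMM scales expect weight_block_size [128, 128], got {block_size}"
        )
    return block_size
```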

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Author: RichardWooSJTU
Date: 2026-03-03 11:32:35 +08:00
Committed by: GitHub
Parent: 375b5b7b21
Commit: 61789febb9

3 changed files with 135 additions and 50 deletions
@@ -223,14 +223,26 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 )
             )
-        up_gate_proj_weight = (
-            paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
-        )
-        down_proj_weight = (
-            paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
-        )
-        up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
-        down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
+        if not self.quant_config.deepgemm_scale_ue8m0:
+            up_gate_proj_weight = (
+                paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
+            )
+            down_proj_weight = (
+                paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
+            )
+            up_gate_proj_weight_scale = (
+                paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
+            )
+            down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
+        else:
+            up_gate_proj_weight = (
+                paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
+            )
+            down_proj_weight = (
+                paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
+            )
+            up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1])
+            down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1])
         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
@@ -239,7 +251,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             "down_proj_weight_scale_inv": down_proj_weight_scale,
         }
         for name, tensor in name_tensor_map.items():
-            getattr(layer, name).set_value(tensor)
+            getattr(layer, name).data = tensor
 
     def apply_ep_prefill(
         self,
@@ -1384,16 +1384,38 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             layer.hidden_size,
             layer.moe_intermediate_size,
         ]
-        self.up_gate_proj_scale_shape = [
-            layer.num_local_experts,
-            ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-        ]
-        self.down_proj_scale_shape = [
-            layer.num_local_experts,
-            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-            ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-        ]
+        if not self.quant_config.deepgemm_scale_ue8m0:
+            self.up_gate_proj_scale_shape = [
+                layer.num_local_experts,
+                ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
+                ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
+            ]
+            self.down_proj_scale_shape = [
+                layer.num_local_experts,
+                ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
+                ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
+            ]
+        else:
+            up_num_scales = ceil_div(
+                layer.hidden_size,
+                self.quant_config.weight_block_size[1],
+            )
+            up_num_scale_packs = (up_num_scales + 3) // 4
+            self.up_gate_proj_scale_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size * 2,
+                up_num_scale_packs,
+            ]
+            down_num_scales = ceil_div(
+                layer.moe_intermediate_size,
+                self.quant_config.weight_block_size[1],
+            )
+            down_num_scale_packs = (down_num_scales + 3) // 4
+            self.down_proj_scale_shape = [
+                layer.num_local_experts,
+                layer.hidden_size,
+                down_num_scale_packs,
+            ]
 
         # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
         self.model_format = extra_weight_attrs.get("model_format")
@@ -1519,24 +1541,44 @@
             ),
         )
         # weight_scale
-        setattr(
-            layer,
-            up_gate_proj_scale_name,
-            layer.create_parameter(
-                shape=up_gate_proj_scale_shape,
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_scale_name,
-            layer.create_parameter(
-                shape=down_proj_scale_shape,
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        if not self.quant_config.deepgemm_scale_ue8m0:
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+        else:
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=up_gate_proj_scale_shape,
+                    dtype="int32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=down_proj_scale_shape,
+                    dtype="int32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
         set_weight_attrs(
             getattr(layer, up_gate_proj_weight_name),
             up_gate_proj_attrs,