[Quantization] Support loading static quant ue8m0 scales for DeepGEMM via v0_loader (#6433)

* Support loading static quant ue8m0 scales for DeepGEMM via v0_loader

* [Fix] Fix ue8m0 scale pack dimension calculation and block size validation

1. Fix pack dimension calculation in fused_moe_triton_backend.py:
   - Changed from `ceil_div(...) // 4` to `(num_scales + 3) // 4` for correct ceiling division
   - This ensures sufficient pack allocation when num_scales is not a multiple of 4

2. Fix block size hardcoding in block_wise_fp8.py:
   - Use `self.quant_config.weight_block_size` instead of hardcoded `[128, 128]`
   - Add assertion to ensure weight_block_size is `[128, 128]` for ue8m0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
RichardWooSJTU
2026-03-03 11:32:35 +08:00
committed by GitHub
parent 375b5b7b21
commit 61789febb9
3 changed files with 135 additions and 50 deletions
@@ -176,12 +176,22 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
)
else:
layer.weight_shape.reverse()
weight_scale_inv_shape = [
(layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
// self.quant_config.weight_block_size[0],
(layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
]
if not self.quant_config.deepgemm_scale_ue8m0:
weight_scale_inv_shape = [
(layer.weight_shape[0] + self.quant_config.weight_block_size[0] - 1)
// self.quant_config.weight_block_size[0],
(layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1)
// self.quant_config.weight_block_size[1],
]
else:
num_scales = (
layer.weight_shape[1] + self.quant_config.weight_block_size[1] - 1
) // self.quant_config.weight_block_size[1]
num_scale_packs = (num_scales + 3) // 4
weight_scale_inv_shape = [
layer.weight_shape[0],
num_scale_packs,
]
if self.model_format != "torch" and layer.fd_config.load_config.load_choices == "default_v1":
weight_shape = layer.weight_shape[::-1]
@@ -204,11 +214,19 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale_inv = layer.create_parameter(
shape=weight_scale_inv_shape,
dtype="float32",
is_bias=False,
)
if not self.quant_config.deepgemm_scale_ue8m0:
layer.weight_scale_inv = layer.create_parameter(
shape=weight_scale_inv_shape,
dtype="float32",
is_bias=False,
)
else:
layer.weight_scale_inv = layer.create_parameter(
shape=weight_scale_inv_shape,
dtype="int32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.weight,
@@ -268,9 +286,22 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
def process_loaded_weights(self, layer, weights) -> None:
weight_tensor = weights.transpose([1, 0])
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
if not self.quant_config.deepgemm_scale_ue8m0:
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
else:
weight_block_size = self.quant_config.weight_block_size
assert weight_block_size == [
128,
128,
], f"weight_block_size must be [128, 128] for ue8m0, but got {weight_block_size}"
quanted_weight_tensor, weight_block_scale_tensor = quant_weight_ue8m0(weight_tensor, weight_block_size)
weight_block_scale_tensor = transform_scale_ue8m0(
weight_block_scale_tensor,
mn=quanted_weight_tensor.shape[-2],
weight_block_size=weight_block_size,
)
layer.weight.copy_(quanted_weight_tensor, False)
layer.weight_scale_inv.set_value(weight_block_scale_tensor)
layer.weight_scale_inv.data = weight_block_scale_tensor
def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
"""