clean nvfp4 related code (#6644)

This commit is contained in:
周周周
2026-03-05 15:48:33 +08:00
committed by GitHub
parent 63414ccc13
commit cebe6f7dae
4 changed files with 20 additions and 9 deletions
@@ -167,12 +167,17 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
layer,
**extra_weight_attrs,
):
# The checkpoint stores weights column-major, so the output_dim flag must be inverted here.
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
weight_shape = layer.weight_shape[::-1]
weight_shape[1] = weight_shape[1] // 2
K = layer.weight_shape[0]
N = layer.weight_shape[1]
# The checkpoint stores the weight with shape [N, K // 2],
# so the weight created here matches that stored layout.
weight_shape = [N, K // 2]
layer.weight_dtype = "uint8"
input_scale_shape = [1]
weight_scale_shape = [layer.weight_shape[::-1][0], layer.weight_shape[::-1][1] // self.quant_config.group_size]
weight_scale_shape = [N, K // self.quant_config.group_size]
weight_scale_2_shape = [1]
self._create_main_weight(layer, weight_shape, extra_weight_attrs)
@@ -332,7 +337,9 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
else:
raise ValueError(f"Unsupported backend: {self.backend}.")
# Restore the shape to [K // 2, N].
w = layer.weight.T
# Restore the shape to [K // group_size, N].
w_scale_interleaved = layer.weight_scale_interleaved.T
if backend == "cutlass":
@@ -343,7 +350,8 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
out = fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, layer.alpha, output_dtype, backend=backend)
if layer.with_bias:
out = paddle.add(out, layer.bias)
return out.view(*output_shape)
assert out.shape == output_shape
return out
class ModelOptNvFp4FusedMoE(QuantMethodBase):
@@ -385,7 +393,7 @@ class ModelOptNvFp4FusedMoE(QuantMethodBase):
def create_weights(self, layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
NVFP4 MoE create weight.
"""
self.up_gate_proj_weight_shape = [
layer.num_local_experts,