[Iluvatar] Support wi4a16 group_gemm (#7078)

This commit is contained in:
yzwu
2026-03-30 19:03:51 +08:00
committed by GitHub
parent 18062c55bb
commit 8789329457
13 changed files with 722 additions and 144 deletions
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.ops.iluvatar import (
moe_expert_ffn,
moe_expert_reduce,
)
from fastdeploy.model_executor.ops.iluvatar.utils import wi4a16_weight_quantize
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
@@ -80,8 +81,11 @@ class IluvatarCutlassMoEMethod(UnquantizedFusedMoEMethod):
getattr(layer, self.added_weight_attrs[1]),
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.up_gate_proj_weight_zeros if hasattr(layer, "up_gate_proj_weight_zeros") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_weight_zeros if hasattr(layer, "down_proj_weight_zeros") else None),
self.moe_quant_type,
self.quant_config.group_size,
layer.fd_config.model_config.moe_phase.phase,
)
@@ -214,7 +218,12 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
if self.moe_quant_type == "weight_only_int8":
self.quant_config.group_size = -1
elif self.moe_quant_type == "weight_only_int4":
self.quant_config.group_size = 128
else:
raise NotImplementedError("Iluvarar only support wint8 nand wint4 yet.")
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
@@ -427,24 +436,33 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
# quantized_weight_name
weight_name = self.added_weight_attrs[weight_idx]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_idx]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
if self.moe_quant_type == "weight_only_int4":
# zeros for int4
zeros = []
zeros_name = scale_name.replace("weight_scale", "weight_zeros")
# 2.crate tmp tensor
weight = paddle.empty(weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
weight, scale = [], []
# 3.quantize weight
for expert_id in range(layer.num_local_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type
)
unquantized_weight = getattr(layer, unquantized_weight_name)[expert_id]
if self.moe_quant_type == "weight_only_int8":
w, s = weight_quantize(unquantized_weight, algo=self.moe_quant_type)
else:
w, s, z = wi4a16_weight_quantize(unquantized_weight)
zeros.append(z)
weight.append(w)
scale.append(s)
weight = paddle.stack(weight, axis=0)
scale = paddle.stack(scale, axis=0)
if self.moe_quant_type == "weight_only_int4":
zeros = paddle.stack(zeros, axis=0)
free_tensor(getattr(layer, unquantized_weight_name))
@@ -453,8 +471,8 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
shape=weight.shape,
dtype=weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
@@ -463,14 +481,27 @@ class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
shape=scale.shape,
dtype=scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
if self.moe_quant_type == "weight_only_int4":
# create zeros
setattr(
layer,
zeros_name,
layer.create_parameter(
shape=zeros.shape,
dtype=zeros.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, zeros_name).copy_(zeros, False)
if self.quant_config.is_checkpoint_bf16:
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
@@ -18,76 +18,35 @@ from typing import Optional
import paddle
from paddle.nn.functional import swiglu
from paddle.nn.quant import weight_only_linear
try:
from fastdeploy.model_executor.ops.iluvatar import (
restore_tokens_per_expert,
w8a16_group_gemm,
w8a16_group_gemv,
wi4a16_group_gemm,
)
except ImportError:
except:
w8a16_group_gemm = None
w8a16_group_gemv = None
wi4a16_group_gemm = None
restore_tokens_per_expert = None
def group_gemm(
input: paddle.Tensor,
def _pre_process_expert_ffn(
moe_phase: str,
quant_method: str,
tokens_expert_prefix_sum: paddle.Tensor,
weight: paddle.Tensor,
scale: paddle.Tensor,
output: paddle.Tensor,
):
assert (
input.dim() == 2
and tokens_expert_prefix_sum.dim() == 1
and weight.dim() == 3
and scale.dim() == 2
and output.dim() == 2
)
num_tokens = input.shape[0]
dim_in = input.shape[1]
dim_out = weight.shape[1]
num_experts = weight.shape[0]
# check shape
assert tokens_expert_prefix_sum.shape == [
num_experts,
]
assert weight.shape == [num_experts, dim_out, dim_in]
assert scale.shape == [num_experts, dim_out]
assert output.shape == [num_tokens, dim_out]
# check dtype
assert input.dtype in (paddle.float16, paddle.bfloat16)
assert scale.dtype == input.dtype and output.dtype == input.dtype
assert tokens_expert_prefix_sum.dtype == paddle.int64
assert weight.dtype == paddle.int8
# check others
assert tokens_expert_prefix_sum.place.is_cpu_place()
assert tokens_expert_prefix_sum[-1] == num_tokens
for i in range(num_experts):
expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1]
expert_end = tokens_expert_prefix_sum[i]
if expert_start == expert_end:
continue
input_i = input[expert_start:expert_end]
weight_i = weight[i]
scale_i = scale[i]
# avoid d2d?
output[expert_start:expert_end] = weight_only_linear(
input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
)
def _pre_process_expert_ffn(moe_phase: str, tokens_expert_prefix_sum: paddle.Tensor):
if moe_phase == "decode":
group_gemm_func = w8a16_group_gemv
tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
if quant_method == "weight_only_int8":
if moe_phase == "decode":
group_gemm_func = w8a16_group_gemv
tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
else:
group_gemm_func = w8a16_group_gemm
tokens_per_expert = tokens_expert_prefix_sum
else:
group_gemm_func = w8a16_group_gemm
group_gemm_func = wi4a16_group_gemm
tokens_per_expert = tokens_expert_prefix_sum
return group_gemm_func, tokens_per_expert
@@ -99,16 +58,25 @@ def iluvatar_moe_expert_ffn(
down_proj_weight: paddle.Tensor,
up_gate_proj_bias: Optional[paddle.Tensor],
up_gate_proj_scale: Optional[paddle.Tensor],
up_gate_proj_zeros: Optional[paddle.Tensor],
down_proj_scale: Optional[paddle.Tensor],
down_proj_zeros: Optional[paddle.Tensor],
quant_method: str,
group_size: int,
moe_phase: str,
):
assert up_gate_proj_bias is None
assert up_gate_proj_scale is not None
assert down_proj_scale is not None
assert quant_method in ("weight_only_int8")
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, tokens_expert_prefix_sum)
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_per_expert, -1)
if quant_method == "weight_only_int4":
assert up_gate_proj_zeros is not None
assert down_proj_zeros is not None
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, quant_method, tokens_expert_prefix_sum)
ffn1_output = group_gemm_func(
permute_input, up_gate_proj_weight, up_gate_proj_scale, up_gate_proj_zeros, tokens_per_expert, group_size
)
act_out = swiglu(ffn1_output)
output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_per_expert, -1)
output = group_gemm_func(
act_out, down_proj_weight, down_proj_scale, down_proj_zeros, tokens_per_expert, group_size
)
return output
@@ -0,0 +1,51 @@
import paddle
try:
from fastdeploy.model_executor.ops.iluvatar import wi4a16_weight_quantize_cuda
except:
wi4a16_weight_quantize_cuda = None
def _get_weight_by_group_size(w, group_size):
assert w.dim() == 2
assert group_size in (-1, 32, 64, 128)
if group_size == -1:
quant_weight = w
else:
assert w.shape[-1] % group_size == 0
quant_weight = w.reshape(-1, group_size)
assert paddle.isnan(quant_weight).sum() == 0
return quant_weight
def _pack_int4_to_int8(weight):
return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)
def wi4a16_weight_quantize(w, group_size=128):
"""Quantize [k, n] weight to packed int4, scales, zeros (MoE wi4a16)."""
k, n = w.shape
assert k % group_size == 0 and n % 2 == 0
if wi4a16_weight_quantize_cuda is not None:
return wi4a16_weight_quantize_cuda(w.contiguous(), group_size)
else:
# [k, n] -> [n, k]
w = w.T.contiguous()
quant_weight = _get_weight_by_group_size(w, group_size)
wmax = quant_weight.abs().max(axis=1, keepdim=True)
scales = wmax / 7
out = paddle.round(quant_weight.to(paddle.float32) / scales).clamp(-8, 7).to(paddle.int8)
out = _pack_int4_to_int8(
# NOTE: conver to numpy since paddle cannot support &
out.view(w.shape[0], -1)
.T.contiguous()
.cpu()
.numpy(),
)
out = paddle.from_numpy(out).T.contiguous()
scales = scales.view(w.shape[0], -1).T.contiguous()
zeros = paddle.zeros_like(scales)
return out, scales, zeros