[Iluvatar] Support wi4a16 group_gemm (#7078)
@@ -18,76 +18,35 @@ from typing import Optional
 import paddle
 from paddle.nn.functional import swiglu
-from paddle.nn.quant import weight_only_linear
 
 try:
     from fastdeploy.model_executor.ops.iluvatar import (
         restore_tokens_per_expert,
         w8a16_group_gemm,
         w8a16_group_gemv,
+        wi4a16_group_gemm,
     )
-except ImportError:
+except:
     w8a16_group_gemm = None
     w8a16_group_gemv = None
+    wi4a16_group_gemm = None
     restore_tokens_per_expert = None
 
 
-def group_gemm(
-    input: paddle.Tensor,
+def _pre_process_expert_ffn(
+    moe_phase: str,
+    quant_method: str,
     tokens_expert_prefix_sum: paddle.Tensor,
-    weight: paddle.Tensor,
-    scale: paddle.Tensor,
-    output: paddle.Tensor,
 ):
-    assert (
-        input.dim() == 2
-        and tokens_expert_prefix_sum.dim() == 1
-        and weight.dim() == 3
-        and scale.dim() == 2
-        and output.dim() == 2
-    )
-    num_tokens = input.shape[0]
-    dim_in = input.shape[1]
-    dim_out = weight.shape[1]
-    num_experts = weight.shape[0]
-
-    # check shape
-    assert tokens_expert_prefix_sum.shape == [
-        num_experts,
-    ]
-    assert weight.shape == [num_experts, dim_out, dim_in]
-    assert scale.shape == [num_experts, dim_out]
-    assert output.shape == [num_tokens, dim_out]
-
-    # check dtype
-    assert input.dtype in (paddle.float16, paddle.bfloat16)
-    assert scale.dtype == input.dtype and output.dtype == input.dtype
-    assert tokens_expert_prefix_sum.dtype == paddle.int64
-    assert weight.dtype == paddle.int8
-
-    # check others
-    assert tokens_expert_prefix_sum.place.is_cpu_place()
-    assert tokens_expert_prefix_sum[-1] == num_tokens
-    for i in range(num_experts):
-        expert_start = 0 if i == 0 else tokens_expert_prefix_sum[i - 1]
-        expert_end = tokens_expert_prefix_sum[i]
-        if expert_start == expert_end:
-            continue
-        input_i = input[expert_start:expert_end]
-        weight_i = weight[i]
-        scale_i = scale[i]
-        # avoid d2d?
-        output[expert_start:expert_end] = weight_only_linear(
-            input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
-        )
-
-
-def _pre_process_expert_ffn(moe_phase: str, tokens_expert_prefix_sum: paddle.Tensor):
-    if moe_phase == "decode":
-        group_gemm_func = w8a16_group_gemv
-        tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
+    if quant_method == "weight_only_int8":
+        if moe_phase == "decode":
+            group_gemm_func = w8a16_group_gemv
+            tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
+        else:
+            group_gemm_func = w8a16_group_gemm
+            tokens_per_expert = tokens_expert_prefix_sum
     else:
-        group_gemm_func = w8a16_group_gemm
+        group_gemm_func = wi4a16_group_gemm
         tokens_per_expert = tokens_expert_prefix_sum
     return group_gemm_func, tokens_per_expert
 
 
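A note on the dispatch above: `restore_tokens_per_expert` is a custom iluvatar op whose kernel is not part of this diff. Judging only from how the decode path uses it (turning the inclusive per-expert prefix sum back into per-expert token counts for the gemv kernel), a plain-Paddle equivalent would look roughly like the sketch below. The name `restore_tokens_per_expert_ref` and the exact semantics are assumptions, not the op's actual implementation.

import paddle

def restore_tokens_per_expert_ref(tokens_expert_prefix_sum: paddle.Tensor) -> paddle.Tensor:
    # Assumed semantics: invert an inclusive prefix sum into per-expert
    # counts, e.g. [3, 3, 7] -> [3, 0, 4]. The real iluvatar op is a
    # custom kernel; this reference only illustrates the expected output.
    zero = paddle.zeros([1], dtype=tokens_expert_prefix_sum.dtype)
    return paddle.diff(tokens_expert_prefix_sum, prepend=zero)

For example, `restore_tokens_per_expert_ref(paddle.to_tensor([3, 3, 7], dtype="int64"))` would yield `[3, 0, 4]`, which the decode branch then casts to int32 before handing it to `w8a16_group_gemv`.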
@@ -99,16 +58,25 @@ def iluvatar_moe_expert_ffn(
     down_proj_weight: paddle.Tensor,
     up_gate_proj_bias: Optional[paddle.Tensor],
     up_gate_proj_scale: Optional[paddle.Tensor],
+    up_gate_proj_zeros: Optional[paddle.Tensor],
     down_proj_scale: Optional[paddle.Tensor],
+    down_proj_zeros: Optional[paddle.Tensor],
     quant_method: str,
+    group_size: int,
     moe_phase: str,
 ):
     assert up_gate_proj_bias is None
     assert up_gate_proj_scale is not None
     assert down_proj_scale is not None
-    assert quant_method in ("weight_only_int8")
-    group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, tokens_expert_prefix_sum)
-    ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_per_expert, -1)
+    if quant_method == "weight_only_int4":
+        assert up_gate_proj_zeros is not None
+        assert down_proj_zeros is not None
+    group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, quant_method, tokens_expert_prefix_sum)
+    ffn1_output = group_gemm_func(
+        permute_input, up_gate_proj_weight, up_gate_proj_scale, up_gate_proj_zeros, tokens_per_expert, group_size
+    )
     act_out = swiglu(ffn1_output)
-    output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_per_expert, -1)
+    output = group_gemm_func(
+        act_out, down_proj_weight, down_proj_scale, down_proj_zeros, tokens_per_expert, group_size
+    )
     return output
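Neither `wi4a16_group_gemm` nor its weight layout appears in this commit, but the deleted int8 reference loop in the first hunk suggests what the int4 path computes: a per-expert weight-only GEMM where, unlike the int8 path's per-output-channel scales with `group_size=-1`, each group of `group_size` input channels shares a scale and a zero point. The pure-Paddle sketch below pins down that arithmetic under assumed shapes (int4 values unpacked to int8 in `[num_experts, dim_out, dim_in]`, scales and zeros in `[num_experts, dim_out, dim_in // group_size]`); none of these layouts is confirmed by the diff.

import paddle

def wi4a16_group_gemm_ref(x, qweight, scale, zeros, tokens_expert_prefix_sum, group_size):
    # Illustrative sketch only; the real op is a fused iluvatar kernel and
    # its packed-int4 layout is not shown in this commit. Assumed here:
    #   x:       [num_tokens, dim_in], fp16/bf16, tokens grouped by expert
    #   qweight: [num_experts, dim_out, dim_in], int4 values stored as int8
    #   scale, zeros: [num_experts, dim_out, dim_in // group_size]
    #   tokens_expert_prefix_sum: inclusive prefix sum held on CPU, as the
    #   deleted int8 reference loop asserted
    num_experts, dim_out, dim_in = qweight.shape
    output = paddle.zeros([x.shape[0], dim_out], dtype=x.dtype)
    for i in range(num_experts):
        start = 0 if i == 0 else int(tokens_expert_prefix_sum[i - 1])
        end = int(tokens_expert_prefix_sum[i])
        if start == end:
            continue  # no tokens routed to this expert
        # Group-wise dequantization: w = (q - zero) * scale, with one
        # (scale, zero) pair per group of `group_size` input channels.
        q = qweight[i].astype(x.dtype).reshape([dim_out, dim_in // group_size, group_size])
        z = zeros[i].astype(x.dtype).unsqueeze(-1)
        s = scale[i].astype(x.dtype).unsqueeze(-1)
        w = ((q - z) * s).reshape([dim_out, dim_in])
        output[start:end] = paddle.matmul(x[start:end], w, transpose_y=True)
    return output

The fused kernel presumably never materializes `w` or runs a Python loop over experts; this sketch only spells out the computation implied by the new `group_gemm_func(..., zeros, tokens_per_expert, group_size)` call sites.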