[Iluvatar GPU] Optimize attention and MoE performance (#3234)

yzwu
2025-08-08 10:51:24 +08:00
committed by GitHub
parent 37569cca86
commit fbdd6b0663
24 changed files with 1130 additions and 1653 deletions
@@ -20,6 +20,11 @@ import paddle
 from paddle.incubate.nn.functional import swiglu
 from paddle.nn.quant import weight_only_linear
+try:
+    from fastdeploy.model_executor.ops.iluvatar import w8a16_group_gemm
+except ImportError:
+    w8a16_group_gemm = None

 def group_gemm(
     input: paddle.Tensor,
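
Note on the guarded import above: the fused Iluvatar kernel may be absent on other builds, so the module still imports and w8a16_group_gemm is simply None in that case. A minimal sketch (not part of the diff; the helper name is hypothetical and the trailing -1 is assumed to mirror group_size=-1) of how a caller could dispatch between the fused op and the pure-Paddle group_gemm fallback defined below:

def grouped_w8a16_matmul(x, weight, scale, tokens_expert_prefix_sum_cpu):
    # Prefer the fused custom op when the Iluvatar build provides it.
    if w8a16_group_gemm is not None:
        return w8a16_group_gemm(x, weight, scale, tokens_expert_prefix_sum_cpu, -1)
    # Fallback: allocate the result and let group_gemm fill it expert by expert,
    # using the same output-dim convention (weight.shape[1]) as the code below.
    out = paddle.empty([x.shape[0], weight.shape[1]], dtype=x.dtype)
    group_gemm(x, tokens_expert_prefix_sum_cpu, weight, scale, out)
    return out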
@@ -67,53 +72,32 @@ def group_gemm(
         scale_i = scale[i]
         # avoid d2d?
         output[expert_start:expert_end] = weight_only_linear(
-            input_i,
-            weight_i,
-            weight_scale=scale_i,
-            weight_dtype="int8",
-            group_size=-1,
+            input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
         )
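
For reference, the per-expert semantics that both group_gemm and the fused w8a16_group_gemm implement: tokens arrive already sorted by expert, and tokens_expert_prefix_sum marks where each expert's contiguous slice ends. A tiny pure-Paddle sketch of the same computation in float (shapes illustrative, no quantization):

import paddle

num_experts, hidden, out_dim = 3, 8, 16
x = paddle.randn([10, hidden])                        # tokens already grouped by expert
prefix_sum = [4, 7, 10]                               # tokens_expert_prefix_sum (cumulative)
w = paddle.randn([num_experts, hidden, out_dim])      # dequantized per-expert weights
y = paddle.empty([10, out_dim])

start = 0
for e in range(num_experts):
    end = prefix_sum[e]
    # Each expert multiplies only its own contiguous slice of tokens.
    y[start:end] = paddle.matmul(x[start:end], w[e])
    start = end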
 def iluvatar_moe_expert_ffn(
     permute_input: paddle.Tensor,
     tokens_expert_prefix_sum: paddle.Tensor,
-    up_gate_proj_weight: paddle.Tensor,
-    down_proj_weight: paddle.Tensor,
-    up_gate_proj_bias: Optional[paddle.Tensor],
-    up_gate_proj_scale: Optional[paddle.Tensor],
-    down_proj_scale: Optional[paddle.Tensor],
-    down_proj_in_scale: Optional[paddle.Tensor],
+    ffn1_weight: paddle.Tensor,
+    ffn2_weight: paddle.Tensor,
+    ffn1_bias: Optional[paddle.Tensor],
+    ffn1_scale: Optional[paddle.Tensor],
+    ffn2_scale: Optional[paddle.Tensor],
+    ffn2_in_scale: Optional[paddle.Tensor],
     expert_idx_per_token: Optional[paddle.Tensor],
     quant_method: str,
     used_in_ep_low_latency: bool,
 ):
-    assert up_gate_proj_bias is None
-    assert up_gate_proj_scale is not None
-    assert down_proj_scale is not None
-    assert down_proj_in_scale is None
+    assert ffn1_bias is None
+    assert ffn1_scale is not None
+    assert ffn2_scale is not None
+    assert ffn2_in_scale is None
     assert expert_idx_per_token is None
     assert quant_method in ("weight_only_int8",)
     assert not used_in_ep_low_latency
     tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
-    up_gate_proj_output = paddle.empty(
-        [permute_input.shape[0], up_gate_proj_weight.shape[1]],
-        dtype=permute_input.dtype,
-    )
-    group_gemm(
-        permute_input,
-        tokens_expert_prefix_sum_cpu,
-        up_gate_proj_weight,
-        up_gate_proj_scale,
-        up_gate_proj_output,
-    )
-    act_out = swiglu(up_gate_proj_output)
-    output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype)
-    group_gemm(
-        act_out,
-        tokens_expert_prefix_sum_cpu,
-        down_proj_weight,
-        down_proj_scale,
-        output,
-    )
+    ffn1_output = w8a16_group_gemm(permute_input, ffn1_weight, ffn1_scale, tokens_expert_prefix_sum_cpu, -1)
+    act_out = swiglu(ffn1_output)
+    output = w8a16_group_gemm(act_out, ffn2_weight, ffn2_scale, tokens_expert_prefix_sum_cpu, -1)
     return output
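
With the fused kernel, the expert FFN collapses to two grouped weight-only GEMMs with a SwiGLU in between; the intermediate paddle.empty buffers and the per-expert Python loop in group_gemm are gone. A call sketch that satisfies the asserts above — every shape, dtype, and the int8 weight layout here is an illustrative assumption, not something the diff specifies:

import paddle

if w8a16_group_gemm is not None:  # only meaningful on an Iluvatar build of fastdeploy
    num_experts, hidden, inter = 4, 1024, 2816
    permute_input = paddle.randn([16, hidden]).astype("float16")      # 16 tokens, sorted by expert
    prefix_sum = paddle.to_tensor([3, 7, 12, 16], dtype="int64")      # tokens_expert_prefix_sum
    ffn1_weight = paddle.zeros([num_experts, inter * 2, hidden], dtype="int8")  # layout assumed
    ffn2_weight = paddle.zeros([num_experts, hidden, inter], dtype="int8")      # layout assumed
    ffn1_scale = paddle.ones([num_experts, inter * 2], dtype="float16")
    ffn2_scale = paddle.ones([num_experts, hidden], dtype="float16")

    out = iluvatar_moe_expert_ffn(
        permute_input,
        prefix_sum,
        ffn1_weight,
        ffn2_weight,
        None,                 # ffn1_bias must be None
        ffn1_scale,
        ffn2_scale,
        None,                 # ffn2_in_scale must be None
        None,                 # expert_idx_per_token must be None
        "weight_only_int8",
        False,                # used_in_ep_low_latency
    )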
@@ -39,8 +39,11 @@ def paged_attention(
     softcap: float = 0.0,
     use_cuda_graph: bool = False,
     use_sqrt_alibi: bool = False,
+    merged_qkv: bool = False,
     k: paddle.Tensor = None,
     v: paddle.Tensor = None,
+    rope_sin: paddle.Tensor = None,
+    rope_cos: paddle.Tensor = None,
 ):
     output = paged_attn(
         q,
@@ -51,6 +54,8 @@ def paged_attention(
         alibi_slopes,
         k,
         v,
+        rope_sin,
+        rope_cos,
         num_kv_heads,
         scale,
         block_size,
@@ -61,5 +66,6 @@ def paged_attention(
         softcap,
         use_cuda_graph,
         use_sqrt_alibi,
+        merged_qkv,
     )
     return output[0] if isinstance(output, list) else output
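
The wrapper now forwards merged_qkv, k, v, rope_sin, and rope_cos straight through to the paged_attn custom op, so rotary position embedding (and the merged-QKV layout) can be handled inside the attention kernel rather than in a separate pass over q and k. One conventional way a caller might precompute the sin/cos tables — the exact layout the Iluvatar kernel expects is an assumption here, not defined by this diff:

import paddle

def build_rope_tables(max_position: int, head_dim: int, base: float = 10000.0):
    # Standard rotary-embedding frequencies: one per pair of channels.
    exponents = paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
    inv_freq = 1.0 / paddle.pow(paddle.to_tensor(base, dtype="float32"), exponents)
    positions = paddle.arange(max_position, dtype="float32")
    freqs = paddle.outer(positions, inv_freq)          # [max_position, head_dim // 2]
    return paddle.cos(freqs), paddle.sin(freqs)

rope_cos, rope_sin = build_rope_tables(max_position=4096, head_dim=128)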