mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Iluvatar] Optimize decode group_gemm and Support cuda graph for ernie (#6803)
This commit is contained in:
@@ -22,12 +22,14 @@ from paddle.nn.quant import weight_only_linear
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
restore_tokens_per_expert,
|
||||
w8a16_group_gemm,
|
||||
w8a16_group_gemv,
|
||||
)
|
||||
except ImportError:
|
||||
w8a16_group_gemm = None
|
||||
w8a16_group_gemv = None
|
||||
restore_tokens_per_expert = None
|
||||
|
||||
|
||||
def group_gemm(
|
||||
@@ -80,13 +82,14 @@ def group_gemm(
|
||||
)
|
||||
|
||||
|
||||
def _select_group_gemm_algo(moe_phase: str):
|
||||
# if moe_phase == "decode":
|
||||
if False:
|
||||
def _pre_process_expert_ffn(moe_phase: str, tokens_expert_prefix_sum: paddle.Tensor):
|
||||
if moe_phase == "decode":
|
||||
group_gemm_func = w8a16_group_gemv
|
||||
tokens_per_expert = restore_tokens_per_expert(tokens_expert_prefix_sum).to("int32")
|
||||
else:
|
||||
group_gemm_func = w8a16_group_gemm
|
||||
return group_gemm_func
|
||||
tokens_per_expert = tokens_expert_prefix_sum
|
||||
return group_gemm_func, tokens_per_expert
|
||||
|
||||
|
||||
def iluvatar_moe_expert_ffn(
|
||||
@@ -110,8 +113,8 @@ def iluvatar_moe_expert_ffn(
|
||||
assert expert_idx_per_token is None
|
||||
assert quant_method in ("weight_only_int8")
|
||||
assert not used_in_ep_low_latency
|
||||
group_gemm_func = _select_group_gemm_algo(moe_phase)
|
||||
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_expert_prefix_sum, -1)
|
||||
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, tokens_expert_prefix_sum)
|
||||
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_per_expert, -1)
|
||||
act_out = swiglu(ffn1_output)
|
||||
output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_expert_prefix_sum, -1)
|
||||
output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_per_expert, -1)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user