mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Iluvatar] Support CudaGraph and optimize flash_attn_unpadded and fused_neox_rope_embedding (#6553)
This commit is contained in:
@@ -19,9 +19,10 @@ PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
|
||||
|
||||
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||
|
||||
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: F401
|
||||
from .paged_attention import ( # noqa: F401
|
||||
from .attention_ops import ( # noqa: F401
|
||||
flash_attn_unpadded,
|
||||
mixed_fused_paged_attention,
|
||||
paged_attention,
|
||||
prefill_fused_paged_attention,
|
||||
)
|
||||
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: F401
|
||||
|
||||
+19
@@ -18,6 +18,7 @@ import paddle
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
cuinfer_flash_attn_unpadded,
|
||||
mixed_fused_paged_attn,
|
||||
paged_attn,
|
||||
prefill_fused_paged_attn,
|
||||
@@ -188,3 +189,21 @@ def mixed_fused_paged_attention(
|
||||
rope_batch_stride,
|
||||
is_interleaved_rope_mode,
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_unpadded(
    query,
    key,
    value,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    scale,
    causal=False,
    training=False,
):
    """Variable-length (unpadded) flash attention via the cuinfer kernel.

    Thin wrapper that forwards all arguments to
    ``cuinfer_flash_attn_unpadded`` and returns a two-tuple so callers
    expecting a ``(output, softmax)`` pair keep working; the softmax slot
    is always ``None`` because the underlying kernel does not expose it.
    """
    # NOTE: the cuinfer kernel takes `causal` before `scale`, which is the
    # reverse of this wrapper's parameter order — keep the call as-is.
    attn_out = cuinfer_flash_attn_unpadded(
        query,
        key,
        value,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal,
        scale,
        training,
    )
    # Second element stands in for `return_softmax`.
    return attn_out, None
|
||||
@@ -21,9 +21,13 @@ from paddle.nn.functional import swiglu
|
||||
from paddle.nn.quant import weight_only_linear
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.iluvatar import w8a16_group_gemm
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
w8a16_group_gemm,
|
||||
w8a16_group_gemv,
|
||||
)
|
||||
except ImportError:
|
||||
w8a16_group_gemm = None
|
||||
w8a16_group_gemv = None
|
||||
|
||||
|
||||
def group_gemm(
|
||||
@@ -76,6 +80,15 @@ def group_gemm(
|
||||
)
|
||||
|
||||
|
||||
def _select_group_gemm_algo(moe_phase: str):
    """Return the grouped-GEMM kernel to use for the given MoE phase.

    The gemv variant (``w8a16_group_gemv``) for the decode phase is
    currently disabled by a hard-coded switch, so this always resolves to
    ``w8a16_group_gemm`` regardless of ``moe_phase``.
    """
    # Deliberately disabled for now; the intended condition was
    # `moe_phase == "decode"` — flip this flag to re-enable the gemv path.
    use_decode_gemv = False
    return w8a16_group_gemv if use_decode_gemv else w8a16_group_gemm
|
||||
|
||||
|
||||
def iluvatar_moe_expert_ffn(
|
||||
permute_input: paddle.Tensor,
|
||||
tokens_expert_prefix_sum: paddle.Tensor,
|
||||
@@ -88,6 +101,7 @@ def iluvatar_moe_expert_ffn(
|
||||
expert_idx_per_token: Optional[paddle.Tensor],
|
||||
quant_method: str,
|
||||
used_in_ep_low_latency: bool,
|
||||
moe_phase: str,
|
||||
):
|
||||
assert up_gate_proj_bias is None
|
||||
assert up_gate_proj_scale is not None
|
||||
@@ -96,10 +110,8 @@ def iluvatar_moe_expert_ffn(
|
||||
assert expert_idx_per_token is None
|
||||
assert quant_method in ("weight_only_int8")
|
||||
assert not used_in_ep_low_latency
|
||||
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
||||
ffn1_output = w8a16_group_gemm(
|
||||
permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_expert_prefix_sum_cpu, -1
|
||||
)
|
||||
group_gemm_func = _select_group_gemm_algo(moe_phase)
|
||||
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_expert_prefix_sum, -1)
|
||||
act_out = swiglu(ffn1_output)
|
||||
output = w8a16_group_gemm(act_out, down_proj_weight, down_proj_scale, tokens_expert_prefix_sum_cpu, -1)
|
||||
output = group_gemm_func(act_out, down_proj_weight, down_proj_scale, tokens_expert_prefix_sum, -1)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user