[Iluvatar GPU] Optimize attention performance and fix moe load ckpt error (#3651)

This commit is contained in:
yzwu
2025-09-22 21:13:59 +08:00
committed by GitHub
parent 5532e8a323
commit 504461b6b5
17 changed files with 1344 additions and 363 deletions
@@ -20,4 +20,8 @@ PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: F401
from .paged_attention import paged_attention # noqa: F401
from .paged_attention import ( # noqa: F401
mixed_fused_paged_attention,
paged_attention,
prefill_fused_paged_attention,
)
@@ -17,9 +17,15 @@
import paddle
try:
from fastdeploy.model_executor.ops.iluvatar import paged_attn
from fastdeploy.model_executor.ops.iluvatar import (
mixed_fused_paged_attn,
paged_attn,
prefill_fused_paged_attn,
)
except ImportError:
paged_attn = None
prefill_fused_paged_attn = None
mixed_fused_paged_attn = None
def paged_attention(
@@ -28,6 +34,8 @@ def paged_attention(
v_cache: paddle.Tensor,
block_tables: paddle.Tensor,
seq_lens: paddle.Tensor,
num_heads: int,
head_dim: int,
num_kv_heads: int,
scale: float,
block_size: int,
@@ -45,7 +53,7 @@ def paged_attention(
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
):
output = paged_attn(
return paged_attn(
q,
k_cache,
v_cache,
@@ -56,6 +64,8 @@ def paged_attention(
v,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
scale,
block_size,
@@ -68,4 +78,99 @@ def paged_attention(
use_sqrt_alibi,
merged_qkv,
)
return output[0] if isinstance(output, list) else output
def prefill_fused_paged_attention(
qkv: paddle.Tensor,
k_cache: paddle.Tensor,
v_cache: paddle.Tensor,
block_tables: paddle.Tensor,
cu_seqlens_qkv: paddle.Tensor,
num_heads: int,
head_dim: int,
num_kv_heads: int,
block_size: int,
max_seq_len: int,
scale: float,
causal: bool = True,
q_rope: bool = True,
k_rope: bool = True,
v_rope: bool = False,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
):
return prefill_fused_paged_attn(
qkv,
k_cache,
v_cache,
block_tables,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
)
def mixed_fused_paged_attention(
qkv: paddle.Tensor,
k_cache: paddle.Tensor,
v_cache: paddle.Tensor,
prefill_block_tables: paddle.Tensor,
decode_block_tables: paddle.Tensor,
cu_seqlens_qkv: paddle.Tensor,
seq_lens: paddle.Tensor,
prefill_num_tokens: int,
num_heads: int,
head_dim: int,
num_kv_heads: int,
block_size: int,
max_seq_len: int,
scale: float,
causal: bool = True,
q_rope: bool = True,
k_rope: bool = True,
v_rope: bool = False,
window_left: int = -1,
window_right: int = -1,
softcap: float = 0.0,
use_cuda_graph: bool = False,
use_sqrt_alibi: bool = False,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
):
return mixed_fused_paged_attn(
qkv,
k_cache,
v_cache,
prefill_block_tables,
decode_block_tables,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
use_cuda_graph,
use_sqrt_alibi,
)