Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00.
[Iluvatar GPU] Optimize attention performance and fix moe load ckpt error (#3651)
This commit is contained in:
@@ -20,4 +20,8 @@ PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
|
||||
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||
|
||||
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: F401
|
||||
from .paged_attention import paged_attention # noqa: F401
|
||||
from .paged_attention import ( # noqa: F401
|
||||
mixed_fused_paged_attention,
|
||||
paged_attention,
|
||||
prefill_fused_paged_attention,
|
||||
)
|
||||
|
||||
@@ -17,9 +17,15 @@
|
||||
import paddle
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.iluvatar import paged_attn
|
||||
from fastdeploy.model_executor.ops.iluvatar import (
|
||||
mixed_fused_paged_attn,
|
||||
paged_attn,
|
||||
prefill_fused_paged_attn,
|
||||
)
|
||||
except ImportError:
|
||||
paged_attn = None
|
||||
prefill_fused_paged_attn = None
|
||||
mixed_fused_paged_attn = None
|
||||
|
||||
|
||||
def paged_attention(
|
||||
@@ -28,6 +34,8 @@ def paged_attention(
|
||||
v_cache: paddle.Tensor,
|
||||
block_tables: paddle.Tensor,
|
||||
seq_lens: paddle.Tensor,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
block_size: int,
|
||||
@@ -45,7 +53,7 @@ def paged_attention(
|
||||
rope_sin: paddle.Tensor = None,
|
||||
rope_cos: paddle.Tensor = None,
|
||||
):
|
||||
output = paged_attn(
|
||||
return paged_attn(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
@@ -56,6 +64,8 @@ def paged_attention(
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
@@ -68,4 +78,99 @@ def paged_attention(
|
||||
use_sqrt_alibi,
|
||||
merged_qkv,
|
||||
)
|
||||
return output[0] if isinstance(output, list) else output
|
||||
|
||||
|
||||
def prefill_fused_paged_attention(
    qkv: paddle.Tensor,
    k_cache: paddle.Tensor,
    v_cache: paddle.Tensor,
    block_tables: paddle.Tensor,
    cu_seqlens_qkv: paddle.Tensor,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    block_size: int,
    max_seq_len: int,
    scale: float,
    causal: bool = True,
    q_rope: bool = True,
    k_rope: bool = True,
    v_rope: bool = False,
    rope_sin: paddle.Tensor = None,
    rope_cos: paddle.Tensor = None,
):
    """Run the Iluvatar fused prefill-phase paged-attention custom op.

    Thin wrapper that reorders the keyword-friendly Python signature into
    the positional calling convention of ``prefill_fused_paged_attn``
    (tensors first — including the rotary tables — then the scalar and
    boolean options).

    Args:
        qkv: Packed QKV activations for the prefill tokens.
        k_cache: Paged key cache tensor.
        v_cache: Paged value cache tensor.
        block_tables: Logical-to-physical cache-block mapping.
        cu_seqlens_qkv: Cumulative sequence lengths of the packed batch.
        num_heads: Number of query heads.
        head_dim: Per-head hidden size.
        num_kv_heads: Number of key/value heads.
        block_size: Tokens per KV-cache block.
        max_seq_len: Maximum supported sequence length.
        scale: Attention softmax scaling factor.
        causal: Apply a causal mask when True.
        q_rope / k_rope / v_rope: Whether the op applies rotary embedding
            to the respective projection (semantics defined by the op).
        rope_sin / rope_cos: Precomputed rotary tables; may be ``None``
            when rotary embedding is disabled — TODO confirm with the op.

    Returns:
        Whatever the custom op returns (the attention output).

    Raises:
        RuntimeError: If the Iluvatar custom-op extension failed to
            import, i.e. ``prefill_fused_paged_attn`` is ``None``.
    """
    # Fail fast with an actionable message instead of the opaque
    # "'NoneType' object is not callable" raised when the custom-op
    # package could not be imported on this platform.
    if prefill_fused_paged_attn is None:
        raise RuntimeError(
            "prefill_fused_paged_attn custom op is unavailable: "
            "fastdeploy.model_executor.ops.iluvatar failed to import."
        )
    return prefill_fused_paged_attn(
        qkv,
        k_cache,
        v_cache,
        block_tables,
        cu_seqlens_qkv,
        rope_sin,
        rope_cos,
        num_heads,
        head_dim,
        num_kv_heads,
        block_size,
        max_seq_len,
        scale,
        causal,
        q_rope,
        k_rope,
        v_rope,
    )
|
||||
|
||||
|
||||
def mixed_fused_paged_attention(
    qkv: paddle.Tensor,
    k_cache: paddle.Tensor,
    v_cache: paddle.Tensor,
    prefill_block_tables: paddle.Tensor,
    decode_block_tables: paddle.Tensor,
    cu_seqlens_qkv: paddle.Tensor,
    seq_lens: paddle.Tensor,
    prefill_num_tokens: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    block_size: int,
    max_seq_len: int,
    scale: float,
    causal: bool = True,
    q_rope: bool = True,
    k_rope: bool = True,
    v_rope: bool = False,
    window_left: int = -1,
    window_right: int = -1,
    softcap: float = 0.0,
    use_cuda_graph: bool = False,
    use_sqrt_alibi: bool = False,
    rope_sin: paddle.Tensor = None,
    rope_cos: paddle.Tensor = None,
):
    """Run the Iluvatar fused mixed prefill+decode paged-attention op.

    Thin wrapper that reorders the keyword-friendly Python signature into
    the positional calling convention of ``mixed_fused_paged_attn``
    (tensors first — including the rotary tables — then the scalar and
    boolean options). Handles a batch that mixes prefill and decode
    requests, each with its own block table — presumably split at
    ``prefill_num_tokens``; verify against the op's implementation.

    Args:
        qkv: Packed QKV activations for the whole mixed batch.
        k_cache: Paged key cache tensor.
        v_cache: Paged value cache tensor.
        prefill_block_tables: Block mapping for the prefill requests.
        decode_block_tables: Block mapping for the decode requests.
        cu_seqlens_qkv: Cumulative sequence lengths of the packed batch.
        seq_lens: Per-request sequence lengths.
        prefill_num_tokens: Token count belonging to the prefill portion.
        num_heads / head_dim / num_kv_heads: Attention geometry.
        block_size: Tokens per KV-cache block.
        max_seq_len: Maximum supported sequence length.
        scale: Attention softmax scaling factor.
        causal: Apply a causal mask when True.
        q_rope / k_rope / v_rope: Whether the op applies rotary embedding
            to the respective projection (semantics defined by the op).
        window_left / window_right: Sliding-window bounds; -1 appears to
            mean "unbounded" — TODO confirm with the op.
        softcap: Logit soft-cap value; 0.0 appears to disable it.
        use_cuda_graph: Forwarded to the op; semantics defined there.
        use_sqrt_alibi: Forwarded to the op; semantics defined there.
        rope_sin / rope_cos: Precomputed rotary tables; may be ``None``
            when rotary embedding is disabled — TODO confirm with the op.

    Returns:
        Whatever the custom op returns (the attention output).

    Raises:
        RuntimeError: If the Iluvatar custom-op extension failed to
            import, i.e. ``mixed_fused_paged_attn`` is ``None``.
    """
    # Fail fast with an actionable message instead of the opaque
    # "'NoneType' object is not callable" raised when the custom-op
    # package could not be imported on this platform.
    if mixed_fused_paged_attn is None:
        raise RuntimeError(
            "mixed_fused_paged_attn custom op is unavailable: "
            "fastdeploy.model_executor.ops.iluvatar failed to import."
        )
    return mixed_fused_paged_attn(
        qkv,
        k_cache,
        v_cache,
        prefill_block_tables,
        decode_block_tables,
        cu_seqlens_qkv,
        seq_lens,
        rope_sin,
        rope_cos,
        prefill_num_tokens,
        num_heads,
        head_dim,
        num_kv_heads,
        block_size,
        max_seq_len,
        scale,
        causal,
        q_rope,
        k_rope,
        v_rope,
        window_left,
        window_right,
        softcap,
        use_cuda_graph,
        use_sqrt_alibi,
    )
|
||||
|
||||
Reference in New Issue
Block a user