mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -20,5 +20,5 @@ PACKAGE = "fastdeploy.model_executor.ops.iluvatar"
|
||||
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
|
||||
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||
|
||||
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: E402, F401
|
||||
from .paged_attention import paged_attention # noqa: E402, F401
|
||||
from .moe_ops import iluvatar_moe_expert_ffn as moe_expert_ffn # noqa: F401
|
||||
from .paged_attention import paged_attention # noqa: F401
|
||||
|
||||
@@ -28,8 +28,13 @@ def group_gemm(
|
||||
scale: paddle.Tensor,
|
||||
output: paddle.Tensor,
|
||||
):
|
||||
assert (input.dim() == 2 and tokens_expert_prefix_sum.dim() == 1
|
||||
and weight.dim() == 3 and scale.dim() == 2 and output.dim() == 2)
|
||||
assert (
|
||||
input.dim() == 2
|
||||
and tokens_expert_prefix_sum.dim() == 1
|
||||
and weight.dim() == 3
|
||||
and scale.dim() == 2
|
||||
and output.dim() == 2
|
||||
)
|
||||
num_tokens = input.shape[0]
|
||||
dim_in = input.shape[1]
|
||||
dim_out = weight.shape[1]
|
||||
@@ -66,7 +71,8 @@ def group_gemm(
|
||||
weight_i,
|
||||
weight_scale=scale_i,
|
||||
weight_dtype="int8",
|
||||
group_size=-1)
|
||||
group_size=-1,
|
||||
)
|
||||
|
||||
|
||||
def iluvatar_moe_expert_ffn(
|
||||
@@ -90,13 +96,24 @@ def iluvatar_moe_expert_ffn(
|
||||
assert quant_method in ("weight_only_int8")
|
||||
assert not used_in_ep_low_latency
|
||||
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
||||
up_gate_proj_output = paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]],
|
||||
dtype=permute_input.dtype)
|
||||
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight,
|
||||
up_gate_proj_scale, up_gate_proj_output)
|
||||
up_gate_proj_output = paddle.empty(
|
||||
[permute_input.shape[0], up_gate_proj_weight.shape[1]],
|
||||
dtype=permute_input.dtype,
|
||||
)
|
||||
group_gemm(
|
||||
permute_input,
|
||||
tokens_expert_prefix_sum_cpu,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_output,
|
||||
)
|
||||
act_out = swiglu(up_gate_proj_output)
|
||||
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]],
|
||||
dtype=act_out.dtype)
|
||||
group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale,
|
||||
output)
|
||||
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype)
|
||||
group_gemm(
|
||||
act_out,
|
||||
tokens_expert_prefix_sum_cpu,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
output,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -15,32 +15,51 @@
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.iluvatar import paged_attn
|
||||
except ImportError:
|
||||
paged_attn = None
|
||||
|
||||
|
||||
def paged_attention(q: paddle.Tensor,
|
||||
k_cache: paddle.Tensor,
|
||||
v_cache: paddle.Tensor,
|
||||
block_tables: paddle.Tensor,
|
||||
seq_lens: paddle.Tensor,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
block_size: int,
|
||||
max_context_len: int,
|
||||
alibi_slopes: paddle.Tensor = None,
|
||||
causal: bool = True,
|
||||
window_left: int = -1,
|
||||
window_right: int = -1,
|
||||
softcap: float = 0.0,
|
||||
use_cuda_graph: bool = False,
|
||||
use_sqrt_alibi: bool = False,
|
||||
k: paddle.Tensor = None,
|
||||
v: paddle.Tensor = None):
|
||||
output = paged_attn(q, k_cache, v_cache, block_tables, seq_lens,
|
||||
alibi_slopes, k, v, num_kv_heads, scale, block_size,
|
||||
max_context_len, causal, window_left, window_right,
|
||||
softcap, use_cuda_graph, use_sqrt_alibi)
|
||||
def paged_attention(
    q: paddle.Tensor,
    k_cache: paddle.Tensor,
    v_cache: paddle.Tensor,
    block_tables: paddle.Tensor,
    seq_lens: paddle.Tensor,
    num_kv_heads: int,
    scale: float,
    block_size: int,
    max_context_len: int,
    alibi_slopes: paddle.Tensor = None,
    causal: bool = True,
    window_left: int = -1,
    window_right: int = -1,
    softcap: float = 0.0,
    use_cuda_graph: bool = False,
    use_sqrt_alibi: bool = False,
    k: paddle.Tensor = None,
    v: paddle.Tensor = None,
):
    """Call the ``paged_attn`` custom op and unwrap its result.

    Every argument is forwarded unchanged. This wrapper only maps the
    keyword-friendly signature onto the positional order the op is called
    with, and normalises the return value. No shape, dtype or range checks
    are done here.

    NOTE(review): the meaning of each argument is defined by the custom op;
    the descriptions below only state how this wrapper passes them on.

    Args:
        q: Forwarded as the 1st positional argument.
        k_cache: Forwarded as the 2nd positional argument.
        v_cache: Forwarded as the 3rd positional argument.
        block_tables: Forwarded as the 4th positional argument.
        seq_lens: Forwarded as the 5th positional argument.
        num_kv_heads: Forwarded after ``v``.
        scale: Forwarded after ``num_kv_heads``.
        block_size: Forwarded after ``scale``.
        max_context_len: Forwarded after ``block_size``.
        alibi_slopes: Optional; forwarded right after ``seq_lens``.
            Defaults to ``None``.
        causal: Forwarded after ``max_context_len``. Defaults to ``True``.
        window_left: Defaults to ``-1``.
        window_right: Defaults to ``-1``.
        softcap: Defaults to ``0.0``.
        use_cuda_graph: Defaults to ``False``.
        use_sqrt_alibi: Defaults to ``False``.
        k: Optional; forwarded right after ``alibi_slopes``.
            Defaults to ``None``.
        v: Optional; forwarded right after ``k``. Defaults to ``None``.

    Returns:
        The first element when the op returns a ``list``, otherwise the
        op's return value as-is.

    Raises:
        RuntimeError: If ``paged_attn`` is ``None`` because importing it
            failed when this module was loaded.
    """
    # The module-level import falls back to ``paged_attn = None`` on
    # ImportError; fail with a clear message instead of
    # "'NoneType' object is not callable".
    if paged_attn is None:
        raise RuntimeError(
            "paged_attn custom op is unavailable: importing it from "
            "fastdeploy.model_executor.ops.iluvatar failed"
        )

    # The op is called positionally, and its order differs from this
    # signature: alibi_slopes, k and v come directly after seq_lens.
    output = paged_attn(
        q,
        k_cache,
        v_cache,
        block_tables,
        seq_lens,
        alibi_slopes,
        k,
        v,
        num_kv_heads,
        scale,
        block_size,
        max_context_len,
        causal,
        window_left,
        window_right,
        softcap,
        use_cuda_graph,
        use_sqrt_alibi,
    )
    # Accept either a bare result or a list whose first element is the result.
    return output[0] if isinstance(output, list) else output
|
||||
|
||||
Reference in New Issue
Block a user