Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Metax] fix release2.4 and support cudagraph (#5547)
Co-authored-by: xiaozude <xiaozude@outlook.com>
@@ -64,6 +64,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     encoder_block_shape_q: int = -1
     decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
+    seq_lens_dec: paddle.Tensor = None
+    block_table_dec: paddle.Tensor = None
 
     # pd_disaggregation
     kv_signal_metadata: Optional[paddle.Tensor] = None
@@ -135,6 +137,12 @@ class FlashAttentionBackend(AttentionBackend):
             shape=[max_num_seqs, 1, 1, self.head_dim],
             dtype=self.dtype,
         )
+        self.attention_metadata.seq_lens_dec = paddle.empty(
+            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32"
+        )
+        self.attention_metadata.block_table_dec = paddle.empty(
+            shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32"
+        )
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,8 +237,9 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
         self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
-        self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
-        self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
+        self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0])
+        self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :])
 
         # update prefilling rope
         self.update_rotary_embs_prefill(forward_meta)
         # update decoding rope
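Why the two assignments above became in-place copy_ calls: CUDA Graph replays a captured kernel sequence against fixed device addresses, so per-step metadata must live in buffers that are allocated once (the paddle.empty calls added to __init__ in the previous hunk) and refreshed in place, never rebound to fresh tensors. A minimal standalone sketch of the pattern, with an assumed max_num_seqs; the buffer name mirrors the diff:

```python
import paddle

# Allocate once: a captured CUDA Graph bakes in this buffer's address.
max_num_seqs = 8  # assumed value for the sketch
seq_lens_dec = paddle.empty(shape=[max_num_seqs, 1], dtype="int32")

def refresh(step_lens: paddle.Tensor) -> None:
    # In-place copy keeps the storage (and address) stable across replays.
    # Rebinding (seq_lens_dec = step_lens) would allocate new storage and
    # silently detach the replayed graph from the data it reads.
    seq_lens_dec.copy_(step_lens)

refresh(paddle.ones([max_num_seqs, 1], dtype="int32"))
```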
@@ -296,13 +305,18 @@ class FlashAttentionBackend(AttentionBackend):
         bs = self.batch_ids_decode.shape[0]
         if self.enable_mm:
             index = paddle.concat(
-                [self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
+                [self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])],
+                axis=1,
             )
             rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
             rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
         else:
-            rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
-            rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
+            rot_cos = paddle.gather(
+                forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec
+            ).view([bs, 1, 1, -1])
+            rot_sin = paddle.gather(
+                forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec
+            ).view([bs, 1, 1, -1])
         self.attention_metadata.rotary_cos_decode[:bs].copy_(
             paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype)
         )
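For reference, the two branches above differ only in how the rope table is indexed: the multimodal path needs a per-request (batch_id, step) lookup via gather_nd, while the text-only path gathers one shared table by decode step. A toy sketch with assumed shapes (the real rotary_embs slicing is as shown in the diff):

```python
import paddle

bs, max_seq, half_dim = 2, 8, 4

# Multimodal path: one rope table per request, so index by (batch_id, step).
cos_table_mm = paddle.rand([bs, max_seq, half_dim])        # ~ rotary_embs[:, 0, 0, :, 0, :]
index = paddle.to_tensor([[0, 3], [1, 5]], dtype="int64")  # rows of [batch_id, seq_len_dec]
cos_mm = paddle.gather_nd(cos_table_mm, index)             # -> [bs, half_dim]

# Text-only path: a single shared table, so a plain gather over steps suffices.
cos_table = paddle.rand([max_seq, half_dim])               # ~ rotary_embs[0, 0, :, 0, :]
steps = paddle.to_tensor([3, 5], dtype="int32")            # seq_lens_dec values
cos = paddle.gather(cos_table, steps)                      # -> [bs, half_dim]
```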
@@ -476,8 +490,8 @@ class FlashAttentionBackend(AttentionBackend):
             q,
             forward_meta.caches[k_cache_id],
             forward_meta.caches[v_cache_id],
-            self.seq_lens_dec,
-            self.block_table_dec,
+            self.attention_metadata.seq_lens_dec,
+            self.attention_metadata.block_table_dec,
             k,
             v,
             rotary_cos=None,
@@ -221,12 +221,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         """
         Calculate kv cache shape for MLA
         """
-        return (
-            max_num_blocks,
-            1,
-            self.block_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
+        value_cache_shape = []
+        return key_cache_shape, value_cache_shape
 
     def compute_flash_mla(
         self,
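The return contract changes from a single shape tuple to a (key_cache_shape, value_cache_shape) pair, with the value shape empty: MLA keeps one compressed latent cache of width kv_lora_rank + qk_rope_head_dim per token rather than separate K and V caches. A hedged sketch of a caller under that contract, assuming the surrounding method is the backend's get_kv_cache_shape and allocate_kv_cache is a hypothetical helper, not FastDeploy API:

```python
import paddle

def allocate_kv_cache(backend, max_num_blocks: int, dtype: str = "bfloat16"):
    key_shape, value_shape = backend.get_kv_cache_shape(max_num_blocks)
    key_cache = paddle.zeros(key_shape, dtype=dtype)
    # MLA: an empty value_shape signals that no separate V cache is needed.
    value_cache = paddle.zeros(value_shape, dtype=dtype) if value_shape else None
    return key_cache, value_cache
```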
+19 −99
@@ -15,6 +15,7 @@
 """
 
 import os
+from typing import Callable
 
 import paddle
 from paddle import nn
@@ -66,25 +67,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias)
         layer.down_proj_bias.set_value(stacked_down_proj_bias)
 
-    def compute_ffn(
-        self,
-        layer: nn.Layer,
-        permute_input: paddle.Tensor,
-        token_nums_per_expert: paddle.Tensor,
-        expert_idx_per_token: paddle.Tensor,
-        used_in_ep_low_latency: bool = False,
-        estimate_total_token_nums: int = -1,
-    ):
-        """
-        Paddle Cutlass compute Fused MoE.
-        """
-        raise NotImplementedError
-
     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -96,6 +84,7 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -107,70 +96,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
         """
-        if layer.topk_method == "noaux_tc":
-            gate_out = gate(x.cast("float32"))
-
-            gate_out, topk_weights, topk_idx = get_moe_scores(
-                gate_out,
-                layer.n_group,
-                layer.topk_group,
-                layer.top_k,
-                layer.routed_scaling_factor,
-                layer.gate_correction_bias,
-                getattr(layer, "renormalize", True),
-            )
-
-            (
-                permute_input,
-                token_nums_per_expert,
-                permute_indices_per_token,
-                topk_weights,
-                topk_idx,
-            ) = moe_expert_dispatch(
-                x,
-                gate_out,
-                layer.top_k,
-                False,
-                True,
-            )
-
-            ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
-
-            fused_moe_out = moe_expert_reduce(
-                ffn_out,
-                topk_weights,
-                permute_indices_per_token,
-                topk_idx,
-                None,
-                False,
-                1.0,
-            )
-        else:
-            raise NotImplementedError
-
-            fused_moe_out = fused_expert_moe(
-                x,
-                gate.weight,
-                getattr(layer, self.added_weight_attrs[0]),
-                getattr(layer, self.added_weight_attrs[1]),
-                None,
-                (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
-                None,
-                (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
-                "weight_only_int8",
-                layer.top_k,
-                True,
-                False,
-            )
-
-        return fused_moe_out
+        raise NotImplementedError
 
 
 class MetaxCutlassMoEMethod(MoEMethodBase):
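Context for the body removed above: the unquantized Metax path routed tokens through a dispatch, then expert FFN, then reduce pipeline built from the fused ops named in the diff (moe_expert_dispatch, moe_expert_ffn, moe_expert_reduce). A toy, eager-mode illustration of that three-stage flow; top-1 routing and dense per-expert matmuls are simplifications for readability, not the fused kernels themselves:

```python
import paddle

num_tokens, hidden, num_experts = 6, 4, 3
x = paddle.randn([num_tokens, hidden])
router_logits = paddle.randn([num_tokens, num_experts])
expert_id = paddle.argmax(router_logits, axis=-1)  # top-1 routing for the sketch
route_w = paddle.nn.functional.softmax(router_logits, -1).max(-1, keepdim=True)

# 1) Dispatch: sort tokens by expert so each expert sees a contiguous block.
order = paddle.argsort(expert_id)
permuted = paddle.gather(x, order)

# 2) Expert FFN: run each expert's token block through its own weights.
expert_w = paddle.randn([num_experts, hidden, hidden])
out = paddle.zeros_like(permuted)
for e in range(num_experts):
    idx = paddle.nonzero(paddle.gather(expert_id, order) == e).flatten()
    if idx.shape[0] > 0:
        out = paddle.scatter(out, idx, paddle.gather(permuted, idx) @ expert_w[e])

# 3) Reduce: un-permute and apply the routing weights.
inverse = paddle.argsort(order)
y = paddle.gather(out, inverse) * route_w
```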
@@ -189,35 +120,12 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
         layer.down_proj_weight.set_value(stacked_down_proj_weights)
 
-    def compute_ffn(
-        self,
-        layer: nn.Layer,
-        permute_input: paddle.Tensor,
-        token_nums_per_expert: paddle.Tensor,
-        expert_idx_per_token: paddle.Tensor,
-        used_in_ep_low_latency: bool = False,
-        estimate_total_token_nums: int = -1,
-    ):
-        """
-        Paddle Cutlass compute Fused MoE.
-        """
-        return moe_expert_ffn(
-            permute_input,
-            token_nums_per_expert,
-            getattr(layer, self.added_weight_attrs[0]),
-            getattr(layer, self.added_weight_attrs[1]),
-            None,
-            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
-            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
-            expert_idx_per_token,  # expert_idx_per_token: only for w4a8
-            self.moe_quant_type,
-        )
-
     def apply_ep_prefill(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP prefill method.
@@ -229,6 +137,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Apply the EP decoder method.
@@ -240,6 +149,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
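The new topk_ids_hookfunc parameter is threaded through every apply_* signature, but its call site is outside this diff; we read it as an optional callback over the routed top-k expert ids (for logging, or rewriting routes on the EP path). A purely hypothetical hook under that assumption:

```python
from typing import Callable

import paddle

# Hypothetical: the diff only adds the parameter; how (or whether) the ids
# are passed to the hook is not shown in these hunks.
def log_topk_ids(topk_ids: paddle.Tensor) -> paddle.Tensor:
    print("routed experts per token:", topk_ids.numpy().tolist())
    return topk_ids  # a rewriting hook would return modified ids

hook: Callable = log_topk_ids
```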
@@ -282,7 +192,17 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
         else:
             expert_idx_per_token = expert_idx_per_token.cast("int64")
 
-        ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
+        ffn_out = moe_expert_ffn(
+            permute_input,
+            token_nums_per_expert,
+            getattr(layer, self.added_weight_attrs[0]),
+            getattr(layer, self.added_weight_attrs[1]),
+            None,
+            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
+            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
+            expert_idx_per_token,  # expert_idx_per_token: only for w4a8
+            self.moe_quant_type,
+        )
 
         fused_moe_out = moe_expert_reduce(
             ffn_out,