[Metax] fix release2.4 and support cudagraph (#5547)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

Co-authored-by: xiaozude <xiaozude@outlook.com>
This commit is contained in:
zhang-chenyi
2025-12-15 14:23:33 +08:00
committed by GitHub
parent 4bd991aa17
commit 77f8ba06e7
5 changed files with 85 additions and 126 deletions
@@ -64,6 +64,8 @@ class FlashAttentionMetadata(AttentionMetadata):
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
seq_lens_dec: paddle.Tensor = None
block_table_dec: paddle.Tensor = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
@@ -135,6 +137,12 @@ class FlashAttentionBackend(AttentionBackend):
shape=[max_num_seqs, 1, 1, self.head_dim],
dtype=self.dtype,
)
self.attention_metadata.seq_lens_dec = paddle.empty(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32"
)
self.attention_metadata.block_table_dec = paddle.empty(
shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32"
)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -229,8 +237,9 @@ class FlashAttentionBackend(AttentionBackend):
self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0])
self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :])
# update prefilling rope
self.update_rotary_embs_prefill(forward_meta)
# update decoding rope
@@ -296,13 +305,18 @@ class FlashAttentionBackend(AttentionBackend):
bs = self.batch_ids_decode.shape[0]
if self.enable_mm:
index = paddle.concat(
[self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
[self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])],
axis=1,
)
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
else:
rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
rot_cos = paddle.gather(
forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec
).view([bs, 1, 1, -1])
rot_sin = paddle.gather(
forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec
).view([bs, 1, 1, -1])
self.attention_metadata.rotary_cos_decode[:bs].copy_(
paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype)
)
@@ -476,8 +490,8 @@ class FlashAttentionBackend(AttentionBackend):
q,
forward_meta.caches[k_cache_id],
forward_meta.caches[v_cache_id],
self.seq_lens_dec,
self.block_table_dec,
self.attention_metadata.seq_lens_dec,
self.attention_metadata.block_table_dec,
k,
v,
rotary_cos=None,
@@ -221,12 +221,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
"""
Calculate kv cache shape for MLA
"""
return (
max_num_blocks,
1,
self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim,
)
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
value_cache_shape = []
return key_cache_shape, value_cache_shape
def compute_flash_mla(
self,
@@ -15,6 +15,7 @@
"""
import os
from typing import Callable
import paddle
from paddle import nn
@@ -66,25 +67,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer.up_gate_proj_bias.set_value(stacked_up_gate_proj_bias)
layer.down_proj_bias.set_value(stacked_down_proj_bias)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
raise NotImplementedError
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -96,6 +84,7 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -107,70 +96,12 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
"""
Paddle Cutlass compute Fused MoE.
"""
if layer.topk_method == "noaux_tc":
gate_out = gate(x.cast("float32"))
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
) = moe_expert_dispatch(
x,
gate_out,
layer.top_k,
False,
True,
)
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,
permute_indices_per_token,
topk_idx,
None,
False,
1.0,
)
else:
raise NotImplementedError
fused_moe_out = fused_expert_moe(
x,
gate.weight,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
None,
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
"weight_only_int8",
layer.top_k,
True,
False,
)
return fused_moe_out
raise NotImplementedError
class MetaxCutlassMoEMethod(MoEMethodBase):
@@ -189,35 +120,12 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
layer.down_proj_weight.set_value(stacked_down_proj_weights)
def compute_ffn(
self,
layer: nn.Layer,
permute_input: paddle.Tensor,
token_nums_per_expert: paddle.Tensor,
expert_idx_per_token: paddle.Tensor,
used_in_ep_low_latency: bool = False,
estimate_total_token_nums: int = -1,
):
"""
Paddle Cutlass compute Fused MoE.
"""
return moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
expert_idx_per_token, # expert_idx_per_token: only for w4a8
self.moe_quant_type,
)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -229,6 +137,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -240,6 +149,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -282,7 +192,17 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
ffn_out = moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
None,
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
expert_idx_per_token, # expert_idx_per_token: only for w4a8
self.moe_quant_type,
)
fused_moe_out = moe_expert_reduce(
ffn_out,