diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 3b38f026fd..7633ca79b1 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -179,6 +179,8 @@ class MoEMethodBase(QuantMethodBase): x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -192,6 +194,8 @@ class MoEMethodBase(QuantMethodBase): x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -232,13 +236,19 @@ class MoEMethodBase(QuantMethodBase): if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer: self.ep_prefill_runner.clean_low_latency_buffer() return self.apply_ep_prefill( - layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts + layer, + x, + gate, + topk_ids_hookfunc, + shared_experts, + fc1_latent_proj, + fc2_latent_proj, ) else: if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer: self.ep_decoder_runner.clean_low_latency_buffer() return self.apply_ep_decode( - layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts + layer, x, gate, topk_ids_hookfunc, shared_experts, fc1_latent_proj, fc2_latent_proj ) else: return self.apply_tp(layer, x, gate, topk_ids_hookfunc, fc1_latent_proj, fc2_latent_proj) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 58755f52e2..760a23734a 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py 
@@ -128,6 +128,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -275,6 +277,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP decoder method. diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 9c6d174daa..cdaa66678f 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -332,6 +332,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -339,7 +341,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate_out = gate(x) gate_out = gate_out.cast("float32") - hidden_size = x.shape[1] + hidden_size = layer.hidden_size # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) @@ -347,6 +349,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) + if fc1_latent_proj: + x = fc1_latent_proj(x) + # 2. 
Dynamic compute blockwise quantization scales if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT: x_fp8, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( @@ -643,6 +648,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): if shared_experts is not None: tmp_ffn_out += s_x + if fc2_latent_proj: + tmp_ffn_out = fc2_latent_proj(tmp_ffn_out) + return tmp_ffn_out def apply_ep_decode( @@ -652,6 +660,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate: nn.Layer, topk_ids_hookfunc: Callable = None, shared_experts: nn.Layer = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -665,6 +675,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): topk_ids_hookfunc(topk_ids=topk_idx) # 2. EP Dispatch + if fc1_latent_proj: + x = fc1_latent_proj(x) + permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( x, topk_idx, topk_weights, use_fp8=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0 ) @@ -728,6 +741,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): if shared_experts is not None: out += s_x + + if fc2_latent_proj: + out = fc2_latent_proj(out) + return out def apply_tp( diff --git a/tests/operators/test_deepgemm_sm90_prefill_precision.py b/tests/operators/test_deepgemm_sm90_prefill_precision.py new file mode 100644 index 0000000000..fc6eb9eff3 --- /dev/null +++ b/tests/operators/test_deepgemm_sm90_prefill_precision.py @@ -0,0 +1,108 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
+
+paddle.set_default_dtype("bfloat16")
+
+
+class TestDeepGemmPrefill(unittest.TestCase):
+    """Run DeepGEMM's contiguous grouped FP8 GEMM next to a float32 matmul reference."""
+
+    def one_invoke(self, num_experts, M, N, K):
+        """Run one grouped GEMM case and print its difference from the reference.
+
+        Args:
+            num_experts: Number of expert groups.
+            M: Rows assigned to every expert.
+            N: Output columns of each expert weight.
+            K: Reduction dimension; one scale covers ``block_size`` columns of it.
+        """
+        tokens_per_expert = [M] * num_experts
+        total_m = sum(tokens_per_expert)
+        block_size = 128
+
+        raw_x = paddle.randn([total_m, K], dtype="bfloat16").cast(paddle.float8_e4m3fn)
+        raw_x_scale = paddle.randn([total_m, K // block_size], dtype="float32")
+
+        raw_w = paddle.randn([num_experts, N, K], dtype="bfloat16").cast(paddle.float8_e4m3fn)
+        raw_w_scale = paddle.randn([num_experts, N // block_size, K // block_size], dtype="float32")
+
+        m_indices = np.zeros([total_m], dtype="int32")
+        baseline_out = paddle.empty([total_m, N], dtype="bfloat16")
+
+        # Reference path: dequantize each expert's rows and weight, then matmul in float32.
+        start = 0
+        for i, token_num in enumerate(tokens_per_expert):
+            end = start + token_num
+
+            this_expert_token = raw_x[start:end].contiguous().cast("float32")
+            # Repeat every activation scale across the block_size columns it covers.
+            this_expert_token_scale = (
+                raw_x_scale[start:end]
+                .contiguous()
+                .reshape([0, 0, 1])
+                .tile([1, 1, block_size])
+                .reshape([0, -1])
+                .cast("float32")
+            )
+            tmp0 = this_expert_token * this_expert_token_scale
+
+            this_expert_weight = raw_w[i].contiguous().cast("float32")
+            # Repeat every weight scale across its block_size x block_size tile.
+            this_expert_weight_scale = (
+                raw_w_scale[i]
+                .contiguous()
+                .reshape([0, 1, -1, 1])
+                .tile([1, block_size, 1, block_size])
+                .reshape([N, K])
+                .cast("float32")
+            )
+            tmp1 = this_expert_weight * this_expert_weight_scale
+
+            baseline_out[start:end] = paddle.matmul(tmp0, tmp1, False, True)
+            m_indices[start:end] = i
+            start = end
+
+        deepgemm_output = paddle.zeros_like(baseline_out)
+        m_indices = paddle.to_tensor(m_indices, dtype="int32")
+
+        for _ in range(10):
+            a = paddle.zeros([1024, 1024, 1024]) + 1
+            del a
+            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+                (raw_x, raw_x_scale.transpose([1, 0]).contiguous().transpose([1, 0])),
+                (raw_w, raw_w_scale),
+                deepgemm_output,
+                m_indices,
+            )
+
+        # NOTE(review): the difference is only printed; nothing here asserts a tolerance.
+        print(baseline_out - deepgemm_output)
+
+    def test_main(self):
+        """Exercise two expert-count / shape configurations."""
+        self.one_invoke(48, 128 * 20, 2048, 4096)
+        self.one_invoke(96, 128 * 20, 2048, 2048)
+
+
+if __name__ == "__main__":
+    unittest.main()