diff --git a/custom_ops/xpu_ops/src/ops/fused_noaux_tc.cc b/custom_ops/xpu_ops/src/ops/fused_noaux_tc.cc
new file mode 100644
index 0000000000..66dbe7d3f9
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/fused_noaux_tc.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE(review): the header names inside <> were lost when this patch was
+// pasted; <infer_ops.h> and <vector> are a reconstruction -- TODO confirm.
+#include <infer_ops.h>
+#include <vector>
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "utility/debug.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+// Fused "noaux_tc" MoE routing (DeepSeek-V3 style): sigmoid(gating_logits)
+// + bias, group top-k, expert top-k, weight normalization and scaling by
+// routed_scaling_factor, in a single XPU kernel launch.
+// Returns {gating_logits (unchanged), topk_weights (f32), topk_ids (i32)}.
+std::vector<paddle::Tensor> FusedNoAuxTc(const paddle::Tensor& gating_logits,
+                                         const paddle::Tensor& bias,
+                                         const int n_group,
+                                         const int topk_group,
+                                         const int top_k,
+                                         const bool apply_norm_weight,
+                                         const float routed_scaling_factor) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx =
+      paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
+
+  PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
+
+  auto gating_logits_dims = gating_logits.shape();
+  int token_num = gating_logits_dims[0];
+  int expert_num = gating_logits_dims[1];
+  auto topk_idx = paddle::empty(
+      {token_num, top_k}, paddle::DataType::INT32, gating_logits.place());
+  auto topk_weights = paddle::empty(
+      {token_num, top_k}, paddle::DataType::FLOAT32, gating_logits.place());
+  int32_t* block_statistic = nullptr;
+  // Nothing to compute for an empty batch; skip the kernel launch.
+  if (token_num > 0) {
+    int ret = infer_ops::moe_sigmoid_topk_norm_fusion(
+        xpu_ctx->x_context(),
+        gating_logits.data<float>(),
+        const_cast<float*>(bias.data<float>()),
+        routed_scaling_factor,
+        topk_weights.mutable_data<float>(),
+        topk_idx.mutable_data<int32_t>(),
+        block_statistic,
+        token_num,
+        expert_num,
+        n_group,
+        topk_group,
+        top_k,
+        0);
+    PD_CHECK(ret == 0);
+  }
+
+  return {gating_logits,
+          topk_weights,
+          topk_idx};  // return gating_logits without change
+}
+
+std::vector<std::vector<int64_t>> FusedNoAuxTcInferShape(
+    const std::vector<int64_t>& gating_logits_shape,
+    const std::vector<int64_t>& bias_shape,
+    const int n_group,
+    const int topk_group,
+    const int top_k,
+    const bool apply_norm_weight,
+    const float routed_scaling_factor) {
+  // Order matches the kernel's return: {gating, topk_weights, topk_ids}.
+  std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], top_k};
+  std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], top_k};
+  return {gating_logits_shape, topk_weights_shape, topk_ids_shape};
+}
+
+// BUGFIX: output order/dtypes must follow the kernel's actual return value
+// {gating_logits, topk_weights (FLOAT32), topk_ids (INT32)}; the original
+// declared {topk_ids (INT64), topk_weights}, which mis-bound the outputs by
+// position and advertised the wrong index dtype for static-graph inference.
+std::vector<paddle::DataType> FusedNoAuxTcInferDtype(
+    const paddle::DataType& gating_logits_dtype,
+    const paddle::DataType& bias_dtype) {
+  return {
+      gating_logits_dtype, paddle::DataType::FLOAT32, paddle::DataType::INT32};
+}
+
+PD_BUILD_STATIC_OP(fused_noaux_tc)
+    .Inputs({"gating_logits", "bias"})
+    .Outputs({"gating_logits_out", "topk_weights", "topk_ids"})
+    .Attrs({"n_group: int",
+            "topk_group: int",
+            "top_k: int",
+            "apply_norm_weight: bool",
+            "routed_scaling_factor: float"})
+    .SetInplaceMap({{"gating_logits", "gating_logits_out"}})
+    .SetKernelFn(PD_KERNEL(FusedNoAuxTc))
+    .SetInferShapeFn(PD_INFER_SHAPE(FusedNoAuxTcInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(FusedNoAuxTcInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index 10b5e0bd11..cf7542ce2e 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -195,6 +195,14 @@ std::vector<paddle::Tensor> MoeTopkSelect(
     const int moe_topk,
     const bool apply_norm_weight);
 
+std::vector<paddle::Tensor> FusedNoAuxTc(const paddle::Tensor& gating_logits,
+                                         const paddle::Tensor& bias,
+                                         const int n_group,
+                                         const int topk_group,
+                                         const int top_k,
+                                         const bool apply_norm_weight,
+                                         const float routed_scaling_factor);
+
 void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                       const paddle::Tensor& draft_tokens,
                       const paddle::Tensor& pre_ids,
@@ -974,6 +982,17 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
           py::arg("moe_topk"),
           py::arg("apply_norm_weight"));
 
+  m.def("fused_noaux_tc",
+        &FusedNoAuxTc,
+        "noaux_tc for Deepseekv3 MoE compute with sigmoid and bias",
+        py::arg("gating_logits"),
+        py::arg("bias"),
+        py::arg("n_group"),
+        py::arg("topk_group"),
+        py::arg("top_k"),
+        py::arg("apply_norm_weight"),
+        py::arg("routed_scaling_factor"));
+
   m.def("prof_start", &prof_start, "prof_start");
   m.def("prof_stop", &prof_stop, "prof_stop");
 
diff --git a/custom_ops/xpu_ops/test/test_fused_noaux_tc.py b/custom_ops/xpu_ops/test/test_fused_noaux_tc.py
new file mode 100644
index 0000000000..4dd6dda0d9
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_fused_noaux_tc.py
@@ -0,0 +1,119 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import fused_noaux_tc
+
+
+class TestMoeRouting(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2024)
+
+    def native_group_topk(
+        self,
+        gating_output: paddle.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int,
+        topk_group: int,
+        routed_scaling_factor: float,
+        e_score_correction_bias: paddle.Tensor,
+    ):
+        original_scores = paddle.nn.functional.sigmoid(gating_output)
+        if len(e_score_correction_bias.shape) == 1:
+            e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
+        scores = original_scores + e_score_correction_bias
+
+        num_token, n_experts = scores.shape
+        group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
+        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
+        group_mask = paddle.zeros_like(group_scores)  # [n, n_group]
+        group_mask.put_along_axis_(group_idx, 1.0, axis=-1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand([num_token, num_expert_group, n_experts // num_expert_group])
+            .reshape([num_token, -1])
+        )
+        tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
+
+        topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
+        topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)
+
+        if renormalize:
+            topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)
+
+        if routed_scaling_factor != 1.0:
+            topk_weights = topk_weights * routed_scaling_factor
+
+        return topk_weights, topk_ids
+
+    def test_group_topk(self):
+
+        renormalize = True
+
+        test_cases = [
+            # (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
+            (128, 1, 1, 8, 1.0),  # glm45-air
+            (256, 8, 4, 8, 2.5),  # deepseek
+        ]
+
+        for case_tuple in test_cases:
+            num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
+            for num_tokens in [1, 32, 64, 128]:
+                gating_output = paddle.rand([num_tokens, num_experts])
+                e_score_correction_bias = paddle.rand([1, num_experts])
+
+                ref_topk_values, ref_topk_idx = self.native_group_topk(
+                    gating_output=gating_output,
+                    topk=top_k,
+                    renormalize=renormalize,
+                    num_expert_group=n_group,
+                    topk_group=topk_group,
+                    routed_scaling_factor=routed_scaling_factor,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+                new_score, topk_values, topk_idx = fused_noaux_tc(
+                    gating_output,
+                    e_score_correction_bias,
+                    n_group,
+                    topk_group,
+                    top_k,
+                    True,  # apply_norm_weight = True
+                    routed_scaling_factor,
+                )
+
+                equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
+                equal_topk_ids = paddle.allclose(
+                    topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
+                ).item()
+                print(
+                    f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
+                )
+                if not equal_topk_value:
+                    print(f"ref_topk_values = {ref_topk_values}")
+                    print(f"topk_values = {topk_values}")
+                if not equal_topk_ids:
+                    print(f"ref_topk_idx = {ref_topk_idx}")
+                    print(f"topk_idx = {topk_idx}")
+                self.assertTrue(equal_topk_value and equal_topk_ids)
+
+
+if __name__ == "__main__":
+    unittest.main()