[XPU] support noaux_tc (#6326)

This commit is contained in:
lizan1999
2026-02-05 12:04:16 +08:00
committed by GitHub
parent cae2709eff
commit 72edd394d9
3 changed files with 238 additions and 0 deletions
@@ -0,0 +1,100 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
std::vector<paddle::Tensor> FusedNoAuxTc(const paddle::Tensor& gating_logits,
const paddle::Tensor& bias,
const int n_group,
const int topk_group,
const int top_k,
const bool apply_norm_weight,
const float routed_scaling_factor) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
auto gating_logits_dims = gating_logits.shape();
int token_num = gating_logits_dims[0];
int expert_num = gating_logits_dims[1];
auto topk_idx = paddle::empty(
{token_num, top_k}, paddle::DataType::INT32, gating_logits.place());
auto topk_weights = paddle::empty(
{token_num, top_k}, paddle::DataType::FLOAT32, gating_logits.place());
int32_t* block_statistic = nullptr;
if (token_num > 0) {
int ret = infer_ops::moe_sigmoid_topk_norm_fusion(
xpu_ctx->x_context(),
gating_logits.data<float>(),
const_cast<float*>(bias.data<float>()),
routed_scaling_factor,
topk_weights.mutable_data<float>(),
topk_idx.mutable_data<int>(),
block_statistic,
token_num,
expert_num,
n_group,
topk_group,
top_k,
0);
PD_CHECK(ret == 0);
}
return {gating_logits,
topk_weights,
topk_idx}; // return gating_logits without change
}
std::vector<std::vector<int64_t>> FusedNoAuxTcInferShape(
const std::vector<int64_t>& gating_logits_shape,
const std::vector<int64_t>& bias_shape,
const int n_group,
const int topk_group,
const int top_k,
const bool apply_norm_weight,
const float routed_scaling_factor) {
std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], top_k};
std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], top_k};
return {gating_logits_shape, topk_ids_shape, topk_weights_shape};
}
std::vector<paddle::DataType> FusedNoAuxTcInferDtype(
const paddle::DataType& gating_logits_dtype,
const paddle::DataType& bias_dtype) {
return {
gating_logits_dtype, paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(fused_noaux_tc)
.Inputs({"gating_logits", "bias"})
.Outputs({"gating_logits_out", "topk_ids", "topk_weights"})
.Attrs({"n_group: int",
"topk_group: int",
"top_k: int",
"apply_norm_weight: bool",
"routed_scaling_factor: float"})
.SetInplaceMap({{"gating_logits", "gating_logits_out"}})
.SetKernelFn(PD_KERNEL(FusedNoAuxTc))
.SetInferShapeFn(PD_INFER_SHAPE(FusedNoAuxTcInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(FusedNoAuxTcInferDtype));
@@ -195,6 +195,14 @@ std::vector<paddle::Tensor> MoeTopkSelect(
const int moe_topk,
const bool apply_norm_weight);
std::vector<paddle::Tensor> FusedNoAuxTc(const paddle::Tensor& gating_logits,
const paddle::Tensor& bias,
const int n_group,
const int topk_group,
const int top_k,
const bool apply_norm_weight,
const float routed_scaling_factor);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
@@ -974,6 +982,17 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("moe_topk"),
py::arg("apply_norm_weight"));
m.def("fused_noaux_tc",
&FusedNoAuxTc,
"noaux_tc for Deepseekv3 MoE compute with sigmoid and bias",
py::arg("gating_logits"),
py::arg("bias"),
py::arg("n_group"),
py::arg("topk_group"),
py::arg("top_k"),
py::arg("apply_norm_weight"),
py::arg("routed_scaling_factor"));
m.def("prof_start", &prof_start, "prof_start");
m.def("prof_stop", &prof_stop, "prof_stop");
@@ -0,0 +1,119 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
import paddle
from fastdeploy.model_executor.ops.xpu import fused_noaux_tc
class TestMoeRouting(unittest.TestCase):
def setUp(self):
paddle.seed(2024)
def native_group_topk(
self,
gating_output: paddle.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int,
topk_group: int,
routed_scaling_factor: float,
e_score_correction_bias: paddle.Tensor,
):
original_scores = paddle.nn.functional.sigmoid(gating_output)
if len(e_score_correction_bias.shape) == 1:
e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
scores = original_scores + e_score_correction_bias
num_token, n_experts = scores.shape
group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group]
group_mask = paddle.zeros_like(group_scores) # [n, n_group]
group_mask.put_along_axis_(group_idx, 1.0, axis=-1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand([num_token, num_expert_group, n_experts // num_expert_group])
.reshape([num_token, -1])
)
tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)
if renormalize:
topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_ids
def test_group_topk(self):
renormalize = True
test_cases = [
# (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
(128, 1, 1, 8, 1.0), # glm45-air
(256, 8, 4, 8, 2.5), # deepseek
]
for case_tuple in test_cases:
num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
for num_tokens in [1, 32, 64, 128]:
gating_output = paddle.rand([num_tokens, num_experts])
e_score_correction_bias = paddle.rand([1, num_experts])
ref_topk_values, ref_topk_idx = self.native_group_topk(
gating_output=gating_output,
topk=top_k,
renormalize=renormalize,
num_expert_group=n_group,
topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
new_score, topk_values, topk_idx = fused_noaux_tc(
gating_output,
e_score_correction_bias,
n_group,
topk_group,
top_k,
True, # apply_norm_weight = True
routed_scaling_factor,
)
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
equal_topk_ids = paddle.allclose(
topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
).item()
print(
f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
)
if not equal_topk_value:
print(f"ref_topk_values = {ref_topk_values}")
print(f"topk_values = {topk_values}")
if not equal_topk_ids:
print(f"ref_topk_idx = {ref_topk_idx}")
print(f"topk_idx = {topk_idx}")
assert equal_topk_value and equal_topk_ids
if __name__ == "__main__":
unittest.main()