// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Ignore CUTLASS warnings about type punning

#pragma once

#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/wint_type_traits.h"
#include "helper.h"
#include "moe/fused_moe_helper.h"

namespace phi {

/**
 * @brief Fills total_rows_before_expert[e] for every expert e in
 *        [0, num_experts).
 *
 * Launch layout: 1-D grid, one thread per expert. Threads whose global index
 * is >= num_experts exit immediately, so the grid may be over-provisioned.
 * No shared memory is used.
 *
 * @param sorted_experts           Device array of expert ids.
 *                                 NOTE(review): the name and the helper used
 *                                 below suggest it is sorted ascending --
 *                                 confirm with the caller.
 * @param sorted_experts_len       Number of entries in sorted_experts.
 * @param num_experts              Number of experts (length of the output).
 * @param total_rows_before_expert Device output array with num_experts
 *                                 entries.
 */
__global__ void compute_total_rows_before_expert_kernel(
    int* sorted_experts,
    const int64_t sorted_experts_len,
    const int64_t num_experts,
    int64_t* total_rows_before_expert) {
  // First, compute the global tid. We only need 1 thread per expert.
  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
  if (expert >= num_experts) return;

  // This should construct the last index where each expert occurs.
  // find_total_elts_leq_target is declared in one of the included headers.
  total_rows_before_expert[expert] =
      find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

/**
 * @brief Host launcher for compute_total_rows_before_expert_kernel.
 *
 * Uses min(1024, num_experts) threads per block and enough blocks to give
 * every expert one thread. The launch is enqueued on `stream` and is not
 * synchronized here.
 *
 * @param sorted_indices           Device array of expert ids.
 * @param total_indices            Number of entries in sorted_indices.
 * @param num_experts              Number of experts; when <= 0 nothing is
 *                                 launched.
 * @param total_rows_before_expert Device output array with num_experts
 *                                 entries.
 * @param stream                   CUDA stream the kernel is enqueued on.
 */
void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int64_t* total_rows_before_expert,
                                      cudaStream_t stream) {
  // With no experts the block size below would be 0, which makes the
  // ceil-division divide by zero and is an invalid launch configuration.
  if (num_experts <= 0) return;

  const int threads = std::min(int64_t(1024), num_experts);
  const int blocks = (num_experts + threads - 1) / threads;

  compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
      sorted_indices, total_indices, num_experts, total_rows_before_expert);

  // Kernel launches do not return a status; surface bad-configuration errors
  // right away instead of at some unrelated later CUDA call.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    PD_THROW("compute_total_rows_before_expert_kernel launch failed: ",
             cudaGetErrorString(launch_status));
  }
}

}  // namespace phi

/**
 * @brief Runs the fused MoE computation for one input dtype.
 *
 * Creates one MoeGemmRunner per weight format, hands them to a MoeHelper
 * configured with `quant_method`, and calls MoeHelper::ComputeFFN, which
 * writes into `output`.
 *
 * @tparam T Paddle dtype of the activations; the matching C++ types are
 *           obtained through PDTraits.
 *
 * Optional scale/bias tensors are forwarded as raw pointers, or nullptr when
 * the optional is empty.
 */
template <paddle::DataType T>
void FusedMoeKernel(const paddle::Tensor& input,
                    const paddle::Tensor& gate_weight,
                    const paddle::Tensor& up_gate_proj_weight,
                    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
                    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
                    const paddle::Tensor& down_proj_weight,
                    const paddle::optional<paddle::Tensor>& down_proj_scale,
                    const paddle::optional<paddle::Tensor>& down_proj_bias,
                    const std::string& quant_method,
                    const int moe_topk,
                    const bool group_moe,
                    const bool norm_topk_prob,
                    paddle::Tensor* output) {
  using namespace phi;

  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  // NOTE(review): the WintQuantMethod enumerators below are assumed from the
  // runner names (fp16 / int8 / int4) -- confirm against
  // cutlass_extensions/wint_type_traits.h.
  auto fp16_moe_gemm_runner = MoeGemmRunner<
      DataType_,
      cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
  auto int8_moe_gemm_runner = MoeGemmRunner<
      DataType_,
      cutlass::WintQuantTraits<DataType_,
                               cutlass::WintQuantMethod::kWeightOnlyInt8>>();
  auto int4_moe_gemm_runner = MoeGemmRunner<
      DataType_,
      cutlass::WintQuantTraits<DataType_,
                               cutlass::WintQuantMethod::kWeightOnlyInt4>>();

  using NvType = typename traits_::DataType;
  // NOTE(review): MoeHelper's template arguments are assumed to be
  // <data_t, NvType> -- confirm against moe/fused_moe_helper.h.
  auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
                                               &fp16_moe_gemm_runner,
                                               &int8_moe_gemm_runner,
                                               &int4_moe_gemm_runner);

  moe_compute.ComputeFFN(
      &input,
      &gate_weight,
      &up_gate_proj_weight,
      up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
      up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
      &down_proj_weight,
      down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
      down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
      nullptr,
      moe_topk,
      group_moe,
      norm_topk_prob,
      1.0,  // ComputeFFN
      "ffn",
      output);
}

/**
 * @brief Allocates the output tensor and dispatches FusedMoeKernel on the
 *        input dtype.
 *
 * The output is created with paddle::empty_like(input). Only BFLOAT16 and
 * FLOAT16 inputs are handled; any other dtype raises through PD_THROW.
 *
 * @return The output tensor filled by FusedMoeKernel.
 */
paddle::Tensor FusedExpertMoeFunc(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_bias,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  const auto input_type = input.dtype();
  auto output = paddle::empty_like(input);

  // Note the argument order: FusedMoeKernel takes (scale, bias) pairs and
  // (group_moe, norm_topk_prob), which differs from this function's own
  // parameter order.
  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
                                                 gate_weight,
                                                 up_gate_proj_weight,
                                                 up_gate_proj_scale,
                                                 up_gate_proj_bias,
                                                 down_proj_weight,
                                                 down_proj_scale,
                                                 down_proj_bias,
                                                 quant_method,
                                                 moe_topk,
                                                 group_moe,
                                                 norm_topk_prob,
                                                 &output);
      break;
    case paddle::DataType::FLOAT16:
      FusedMoeKernel<paddle::DataType::FLOAT16>(input,
                                                gate_weight,
                                                up_gate_proj_weight,
                                                up_gate_proj_scale,
                                                up_gate_proj_bias,
                                                down_proj_weight,
                                                down_proj_scale,
                                                down_proj_bias,
                                                quant_method,
                                                moe_topk,
                                                group_moe,
                                                norm_topk_prob,
                                                &output);
      break;
    default:
      PD_THROW("Unsupported data type for FusedMoeKernel");
  }
  return output;
}

/**
 * @brief Kernel entry point registered for the fused_expert_moe operator.
 *
 * Forwards every argument to FusedExpertMoeFunc and wraps the resulting
 * tensor in a single-element vector.
 */
std::vector<paddle::Tensor> FusedExpertMoe(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_bias,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  return {FusedExpertMoeFunc(input,
                             gate_weight,
                             up_gate_proj_weight,
                             down_proj_weight,
                             up_gate_proj_bias,
                             up_gate_proj_scale,
                             down_proj_bias,
                             down_proj_scale,
                             quant_method,
                             moe_topk,
                             norm_topk_prob,
                             group_moe)};
}

/**
 * @brief Shape inference: the single output has the same shape as `input`.
 *        All other shapes are ignored.
 */
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& gate_weight_shape,
    const std::vector<int64_t>& up_gate_proj_weight_shape,
    const std::vector<int64_t>& down_proj_weight_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
  return {input_shape};
}

/**
 * @brief Dtype inference: the single output has the same dtype as `input`.
 *        All other dtypes are ignored.
 */
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& gate_weight_dtype,
    const paddle::DataType& up_gate_proj_weight_dtype,
    const paddle::DataType& down_proj_weight_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
    const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
    const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
  return {input_dtype};
}

/**
 * @brief Fused Mixture-of-Experts (MoE) Operator
 *
 * This operator combines three key MoE operations into a single optimized
 * kernel:
 * 1. moe_dispatch - Routes tokens to top-k experts using gating network
 * 2. moe_ffn      - Processes tokens through parallel expert FFNs
 * 3. moe_reduce   - Combines expert outputs with routing weights
 *
 * Key Features:
 * - Supports both dense and quantized expert weights
 * - Optimized for GPU execution with fused operations
 *
 * Mathematical Formulation:
 *   output = ∑_i^topk(softmax(gate(x))_i * FFN_i(x))
 *
 * Reference Components:
 *   moe_dispatch: Selects top-k experts per token and generates permutation
 *                 indices
 *   moe_ffn:      Applies SwiGLU activation expert networks in parallel
 *   moe_reduce:   Combines weighted expert outputs and restores original
 *                 token order
 *
 * Performance Notes:
 * - Recommended hidden_size multiples of 128 for optimal memory alignment
 * - For best throughput, num_experts should be powers of 2
 */
PD_BUILD_STATIC_OP(fused_expert_moe)
    .Inputs({"input",
             "gate_weight",
             "up_gate_proj_weight",
             "down_proj_weight",
             paddle::Optional("up_gate_proj_bias"),
             paddle::Optional("up_gate_proj_scale"),
             paddle::Optional("down_proj_bias"),
             paddle::Optional("down_proj_scale")})
    .Outputs({"output"})
    .Attrs({"quant_method:std::string",
            "moe_topk:int",
            "norm_topk_prob:bool",
            "group_moe:bool"})
    .SetKernelFn(PD_KERNEL(FusedExpertMoe))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));