Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 08:21:53 +08:00.
[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164)
This commit is contained in:
@@ -15,10 +15,11 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, bool kComputeCumsum>
|
||||
__global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
|
||||
int32_t *__restrict__ res,
|
||||
int32_t *__restrict__ res_padded,
|
||||
int32_t *__restrict__ res_padded_cumsum,
|
||||
size_t numel,
|
||||
int num_experts) {
|
||||
extern __shared__ int32_t tokens_per_ep[];
|
||||
@@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
|
||||
res[i] = tokens_per_ep[i];
|
||||
res_padded[i] = (res[i] + 127) / 128 * 128;
|
||||
if constexpr (kComputeCumsum) {
|
||||
if (threadIdx.x == 0) {
|
||||
int32_t running_sum = 0;
|
||||
for (int i = 0; i < num_experts; i++) {
|
||||
int32_t count = tokens_per_ep[i];
|
||||
int32_t padded = (count + 127) / 128 * 128;
|
||||
res[i] = count;
|
||||
res_padded[i] = padded;
|
||||
running_sum += padded;
|
||||
res_padded_cumsum[i] = running_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
|
||||
res[i] = tokens_per_ep[i];
|
||||
res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Count how many tokens are routed to each expert.
 *
 * Launches cuda_kernel with a single block of 1024 threads and one int32
 * shared-memory counter per expert over the flattened topk_ids tensor,
 * and returns one INT32 tensor laid out row-wise:
 *   row 0: raw token count per expert
 *   row 1: count rounded up to a multiple of 128 (padded count)
 *   row 2: inclusive cumulative sum of the padded counts
 *          (present only when compute_padded_cumsum is true)
 *
 * @param topk_ids              2-D int64 tensor of chosen expert ids
 *                              (assumed [num_tokens, k] — TODO confirm with caller).
 * @param num_experts           number of experts; sizes every output row.
 * @param compute_padded_cumsum when true, also emit the padded-cumsum row.
 * @return a single tensor of shape {2 or 3, num_experts}, INT32, allocated
 *         on topk_ids' place.
 */
std::vector<paddle::Tensor> count_tokens_per_expert_func(
    const paddle::Tensor &topk_ids,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  // Flattened element count; widen before multiplying so very large
  // batches cannot overflow (the kernel takes a size_t numel).
  const size_t topk_ids_numel =
      static_cast<size_t>(topk_ids.shape()[0]) * topk_ids.shape()[1];

  // Two rows (count, padded count) plus an optional cumsum row.
  const int64_t num_rows = compute_padded_cumsum ? 3 : 2;
  auto token_nums_per_expert = paddle::empty(
      {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place());

  auto stream = topk_ids.stream();
  using scalar_t = int64_t;  // dtype of topk_ids elements

  int32_t *out = token_nums_per_expert.data<int32_t>();
  // kComputeCumsum is a compile-time template parameter, so dispatch here.
  if (compute_padded_cumsum) {
    cuda_kernel<scalar_t, true>
        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
            topk_ids.data<scalar_t>(),
            out,                      // row 0: raw counts
            out + num_experts,        // row 1: padded counts
            out + 2 * num_experts,    // row 2: padded-count cumsum
            topk_ids_numel,
            num_experts);
  } else {
    cuda_kernel<scalar_t, false>
        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
            topk_ids.data<scalar_t>(),
            out,                      // row 0: raw counts
            out + num_experts,        // row 1: padded counts
            nullptr,                  // no cumsum row requested
            topk_ids_numel,
            num_experts);
  }
  // Surface launch-configuration failures (e.g. shared-memory
  // oversubscription for huge num_experts) instead of letting them fail
  // silently at the next synchronizing call.
  CUDA_CHECK(cudaGetLastError());

  return {token_nums_per_expert};
}
|
||||
|
||||
/**
 * Dtype inference for count_tokens_per_expert_func: the op always produces
 * a single INT32 tensor, independent of the input dtype or attributes
 * (both parameters are part of the registration contract but unused here).
 */
std::vector<paddle::DataType> count_tokens_per_expert_func_infer_dtype(
    const paddle::DataType &topk_ids_dtype,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  return {paddle::DataType::INT32};
}
|
||||
|
||||
/**
 * Shape inference for count_tokens_per_expert_func.
 *
 * The output packs per-expert statistics as rows — {count, padded count}
 * plus, when compute_padded_cumsum is set, a third row holding the
 * padded-count cumulative sum — so the shape is {2 or 3, num_experts}.
 *
 * @param topk_ids_shape shape of the input ids tensor (unused: the output
 *                       size depends only on num_experts).
 * @return {{num_rows, num_experts}} matching the kernel's output layout.
 */
std::vector<std::vector<int64_t>> count_tokens_per_expert_func_infer_shape(
    const std::vector<int64_t> &topk_ids_shape,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  const int64_t num_rows = compute_padded_cumsum ? 3 : 2;
  return {{num_rows, num_experts}};
}
|
||||
|
||||
// Register the static op: one input tensor, one packed per-expert-stats
// output tensor, and the two attributes consumed by the kernel and the
// infer functions above. Note: exactly one .Attrs(...) call — listing both
// attributes together — otherwise the later call clobbers the earlier one.
PD_BUILD_STATIC_OP(count_tokens_per_expert_func)
    .Inputs({"topk_ids"})
    .Outputs({"token_nums_per_expert"})
    .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"})
    .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func))
    .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape))
    .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype));
|
||||
|
||||
Reference in New Issue
Block a user