Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions
+56 -111
View File
@@ -30,9 +30,9 @@ namespace cg = cooperative_groups;
template<typename T>
__device__ T warpReduceSum(T val){
for(int lane_mask = 16; lane_mask > 0; lane_mask /=2){
val += __shfl_down_sync(0xffffffff, val, lane_mask);
val += __shfl_down_sync(0xffffffff, val, lane_mask);
}
return val;
return val;
}
__global__ void get_expert_token_num(
@@ -88,7 +88,7 @@ __global__ void get_expert_token_num(
sum = (threadIdx.x < KNWARPS) ? warp_sum[laneId] : 0;
sum_padded = (threadIdx.x < KNWARPS) ? warp_sum[laneId + KNWARPS] : 0;
if (warpId == 0) {
sum = warpReduceSum<int>(sum);
sum = warpReduceSum<int>(sum);
sum_padded = warpReduceSum<int>(sum_padded);
}
if (threadIdx.x == 0) {
@@ -167,7 +167,7 @@ __global__ void combine_prmt_back_kernel(
#pragma unroll
for (int vid = 0; vid < VEC_SIZE; vid++) {
res_vec[vid] += static_cast<T>(
row_scale * static_cast<float>(load_vec[vid]) +
row_scale * static_cast<float>(load_vec[vid]) +
static_cast<float>(bias_vec[vid]));
}
} else {
@@ -497,7 +497,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
paddle::DataType::INT32,
place);
auto expert_idx_per_token = GetEmptyTensor(
{token_nums_this_rank}, paddle::DataType::INT64, place);
@@ -619,7 +619,7 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
"cumsum_idx_gpu",
"expert_idx_per_token"})
.Attrs({
"token_nums_per_expert: std::vector<int>",
"token_nums_per_expert: std::vector<int>",
"token_nums_this_rank: int",
"moe_quant_type: std::string"
})
@@ -672,18 +672,21 @@ __global__ void permute_x_fp8_kernel(const T *src_x,
const int hidden_size_int4 = hidden_size / vec_size;
const int hidden_size_scale = hidden_size / 128;
const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size;
const int token_nums_feed_to_ffn = token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK-1];
// prmt
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_this_rank_padded; s_token_idx += gridDim.x) {
if (tid == 0) {
for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) {
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
const int end_idx = token_nums_per_expert_cum[i];
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
m_indices[s_token_idx] = i;
break;
}
for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_feed_to_ffn; s_token_idx += gridDim.x) {
// the m_indices[s_token_idx] must be a value `i` in [0, NUM_EXPERTS_PER_RANK)
// here we parallel wo find the `i` we want.
for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i+= blockDim.x) {
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
const int end_idx = token_nums_per_expert_cum[i];
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
if ((s_token_idx - start_idx) < token_nums_per_expert[i]) m_indices[s_token_idx] = i;
break;
}
}
if (s_token_idx < num_rows) {
const int64_t *topk_idx_now = topk_idx + s_token_idx * moe_topk;
#pragma unroll
@@ -738,7 +741,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
paddle::Tensor* m_indices) {
auto stream = input.stream();
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
// const int gridx = min(132 * 8, num_rows);
const int gridx = 132 * 8;
if (num_experts_per_rank == 8) {
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
input.data<phi::dtype::float8_e4m3fn>(),
@@ -831,8 +835,31 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
token_nums_per_expert_padded_cumsum->data<int64_t>(),
m_indices->data<int>()
);
} else if (num_experts_per_rank == 128) {
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
input.data<phi::dtype::float8_e4m3fn>(),
scale.data<float>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
token_nums_per_expert_padded.data<int>(),
moe_topk,
num_rows,
token_nums_this_rank,
token_nums_this_rank_padded,
hidden_size,
permute_input->data<phi::dtype::float8_e4m3fn>(),
permute_scale->data<float>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
token_nums_per_expert_padded_cumsum->data<int64_t>(),
m_indices->data<int>()
);
} else {
PD_THROW("Not dispatching this num_experts_per_rank for EPMoeDispatchFP8Kernel");
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
}
}
@@ -842,10 +869,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor& scale,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
const paddle::Tensor& num_experts_per_rank_tensor,
const paddle::Tensor& num_experts_per_rank_padded_tensor) {
const auto input_type = input.dtype();
const int moe_topk = topk_ids.dims()[1];
auto place = input.place();
@@ -859,7 +884,10 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
}
const int num_rows = token_rows;
const int hidden_size = input.dims()[input_dims.size() - 1];
const int num_experts_per_rank = token_nums_per_expert.size();
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
auto permute_input = GetEmptyTensor(
{token_nums_this_rank_padded, hidden_size},
@@ -869,30 +897,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
{token_nums_this_rank_padded, hidden_size / 128},
paddle::DataType::FLOAT32,
place);
auto num_experts_per_rank_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
place);
auto num_experts_per_rank_padded_tensor = GetEmptyTensor(
{num_experts_per_rank},
paddle::DataType::INT32,
place);
auto m_indices = GetEmptyTensor(
{token_nums_this_rank_padded},
paddle::DataType::INT32,
place);
cudaMemcpyAsync(
num_experts_per_rank_tensor.data<int>(),
token_nums_per_expert.data(),
num_experts_per_rank * sizeof(int),
cudaMemcpyHostToDevice,
input.stream());
cudaMemcpyAsync(
num_experts_per_rank_padded_tensor.data<int>(),
token_nums_per_expert_padded.data(),
num_experts_per_rank * sizeof(int),
cudaMemcpyHostToDevice,
input.stream());
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
@@ -908,8 +914,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
num_experts_per_rank_padded_tensor,
moe_topk,
num_rows,
token_nums_this_rank,
token_nums_this_rank_padded,
-1,
-1,
hidden_size,
num_experts_per_rank,
&permute_input,
@@ -932,61 +938,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
m_indices};
}
std::vector<std::vector<int64_t>> EPMoeExpertDispatchFP8InferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& scale_shape,
const std::vector<int64_t>& topk_ids_shape,
const std::vector<int64_t>& topk_weights_shape,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
int token_rows = -1; // real token row
int moe_topk = topk_ids_shape[1];
if (input_shape.size() == 3) {
token_rows = input_shape[0] * input_shape[1];
} else {
token_rows = input_shape[0];
}
const int expert_num = token_nums_per_expert.size(); // 本地专家个数
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
return {{token_nums_this_rank_padded, hidden_size}, // x
{token_nums_this_rank_padded, hidden_size / 128}, // scale
{expert_num, num_rows},
{expert_num},
{expert_num},
{token_nums_this_rank_padded},
{num_rows, expert_num},
{expert_num},
{token_nums_this_rank_padded}}; // dst_idx per expert
}
std::vector<paddle::DataType> EPMoeExpertDispatchFP8InferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& scale_dtype,
const paddle::DataType& topk_ids_dtype,
const paddle::DataType& topk_weights_dtype,
const std::vector<int>& token_nums_per_expert,
const std::vector<int>& token_nums_per_expert_padded,
const int token_nums_this_rank,
const int token_nums_this_rank_padded) {
return {input_dtype,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT64,
paddle::DataType::INT64,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
.Inputs({"input", "scale", "topk_ids", "topk_weights"})
.Inputs({"input", "scale", "topk_ids", "topk_weights", "num_experts_per_rank_tensor", "num_experts_per_rank_padded_tensor"})
.Outputs({"permute_input",
"permute_scale",
"permute_indices_per_token",
@@ -996,12 +949,4 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
"dst_indices",
"cumsum_idx_gpu",
"m_indices"})
.Attrs({
"token_nums_per_expert: std::vector<int>",
"token_nums_per_expert_padded: std::vector<int>",
"token_nums_this_rank: int",
"token_nums_this_rank_padded: int",
})
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8))
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchFP8InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchFP8InferDtype));
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));