Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
+104 -93
View File
@@ -28,13 +28,13 @@ __global__ void compute_total_rows_before_expert_kernel(
const int64_t sorted_experts_len,
const int64_t num_experts,
int64_t* total_rows_before_expert) {
// First, compute the global tid. We only need 1 thread per expert.
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
// First, compute the global tid. We only need 1 thread per expert.
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
if (expert >= num_experts) return;
// This should construct the last index where each expert occurs.
total_rows_before_expert[expert] =
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
// This should construct the last index where each expert occurs.
total_rows_before_expert[expert] =
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}
void compute_total_rows_before_expert(int* sorted_indices,
@@ -42,11 +42,11 @@ void compute_total_rows_before_expert(int* sorted_indices,
const int64_t num_experts,
int64_t* total_rows_before_expert,
cudaStream_t stream) {
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
const int threads = std::min(int64_t(1024), num_experts);
const int blocks = (num_experts + threads - 1) / threads;
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
} // namespace phi
@@ -65,38 +65,47 @@ void FusedMoeKernel(const paddle::Tensor& input,
const bool group_moe,
const bool norm_topk_prob,
paddle::Tensor* output) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto* output_data = output->data<data_t>();
auto* output_data = output->data<data_t>();
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
auto fp16_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
auto int8_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt8>>();
auto int4_moe_gemm_runner = MoeGemmRunner<
DataType_,
cutlass::WintQuantTraits<DataType_,
cutlass::WintQuantMethod::kWeightOnlyInt4>>();
using NvType = typename traits_::DataType;
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
&fp16_moe_gemm_runner,
&int8_moe_gemm_runner,
&int4_moe_gemm_runner);
using NvType = typename traits_::DataType;
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
&fp16_moe_gemm_runner,
&int8_moe_gemm_runner,
&int4_moe_gemm_runner);
moe_compute.ComputeFFN(&input,
&gate_weight,
&up_gate_proj_weight,
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
&down_proj_weight,
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
moe_compute.ComputeFFN(
&input,
&gate_weight,
&up_gate_proj_weight,
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
&down_proj_weight,
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
}
paddle::Tensor FusedExpertMoeFunc(
@@ -112,44 +121,44 @@ paddle::Tensor FusedExpertMoeFunc(
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
case paddle::DataType::FLOAT16:
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
default:
PD_THROW("Unsupported data type for FusedMoeKernel");
}
return output;
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
case paddle::DataType::FLOAT16:
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
default:
PD_THROW("Unsupported data type for FusedMoeKernel");
}
return output;
}
std::vector<paddle::Tensor> FusedExpertMoe(
@@ -165,18 +174,18 @@ std::vector<paddle::Tensor> FusedExpertMoe(
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
return {FusedExpertMoeFunc(input,
gate_weight,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_bias,
down_proj_scale,
quant_method,
moe_topk,
norm_topk_prob,
group_moe)};
return {FusedExpertMoeFunc(input,
gate_weight,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_bias,
down_proj_scale,
quant_method,
moe_topk,
norm_topk_prob,
group_moe)};
}
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
@@ -188,7 +197,7 @@ std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
return {input_shape};
return {input_shape};
}
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
@@ -200,13 +209,14 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
return {input_dtype};
return {input_dtype};
}
/**
* @brief Fused Mixture-of-Experts (MoE) Operator
*
* This operator combines three key MoE operations into a single optimized kernel:
* This operator combines three key MoE operations into a single optimized
* kernel:
* 1. moe_dispatch - Routes tokens to top-k experts using gating network
* 2. moe_ffn - Processes tokens through parallel expert FFNs
* 3. moe_reduce - Combines expert outputs with routing weights
@@ -219,9 +229,10 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
* output = ∑_{i=1}^{topk} softmax(gate(x))_i * FFN_i(x)
*
* Reference Components:
* moe_dispatch: Selects top-k experts per token and generates permutation indices
* moe_ffn: Applies SwiGLU activation expert networks in parallel
* moe_reduce: Combines weighted expert outputs and restores original token order
* moe_dispatch: Selects top-k experts per token and generates permutation
* indices
* moe_ffn: Applies SwiGLU activation expert networks in parallel
* moe_reduce: Combines weighted expert outputs and restores original token
* order
*
* Performance Notes:
* - Recommended hidden_size multiples of 128 for optimal memory alignment
+71 -76
View File
@@ -7,8 +7,10 @@ namespace MARLIN_NAMESPACE_NAME {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
uint32_t const* __restrict__ perm_ptr,
uint32_t* __restrict__ out_ptr,
int size_k,
int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
@@ -224,7 +226,8 @@ __global__ void gptq_marlin_repack_kernel(
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
fetch_to_shared((pipe + repack_stages - 1) % repack_stages,
k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
@@ -234,85 +237,83 @@ __global__ void gptq_marlin_repack_kernel(
}
}
} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, MARLIN_NAMESPACE_NAME::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute(MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel< \
MARLIN_NAMESPACE_NAME::repack_threads, \
NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel< \
MARLIN_NAMESPACE_NAME::repack_threads, \
NUM_BITS, \
HAS_PERM> \
<<<blocks, \
MARLIN_NAMESPACE_NAME::repack_threads, \
max_shared_mem, \
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddle::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight,
paddle::Tensor& perm,
int64_t size_k,
int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
PADDLE_ENFORCE(
size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
"size_k = ", size_k,
" is not divisible by tile_k_size = ",
MARLIN_NAMESPACE_NAME::tile_k_size);
PADDLE_ENFORCE(size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_k_size = ",
MARLIN_NAMESPACE_NAME::tile_k_size);
PADDLE_ENFORCE(
size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
"size_n = ", size_n,
" is not divisible by tile_n_size = ",
MARLIN_NAMESPACE_NAME::tile_n_size);
PADDLE_ENFORCE(size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
"size_n = ",
size_n,
" is not divisible by tile_n_size = ",
MARLIN_NAMESPACE_NAME::tile_n_size);
PADDLE_ENFORCE(
num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
PADDLE_ENFORCE(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ",
num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
// shape checks
PADDLE_ENFORCE(
(size_k / pack_factor) == b_q_weight.dims()[0],
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.dims()[0]);
PADDLE_ENFORCE((size_k / pack_factor) == b_q_weight.dims()[0],
"Shape mismatch: b_q_weight.size(0) = ",
b_q_weight.dims()[0]);
PADDLE_ENFORCE(
b_q_weight.dims()[1] == size_n,
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.dims()[1],
", expected size_n = ", size_n);
PADDLE_ENFORCE(b_q_weight.dims()[1] == size_n,
"Shape mismatch: b_q_weight.size(1) = ",
b_q_weight.dims()[1],
", expected size_n = ",
size_n);
// Verify device and strides
PADDLE_ENFORCE(
b_q_weight.is_gpu(),
"b_q_weight is not on GPU");
// Verify device and strides
PADDLE_ENFORCE(b_q_weight.is_gpu(), "b_q_weight is not on GPU");
PADDLE_ENFORCE(
b_q_weight.is_contiguous(),
"b_q_weight is not contiguous");
PADDLE_ENFORCE(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
PADDLE_ENFORCE(
b_q_weight.dtype() == phi::DataType::INT32,
"b_q_weight type is not kInt");
PADDLE_ENFORCE(b_q_weight.dtype() == phi::DataType::INT32,
"b_q_weight type is not kInt");
PADDLE_ENFORCE(
perm.is_gpu(),
"perm is not on GPU");
PADDLE_ENFORCE(perm.is_gpu(), "perm is not on GPU");
PADDLE_ENFORCE(
perm.is_contiguous(),
"perm is not contiguous");
PADDLE_ENFORCE(perm.is_contiguous(), "perm is not contiguous");
PADDLE_ENFORCE(
perm.dtype() == phi::DataType::INT32,
"perm type is not kInt");
PADDLE_ENFORCE(perm.dtype() == phi::DataType::INT32, "perm type is not kInt");
// Alloc buffers
// const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
paddle::Tensor out = paddle::empty(
{size_k / MARLIN_NAMESPACE_NAME::tile_size, size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
b_q_weight.dtype(),
b_q_weight.place());
paddle::Tensor out =
paddle::empty({size_k / MARLIN_NAMESPACE_NAME::tile_size,
size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
b_q_weight.dtype(),
b_q_weight.place());
// Detect if there is act_order
bool has_perm = perm.dims()[0] != 0;
@@ -330,13 +331,11 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
cudaDeviceGetAttribute(
&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
// TORCH_CHECK(max_shared_mem > 0);
PADDLE_ENFORCE(
max_shared_mem > 0,
"max_shared_mem must be > 0. Got = ", max_shared_mem);
max_shared_mem > 0, "max_shared_mem must be > 0. Got = ", max_shared_mem);
if (false) {
}
@@ -347,11 +346,11 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
else {
// TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
// ", has_perm = ", has_perm);
PADDLE_ENFORCE(
false,
"Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
PADDLE_ENFORCE(false,
"Unsupported repack config: num_bits = ",
num_bits,
", has_perm = ",
has_perm);
}
return {out};
@@ -360,9 +359,5 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
PD_BUILD_STATIC_OP(gptq_marlin_repack)
.Inputs({"b_q_weight", "perm"})
.Outputs({"out"})
.Attrs({
"size_k: int64_t",
"size_n: int64_t",
"num_bits: int64_t"
})
.Attrs({"size_k: int64_t", "size_n: int64_t", "num_bits: int64_t"})
.SetKernelFn(PD_KERNEL(gptq_marlin_repack));
@@ -19,115 +19,119 @@
#pragma once
/**
 * SwiGLU over the valid rows of a grouped buffer.
 *
 * Indexing used by this kernel:
 *   input   row (g, r) starts at (g * group_size + r) * hidden_dim * 2 and
 *           holds two halves of hidden_dim elements each (call them a, b);
 *   act_out row (g, r) starts at (g * group_size + r) * hidden_dim.
 * For every valid row it writes act_out = b * a / (1 + exp(-a)), computed in
 * float and cast back to T.
 *
 * Only the first token_nums_per_expert[g] rows of group g are processed.
 * Valid rows are numbered consecutively across groups; block `blockIdx.x`
 * handles row numbers blockIdx.x, blockIdx.x + gridDim.x, ... and exits once
 * its row number is past the total valid-row count.
 *
 * Preconditions (not checked here):
 *   - blockDim.x is a multiple of 32: the row lookup is done by lane 0 of each
 *     warp and broadcast with a full-mask __shfl_sync.
 *   - hidden_dim is a multiple of VecSize (vectorised Load/Store).
 *
 * @tparam index    integer type of token_nums_per_expert
 * @tparam T        element type of input / act_out
 * @tparam VecSize  elements moved per vector Load/Store
 */
template <typename index, typename T, int VecSize>
__global__ void group_swiglu_with_masked_kernel(
    T* act_out,
    const T* input,
    const index* token_nums_per_expert,
    const int64_t group_num,
    const int64_t group_size,
    const int64_t hidden_dim) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec0, src_vec1;
  LoadT res_vec;

  int64_t block_id = static_cast<int64_t>(blockIdx.x);
  const int lane_idx = threadIdx.x % 32;

  while (true) {
    int dealt_group_id = -1;
    int dealt_seq_id = -1;
    if (lane_idx == 0) {
      // Linear scan of the per-group counts: find the group whose running
      // range [cumsum1, cumsum2) contains this block's current row number.
      int cumsum1 = 0;
      int cumsum2 = 0;
      for (int i = 0; i < group_num; i++) {
        int tmp = token_nums_per_expert[i];
        cumsum2 += tmp;
        if (block_id >= cumsum1 && block_id < cumsum2) {
          dealt_group_id = i;
          dealt_seq_id = block_id - cumsum1;
          break;
        }
        cumsum1 += tmp;
      }
    }
    // Broadcast lane 0's result to the rest of the warp.
    dealt_group_id = __shfl_sync(0xffffffff, dealt_group_id, 0);
    dealt_seq_id = __shfl_sync(0xffffffff, dealt_seq_id, 0);
    // No group matched: the row number is beyond the last valid row.
    if (dealt_group_id < 0) break;

    const int64_t r_offset =
        (dealt_group_id * group_size + dealt_seq_id) * hidden_dim * 2;
    const int64_t w_offset =
        (dealt_group_id * group_size + dealt_seq_id) * hidden_dim;

    for (int64_t col_id = threadIdx.x * VecSize; col_id < hidden_dim;
         col_id += blockDim.x * VecSize) {
      Load<T, VecSize>(&input[r_offset + col_id], &src_vec0);
      Load<T, VecSize>(&input[r_offset + col_id + hidden_dim], &src_vec1);
      for (int j = 0; j < VecSize; ++j) {
        float a = static_cast<float>(src_vec0[j]);
        float b = static_cast<float>(src_vec1[j]);
        // Explicit single-precision exponential.
        float res = b * a / (1.f + expf(-a));
        res_vec[j] = static_cast<T>(res);
      }
      Store<T, VecSize>(res_vec, &act_out[w_offset + col_id]);
    }
    block_id += gridDim.x;
  }
}
/**
 * Host launcher for group_swiglu_with_masked_kernel.
 *
 * Reads group_size from fc1_out_tensor.shape()[1] and hidden_dim as
 * fc1_out_tensor.shape()[2] / 2; group_num comes from
 * token_nums_per_expert.shape()[0]. Allocates and returns a
 * [group_num, group_size, hidden_dim] tensor with the dtype and place of
 * fc1_out_tensor.
 *
 * Requirements enforced here:
 *   - fc1_out_tensor is BFLOAT16;
 *   - hidden_dim is a multiple of 8 (the kernel's vector width);
 *   - token_nums_per_expert is INT32 or INT64 (otherwise PD_THROW).
 *
 * The kernel is launched with a fixed 256 x 512 configuration on
 * fc1_out_tensor's stream. Rows of the output beyond each group's token count
 * are not written by the kernel.
 */
paddle::Tensor GroupSwigluWithMasked(
    const paddle::Tensor& fc1_out_tensor,
    const paddle::Tensor& token_nums_per_expert) {
  const int64_t group_num = token_nums_per_expert.shape()[0];
  const int64_t group_size = fc1_out_tensor.shape()[1];
  const int64_t hidden_dim = fc1_out_tensor.shape()[2] / 2;
  auto act_out_tensor = GetEmptyTensor({group_num, group_size, hidden_dim},
                                       fc1_out_tensor.dtype(),
                                       fc1_out_tensor.place());

  constexpr int VecSize = 8;
  PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
  PD_CHECK(hidden_dim % VecSize == 0);

  constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int threads = 512;
  const int blocks = 256;

// Proper do { ... } while (0) wrapper so the macro expands to exactly one
// statement (the previous form left a dangling `while (0);` behind the block).
#define dispatch_by_index(index)                                              \
  do {                                                                        \
    group_swiglu_with_masked_kernel<index, DataType_, VecSize>                \
        <<<blocks, threads, 0, fc1_out_tensor.stream()>>>(                    \
            reinterpret_cast<DataType_*>(                                     \
                const_cast<data_t*>(act_out_tensor.data<data_t>())),          \
            reinterpret_cast<const DataType_*>(                               \
                fc1_out_tensor.data<data_t>()),                               \
            token_nums_per_expert.data<index>(),                              \
            group_num,                                                        \
            group_size,                                                       \
            hidden_dim);                                                      \
  } while (0)

  if (token_nums_per_expert.dtype() == paddle::DataType::INT64) {
    dispatch_by_index(int64_t);
  } else if (token_nums_per_expert.dtype() == paddle::DataType::INT32) {
    dispatch_by_index(int32_t);
  } else {
    PD_THROW("Unsupported token_nums_per_expert's data dtype.");
  }

// Keep the helper macro local to this function.
#undef dispatch_by_index

  return act_out_tensor;
}
// Custom-op kernel function: forwards to GroupSwigluWithMasked and returns its
// single result tensor wrapped in a vector.
std::vector<paddle::Tensor> GroupSwigluWithMaskedWrapper(
    const paddle::Tensor& input, const paddle::Tensor& token_nums_per_expert) {
  std::vector<paddle::Tensor> outputs;
  outputs.push_back(GroupSwigluWithMasked(input, token_nums_per_expert));
  return outputs;
}
PD_BUILD_STATIC_OP(group_swiglu_with_masked)
.Inputs({"input",
"token_nums_per_expert"})
.Inputs({"input", "token_nums_per_expert"})
.Outputs({"output_tensor"})
.SetKernelFn(PD_KERNEL(GroupSwigluWithMaskedWrapper));
@@ -15,6 +15,6 @@
#pragma once
#include "helper.h"
paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
const paddle::Tensor &token_nums_per_expert);
paddle::Tensor GroupSwigluWithMasked(
const paddle::Tensor &fc1_out_tensor,
const paddle::Tensor &token_nums_per_expert);
+133 -115
View File
@@ -14,146 +14,164 @@
#include "helper.h"
/**
 * Weighted combine of per-expert FFN rows back into per-token rows.
 *
 * For each token t (one block handles tokens blockIdx.x, blockIdx.x +
 * gridDim.x, ...):
 *   phase 1: warp w loads the row of ffn_out selected by
 *            (topk_idx[t * TopK + w], permute_indices_per_token[t * TopK + w]),
 *            scales it by topk_weights[t * TopK + w] and stages it in shared
 *            memory at shm_hidden + w * hidden;
 *   phase 2: after a block barrier, all threads sum the TopK staged rows
 *            element-wise and store the result to out + t * hidden.
 *
 * Launch requirements (not checked here):
 *   - blockDim.x == 32 * TopK, i.e. one warp per selected expert, because the
 *     warp id is used directly as the top-k slot;
 *   - dynamic shared memory of at least TopK * hidden * sizeof(T) bytes;
 *   - num_vecs * VecSize elements per row are processed, so the caller is
 *     expected to pass num_vecs = hidden / VecSize with hidden a multiple of
 *     VecSize.
 *
 * ffn_out rows are addressed as
 *   (expert * max_tokens_per_expert + row_in_expert) * hidden.
 */
template <typename T, int VecSize, int TopK>
__global__ void MoEDeepGEMMDePermuteKernel(T* out,
                                           const T* ffn_out,
                                           const int* permute_indices_per_token,
                                           const int64_t* topk_idx,
                                           const float* topk_weights,
                                           const int token_num,
                                           const int num_vecs,
                                           const int hidden,
                                           const int max_tokens_per_expert) {
  AlignedVector<T, VecSize> in_vec;
  AlignedVector<T, VecSize> acc_vec[TopK];

  const int bid = blockIdx.x;
  const int wid = threadIdx.x / 32;
  const int tid = threadIdx.x % 32;
  extern __shared__ char shm[];  // TopK * hidden elements of T
  T* shm_hidden = reinterpret_cast<T*>(shm);

  for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
    // Index of this warp's (token, top-k slot) entry; 64-bit so that large
    // token counts cannot overflow.
    const int64_t slot = static_cast<int64_t>(token_idx) * TopK + wid;
    int src_expert_id = topk_idx[slot];
    int src_expert_token = permute_indices_per_token[slot];
    float weight = topk_weights[slot];

    // Row offsets are formed in 64-bit: expert * max_tokens * hidden easily
    // exceeds the range of a 32-bit int for realistic sizes.
    const T* src_row =
        ffn_out + (static_cast<int64_t>(src_expert_id) * max_tokens_per_expert +
                   src_expert_token) *
                      hidden;
    T* staged_row = shm_hidden + static_cast<int64_t>(wid) * hidden;

    for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
         hidden_vec_id += 32) {
      Load<T, VecSize>(src_row + hidden_vec_id * VecSize, &in_vec);
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        in_vec[i] *= weight;
      }
      Store<T, VecSize>(in_vec, staged_row + hidden_vec_id * VecSize);
    }
    // All TopK staged rows must be complete before any thread reduces them.
    __syncthreads();

    T* out_row = out + static_cast<int64_t>(token_idx) * hidden;
    for (int hidden_vec_id = threadIdx.x; hidden_vec_id < num_vecs;
         hidden_vec_id += blockDim.x) {
#pragma unroll
      for (int topk_id = 0; topk_id < TopK; topk_id++) {
        Load<T, VecSize>(
            shm_hidden + topk_id * hidden + hidden_vec_id * VecSize,
            &acc_vec[topk_id]);
      }
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
#pragma unroll
        for (int topk_id = 1; topk_id < TopK; topk_id++) {
          acc_vec[0][i] += acc_vec[topk_id][i];
        }
      }
      Store<T, VecSize>(acc_vec[0], out_row + hidden_vec_id * VecSize);
    }
    // The staging buffer is reused for the next token. Without this barrier a
    // warp that finishes early could start overwriting it while other warps
    // are still reading the current token's rows. Every thread of the block
    // runs the same number of loop iterations (token_idx depends only on
    // blockIdx.x), so the barrier is reached uniformly.
    __syncthreads();
  }
}
/**
 * Typed host launcher for MoEDeepGEMMDePermuteKernel.
 *
 * Shapes are read as:
 *   ffn_out                    [*, max_tokens_per_expert, hidden]
 *   permute_indices_per_token  [token_num, topk]
 * and the result is a freshly allocated [token_num, hidden] tensor with the
 * dtype and place of ffn_out, filled on ffn_out's stream.
 *
 * Only topk == 4 and topk == 8 are instantiated; any other value raises via
 * PD_THROW. hidden must be a multiple of the 16-byte vector width
 * (16 / sizeof(data_t) elements), otherwise the trailing elements of each row
 * would never be written.
 *
 * The kernel stages topk * hidden elements in dynamic shared memory; when that
 * reaches 48 KiB the per-kernel limit is raised with cudaFuncSetAttribute
 * before launching.
 */
template <paddle::DataType D>
std::vector<paddle::Tensor> MoEDeepGEMMDePermuteDispatch(
    const paddle::Tensor&
        ffn_out,  // [num_experts, max_tokens_per_expert, hidden]
    const paddle::Tensor& permute_indices_per_token,  // [token_num, topk]
    const paddle::Tensor& topk_idx,
    const paddle::Tensor& topk_weights) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int token_num = permute_indices_per_token.shape()[0];
  const int max_tokens_per_expert = ffn_out.shape()[1];
  const int hidden = ffn_out.shape()[2];
  const int topk = permute_indices_per_token.shape()[1];

  auto place = ffn_out.place();
  auto stream = ffn_out.stream();

  auto out = GetEmptyTensor({token_num, hidden}, ffn_out.dtype(), place);

  constexpr int VecSize = 16 / sizeof(data_t);
  // The kernel copies whole vectors only; a ragged tail would be left
  // uninitialised in the output.
  PD_CHECK(hidden % VecSize == 0);

  // One warp per selected expert.
  const int threads_per_block = 32 * topk;
  // A grid of zero blocks is an invalid launch configuration, so keep at least
  // one block even for an empty batch (its token loop then runs zero times).
  const int num_blocks = max(1, min(132 * 4, token_num));
  const int num_vecs = hidden / VecSize;

  assert(threads_per_block <= 1024);
  // Staging buffer: topk rows of `hidden` elements. Computed once and used for
  // both the attribute and the launch so the two can never disagree.
  const int dyn_smem_size =
      static_cast<int>(topk * hidden * sizeof(DataType_));

  switch (topk) {
    case 4:
      if (dyn_smem_size >= (48 << 10)) {
        cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             dyn_smem_size);
      }
      MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4>
          <<<num_blocks, threads_per_block, dyn_smem_size, stream>>>(
              reinterpret_cast<DataType_*>(out.data<data_t>()),
              reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
              permute_indices_per_token.data<int32_t>(),
              topk_idx.data<int64_t>(),
              topk_weights.data<float>(),
              token_num,
              num_vecs,
              hidden,
              max_tokens_per_expert);
      break;

    case 8:
      if (dyn_smem_size >= (48 << 10)) {
        cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             dyn_smem_size);
      }
      MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8>
          <<<num_blocks, threads_per_block, dyn_smem_size, stream>>>(
              reinterpret_cast<DataType_*>(out.data<data_t>()),
              reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
              permute_indices_per_token.data<int32_t>(),
              topk_idx.data<int64_t>(),
              topk_weights.data<float>(),
              token_num,
              num_vecs,
              hidden,
              max_tokens_per_expert);
      break;

    default:
      PD_THROW("Unsupported topk");
  }
  return {out};
}
// Custom-op kernel function for moe_deepgemm_depermute: selects the typed
// dispatch by the dtype of ffn_out (BFLOAT16 or FLOAT16) and raises for any
// other dtype.
std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
    const paddle::Tensor&
        ffn_out,  // [num_experts, max_tokens_per_expert, hidden]
    const paddle::Tensor& permute_indices_per_token,  // [token_num, topk]
    const paddle::Tensor& topk_idx,
    const paddle::Tensor& topk_weights) {
  const auto ffn_dtype = ffn_out.dtype();

  if (ffn_dtype == paddle::DataType::BFLOAT16) {
    return MoEDeepGEMMDePermuteDispatch<paddle::DataType::BFLOAT16>(
        ffn_out, permute_indices_per_token, topk_idx, topk_weights);
  }
  if (ffn_dtype == paddle::DataType::FLOAT16) {
    return MoEDeepGEMMDePermuteDispatch<paddle::DataType::FLOAT16>(
        ffn_out, permute_indices_per_token, topk_idx, topk_weights);
  }
  PD_THROW("Unsupported data type");
}
PD_BUILD_STATIC_OP(moe_deepgemm_depermute)
.Inputs({"ffn_out", "permute_indices_per_token", "topk_idx", "topk_weights"})
.Inputs(
{"ffn_out", "permute_indices_per_token", "topk_idx", "topk_weights"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(MoEDeepGEMMDePermute));
+85 -66
View File
@@ -15,29 +15,41 @@
#include "helper.h"
// topk warps
/**
 * Scatters token rows into per-expert slots.
 *
 * One block handles tokens blockIdx.x, blockIdx.x + gridDim.x, ...; within a
 * block, warp w handles the token's w-th selected expert
 * (topk_idx[token * topk + w]). Lane 0 of the warp reserves the next free row
 * of that expert with atomicAdd on token_nums_per_expert[expert], records the
 * row in permute_indices_per_token[token * topk + w], and broadcasts it to the
 * warp. The warp then copies the token's row from x to
 *   out + (expert * max_tokens_per_expert + row) * hidden.
 *
 * Launch requirements (not checked here):
 *   - blockDim.x == 32 * topk, since the warp id is used as the top-k slot;
 *   - num_vecs * VecSize elements per row are copied, so the caller is
 *     expected to pass num_vecs = hidden / VecSize.
 *
 * NOTE(review): the reserved row is not compared against
 * max_tokens_per_expert here; presumably the caller sizes `out` so that no
 * expert can receive more rows than that -- confirm.
 */
template <typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out,
                                         int* token_nums_per_expert,
                                         int* permute_indices_per_token,
                                         const T* x,
                                         const int64_t* topk_idx,
                                         const int token_num,
                                         const int topk,
                                         const int num_vecs,
                                         const int hidden,
                                         const int max_tokens_per_expert) {
  AlignedVector<T, VecSize> in_vec;

  const int bid = blockIdx.x;
  const int wid = threadIdx.x / 32;
  const int tid = threadIdx.x % 32;

  for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
    // 64-bit index of this warp's (token, top-k slot) entry.
    const int64_t slot = static_cast<int64_t>(token_idx) * topk + wid;
    const int tgt_expert_id = topk_idx[slot];

    // Initialised so the lanes that do not perform the atomic still pass a
    // defined value into the shuffle below.
    int tgt_expert_token = 0;
    if (tid == 0) {
      tgt_expert_token = atomicAdd(token_nums_per_expert + tgt_expert_id, 1);
      permute_indices_per_token[slot] = tgt_expert_token;
    }
    // Broadcast the reserved row from lane 0 to the whole warp.
    tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);

    // Row offsets in 64-bit: expert * max_tokens * hidden can exceed the range
    // of a 32-bit int.
    const T* src_row = x + static_cast<int64_t>(token_idx) * hidden;
    T* dst_row =
        out + (static_cast<int64_t>(tgt_expert_id) * max_tokens_per_expert +
               tgt_expert_token) *
                  hidden;

    for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
         hidden_vec_id += 32) {
      Load<T, VecSize>(src_row + hidden_vec_id * VecSize, &in_vec);
      Store<T, VecSize>(in_vec, dst_row + hidden_vec_id * VecSize);
    }
  }
}
// Allocates the outputs of the DeepGEMM permute op and launches
// MoEDeepGEMMPermuteKernel for dtype D.
//
// Params:
//   x                      activations; shape()[0] is read as token_num and
//                          shape()[1] as hidden
//   topk_idx               int64 expert ids; shape()[1] is read as topk
//   num_experts            number of expert slabs in the output
//   max_tokens_per_expert  number of rows reserved per expert
//
// Returns {permute_output, token_nums_per_expert, permute_indices_per_token}:
//   permute_output             [num_experts, max_tokens_per_expert, hidden],
//                              same dtype as x; only the rows written by the
//                              kernel are filled by this function
//   token_nums_per_expert      [num_experts] int32, zeroed before the launch
//   permute_indices_per_token  [token_num, topk] int32
//
// Raises (PD_THROW) when topk needs more than 1024 threads per block or when
// hidden is not a multiple of the 16-byte vector width, because the kernel
// copies whole vectors only and would otherwise silently drop the row tail.
template <paddle::DataType D>
std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int token_num = x.shape()[0];
  const int hidden = x.shape()[1];
  const int topk = topk_idx.shape()[1];

  // One 16-byte vector per load/store.
  constexpr int VecSize = 16 / sizeof(data_t);

  // One warp per top-k slot. This used to be an assert(), which is compiled
  // out in release builds and would let an invalid launch through.
  const int blocks = 32 * topk;
  if (blocks > 1024) {
    PD_THROW("moe_deepgemm_permute: topk = ",
             topk,
             " needs more than 1024 threads per block");
  }
  if (hidden % VecSize != 0) {
    PD_THROW("moe_deepgemm_permute: hidden = ",
             hidden,
             " must be a multiple of ",
             VecSize);
  }

  auto place = x.place();
  auto stream = x.stream();

  auto token_nums_per_expert =
      GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
  auto permute_indices_per_token =
      GetEmptyTensor({token_num, topk}, paddle::DataType::INT32, place);

  // The kernel uses these as atomic counters, so they must start at zero.
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
                      0,
                      num_experts * sizeof(int32_t),
                      stream));

  auto permute_output = GetEmptyTensor(
      {num_experts, max_tokens_per_expert, hidden}, x.dtype(), place);

  // Nothing to scatter: a zero-sized grid or block is an invalid launch
  // configuration, so return the (zeroed) counters and empty outputs directly.
  if (token_num == 0 || topk == 0) {
    return {permute_output, token_nums_per_expert, permute_indices_per_token};
  }

  auto permute_output_data = permute_output.data<data_t>();

  // Grid cap of 132 * 4 blocks; presumably 4 blocks per SM on a 132-SM
  // device -- TODO confirm. The kernel strides over tokens by gridDim.x, so
  // any positive grid size is correct.
  const int grids = min(132 * 4, token_num);
  const int num_vecs = hidden / VecSize;

  MoEDeepGEMMPermuteKernel<DataType_, VecSize><<<grids, blocks, 0, stream>>>(
      reinterpret_cast<DataType_*>(permute_output_data),
      token_nums_per_expert.data<int32_t>(),
      permute_indices_per_token.data<int32_t>(),
      reinterpret_cast<const DataType_*>(x.data<data_t>()),
      topk_idx.data<int64_t>(),
      token_num,
      topk,
      num_vecs,
      hidden,
      max_tokens_per_expert);
  // Kernel launches do not return a status; surface configuration errors now.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());

  return {permute_output, token_nums_per_expert, permute_indices_per_token};
}
// Dtype front-end for the DeepGEMM permute op.
//
// Selects the MoEDeepGEMMPermuteDispatch instantiation matching the dtype of
// `x` and forwards the arguments unchanged. BFLOAT16 and FLOAT16 are the only
// dtypes handled; anything else raises through PD_THROW.
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert) {
  const auto dtype = x.dtype();

  if (dtype == paddle::DataType::BFLOAT16) {
    return MoEDeepGEMMPermuteDispatch<paddle::DataType::BFLOAT16>(
        x, topk_idx, num_experts, max_tokens_per_expert);
  }

  if (dtype == paddle::DataType::FLOAT16) {
    return MoEDeepGEMMPermuteDispatch<paddle::DataType::FLOAT16>(
        x, topk_idx, num_experts, max_tokens_per_expert);
  }

  PD_THROW("Unsupported data type");
}
// Registers MoEDeepGEMMPermute as the static op `moe_deepgemm_permute`:
// two tensor inputs, three tensor outputs, two integer attributes.
PD_BUILD_STATIC_OP(moe_deepgemm_permute)
    .Inputs({"x", "topk_idx"})
    .Outputs(
        {"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
    .Attrs({"num_experts: int", "max_tokens_per_expert: int"})
    .SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));
+70 -38
View File
@@ -27,8 +27,10 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor, const int num_rows,
const int hidden_size, const int topk,
const float routed_scaling_factor,
const int num_rows,
const int hidden_size,
const int topk,
paddle::Tensor *output) {
using namespace phi;
typedef PDTraits<T> traits_;
@@ -37,19 +39,29 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
auto stream = ffn_out.stream();
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(), output->data<data_t>(),
ffn_out.data<data_t>(),
output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(), num_rows, hidden_size, topk,
static_cast<int>(1), norm_topk_prob, routed_scaling_factor, stream);
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
num_rows,
hidden_size,
topk,
static_cast<int>(1),
norm_topk_prob,
routed_scaling_factor,
stream);
}
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor) {
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
@@ -59,38 +71,57 @@ paddle::Tensor MoeExpertReduceFunc(
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
if(num_rows == 0){
if (num_rows == 0) {
return output;
}
// Dispatch on the dtype of ffn_out. The template argument of MoeReduceKernel
// selects the element type used to read ffn_out / write output, so each case
// must instantiate it with its own dtype. The FLOAT16 case previously
// instantiated the BFLOAT16 kernel, which reinterprets half-precision data as
// bfloat16.
switch (input_type) {
  case paddle::DataType::BFLOAT16:
    MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
                                                top_k_weight,
                                                permute_indices_per_token,
                                                top_k_indices,
                                                down_proj_bias,
                                                norm_topk_prob,
                                                routed_scaling_factor,
                                                num_rows,
                                                hidden_size,
                                                topk,
                                                &output);
    break;
  case paddle::DataType::FLOAT16:
    MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
                                               top_k_weight,
                                               permute_indices_per_token,
                                               top_k_indices,
                                               down_proj_bias,
                                               norm_topk_prob,
                                               routed_scaling_factor,
                                               num_rows,
                                               hidden_size,
                                               topk,
                                               &output);
    break;
  default:
    // The message used to name MoeDispatchKernel, a different op.
    PD_THROW("Unsupported data type for MoeExpertReduce");
}
return output;
}
// Op entry point for moe_expert_reduce: forwards every argument to
// MoeExpertReduceFunc and wraps the single resulting tensor in the
// one-element vector expected from a custom-op kernel function.
std::vector<paddle::Tensor> MoeExpertReduce(
    const paddle::Tensor &ffn_out,
    const paddle::Tensor &top_k_weight,
    const paddle::Tensor &permute_indices_per_token,
    const paddle::Tensor &top_k_indices,
    const paddle::optional<paddle::Tensor> &down_proj_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor) {
  paddle::Tensor reduced = MoeExpertReduceFunc(ffn_out,
                                               top_k_weight,
                                               permute_indices_per_token,
                                               top_k_indices,
                                               down_proj_bias,
                                               norm_topk_prob,
                                               routed_scaling_factor);
  return {reduced};
}
@@ -115,7 +146,6 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
return {ffn_out_dtype};
}
/**
* @brief Mixture of Experts (MoE) Expert Reduce Operator
*
@@ -131,9 +161,8 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
* - top_k_weight: Routing weights for top-k experts per token
* Shape: [total_tokens, moe_topk]
* dtype: float32
* - permute_indices_per_token: Indices mapping for reconstructing original order
* Shape: [moe_topk, total_tokens]
* dtype: int32
 * - permute_indices_per_token: Indices mapping for reconstructing the
 *                     original order
 *                     Shape: [moe_topk, total_tokens]
 *                     dtype: int32
* - top_k_indices: Indices of selected top-k experts for each token
* Shape: [total_tokens, moe_topk]
* dtype: int32
@@ -157,8 +186,11 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
* - For optimal performance, hidden_size should be a multiple of 128
*/
PD_BUILD_STATIC_OP(moe_expert_reduce)
.Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token",
"top_k_indices", paddle::Optional("down_proj_bias")})
.Inputs({"ffn_out",
"top_k_weight",
"permute_indices_per_token",
"top_k_indices",
paddle::Optional("down_proj_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
@@ -27,35 +26,42 @@
using namespace phi;
template <typename T>
void moe_redundant_topk_select_kernel(const T* input,
const T* bias,
T* output,
T* softmax,
const int* expert_id_to_ep_rank_array,
const int* expert_in_rank_num_list,
int* tokens_per_expert_stats_list,
int64_t* indices,
int64_t* indices_tmp,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
const int redundant_ep_rank_num_plus_one,
cudaStream_t stream,
const bool apply_norm_weight = false,
const bool enable_softmax_top_k_fused = false
) {
void moe_redundant_topk_select_kernel(
const T* input,
const T* bias,
T* output,
T* softmax,
const int* expert_id_to_ep_rank_array,
const int* expert_in_rank_num_list,
int* tokens_per_expert_stats_list,
int64_t* indices,
int64_t* indices_tmp,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
const int redundant_ep_rank_num_plus_one,
cudaStream_t stream,
const bool apply_norm_weight = false,
const bool enable_softmax_top_k_fused = false) {
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(input, \
bias, \
output, \
indices, \
source_row, \
num_rows, \
num_experts, \
k, \
stream); \
break; \
}
int64_t tem_num_experts = num_experts;
if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
if (bias != nullptr || apply_norm_weight) tem_num_experts = 0;
switch (tem_num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
@@ -70,60 +76,60 @@ void moe_redundant_topk_select_kernel(const T* input,
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
if (!enable_softmax_top_k_fused) {
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_redundant_top_k_normed<T, TPB>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
bias,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
output,
indices,
indices_tmp,
source_row,
num_experts,
k,
num_rows,
redundant_ep_rank_num_plus_one);
} else {
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_redundant_top_k_normed<T, TPB>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
softmax,
bias,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
output,
indices,
indices_tmp,
source_row,
num_experts,
k,
num_rows,
redundant_ep_rank_num_plus_one);
} else {
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
} else {
assert(k <= TPB);
if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_fused<T, TPB, true>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB, false>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
}
}
@@ -195,24 +201,25 @@ std::vector<paddle::Tensor> MoERedundantTopKSelectKernel(
softmax_out_ = nullptr;
}
moe_redundant_topk_select_kernel<float>(gating_logits.data<float>(),
bias ? bias.get().data<float>() : nullptr,
topk_weights.data<float>(),
softmax_out_,
expert_id_to_ep_rank_array.data<int>(),
expert_in_rank_num_list.data<int>(),
tokens_per_expert_stats_list.data<int>(),
topk_ids_data,
topk_ids_tmp_data,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
redundant_ep_rank_num_plus_one,
stream,
apply_norm_weight,
enable_softmax_top_k_fused);
moe_redundant_topk_select_kernel<float>(
gating_logits.data<float>(),
bias ? bias.get().data<float>() : nullptr,
topk_weights.data<float>(),
softmax_out_,
expert_id_to_ep_rank_array.data<int>(),
expert_in_rank_num_list.data<int>(),
tokens_per_expert_stats_list.data<int>(),
topk_ids_data,
topk_ids_tmp_data,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
redundant_ep_rank_num_plus_one,
stream,
apply_norm_weight,
enable_softmax_top_k_fused);
return {topk_ids, topk_weights};
}
@@ -235,8 +242,7 @@ std::vector<std::vector<int64_t>> MoERedundantTopKSelectKernelInferShape(
}
const int num_rows = token_rows;
return {{num_rows, moe_topk},
{num_rows, moe_topk}};
return {{num_rows, moe_topk}, {num_rows, moe_topk}};
}
std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
@@ -249,18 +255,22 @@ std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused,
const int redundant_ep_rank_num_plus_one) {
return {paddle::DataType::INT64,
paddle::DataType::FLOAT32};
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}
// Registers MoERedundantTopKSelectKernel as the static op
// `moe_redundant_topk_select`. `bias` is an optional input, and
// `tokens_per_expert_stats_list` is mapped in place onto the output
// `tokens_per_expert_stats_list_out`.
PD_BUILD_STATIC_OP(moe_redundant_topk_select)
    .Inputs({"gating_logits",
             "expert_id_to_ep_rank_array",
             "expert_in_rank_num_list",
             "tokens_per_expert_stats_list",
             paddle::Optional("bias")})
    .Outputs({"topk_ids",
              "topk_weights",
              "tokens_per_expert_stats_list_out"})
    .Attrs({"moe_topk:int",
            "apply_norm_weight:bool",
            "enable_softmax_top_k_fused:bool",
            "redundant_ep_rank_num_plus_one:int"})
    .SetInplaceMap(
        {{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,6 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "paddle/phi/api/include/api.h"
@@ -41,15 +41,18 @@ class ScalarType {
NAN_REPR_ID_MAX
};
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
int32_t bias, bool finite_values_only = false,
constexpr ScalarType(uint8_t exponent,
uint8_t mantissa,
bool signed_,
int32_t bias,
bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
nan_repr(nan_repr){};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
@@ -67,16 +70,17 @@ class ScalarType {
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
static constexpr ScalarType float_(uint8_t exponent,
uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
// "use `float_IEEE754` constructor for floating point types that "
// "follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
// "use `float_IEEE754` constructor for floating point types
// that " "follow IEEE 754 conventions");
return ScalarType(
exponent, mantissa, true, 0, finite_values_only, nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
@@ -103,7 +107,9 @@ class ScalarType {
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
static constexpr auto reduce_members_helper(Fn f,
Init val,
Member member,
Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
@@ -116,8 +122,14 @@ class ScalarType {
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
finite_values_only, nan_repr);
return reduce_members_helper(f,
init,
exponent,
mantissa,
signed_,
bias,
finite_values_only,
nan_repr);
};
template <typename Fn, typename Init>
@@ -194,7 +206,8 @@ class ScalarType {
private:
double _floating_point_max() const {
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
"Cannot represent max/min as a double for type ",
str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
@@ -204,7 +217,8 @@ class ScalarType {
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
PADDLE_ENFORCE(exponent < 11,
"Cannot represent max/min as a double for type ", str());
"Cannot represent max/min as a double for type ",
str());
max_exponent += 1;
}
@@ -327,23 +341,32 @@ using ScalarTypeId = MARLIN_NAMESPACE_NAME::ScalarType::Id;
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = MARLIN_NAMESPACE_NAME::ScalarType::int_(4);
static inline constexpr auto kU4 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4);
static inline constexpr auto kU4B8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8);
static inline constexpr auto kU4B8 =
MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8);
static inline constexpr auto kS8 = MARLIN_NAMESPACE_NAME::ScalarType::int_(8);
static inline constexpr auto kU8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8);
static inline constexpr auto kU8B128 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8, 128);
static inline constexpr auto kU8B128 =
MARLIN_NAMESPACE_NAME::ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
MARLIN_NAMESPACE_NAME::ScalarType::float_(2, 1, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
MARLIN_NAMESPACE_NAME::ScalarType::float_(
2, 1, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
MARLIN_NAMESPACE_NAME::ScalarType::float_(3, 2, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
MARLIN_NAMESPACE_NAME::ScalarType::float_(
3, 2, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
MARLIN_NAMESPACE_NAME::ScalarType::float_(4, 3, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 10);
MARLIN_NAMESPACE_NAME::ScalarType::float_(
4, 3, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 =
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 =
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 =
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 10);
// // Fixed width style names, generally following:
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
// //
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
constexpr auto kInt4 = kS4;
constexpr auto kUint4 = kU4;
constexpr auto kUint4b8 = kU4B8;
@@ -369,4 +392,4 @@ constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
constexpr auto kFloat32 = phi::DataType::FLOAT32;
constexpr auto kByte = phi::DataType::INT8;
}; // namespace Marlin
}; // namespace MARLIN_NAMESPACE_NAME
@@ -92,7 +92,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
return res;
}
template <typename scalar_t2, MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id,
template <typename scalar_t2,
MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
@@ -106,8 +107,8 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
int q, half2* frag_b) {
const int MASK = 0x000f000f;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
@@ -120,8 +121,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(i
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@@ -143,14 +144,14 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), true>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
int q, half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@@ -172,7 +173,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(in
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
@@ -189,7 +191,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(),
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
@@ -200,13 +203,15 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(),
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(q, frag_b);
@@ -225,8 +230,9 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), fa
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(int q,
half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(int q,
half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -239,8 +245,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
int q, half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(int q,
half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
@@ -251,14 +258,14 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(
int q, half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(int q,
half2* frag_b) {
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
int q, half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
@@ -269,7 +276,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(in
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
@@ -287,14 +295,15 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id()
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
bf16_result_ptr[0] = __byte_perm(
fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(
fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
@@ -312,15 +321,16 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), fa
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
bf16_result_ptr[0] = __byte_perm(
fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(
fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
int q, half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(int q,
half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
@@ -337,8 +347,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
int q, half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(int q,
half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and FP16 formats
@@ -355,7 +366,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), fals
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
@@ -374,7 +386,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(q, frag_b);
@@ -396,8 +409,9 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(int q,
half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(int q,
half2* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
@@ -414,8 +428,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>
}
template <>
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
int q, half2* frag_b) {
__device__ inline void
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(int q,
half2* frag_b) {
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and FP16 formats
@@ -432,7 +447,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
@@ -450,7 +466,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id()
}
template <>
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
__device__ inline void
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(q, frag_b);
@@ -1,6 +1,6 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "moe/moe_wna16_marlin_utils/marlin.cuh"
#include "moe/moe_wna16_marlin_utils/marlin_dtypes.cuh"
@@ -22,7 +22,8 @@
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
// ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@@ -10,7 +10,7 @@
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
namespace MARLIN_NAMESPACE_NAME {
@@ -51,10 +51,10 @@ using I4 = Vec<int, 4>;
// Integer ceiling division: evaluates (a + b - 1) / b, which for a >= 0 and
// b > 0 is the number of size-b chunks needed to cover a.
constexpr int div_ceil(int a, int b) {
  const int biased = a + b - 1;
  return biased / b;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
__device__ inline void cp_async4_pred(void* smem_ptr,
const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
@@ -64,7 +64,9 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
"r"(smem),
"l"(glob_ptr),
"n"(BYTES));
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
@@ -74,7 +76,8 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
"l"(glob_ptr),
"n"(BYTES));
}
__device__ inline void cp_async_fence() {
@@ -93,7 +96,6 @@ inline void cp_async_fence() {}
// Empty stub of cp_async_wait<n>: the body does nothing and `n` is unused.
// NOTE(review): presumably the fallback used when the cp.async path guarded
// by the preceding #if is not compiled — the guard is outside this hunk,
// confirm.
template <int n>
inline void cp_async_wait() {}
#endif
} // namespace MARLIN_NAMESPACE_NAME
@@ -6,7 +6,7 @@
#include <cuda_bf16.h>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
namespace MARLIN_NAMESPACE_NAME::kernel_types {
@@ -78,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};
} // namespace MARLIN_NAMESPACE_NAME
} // namespace MARLIN_NAMESPACE_NAME::kernel_types
#endif
@@ -20,7 +20,7 @@
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "moe/moe_wna16_marlin_utils/marlin.cuh"
@@ -38,7 +38,8 @@ namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
// ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@@ -86,9 +87,10 @@ __global__ void Marlin(
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <typename scalar_t>
__device__ inline void mma(const typename kernel_types::ScalarType<scalar_t>::FragA& a_frag,
const typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragC& frag_c) {
__device__ inline void mma(
const typename kernel_types::ScalarType<scalar_t>::FragA& a_frag,
const typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
@@ -97,15 +99,31 @@ __device__ inline void mma(const typename kernel_types::ScalarType<scalar_t>::Fr
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(a[0]),
"r"(a[1]),
"r"(a[2]),
"r"(a[3]),
"r"(b[0]),
"r"(b[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(a[0]),
"r"(a[1]),
"r"(a[2]),
"r"(a[3]),
"r"(b[0]),
"r"(b[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
@@ -126,15 +144,31 @@ __device__ inline void mma_trans(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(b[0]),
"r"(b2[0]),
"r"(b[1]),
"r"(b2[1]),
"r"(a[0]),
"r"(a[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
: "r"(b[0]),
"r"(b2[0]),
"r"(b[1]),
"r"(b2[1]),
"r"(a[0]),
"r"(a[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
@@ -143,8 +177,9 @@ __device__ inline void mma_trans(
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, typename scalar_t>
__device__ inline void ldsm(typename kernel_types::ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
__device__ inline void ldsm(
typename kernel_types::ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
if constexpr (count == 4) {
@@ -168,19 +203,22 @@ __device__ inline void ldsm(typename kernel_types::ScalarType<scalar_t>::FragA&
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void scale(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s,
int i) {
__device__ inline void scale(
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s,
int i) {
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
scalar_t2 s =
kernel_types::ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
scalar_t2 s = kernel_types::ScalarType<scalar_t>::num2num2(
reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
__device__ inline void scale_and_sub(
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
scalar_t s,
scalar_t zp) {
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
scalar_t2 s2 = kernel_types::ScalarType<scalar_t>::num2num2(s);
scalar_t2 zp2 = kernel_types::ScalarType<scalar_t>::num2num2(zp);
@@ -189,24 +227,26 @@ __device__ inline void scale_and_sub(
}
template <typename scalar_t>
__device__ inline void sub_zp(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
__device__ inline void sub_zp(
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp =
kernel_types::ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
scalar_t2 zp = kernel_types::ScalarType<scalar_t>::num2num2(
reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_1,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_2,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_3,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_4,
int i) {
__device__ inline void scale4(
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_1,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_2,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_3,
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_4,
int i) {
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
@@ -222,11 +262,13 @@ __device__ inline void scale4(typename kernel_types::ScalarType<scalar_t>::FragB
// Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float* c,
typename kernel_types::ScalarType<scalar_t>::FragS& s) {
__device__ inline void scale_float(
float* c, typename kernel_types::ScalarType<scalar_t>::FragS& s) {
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[1]));
c[0] =
__fmul_rn(c[0], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] =
__fmul_rn(c[1], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
@@ -279,7 +321,8 @@ __device__ inline void wait_negative_and_add(int* lock) {
}
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
// ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@@ -341,10 +384,14 @@ __global__ void Marlin(
using FragZP = typename kernel_types::ScalarType<scalar_t>::FragZP;
extern __shared__ int4 sh[];
static constexpr auto w_type = MARLIN_NAMESPACE_NAME::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8;
constexpr bool is_int_type = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8 ||
w_type == MARLIN_NAMESPACE_NAME::kU4B8 || w_type == MARLIN_NAMESPACE_NAME::kU8B128;
static constexpr auto w_type =
MARLIN_NAMESPACE_NAME::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == MARLIN_NAMESPACE_NAME::kU4 ||
w_type == MARLIN_NAMESPACE_NAME::kU8;
constexpr bool is_int_type = w_type == MARLIN_NAMESPACE_NAME::kU4 ||
w_type == MARLIN_NAMESPACE_NAME::kU8 ||
w_type == MARLIN_NAMESPACE_NAME::kU4B8 ||
w_type == MARLIN_NAMESPACE_NAME::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
@@ -361,7 +408,8 @@ __global__ void Marlin(
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 16 : 8);
prob_n * prob_k / group_size /
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -444,12 +492,12 @@ __global__ void Marlin(
// block_sorted_ids / block_num_valid_tokens / block_topk_weights
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
#pragma unroll
#pragma unroll
for (int i = 0; i < moe_block_size / 4; i++) {
int4 sorted_token_ids_int4 = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);
#pragma unroll
#pragma unroll
for (int j = 0; j < 4; j++) {
if (sorted_token_ids[j] >= prob_m * top_k) {
block_num_valid_tokens = i * 4 + j;
@@ -465,20 +513,21 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
sh_block_topk_weights[idx] =
__hmul2(global_scale,
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
@@ -631,7 +680,8 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1)
? thread_k_blocks / group_blocks /
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@@ -714,7 +764,8 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
if constexpr (group_blocks != -1 &&
w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@@ -767,13 +818,13 @@ __global__ void Marlin(
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
@@ -784,7 +835,7 @@ __global__ void Marlin(
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
@@ -824,7 +875,7 @@ __global__ void Marlin(
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float*>(frag_c)[i] = 0;
};
@@ -832,39 +883,39 @@ __global__ void Marlin(
int sh_first_group_id = -1;
int sh_num_groups = -1;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
auto fetch_act_order_scales_to_shared =
[&](bool is_async, int first_group_id, int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}
int row_offset = first_group_id * s_gl_stride;
if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) +
slice_n_offset + threadIdx.x]);
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}
}
} else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
threadIdx.x];
if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}
}
}
};
int row_offset = first_group_id * s_gl_stride;
if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) +
slice_n_offset + threadIdx.x]);
}
}
} else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
threadIdx.x];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
@@ -872,101 +923,102 @@ __global__ void Marlin(
int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) {
if (should_load_a) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
B_ptr[i] + j + B_expert_off);
}
B_ptr[i] += b_gl_rd_delta_o;
}
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
if (cur_k < prob_k && cur_k < slice_k_finish) {
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int4 const* cur_g_idx_stage_ptr =
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
if (threadIdx.x < g_idx_stage) {
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
&cur_g_idx_stage_ptr[threadIdx.x]);
auto fetch_to_shared =
[&](int pipe, int a_off, bool pred = true, int pipe_a = 0) {
if (pred) {
if (should_load_a) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx = sorted_row * a_gl_stride + a_gl_rd_col +
a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
&A[true_idx],
row < block_num_valid_tokens);
}
}
}
} else {
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
B_ptr[i] + j + B_expert_off);
}
B_ptr[i] += b_gl_rd_delta_o;
}
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
if (cur_k < prob_k && cur_k < slice_k_finish) {
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int4 const* cur_g_idx_stage_ptr =
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
if (threadIdx.x < g_idx_stage) {
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
&cur_g_idx_stage_ptr[threadIdx.x]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
s_gl_rd += s_gl_rd_delta;
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Insert a fence even when we are winding down the pipeline to ensure
// that waiting is also correct at this point.
cp_async_fence();
};
auto fetch_col_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
@@ -994,13 +1046,13 @@ __global__ void Marlin(
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
@@ -1058,7 +1110,8 @@ __global__ void Marlin(
int k_blocks = cur_k / 16;
int cur_group_id =
k_blocks / (group_blocks * (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1));
k_blocks / (group_blocks *
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
@@ -1129,10 +1182,10 @@ __global__ void Marlin(
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
constexpr int k_frag_offsets[4] = {0, 1, 8,
9}; // Tensor core offsets per thread
constexpr int k_frag_offsets[4] = {
0, 1, 8, 9}; // Tensor core offsets per thread
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int actual_k = cur_k + k_frag_offsets[i];
@@ -1156,7 +1209,7 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
@@ -1167,7 +1220,7 @@ __global__ void Marlin(
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
@@ -1186,16 +1239,16 @@ __global__ void Marlin(
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
@@ -1227,10 +1280,10 @@ __global__ void Marlin(
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
@@ -1287,9 +1340,9 @@ __global__ void Marlin(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
@@ -1319,10 +1372,18 @@ __global__ void Marlin(
// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
scale4<scalar_t>(frag_b0,
act_frag_s[k2][0][j],
act_frag_s[k2][1][j],
act_frag_s[k2][2][j],
act_frag_s[k2][3][j],
0);
scale4<scalar_t>(frag_b1,
act_frag_s[k2][0][j],
act_frag_s[k2][1][j],
act_frag_s[k2][2][j],
act_frag_s[k2][3][j],
1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
@@ -1343,7 +1404,7 @@ __global__ void Marlin(
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
}
#pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);
@@ -1372,12 +1433,12 @@ __global__ void Marlin(
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
@@ -1385,7 +1446,7 @@ __global__ void Marlin(
float* c_rd = reinterpret_cast<float*>(
&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
@@ -1397,11 +1458,11 @@ __global__ void Marlin(
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {
float* c_rd =
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
@@ -1444,7 +1505,7 @@ __global__ void Marlin(
if (!first) {
#pragma unroll
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
int c_idx;
if constexpr (m_block_size_8)
@@ -1461,11 +1522,11 @@ __global__ void Marlin(
}
}
#pragma unroll
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
@@ -1478,7 +1539,7 @@ __global__ void Marlin(
}
if (!last) {
int4 c;
#pragma unroll
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
@@ -1527,7 +1588,7 @@ __global__ void Marlin(
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
if (k % 2) continue;
@@ -1540,7 +1601,7 @@ __global__ void Marlin(
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
#pragma unroll
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
@@ -1549,7 +1610,7 @@ __global__ void Marlin(
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
if (k % 2) continue;
@@ -1619,26 +1680,38 @@ __global__ void Marlin(
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 4; j++) {
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
write(wr,
frag_c[i][j][0][0],
frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
write(wr + 8,
frag_c[i][j][0][2],
frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 0 + 0,
frag_c[i][j][0][0],
frag_c[i][j][0][1],
frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0,
frag_c[i][j][0][2],
frag_c[i][j][0][3],
frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4,
frag_c[i][j][1][0],
frag_c[i][j][1][1],
frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4,
frag_c[i][j][1][2],
frag_c[i][j][1][3],
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
@@ -1646,7 +1719,7 @@ __global__ void Marlin(
}
__syncthreads();
#pragma unroll
#pragma unroll
for (int i = 0;
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
@@ -1660,7 +1733,7 @@ __global__ void Marlin(
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]);
scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll
#pragma unroll
for (int a = 0; a < 4; a++) {
scalar_t2 res = sh_red_half2[a];
if (mul_topk_weights) {
@@ -1686,15 +1759,15 @@ __global__ void Marlin(
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
if (has_act_order && i == 0) {
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
fetch_act_order_scales_to_shared(true, g_idx[slice_k_start],
g_idx[last_g_idx]);
fetch_act_order_scales_to_shared(
true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
@@ -1732,9 +1805,9 @@ __global__ void Marlin(
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
int idx =
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
@@ -1747,8 +1820,8 @@ __global__ void Marlin(
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
fetch_to_shared(
(pipe + stages - 1) % stages, pipe, slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
@@ -1775,8 +1848,8 @@ __global__ void Marlin(
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
fetch_act_order_scales_to_shared(
false, first_group_id, last_group_id);
__syncthreads();
}
}
@@ -1816,7 +1889,7 @@ __global__ void Marlin(
if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);
@@ -1833,9 +1906,9 @@ __global__ void Marlin(
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
@@ -1896,11 +1969,11 @@ __global__ void Marlin(
if (slice_iters) {
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
@@ -49,7 +49,6 @@ struct Tensor {
int64_t size(int64_t d) const { return raw_tensor_.dims().at(d); }
template <typename T>
T *data_ptr() const {
return const_cast<T *>(raw_tensor_.data<T>());
@@ -74,7 +73,6 @@ struct Tensor {
return raw_tensor_.place().GetType() == phi::AllocationType::GPU;
}
// MARLIN_NAMESPACE_NAME::ScalarType scalar_type() const {
// return raw_tensor_.dtype();
// }
+127 -124
View File
@@ -18,168 +18,171 @@
#pragma once
// dim3 grid(256)
// dim3 block(512)
// Clamped SwiGLU-style activation over an input whose two operands are
// interleaved element by element.
//
// Layout, as indexed by this kernel:
//   input   : seq_len rows of 2 * hidden_dim elements. Within a row, element
//             2 * k is the value `a` and element 2 * k + 1 is the value `b`
//             that together produce output element k.
//   act_out : seq_len rows of hidden_dim elements, with
//             act_out[row][k] = (clamp(b, -limit, limit) + 1) * a'
//                               / (1 + exp(-alpha * a')),  a' = min(a, limit).
//
// Every loop iteration consumes 2 * VecSize consecutive inputs (two vector
// loads) and writes VecSize consecutive outputs (one vector store).
//
// Launch: 1-D grid of 1-D blocks; any size is valid because each thread walks
// the work list with a stride of gridDim.x * blockDim.x.
// Preconditions: hidden_dim % VecSize == 0 (hidden_dim / VecSize is used as
// the number of vectors per output row) and VecSize is even.
// NOTE(review): Load/Store presumably require VecSize * sizeof(T) alignment of
// the addressed memory -- confirm against their definition in helper.h.
template <typename T, int VecSize>
__global__ void swigluoai_interleave_kernel(T* act_out,
                                            const T* input,
                                            const float alpha,
                                            const float limit,
                                            const int64_t seq_len,
                                            const int64_t hidden_dim) {
  static_assert(VecSize % 2 == 0,
                "each loaded vector is consumed as VecSize / 2 (a, b) pairs");
  using LoadT = AlignedVector<T, VecSize>;

  const int64_t col_size = hidden_dim / VecSize;  // output vectors per row
  const int64_t vec_num = col_size * seq_len;     // output vectors in total
  // Widen before multiplying so the stride cannot wrap in 32-bit arithmetic.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // `index < vec_num` also keeps the division below away from col_size == 0.
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < vec_num;
       index += stride) {
    const int64_t row = index / col_size;
    const int64_t col = index % col_size;

    LoadT src_vec0, src_vec1;
    LoadT res_vec;

    // One output vector needs 2 * VecSize interleaved inputs.
    const T* src = input + row * hidden_dim * 2 + col * VecSize * 2;
    Load<T, VecSize>(src, &src_vec0);
    Load<T, VecSize>(src + VecSize, &src_vec1);

    // First input vector -> lower half of the output vector.
#pragma unroll
    for (int j = 0; j < VecSize / 2; ++j) {
      float a = static_cast<float>(src_vec0[2 * j]);
      float b = static_cast<float>(src_vec0[2 * j + 1]);
      a = fminf(a, limit);                // `a` is only clamped from above
      b = fminf(fmaxf(b, -limit), limit);
      const float res = (b + 1.f) * a / (1.f + expf(-a * alpha));
      res_vec[j] = static_cast<T>(res);
    }

    // Second input vector -> upper half of the output vector.
#pragma unroll
    for (int j = 0; j < VecSize / 2; ++j) {
      float a = static_cast<float>(src_vec1[2 * j]);
      float b = static_cast<float>(src_vec1[2 * j + 1]);
      a = fminf(a, limit);
      b = fminf(fmaxf(b, -limit), limit);
      const float res = (b + 1.f) * a / (1.f + expf(-a * alpha));
      res_vec[j + VecSize / 2] = static_cast<T>(res);
    }

    Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
  }
}
// dim3 grid(256)
// dim3 block(512)
// Clamped SwiGLU-style activation over an input whose two operands are stored
// as two contiguous halves of each row.
//
// Layout, as indexed by this kernel:
//   input   : seq_len rows of 2 * hidden_dim elements. `a` is element k of
//             the first half of the row, `b` is element k of the second half
//             (offset hidden_dim).
//   act_out : seq_len rows of hidden_dim elements, with
//             act_out[row][k] = b * a / (1 + exp(-z)),
//             z = clamp(a * alpha, -limit, limit).
//
// NOTE(review): the clamping here (on alpha * a only, with no "+ 1" on b)
// differs from swigluoai_interleave_kernel; this edit keeps the arithmetic
// exactly as it was -- confirm the difference is intentional.
//
// Launch: 1-D grid of 1-D blocks; any size is valid because each thread walks
// the work list with a stride of gridDim.x * blockDim.x.
// Precondition: hidden_dim % VecSize == 0 (hidden_dim / VecSize is used as
// the number of vectors per output row).
template <typename T, int VecSize>
__global__ void swigluoai_norm_kernel(T* act_out,
                                      const T* input,
                                      const float alpha,
                                      const float limit,
                                      const int64_t seq_len,
                                      const int64_t hidden_dim) {
  using LoadT = AlignedVector<T, VecSize>;

  const int64_t col_size = hidden_dim / VecSize;  // output vectors per row
  const int64_t vec_num = col_size * seq_len;     // output vectors in total
  // Widen before multiplying so the stride cannot wrap in 32-bit arithmetic.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // `index < vec_num` also keeps the division below away from col_size == 0.
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < vec_num;
       index += stride) {
    const int64_t row = index / col_size;
    const int64_t col = index % col_size;

    LoadT src_vec0, src_vec1;
    LoadT res_vec;

    const T* row_ptr = input + row * hidden_dim * 2;
    Load<T, VecSize>(row_ptr + col * VecSize, &src_vec0);               // a
    Load<T, VecSize>(row_ptr + hidden_dim + col * VecSize, &src_vec1);  // b

#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      const float a = static_cast<float>(src_vec0[j]);
      const float b = static_cast<float>(src_vec1[j]);
      const float z = fminf(fmaxf(a * alpha, -limit), limit);
      const float res = b * a / (1.f + expf(-z));
      res_vec[j] = static_cast<T>(res);
    }

    Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
  }
}
// Host entry point for the clamped SwiGLU activation.
//
// fc1_out_tensor : 2-D BFLOAT16 tensor of shape [seq_len, 2 * hidden_dim];
//                  hidden_dim must be a multiple of 8 (the vector width used
//                  by the kernels).
// alpha, limit   : forwarded unchanged to the kernel.
// type           : "interleave" selects swigluoai_interleave_kernel; every
//                  other value selects swigluoai_norm_kernel.
//
// Returns a new [seq_len, hidden_dim] tensor with the dtype and place of the
// input. The kernel is enqueued on fc1_out_tensor.stream(); this function
// does not synchronize.
paddle::Tensor SwigluOAI(const paddle::Tensor& fc1_out_tensor,
                         const float alpha,
                         const float limit,
                         const std::string& type) {
  constexpr int VecSize = 8;
  constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
  using traits_ = PDTraits<D>;
  using DataType_ = typename traits_::DataType;
  using data_t = typename traits_::data_t;

  const auto in_shape = fc1_out_tensor.shape();
  PD_CHECK(in_shape.size() == 2,
           "SwigluOAI expects a 2-D input of shape [seq_len, 2 * hidden_dim].");
  PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
  // The kernels use 2 * hidden_dim as the input row stride, so an odd last
  // dimension would make every row after the first start at the wrong offset.
  PD_CHECK(in_shape[1] % 2 == 0,
           "SwigluOAI expects an even size for the last input dimension.");

  const int64_t seq_len = in_shape[0];
  const int64_t hidden_dim = in_shape[1] / 2;
  PD_CHECK(hidden_dim % VecSize == 0);

  auto act_out_tensor = GetEmptyTensor(
      {seq_len, hidden_dim}, fc1_out_tensor.dtype(), fc1_out_tensor.place());

  // Nothing to compute for an empty tensor; skip the launch entirely.
  if (seq_len == 0 || hidden_dim == 0) {
    return act_out_tensor;
  }

  // Fixed launch shape; the kernels stride over the data, so it does not have
  // to match the problem size.
  const int block_size = 512;
  const int grid_size = 256;

  auto* out_ptr = reinterpret_cast<DataType_*>(
      const_cast<data_t*>(act_out_tensor.data<data_t>()));
  const auto* in_ptr =
      reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>());
  auto stream = fc1_out_tensor.stream();

  if (type == "interleave") {
    swigluoai_interleave_kernel<DataType_, VecSize>
        <<<grid_size, block_size, 0, stream>>>(
            out_ptr, in_ptr, alpha, limit, seq_len, hidden_dim);
  } else {
    swigluoai_norm_kernel<DataType_, VecSize>
        <<<grid_size, block_size, 0, stream>>>(
            out_ptr, in_ptr, alpha, limit, seq_len, hidden_dim);
  }

  return act_out_tensor;
}
// Adapter around SwigluOAI that returns its single result wrapped in a
// one-element vector; all arguments are forwarded unchanged.
std::vector<paddle::Tensor> SwigluOAIWrapper(
    const paddle::Tensor& fc1_out_tensor,
    const float alpha,
    const float limit,
    const std::string& type) {
  return {SwigluOAI(fc1_out_tensor, alpha, limit, type)};
}
PD_BUILD_STATIC_OP(swigluoai)
+4 -2
View File
@@ -15,5 +15,7 @@
#pragma once
#include "helper.h"
// Applies the clamped SwiGLU activation to `fc1_out_tensor` and returns the
// activated tensor. `type` selects the kernel variant; `alpha` and `limit` are
// forwarded to it.
paddle::Tensor SwigluOAI(const paddle::Tensor& fc1_out_tensor,
                         const float alpha,
                         const float limit,
                         const std::string& type);
+95 -96
View File
@@ -15,13 +15,14 @@
#include "helper.h"
#include "paddle/extension.h"
// Integer ceiling division for positive operands. Both arguments are fully
// parenthesized so compound expressions (e.g. CEILDIV(x + 1, y * 2)) expand
// with the intended precedence. Each argument is evaluated more than once, so
// do not pass expressions with side effects.
#define CEILDIV(a, b) (((a) + (b) - 1) / (b))
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer,
size_t numel) {
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
@@ -33,16 +34,17 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}
template <typename scalar_t, int num_experts>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t GEMM_BLOCK_SIZE_M,
size_t numel,
int32_t* __restrict__ cumsum_buffer) {
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t GEMM_BLOCK_SIZE_M,
size_t numel,
int32_t* __restrict__ cumsum_buffer) {
__shared__ int32_t tokens_per_ep[num_experts];
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
tokens_per_ep[i] = 0;
tokens_per_ep[i] = 0;
}
__syncthreads();
@@ -57,8 +59,10 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
if (threadIdx.x == 0) {
cumsum_buffer[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = tokens_per_ep[i-1];
cumsum_buffer[i] = cumsum_buffer[i - 1] + CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
int expert_count = tokens_per_ep[i - 1];
cumsum_buffer[i] =
cumsum_buffer[i - 1] +
CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
}
*total_tokens_post_pad = cumsum_buffer[num_experts];
}
@@ -66,33 +70,39 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1]; i += GEMM_BLOCK_SIZE_M) {
for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1];
i += GEMM_BLOCK_SIZE_M) {
expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x;
}
}
}
// Shape inference for the tritonmoe_preprocess op.
//
// topk_ids          : shape of the 2-D topk_ids input; its two extents are
//                     multiplied to get the number of (token, expert) entries.
// num_experts       : number of experts.
// GEMM_BLOCK_SIZE_M : block size each expert's token count is padded up to.
//
// Returns three 1-D shapes, in order:
//   sorted_ids          : [numel + num_experts * (GEMM_BLOCK_SIZE_M - 1)],
//                         the worst-case padded length (every expert needs
//                         GEMM_BLOCK_SIZE_M - 1 padding slots);
//   expert_ids          : [that length / GEMM_BLOCK_SIZE_M];
//   num_tokens_post_pad : [1].
//
// All arithmetic stays in int64_t so large shapes are not truncated.
std::vector<std::vector<int64_t>> tritonmoe_preprocessInferShape(
    const std::vector<int64_t>& topk_ids,
    int64_t num_experts,
    int64_t GEMM_BLOCK_SIZE_M) {
  const int64_t topk_ids_numel = topk_ids[0] * topk_ids[1];
  const int64_t max_num_tokens_padded =
      topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
  const int64_t max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;

  std::vector<int64_t> sorted_ids = {max_num_tokens_padded};
  std::vector<int64_t> expert_ids = {max_num_m_blocks};
  std::vector<int64_t> num_tokens_post_pad = {1};

  return {sorted_ids, expert_ids, num_tokens_post_pad};
}
// Dtype inference for the tritonmoe_preprocess op: all three outputs are
// INT32, independent of the arguments.
std::vector<paddle::DataType> tritonmoe_preprocessIferDtype(
    const paddle::DataType& topk_ids,
    int64_t num_experts,
    int64_t GEMM_BLOCK_SIZE_M) {
  return {paddle::DataType::INT32,
          paddle::DataType::INT32,
          paddle::DataType::INT32};
}
/*
Suppose num_experts = 8, GEMM_BLOCK_SIZE_M = 4,
topk_ids.shape = [4,4], which means topk = 4
@@ -113,85 +123,74 @@ Then return value `sorted_ids` is
0,16,16,16
*/
// Host launcher for the tritonmoe_preprocess op.
//
// topk_ids          : 2-D INT64 tensor; its two extents are multiplied to get
//                     the number of (token, expert) entries.
// num_experts       : must be one of 2, 8, 32, 64, 128, 160, 256 (the values
//                     moe_align_block_size_kernel is instantiated for here).
// GEMM_BLOCK_SIZE_M : positive block size each expert's count is padded to.
//
// Returns {sorted_ids, expert_ids, num_tokens_post_pad}, all INT32 and on the
// place of topk_ids:
//   sorted_ids          : [numel + num_experts * (GEMM_BLOCK_SIZE_M - 1)],
//                         pre-filled with `numel` before the sort kernel runs;
//   expert_ids          : [sorted_ids length / GEMM_BLOCK_SIZE_M];
//   num_tokens_post_pad : [1].
// Both kernels are enqueued on topk_ids.stream(); nothing is synchronized.
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(
    const paddle::Tensor& topk_ids,
    int64_t num_experts,
    int64_t GEMM_BLOCK_SIZE_M) {
  using scalar_t = int64_t;

  // The kernels read topk_ids as int64_t; reject anything else up front.
  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64,
           "tritonmoe_preprocess expects INT64 topk_ids.");
  // Used as a divisor below.
  PD_CHECK(GEMM_BLOCK_SIZE_M > 0,
           "tritonmoe_preprocess expects GEMM_BLOCK_SIZE_M > 0.");

  int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
  // Worst case: every expert needs GEMM_BLOCK_SIZE_M - 1 padding slots.
  int max_num_tokens_padded =
      topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
  int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;

  // Every slot starts out as `topk_ids_numel`.
  auto sorted_ids = paddle::full({max_num_tokens_padded},
                                 topk_ids_numel,
                                 paddle::DataType::INT32,
                                 topk_ids.place());
  auto expert_ids = paddle::empty(
      {max_num_m_blocks}, paddle::DataType::INT32, topk_ids.place());
  auto num_tokens_post_pad =
      paddle::empty({1}, paddle::DataType::INT32, topk_ids.place());
  auto cumsum_buffer = paddle::empty(
      {num_experts + 1}, paddle::DataType::INT32, topk_ids.place());

  auto stream = topk_ids.stream();

// The expert count is a template argument of the align kernel, so each
// supported value needs its own instantiation. The kernel runs as a single
// block of 1024 threads.
#define run_align_kernel(EXPERTS)                                        \
  do {                                                                   \
    auto align_kernel = moe_align_block_size_kernel<scalar_t, EXPERTS>;  \
    align_kernel<<<1, 1024, 0, stream>>>(                                \
        topk_ids.data<scalar_t>(),                                       \
        expert_ids.data<int32_t>(),                                      \
        num_tokens_post_pad.data<int32_t>(),                             \
        GEMM_BLOCK_SIZE_M,                                               \
        topk_ids_numel,                                                  \
        cumsum_buffer.data<int32_t>());                                  \
  } while (0)

  switch (num_experts) {
    case 2:
      run_align_kernel(2);
      break;
    case 8:
      run_align_kernel(8);
      break;
    case 32:
      run_align_kernel(32);
      break;
    case 64:
      run_align_kernel(64);
      break;
    case 128:
      run_align_kernel(128);
      break;
    case 160:
      run_align_kernel(160);
      break;
    case 256:
      run_align_kernel(256);
      break;
    default:
      // PD_THROW concatenates its arguments; it is not printf-style.
      PD_THROW("Not support num_experts: ", num_experts);
  }
#undef run_align_kernel

  // A zero-block grid is an invalid launch configuration, so skip the sort
  // kernel when there is nothing to sort.
  if (topk_ids_numel > 0) {
    const int block_threads = 256;
    const int num_blocks = CEILDIV(topk_ids_numel, block_threads);
    const int max_blocks = 65535;
    const int actual_blocks = std::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
        topk_ids.data<scalar_t>(),
        sorted_ids.data<int32_t>(),
        cumsum_buffer.data<int32_t>(),
        topk_ids_numel);
  }

  return {sorted_ids, expert_ids, num_tokens_post_pad};
}
PD_BUILD_STATIC_OP(tritonmoe_preprocess)
+112 -58
View File
@@ -17,12 +17,17 @@
template <typename T, int TileRows, int TileColumns, int NumThreads>
__global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
const T *super_scale_ptr, T *weight_ptr,
const int64_t batch, const int64_t num_rows,
const T *super_scale_ptr,
T *weight_ptr,
const int64_t batch,
const int64_t num_rows,
const int64_t num_columns) {
using UnzipFunctor =
cutlass::gemm::threadblock::UnzipAndDequantFunctor<T, cutlass::WintQuantMethod::kWeightOnlyInt25, TileRows,
TileColumns, NumThreads>;
using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
T,
cutlass::WintQuantMethod::kWeightOnlyInt25,
TileRows,
TileColumns,
NumThreads>;
__shared__ T smem[TileRows * TileColumns];
@@ -41,7 +46,8 @@ __global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
// unzip to shared memory
UnzipFunctor unzip_functor;
unzip_functor(block_zipped_weight_ptr, block_super_scale_ptr, smem, num_columns);
unzip_functor(
block_zipped_weight_ptr, block_super_scale_ptr, smem, num_columns);
// write back to global memory
for (int row = 0; row < TileRows; ++row) {
@@ -55,19 +61,26 @@ __global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
}
template <typename T, int64_t TileRows, int64_t TileColumns, int NumThreads>
__global__ void
Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
const uint8_t *local_scale_ptr, const float *code_scale_ptr,
const float *code_zp_ptr, const T *super_scale_ptr,
T *weight_ptr, const int64_t batch, const int64_t num_rows,
const int64_t num_columns) {
using UnzipFunctor =
cutlass::gemm::threadblock::UnzipAndDequantFunctor<T, cutlass::WintQuantMethod::kWeightOnlyInt2, TileRows,
TileColumns, NumThreads>;
__global__ void Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
const uint8_t *local_scale_ptr,
const float *code_scale_ptr,
const float *code_zp_ptr,
const T *super_scale_ptr,
T *weight_ptr,
const int64_t batch,
const int64_t num_rows,
const int64_t num_columns) {
using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
T,
cutlass::WintQuantMethod::kWeightOnlyInt2,
TileRows,
TileColumns,
NumThreads>;
constexpr bool kUseAsyncLoad = true;
__shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes + UnzipFunctor::kColumnWiseSmemBytes];
__shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes +
UnzipFunctor::kColumnWiseSmemBytes];
__shared__ T smem[TileRows * TileColumns];
int64_t block_start_column = blockIdx.x * TileColumns;
@@ -95,15 +108,21 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
? super_scale_ptr + blockIdx.z * num_columns + block_start_column
: nullptr;
typename UnzipFunctor::Arguments args(zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);
typename UnzipFunctor::Arguments args(
zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);
// unzip to shared memory
UnzipFunctor functor;
if (kUseAsyncLoad) {
functor.LoadAsync(block_zipped_weight_ptr, block_local_scale_ptr,
block_code_scale_ptr, block_code_zp_ptr, block_super_scale_ptr,
&args, num_columns, true);
functor.LoadAsync(block_zipped_weight_ptr,
block_local_scale_ptr,
block_code_scale_ptr,
block_code_zp_ptr,
block_super_scale_ptr,
&args,
num_columns,
true);
// Fence off the cp.async copies issued above
cutlass::arch::cp_async_fence();
@@ -112,9 +131,14 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
cutlass::arch::cp_async_wait<0>();
__syncthreads();
} else {
functor.Load(block_zipped_weight_ptr, block_local_scale_ptr,
block_code_scale_ptr, block_code_zp_ptr, block_super_scale_ptr,
&args, num_columns, true);
functor.Load(block_zipped_weight_ptr,
block_local_scale_ptr,
block_code_scale_ptr,
block_code_zp_ptr,
block_super_scale_ptr,
&args,
num_columns,
true);
}
functor.Compute(args, smem, block_start_row);
@@ -132,8 +156,10 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
template <typename T>
void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
const T *supper_scale, T *weight,
const int64_t batch, const int64_t num_rows,
const T *supper_scale,
T *weight,
const int64_t batch,
const int64_t num_rows,
const int64_t num_columns) {
constexpr int kTileRows = 64;
constexpr int kTileColumns = 128;
@@ -146,16 +172,19 @@ void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
dim3 grid_dim(block_dim_x, block_dim_y, batch);
Wint25UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
<<<grid_dim, block_dim>>>(zipped_weight, supper_scale, weight, batch,
num_rows, num_columns);
<<<grid_dim, block_dim>>>(
zipped_weight, supper_scale, weight, batch, num_rows, num_columns);
}
template <typename T>
void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
const uint8_t *local_scale,
const float *code_scale, const float *code_zp,
const T *supper_scale, T *weight,
const int64_t batch, const int64_t num_rows,
const float *code_scale,
const float *code_zp,
const T *supper_scale,
T *weight,
const int64_t batch,
const int64_t num_rows,
const int64_t num_columns) {
constexpr int kTileRows = 64;
constexpr int kTileColumns = 256;
@@ -168,8 +197,14 @@ void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
dim3 grid_dim(block_dim_x, block_dim_y, batch);
Wint2UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
<<<grid_dim, block_dim>>>(zipped_weight, local_scale, code_scale, code_zp,
supper_scale, weight, batch, num_rows,
<<<grid_dim, block_dim>>>(zipped_weight,
local_scale,
code_scale,
code_zp,
supper_scale,
weight,
batch,
num_rows,
num_columns);
}
@@ -179,7 +214,8 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
const paddle::optional<paddle::Tensor> &code_scale,
const paddle::optional<paddle::Tensor> &code_zp,
const paddle::optional<paddle::Tensor> &super_scale,
paddle::Tensor &weight, const std::string &quant_method) {
paddle::Tensor &weight,
const std::string &quant_method) {
using data_t = typename PDTraits<T>::data_t;
using NvType = typename PDTraits<T>::DataType;
@@ -199,7 +235,10 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
Wint25UnzipKernelLauncher<NvType>(
reinterpret_cast<const uint16_t *>(zipped_weight_ptr),
reinterpret_cast<const NvType *>(super_scale_ptr),
reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
reinterpret_cast<NvType *>(weight_ptr),
batch,
num_rows,
num_columns);
} else if (quant_method == "weight_only_int2") {
paddle::Tensor *local_scale_tensor =
const_cast<paddle::Tensor *>(local_scale.get_ptr());
@@ -209,22 +248,27 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
const_cast<paddle::Tensor *>(code_zp.get_ptr());
Wint2UnzipKernelLauncher<NvType>(
zipped_weight.data<uint8_t>(), local_scale_tensor->data<uint8_t>(),
code_scale_tensor->data<float>(), code_zp_tensor->data<float>(),
zipped_weight.data<uint8_t>(),
local_scale_tensor->data<uint8_t>(),
code_scale_tensor->data<float>(),
code_zp_tensor->data<float>(),
reinterpret_cast<const NvType *>(super_scale_ptr),
reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
reinterpret_cast<NvType *>(weight_ptr),
batch,
num_rows,
num_columns);
} else {
PD_THROW("Unsupported quant_method for WintxUnzip.");
}
}
std::vector<paddle::Tensor>
WintXUnzip(const paddle::Tensor &zipped_weight,
const paddle::optional<paddle::Tensor> &local_scale,
const paddle::optional<paddle::Tensor> &code_scale,
const paddle::optional<paddle::Tensor> &code_zp,
const paddle::optional<paddle::Tensor> &super_scale,
const std::string &quant_method) {
std::vector<paddle::Tensor> WintXUnzip(
const paddle::Tensor &zipped_weight,
const paddle::optional<paddle::Tensor> &local_scale,
const paddle::optional<paddle::Tensor> &code_scale,
const paddle::optional<paddle::Tensor> &code_zp,
const paddle::optional<paddle::Tensor> &super_scale,
const std::string &quant_method) {
paddle::Tensor *local_scale_tensor =
const_cast<paddle::Tensor *>(local_scale.get_ptr());
paddle::Tensor *super_scale_tensor =
@@ -251,18 +295,26 @@ WintXUnzip(const paddle::Tensor &zipped_weight,
auto output_tensor = GetEmptyTensor(output_dims, dtype, place);
switch (dtype) {
case paddle::DataType::BFLOAT16:
WintxUnzipKernel<paddle::DataType::BFLOAT16>(
zipped_weight, local_scale, code_scale, code_zp, super_scale,
output_tensor, quant_method);
break;
case paddle::DataType::FLOAT16:
WintxUnzipKernel<paddle::DataType::FLOAT16>(
zipped_weight, local_scale, code_scale, code_zp, super_scale,
output_tensor, quant_method);
break;
default:
PD_THROW("Unsupported data type for WintxUnzip");
case paddle::DataType::BFLOAT16:
WintxUnzipKernel<paddle::DataType::BFLOAT16>(zipped_weight,
local_scale,
code_scale,
code_zp,
super_scale,
output_tensor,
quant_method);
break;
case paddle::DataType::FLOAT16:
WintxUnzipKernel<paddle::DataType::FLOAT16>(zipped_weight,
local_scale,
code_scale,
code_zp,
super_scale,
output_tensor,
quant_method);
break;
default:
PD_THROW("Unsupported data type for WintxUnzip");
}
return {output_tensor};
}
@@ -306,8 +358,10 @@ std::vector<paddle::DataType> WintXUnzipInferDtype(
}
PD_BUILD_STATIC_OP(winx_unzip)
.Inputs({"zipped_weight", paddle::Optional("local_scale"),
paddle::Optional("code_scale"), paddle::Optional("code_zp"),
.Inputs({"zipped_weight",
paddle::Optional("local_scale"),
paddle::Optional("code_scale"),
paddle::Optional("code_zp"),
paddle::Optional("super_scale")})
.Outputs({"weight"})
.Attrs({"quant_method:std::string"})