mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -28,13 +28,13 @@ __global__ void compute_total_rows_before_expert_kernel(
|
||||
const int64_t sorted_experts_len,
|
||||
const int64_t num_experts,
|
||||
int64_t* total_rows_before_expert) {
|
||||
// First, compute the global tid. We only need 1 thread per expert.
|
||||
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert >= num_experts) return;
|
||||
// First, compute the global tid. We only need 1 thread per expert.
|
||||
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert >= num_experts) return;
|
||||
|
||||
// This should construct the last index where each expert occurs.
|
||||
total_rows_before_expert[expert] =
|
||||
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
|
||||
// This should construct the last index where each expert occurs.
|
||||
total_rows_before_expert[expert] =
|
||||
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void compute_total_rows_before_expert(int* sorted_indices,
|
||||
@@ -42,11 +42,11 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
||||
const int64_t num_experts,
|
||||
int64_t* total_rows_before_expert,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(int64_t(1024), num_experts);
|
||||
const int blocks = (num_experts + threads - 1) / threads;
|
||||
const int threads = std::min(int64_t(1024), num_experts);
|
||||
const int blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
} // namespace phi
|
||||
@@ -65,38 +65,47 @@ void FusedMoeKernel(const paddle::Tensor& input,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
paddle::Tensor* output) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
|
||||
auto fp16_moe_gemm_runner = MoeGemmRunner<
|
||||
DataType_,
|
||||
cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
|
||||
auto int8_moe_gemm_runner = MoeGemmRunner<
|
||||
DataType_,
|
||||
cutlass::WintQuantTraits<DataType_,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt8>>();
|
||||
auto int4_moe_gemm_runner = MoeGemmRunner<
|
||||
DataType_,
|
||||
cutlass::WintQuantTraits<DataType_,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt4>>();
|
||||
|
||||
using NvType = typename traits_::DataType;
|
||||
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
|
||||
&fp16_moe_gemm_runner,
|
||||
&int8_moe_gemm_runner,
|
||||
&int4_moe_gemm_runner);
|
||||
using NvType = typename traits_::DataType;
|
||||
auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
|
||||
&fp16_moe_gemm_runner,
|
||||
&int8_moe_gemm_runner,
|
||||
&int4_moe_gemm_runner);
|
||||
|
||||
moe_compute.ComputeFFN(&input,
|
||||
&gate_weight,
|
||||
&up_gate_proj_weight,
|
||||
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
|
||||
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
|
||||
&down_proj_weight,
|
||||
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
|
||||
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
moe_compute.ComputeFFN(
|
||||
&input,
|
||||
&gate_weight,
|
||||
&up_gate_proj_weight,
|
||||
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
|
||||
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
|
||||
&down_proj_weight,
|
||||
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
|
||||
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
}
|
||||
|
||||
paddle::Tensor FusedExpertMoeFunc(
|
||||
@@ -112,44 +121,44 @@ paddle::Tensor FusedExpertMoeFunc(
|
||||
const int moe_topk,
|
||||
const bool norm_topk_prob,
|
||||
const bool group_moe) {
|
||||
const auto input_type = input.dtype();
|
||||
auto output = paddle::empty_like(input);
|
||||
const auto input_type = input.dtype();
|
||||
auto output = paddle::empty_like(input);
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_bias,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
down_proj_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_bias,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
down_proj_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for FusedMoeKernel");
|
||||
}
|
||||
return output;
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_bias,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
down_proj_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_bias,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
down_proj_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for FusedMoeKernel");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
@@ -165,18 +174,18 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const int moe_topk,
|
||||
const bool norm_topk_prob,
|
||||
const bool group_moe) {
|
||||
return {FusedExpertMoeFunc(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_bias,
|
||||
down_proj_scale,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
norm_topk_prob,
|
||||
group_moe)};
|
||||
return {FusedExpertMoeFunc(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_bias,
|
||||
down_proj_scale,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
norm_topk_prob,
|
||||
group_moe)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
||||
@@ -188,7 +197,7 @@ std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
|
||||
return {input_shape};
|
||||
return {input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
@@ -200,13 +209,14 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
|
||||
return {input_dtype};
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fused Mixture-of-Experts (MoE) Operator
|
||||
*
|
||||
* This operator combines three key MoE operations into a single optimized kernel:
|
||||
* This operator combines three key MoE operations into a single optimized
|
||||
* kernel:
|
||||
* 1. moe_dispatch - Routes tokens to top-k experts using gating network
|
||||
* 2. moe_ffn - Processes tokens through parallel expert FFNs
|
||||
* 3. moe_reduce - Combines expert outputs with routing weights
|
||||
@@ -219,9 +229,10 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
* output = ∑_i^topk(softmax(gate(x))_i * FFN_i(x)
|
||||
*
|
||||
* Reference Components:
|
||||
* moe_dispatch: Selects top-k experts per token and generates permutation indices
|
||||
* moe_ffn: Applies SwiGLU activation expert networks in parallel
|
||||
* moe_reduce: Combines weighted expert outputs and restores original token order
|
||||
* moe_dispatch: Selects top-k experts per token and generates permutation
|
||||
* indices moe_ffn: Applies SwiGLU activation expert networks in parallel
|
||||
* moe_reduce: Combines weighted expert outputs and restores original token
|
||||
* order
|
||||
*
|
||||
* Performance Notes:
|
||||
* - Recommended hidden_size multiples of 128 for optimal memory alignment
|
||||
|
||||
@@ -7,8 +7,10 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
uint32_t const* __restrict__ perm_ptr,
|
||||
uint32_t* __restrict__ out_ptr,
|
||||
int size_k,
|
||||
int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
@@ -224,7 +226,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages,
|
||||
k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
@@ -234,85 +237,83 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel<MARLIN_NAMESPACE_NAME::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, MARLIN_NAMESPACE_NAME::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute(MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel< \
|
||||
MARLIN_NAMESPACE_NAME::repack_threads, \
|
||||
NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
||||
max_shared_mem); \
|
||||
MARLIN_NAMESPACE_NAME::gptq_marlin_repack_kernel< \
|
||||
MARLIN_NAMESPACE_NAME::repack_threads, \
|
||||
NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, \
|
||||
MARLIN_NAMESPACE_NAME::repack_threads, \
|
||||
max_shared_mem, \
|
||||
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddle::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight,
|
||||
paddle::Tensor& perm,
|
||||
int64_t size_k,
|
||||
int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
PADDLE_ENFORCE(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
|
||||
"size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_k_size);
|
||||
PADDLE_ENFORCE(size_k % MARLIN_NAMESPACE_NAME::tile_k_size == 0,
|
||||
"size_k = ",
|
||||
size_k,
|
||||
" is not divisible by tile_k_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_k_size);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
|
||||
"size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_n_size);
|
||||
PADDLE_ENFORCE(size_n % MARLIN_NAMESPACE_NAME::tile_n_size == 0,
|
||||
"size_n = ",
|
||||
size_n,
|
||||
" is not divisible by tile_n_size = ",
|
||||
MARLIN_NAMESPACE_NAME::tile_n_size);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
PADDLE_ENFORCE(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ",
|
||||
num_bits);
|
||||
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
// shape checks
|
||||
PADDLE_ENFORCE(
|
||||
(size_k / pack_factor) == b_q_weight.dims()[0],
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.dims()[0]);
|
||||
PADDLE_ENFORCE((size_k / pack_factor) == b_q_weight.dims()[0],
|
||||
"Shape mismatch: b_q_weight.size(0) = ",
|
||||
b_q_weight.dims()[0]);
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.dims()[1] == size_n,
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.dims()[1],
|
||||
", expected size_n = ", size_n);
|
||||
PADDLE_ENFORCE(b_q_weight.dims()[1] == size_n,
|
||||
"Shape mismatch: b_q_weight.size(1) = ",
|
||||
b_q_weight.dims()[1],
|
||||
", expected size_n = ",
|
||||
size_n);
|
||||
|
||||
// Verify device and strides
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.is_gpu(),
|
||||
"b_q_weight is not on GPU");
|
||||
// Verify device and strides
|
||||
PADDLE_ENFORCE(b_q_weight.is_gpu(), "b_q_weight is not on GPU");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.is_contiguous(),
|
||||
"b_q_weight is not contiguous");
|
||||
PADDLE_ENFORCE(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
b_q_weight.dtype() == phi::DataType::INT32,
|
||||
"b_q_weight type is not kInt");
|
||||
PADDLE_ENFORCE(b_q_weight.dtype() == phi::DataType::INT32,
|
||||
"b_q_weight type is not kInt");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.is_gpu(),
|
||||
"perm is not on GPU");
|
||||
PADDLE_ENFORCE(perm.is_gpu(), "perm is not on GPU");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.is_contiguous(),
|
||||
"perm is not contiguous");
|
||||
PADDLE_ENFORCE(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
perm.dtype() == phi::DataType::INT32,
|
||||
"perm type is not kInt");
|
||||
PADDLE_ENFORCE(perm.dtype() == phi::DataType::INT32, "perm type is not kInt");
|
||||
|
||||
// Alloc buffers
|
||||
// const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
paddle::Tensor out = paddle::empty(
|
||||
{size_k / MARLIN_NAMESPACE_NAME::tile_size, size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
|
||||
b_q_weight.dtype(),
|
||||
b_q_weight.place());
|
||||
|
||||
paddle::Tensor out =
|
||||
paddle::empty({size_k / MARLIN_NAMESPACE_NAME::tile_size,
|
||||
size_n * MARLIN_NAMESPACE_NAME::tile_size / pack_factor},
|
||||
b_q_weight.dtype(),
|
||||
b_q_weight.place());
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.dims()[0] != 0;
|
||||
@@ -330,13 +331,11 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
cudaDeviceGetAttribute(
|
||||
&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
// TORCH_CHECK(max_shared_mem > 0);
|
||||
PADDLE_ENFORCE(
|
||||
max_shared_mem > 0,
|
||||
"max_shared_mem must be > 0. Got = ", max_shared_mem);
|
||||
|
||||
max_shared_mem > 0, "max_shared_mem must be > 0. Got = ", max_shared_mem);
|
||||
|
||||
if (false) {
|
||||
}
|
||||
@@ -347,11 +346,11 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
|
||||
else {
|
||||
// TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
||||
// ", has_perm = ", has_perm);
|
||||
PADDLE_ENFORCE(
|
||||
false,
|
||||
"Unsupported repack config: num_bits = ", num_bits,
|
||||
", has_perm = ", has_perm);
|
||||
|
||||
PADDLE_ENFORCE(false,
|
||||
"Unsupported repack config: num_bits = ",
|
||||
num_bits,
|
||||
", has_perm = ",
|
||||
has_perm);
|
||||
}
|
||||
|
||||
return {out};
|
||||
@@ -360,9 +359,5 @@ std::vector<paddle::Tensor> gptq_marlin_repack(paddle::Tensor& b_q_weight, paddl
|
||||
PD_BUILD_STATIC_OP(gptq_marlin_repack)
|
||||
.Inputs({"b_q_weight", "perm"})
|
||||
.Outputs({"out"})
|
||||
.Attrs({
|
||||
"size_k: int64_t",
|
||||
"size_n: int64_t",
|
||||
"num_bits: int64_t"
|
||||
})
|
||||
.Attrs({"size_k: int64_t", "size_n: int64_t", "num_bits: int64_t"})
|
||||
.SetKernelFn(PD_KERNEL(gptq_marlin_repack));
|
||||
|
||||
@@ -19,115 +19,119 @@
|
||||
#pragma once
|
||||
|
||||
template <typename index, typename T, int VecSize>
|
||||
__global__ void group_swiglu_with_masked_kernel(T* act_out,
|
||||
const T* input,
|
||||
const index *token_nums_per_expert,
|
||||
const int64_t group_num,
|
||||
const int64_t group_size,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t global_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = group_num * group_size * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
__global__ void group_swiglu_with_masked_kernel(
|
||||
T* act_out,
|
||||
const T* input,
|
||||
const index* token_nums_per_expert,
|
||||
const int64_t group_num,
|
||||
const int64_t group_size,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t global_idx =
|
||||
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = group_num * group_size * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
|
||||
int64_t block_id = static_cast<int64_t>(blockIdx.x);
|
||||
const int lane_idx = threadIdx.x % 32;
|
||||
int64_t block_id = static_cast<int64_t>(blockIdx.x);
|
||||
const int lane_idx = threadIdx.x % 32;
|
||||
|
||||
while(true) {
|
||||
int dealt_group_id = -1;
|
||||
int dealt_seq_id = -1;
|
||||
if (lane_idx == 0 ) {
|
||||
int cumsum1 = 0;
|
||||
int cumsum2 = 0;
|
||||
for (int i = 0; i < group_num; i++) {
|
||||
int tmp = token_nums_per_expert[i];
|
||||
cumsum2 += tmp;
|
||||
if (block_id >= cumsum1 && block_id < cumsum2) {
|
||||
dealt_group_id = i;
|
||||
dealt_seq_id = block_id - cumsum1;
|
||||
break;
|
||||
}
|
||||
cumsum1 += tmp;
|
||||
}
|
||||
while (true) {
|
||||
int dealt_group_id = -1;
|
||||
int dealt_seq_id = -1;
|
||||
if (lane_idx == 0) {
|
||||
int cumsum1 = 0;
|
||||
int cumsum2 = 0;
|
||||
for (int i = 0; i < group_num; i++) {
|
||||
int tmp = token_nums_per_expert[i];
|
||||
cumsum2 += tmp;
|
||||
if (block_id >= cumsum1 && block_id < cumsum2) {
|
||||
dealt_group_id = i;
|
||||
dealt_seq_id = block_id - cumsum1;
|
||||
break;
|
||||
}
|
||||
dealt_group_id = __shfl_sync(0xffffffff, dealt_group_id, 0);
|
||||
dealt_seq_id = __shfl_sync(0xffffffff, dealt_seq_id, 0);
|
||||
if (dealt_group_id < 0) break;
|
||||
|
||||
const int64_t r_offset = (dealt_group_id * group_size + dealt_seq_id) * hidden_dim * 2;
|
||||
const int64_t w_offset = (dealt_group_id * group_size + dealt_seq_id) * hidden_dim;
|
||||
|
||||
for (int64_t col_id = threadIdx.x * VecSize; col_id < hidden_dim; col_id += blockDim.x * VecSize) {
|
||||
|
||||
Load<T, VecSize>(&input[r_offset + col_id], &src_vec0);
|
||||
Load<T, VecSize>(&input[r_offset + col_id + hidden_dim], &src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize; ++j) {
|
||||
float a = static_cast<float>(src_vec0[j]);
|
||||
float b = static_cast<float>(src_vec1[j]);
|
||||
float res = b * a / (1.f + exp(-a));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[w_offset + col_id]);
|
||||
}
|
||||
block_id += gridDim.x;
|
||||
cumsum1 += tmp;
|
||||
}
|
||||
}
|
||||
dealt_group_id = __shfl_sync(0xffffffff, dealt_group_id, 0);
|
||||
dealt_seq_id = __shfl_sync(0xffffffff, dealt_seq_id, 0);
|
||||
if (dealt_group_id < 0) break;
|
||||
|
||||
const int64_t r_offset =
|
||||
(dealt_group_id * group_size + dealt_seq_id) * hidden_dim * 2;
|
||||
const int64_t w_offset =
|
||||
(dealt_group_id * group_size + dealt_seq_id) * hidden_dim;
|
||||
|
||||
for (int64_t col_id = threadIdx.x * VecSize; col_id < hidden_dim;
|
||||
col_id += blockDim.x * VecSize) {
|
||||
Load<T, VecSize>(&input[r_offset + col_id], &src_vec0);
|
||||
Load<T, VecSize>(&input[r_offset + col_id + hidden_dim], &src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize; ++j) {
|
||||
float a = static_cast<float>(src_vec0[j]);
|
||||
float b = static_cast<float>(src_vec1[j]);
|
||||
float res = b * a / (1.f + exp(-a));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[w_offset + col_id]);
|
||||
}
|
||||
block_id += gridDim.x;
|
||||
}
|
||||
}
|
||||
|
||||
paddle::Tensor GroupSwigluWithMasked(const paddle::Tensor& fc1_out_tensor,
|
||||
const paddle::Tensor& token_nums_per_expert
|
||||
)
|
||||
{
|
||||
const int64_t group_num = token_nums_per_expert.shape()[0];
|
||||
const int64_t group_size = fc1_out_tensor.shape()[1];
|
||||
const int64_t hidden_dim = fc1_out_tensor.shape()[2] / 2;
|
||||
auto act_out_tensor = GetEmptyTensor({group_num, group_size, hidden_dim}, fc1_out_tensor.dtype(), fc1_out_tensor.place());
|
||||
paddle::Tensor GroupSwigluWithMasked(
|
||||
const paddle::Tensor& fc1_out_tensor,
|
||||
const paddle::Tensor& token_nums_per_expert) {
|
||||
const int64_t group_num = token_nums_per_expert.shape()[0];
|
||||
const int64_t group_size = fc1_out_tensor.shape()[1];
|
||||
const int64_t hidden_dim = fc1_out_tensor.shape()[2] / 2;
|
||||
auto act_out_tensor = GetEmptyTensor({group_num, group_size, hidden_dim},
|
||||
fc1_out_tensor.dtype(),
|
||||
fc1_out_tensor.place());
|
||||
|
||||
constexpr int VecSize = 8;
|
||||
PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
|
||||
PD_CHECK(hidden_dim % VecSize == 0);
|
||||
constexpr int VecSize = 8;
|
||||
PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
|
||||
PD_CHECK(hidden_dim % VecSize == 0);
|
||||
|
||||
constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int threads = 512;
|
||||
const int blocks = 256;
|
||||
const int threads = 512;
|
||||
const int blocks = 256;
|
||||
|
||||
#define dispatch_by_index(index) {\
|
||||
group_swiglu_with_masked_kernel<index, DataType_, VecSize><<<blocks, threads, 0, fc1_out_tensor.stream()>>>(\
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(act_out_tensor.data<data_t>())),\
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()),\
|
||||
token_nums_per_expert.data<index>(),\
|
||||
group_num,\
|
||||
group_size,\
|
||||
hidden_dim\
|
||||
);} while(0)
|
||||
if (token_nums_per_expert.dtype() == paddle::DataType::INT64) {
|
||||
dispatch_by_index(int64_t);
|
||||
} else if(token_nums_per_expert.dtype() == paddle::DataType::INT32) {
|
||||
dispatch_by_index(int32_t);
|
||||
} else {
|
||||
PD_THROW("Unsupported token_nums_per_expert's data dtype.");
|
||||
}
|
||||
#define dispatch_by_index(index) \
|
||||
{ \
|
||||
group_swiglu_with_masked_kernel<index, DataType_, VecSize> \
|
||||
<<<blocks, threads, 0, fc1_out_tensor.stream()>>>( \
|
||||
reinterpret_cast<DataType_*>( \
|
||||
const_cast<data_t*>(act_out_tensor.data<data_t>())), \
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()), \
|
||||
token_nums_per_expert.data<index>(), \
|
||||
group_num, \
|
||||
group_size, \
|
||||
hidden_dim); \
|
||||
} \
|
||||
while (0)
|
||||
if (token_nums_per_expert.dtype() == paddle::DataType::INT64) {
|
||||
dispatch_by_index(int64_t);
|
||||
} else if (token_nums_per_expert.dtype() == paddle::DataType::INT32) {
|
||||
dispatch_by_index(int32_t);
|
||||
} else {
|
||||
PD_THROW("Unsupported token_nums_per_expert's data dtype.");
|
||||
}
|
||||
|
||||
return act_out_tensor;
|
||||
return act_out_tensor;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> GroupSwigluWithMaskedWrapper(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& token_nums_per_expert) {
|
||||
return {GroupSwigluWithMasked(input, token_nums_per_expert)};
|
||||
const paddle::Tensor& input, const paddle::Tensor& token_nums_per_expert) {
|
||||
return {GroupSwigluWithMasked(input, token_nums_per_expert)};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(group_swiglu_with_masked)
|
||||
.Inputs({"input",
|
||||
"token_nums_per_expert"})
|
||||
.Inputs({"input", "token_nums_per_expert"})
|
||||
.Outputs({"output_tensor"})
|
||||
.SetKernelFn(PD_KERNEL(GroupSwigluWithMaskedWrapper));
|
||||
|
||||
@@ -15,6 +15,6 @@
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
const paddle::Tensor &token_nums_per_expert);
|
||||
paddle::Tensor GroupSwigluWithMasked(
|
||||
const paddle::Tensor &fc1_out_tensor,
|
||||
const paddle::Tensor &token_nums_per_expert);
|
||||
|
||||
@@ -14,146 +14,164 @@
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template<typename T, int VecSize, int TopK>
|
||||
__global__ void MoEDeepGEMMDePermuteKernel(T* out, const T* ffn_out, const int* permute_indices_per_token, const int64_t* topk_idx, const float* topk_weights, const int token_num, const int num_vecs, const int hidden, const int max_tokens_per_expert) {
|
||||
AlignedVector<T, VecSize> in_vec;
|
||||
template <typename T, int VecSize, int TopK>
|
||||
__global__ void MoEDeepGEMMDePermuteKernel(T* out,
|
||||
const T* ffn_out,
|
||||
const int* permute_indices_per_token,
|
||||
const int64_t* topk_idx,
|
||||
const float* topk_weights,
|
||||
const int token_num,
|
||||
const int num_vecs,
|
||||
const int hidden,
|
||||
const int max_tokens_per_expert) {
|
||||
AlignedVector<T, VecSize> in_vec;
|
||||
|
||||
AlignedVector<T, VecSize> acc_vec[TopK];
|
||||
AlignedVector<T, VecSize> acc_vec[TopK];
|
||||
|
||||
const int bid = blockIdx.x;
|
||||
const int wid = threadIdx.x / 32;
|
||||
const int tid = threadIdx.x % 32;
|
||||
extern __shared__ char shm[]; // TopK * hidden
|
||||
T* shm_hidden = reinterpret_cast<T*>(shm);
|
||||
const int bid = blockIdx.x;
|
||||
const int wid = threadIdx.x / 32;
|
||||
const int tid = threadIdx.x % 32;
|
||||
extern __shared__ char shm[]; // TopK * hidden
|
||||
T* shm_hidden = reinterpret_cast<T*>(shm);
|
||||
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
int src_expert_id = topk_idx[token_idx * TopK + wid];
|
||||
int src_expert_token = permute_indices_per_token[token_idx * TopK + wid];
|
||||
float weight = topk_weights[token_idx * TopK + wid];
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
int src_expert_id = topk_idx[token_idx * TopK + wid];
|
||||
int src_expert_token = permute_indices_per_token[token_idx * TopK + wid];
|
||||
float weight = topk_weights[token_idx * TopK + wid];
|
||||
|
||||
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
|
||||
Load<T, VecSize>(ffn_out + src_expert_id * max_tokens_per_expert * hidden + src_expert_token * hidden + hidden_vec_id * VecSize, &in_vec);
|
||||
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
|
||||
hidden_vec_id += 32) {
|
||||
Load<T, VecSize>(ffn_out +
|
||||
src_expert_id * max_tokens_per_expert * hidden +
|
||||
src_expert_token * hidden + hidden_vec_id * VecSize,
|
||||
&in_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
in_vec[i] *= weight;
|
||||
}
|
||||
Store<T, VecSize>(in_vec, shm_hidden + wid * hidden + hidden_vec_id * VecSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int hidden_vec_id = threadIdx.x; hidden_vec_id < num_vecs; hidden_vec_id += blockDim.x) {
|
||||
#pragma unroll
|
||||
for (int topk_id = 0; topk_id < TopK; topk_id++) {
|
||||
Load<T, VecSize>(shm_hidden + topk_id * hidden + hidden_vec_id * VecSize, &acc_vec[topk_id]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
#pragma unroll
|
||||
for (int topk_id = 1; topk_id < TopK; topk_id++) {
|
||||
acc_vec[0][i] += acc_vec[topk_id][i];
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(acc_vec[0], out + token_idx * hidden + hidden_vec_id * VecSize);
|
||||
}
|
||||
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
in_vec[i] *= weight;
|
||||
}
|
||||
Store<T, VecSize>(in_vec,
|
||||
shm_hidden + wid * hidden + hidden_vec_id * VecSize);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int hidden_vec_id = threadIdx.x; hidden_vec_id < num_vecs;
|
||||
hidden_vec_id += blockDim.x) {
|
||||
#pragma unroll
|
||||
for (int topk_id = 0; topk_id < TopK; topk_id++) {
|
||||
Load<T, VecSize>(
|
||||
shm_hidden + topk_id * hidden + hidden_vec_id * VecSize,
|
||||
&acc_vec[topk_id]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
#pragma unroll
|
||||
for (int topk_id = 1; topk_id < TopK; topk_id++) {
|
||||
acc_vec[0][i] += acc_vec[topk_id][i];
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(acc_vec[0],
|
||||
out + token_idx * hidden + hidden_vec_id * VecSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMDePermuteDispatch(
|
||||
const paddle::Tensor& ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor& permute_indices_per_token, // [token_num, topk}]
|
||||
const paddle::Tensor&
|
||||
ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor& permute_indices_per_token, // [token_num, topk}]
|
||||
const paddle::Tensor& topk_idx,
|
||||
const paddle::Tensor& topk_weights
|
||||
) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
const paddle::Tensor& topk_weights) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int token_num = permute_indices_per_token.shape()[0];
|
||||
const int max_tokens_per_expert = ffn_out.shape()[1];
|
||||
const int hidden = ffn_out.shape()[2];
|
||||
const int topk = permute_indices_per_token.shape()[1];
|
||||
const int token_num = permute_indices_per_token.shape()[0];
|
||||
const int max_tokens_per_expert = ffn_out.shape()[1];
|
||||
const int hidden = ffn_out.shape()[2];
|
||||
const int topk = permute_indices_per_token.shape()[1];
|
||||
|
||||
auto place = ffn_out.place();
|
||||
auto stream = ffn_out.stream();
|
||||
auto place = ffn_out.place();
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
auto out = GetEmptyTensor({token_num, hidden}, ffn_out.dtype(), place);
|
||||
auto out = GetEmptyTensor({token_num, hidden}, ffn_out.dtype(), place);
|
||||
|
||||
constexpr int VecSize = 16 / sizeof(data_t);
|
||||
int blocks = 32 * topk;
|
||||
int grids = min(132 * 4, token_num);
|
||||
int num_vecs = hidden / VecSize;
|
||||
constexpr int VecSize = 16 / sizeof(data_t);
|
||||
int blocks = 32 * topk;
|
||||
int grids = min(132 * 4, token_num);
|
||||
int num_vecs = hidden / VecSize;
|
||||
|
||||
assert(blocks <= 1024);
|
||||
int dyn_smem_size = 0;
|
||||
assert(blocks <= 1024);
|
||||
int dyn_smem_size = 0;
|
||||
|
||||
switch (topk) {
|
||||
case 4:
|
||||
dyn_smem_size = topk * hidden * sizeof(DataType_);
|
||||
if (dyn_smem_size >= (48 << 10)) {
|
||||
cudaFuncSetAttribute(
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
dyn_smem_size);
|
||||
}
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4><<<grids, blocks, dyn_smem_size, stream>>>(
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
topk_idx.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_num, num_vecs, hidden, max_tokens_per_expert
|
||||
);
|
||||
break;
|
||||
switch (topk) {
|
||||
case 4:
|
||||
dyn_smem_size = topk * hidden * sizeof(DataType_);
|
||||
if (dyn_smem_size >= (48 << 10)) {
|
||||
cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
dyn_smem_size);
|
||||
}
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 4>
|
||||
<<<grids, blocks, dyn_smem_size, stream>>>(
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
topk_idx.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_num,
|
||||
num_vecs,
|
||||
hidden,
|
||||
max_tokens_per_expert);
|
||||
break;
|
||||
|
||||
case 8:
|
||||
dyn_smem_size = topk * hidden * sizeof(DataType_);
|
||||
if (dyn_smem_size >= (48 << 10)) {
|
||||
cudaFuncSetAttribute(
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
dyn_smem_size);
|
||||
}
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8><<<grids, blocks, topk * hidden * sizeof(DataType_), stream>>>(
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
topk_idx.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_num, num_vecs, hidden, max_tokens_per_expert
|
||||
);
|
||||
break;
|
||||
case 8:
|
||||
dyn_smem_size = topk * hidden * sizeof(DataType_);
|
||||
if (dyn_smem_size >= (48 << 10)) {
|
||||
cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
dyn_smem_size);
|
||||
}
|
||||
MoEDeepGEMMDePermuteKernel<DataType_, VecSize, 8>
|
||||
<<<grids, blocks, topk * hidden * sizeof(DataType_), stream>>>(
|
||||
reinterpret_cast<DataType_*>(out.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(ffn_out.data<data_t>()),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
topk_idx.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_num,
|
||||
num_vecs,
|
||||
hidden,
|
||||
max_tokens_per_expert);
|
||||
break;
|
||||
|
||||
default:
|
||||
PD_THROW("Unsupported topk");
|
||||
}
|
||||
return {out};
|
||||
default:
|
||||
PD_THROW("Unsupported topk");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor& ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor& permute_indices_per_token, // [token_num, topk}]
|
||||
const paddle::Tensor&
|
||||
ffn_out, // [num_experts, max_tokens_per_expert, hidden]
|
||||
const paddle::Tensor& permute_indices_per_token, // [token_num, topk}]
|
||||
const paddle::Tensor& topk_idx,
|
||||
const paddle::Tensor& topk_weights
|
||||
) {
|
||||
switch (ffn_out.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return MoEDeepGEMMDePermuteDispatch<paddle::DataType::BFLOAT16>(
|
||||
ffn_out, permute_indices_per_token, topk_idx, topk_weights
|
||||
);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return MoEDeepGEMMDePermuteDispatch<paddle::DataType::FLOAT16>(
|
||||
ffn_out, permute_indices_per_token, topk_idx, topk_weights
|
||||
);
|
||||
default:
|
||||
PD_THROW("Unsupported data type");
|
||||
}
|
||||
const paddle::Tensor& topk_weights) {
|
||||
switch (ffn_out.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return MoEDeepGEMMDePermuteDispatch<paddle::DataType::BFLOAT16>(
|
||||
ffn_out, permute_indices_per_token, topk_idx, topk_weights);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return MoEDeepGEMMDePermuteDispatch<paddle::DataType::FLOAT16>(
|
||||
ffn_out, permute_indices_per_token, topk_idx, topk_weights);
|
||||
default:
|
||||
PD_THROW("Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_deepgemm_depermute)
|
||||
.Inputs({"ffn_out", "permute_indices_per_token", "topk_idx", "topk_weights"})
|
||||
.Inputs(
|
||||
{"ffn_out", "permute_indices_per_token", "topk_idx", "topk_weights"})
|
||||
.Outputs({"out"})
|
||||
.SetKernelFn(PD_KERNEL(MoEDeepGEMMDePermute));
|
||||
|
||||
@@ -15,29 +15,41 @@
|
||||
#include "helper.h"
|
||||
|
||||
// topk warps
|
||||
template<typename T, int VecSize>
|
||||
__global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) {
|
||||
template <typename T, int VecSize>
|
||||
__global__ void MoEDeepGEMMPermuteKernel(T* out,
|
||||
int* token_nums_per_expert,
|
||||
int* permute_indices_per_token,
|
||||
const T* x,
|
||||
const int64_t* topk_idx,
|
||||
const int token_num,
|
||||
const int topk,
|
||||
const int num_vecs,
|
||||
const int hidden,
|
||||
const int max_tokens_per_expert) {
|
||||
AlignedVector<T, VecSize> in_vec;
|
||||
|
||||
AlignedVector<T, VecSize> in_vec;
|
||||
|
||||
const int bid = blockIdx.x;
|
||||
const int wid = threadIdx.x / 32;
|
||||
const int tid = threadIdx.x % 32;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const int tgt_expert_id = topk_idx[token_idx * topk + wid];
|
||||
int tgt_expert_token;
|
||||
if (tid == 0) {
|
||||
tgt_expert_token = atomicAdd(token_nums_per_expert + tgt_expert_id, 1);
|
||||
permute_indices_per_token[token_idx * topk + wid] = tgt_expert_token;
|
||||
}
|
||||
tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);
|
||||
|
||||
|
||||
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
|
||||
Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize, &in_vec);
|
||||
Store<T, VecSize>(in_vec, out + tgt_expert_id * max_tokens_per_expert * hidden + tgt_expert_token * hidden + hidden_vec_id * VecSize);
|
||||
}
|
||||
const int bid = blockIdx.x;
|
||||
const int wid = threadIdx.x / 32;
|
||||
const int tid = threadIdx.x % 32;
|
||||
for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
|
||||
const int tgt_expert_id = topk_idx[token_idx * topk + wid];
|
||||
int tgt_expert_token;
|
||||
if (tid == 0) {
|
||||
tgt_expert_token = atomicAdd(token_nums_per_expert + tgt_expert_id, 1);
|
||||
permute_indices_per_token[token_idx * topk + wid] = tgt_expert_token;
|
||||
}
|
||||
tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);
|
||||
|
||||
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
|
||||
hidden_vec_id += 32) {
|
||||
Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize,
|
||||
&in_vec);
|
||||
Store<T, VecSize>(in_vec,
|
||||
out + tgt_expert_id * max_tokens_per_expert * hidden +
|
||||
tgt_expert_token * hidden +
|
||||
hidden_vec_id * VecSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
@@ -45,71 +57,78 @@ std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& topk_idx,
|
||||
const int num_experts,
|
||||
const int max_tokens_per_expert
|
||||
) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
const int max_tokens_per_expert) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int token_num = x.shape()[0];
|
||||
const int hidden = x.shape()[1];
|
||||
const int topk = topk_idx.shape()[1];
|
||||
const int token_num = x.shape()[0];
|
||||
const int hidden = x.shape()[1];
|
||||
const int topk = topk_idx.shape()[1];
|
||||
|
||||
auto place = x.place();
|
||||
auto stream = x.stream();
|
||||
auto place = x.place();
|
||||
auto stream = x.stream();
|
||||
|
||||
auto token_nums_per_expert = GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
||||
auto permute_indices_per_token = GetEmptyTensor({token_num, topk}, paddle::DataType::INT32, place);
|
||||
auto token_nums_per_expert =
|
||||
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({token_num, topk}, paddle::DataType::INT32, place);
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(token_nums_per_expert.data<int32_t>(), 0, num_experts * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
|
||||
0,
|
||||
num_experts * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
auto permute_output = GetEmptyTensor({num_experts, max_tokens_per_expert, hidden}, x.dtype(), place);
|
||||
auto permute_output = GetEmptyTensor(
|
||||
{num_experts, max_tokens_per_expert, hidden}, x.dtype(), place);
|
||||
|
||||
auto permute_output_data = permute_output.data<data_t>();
|
||||
auto permute_output_data = permute_output.data<data_t>();
|
||||
|
||||
constexpr int VecSize = 16 / sizeof(data_t);
|
||||
constexpr int VecSize = 16 / sizeof(data_t);
|
||||
|
||||
int blocks = 32 * topk;
|
||||
int grids = min(132 * 4, token_num);
|
||||
int num_vecs = hidden / VecSize;
|
||||
int blocks = 32 * topk;
|
||||
int grids = min(132 * 4, token_num);
|
||||
int num_vecs = hidden / VecSize;
|
||||
|
||||
assert(blocks <= 1024);
|
||||
assert(blocks <= 1024);
|
||||
|
||||
MoEDeepGEMMPermuteKernel<DataType_, VecSize><<<grids, blocks, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(permute_output_data),
|
||||
token_nums_per_expert.data<int32_t>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
reinterpret_cast<const DataType_ *>(x.data<data_t>()),
|
||||
topk_idx.data<int64_t>(),
|
||||
token_num, topk, num_vecs,
|
||||
hidden, max_tokens_per_expert
|
||||
);
|
||||
MoEDeepGEMMPermuteKernel<DataType_, VecSize><<<grids, blocks, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(permute_output_data),
|
||||
token_nums_per_expert.data<int32_t>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
reinterpret_cast<const DataType_*>(x.data<data_t>()),
|
||||
topk_idx.data<int64_t>(),
|
||||
token_num,
|
||||
topk,
|
||||
num_vecs,
|
||||
hidden,
|
||||
max_tokens_per_expert);
|
||||
|
||||
return {permute_output, token_nums_per_expert, permute_indices_per_token};
|
||||
return {permute_output, token_nums_per_expert, permute_indices_per_token};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& topk_idx,
|
||||
const int num_experts,
|
||||
const int max_tokens_per_expert
|
||||
) {
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return MoEDeepGEMMPermuteDispatch<paddle::DataType::BFLOAT16>(
|
||||
x, topk_idx, num_experts, max_tokens_per_expert
|
||||
);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return MoEDeepGEMMPermuteDispatch<paddle::DataType::FLOAT16>(
|
||||
x, topk_idx, num_experts, max_tokens_per_expert
|
||||
);
|
||||
default:
|
||||
PD_THROW("Unsupported data type");
|
||||
}
|
||||
const int max_tokens_per_expert) {
|
||||
switch (x.dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return MoEDeepGEMMPermuteDispatch<paddle::DataType::BFLOAT16>(
|
||||
x, topk_idx, num_experts, max_tokens_per_expert);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return MoEDeepGEMMPermuteDispatch<paddle::DataType::FLOAT16>(
|
||||
x, topk_idx, num_experts, max_tokens_per_expert);
|
||||
default:
|
||||
PD_THROW("Unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_deepgemm_permute)
|
||||
.Inputs({"x", "topk_idx"})
|
||||
.Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
|
||||
.Outputs({"permute_output",
|
||||
"token_nums_per_expert",
|
||||
"permute_indices_per_token"})
|
||||
.Attrs({"num_experts: int", "max_tokens_per_expert: int"})
|
||||
.SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));
|
||||
|
||||
@@ -27,8 +27,10 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor, const int num_rows,
|
||||
const int hidden_size, const int topk,
|
||||
const float routed_scaling_factor,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int topk,
|
||||
paddle::Tensor *output) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
@@ -37,19 +39,29 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(), output->data<data_t>(),
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(), num_rows, hidden_size, topk,
|
||||
static_cast<int>(1), norm_topk_prob, routed_scaling_factor, stream);
|
||||
top_k_weight.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
|
||||
paddle::Tensor MoeExpertReduceFunc(
|
||||
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
|
||||
const paddle::Tensor &ffn_out,
|
||||
const paddle::Tensor &top_k_weight,
|
||||
const paddle::Tensor &permute_indices_per_token,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob, const float routed_scaling_factor) {
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
const auto input_type = ffn_out.dtype();
|
||||
auto place = ffn_out.place();
|
||||
|
||||
@@ -59,38 +71,57 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
if(num_rows == 0){
|
||||
if (num_rows == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
|
||||
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
|
||||
topk, &output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
|
||||
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
|
||||
topk, &output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MoeExpertReduce(const paddle::Tensor &ffn_out,
|
||||
const paddle::Tensor &top_k_weight,
|
||||
const paddle::Tensor &permute_indices_per_token,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob, const float routed_scaling_factor) {
|
||||
return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token,
|
||||
top_k_indices, down_proj_bias, norm_topk_prob,
|
||||
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::Tensor &ffn_out,
|
||||
const paddle::Tensor &top_k_weight,
|
||||
const paddle::Tensor &permute_indices_per_token,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
return {MoeExpertReduceFunc(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor)};
|
||||
}
|
||||
|
||||
@@ -115,7 +146,6 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Mixture of Experts (MoE) Expert Reduce Operator
|
||||
*
|
||||
@@ -131,9 +161,8 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
* - top_k_weight: Routing weights for top-k experts per token
|
||||
* Shape: [total_tokens, moe_topk]
|
||||
* dtype: float32
|
||||
* - permute_indices_per_token: Indices mapping for reconstructing original order
|
||||
* Shape: [moe_topk, total_tokens]
|
||||
* dtype: int32
|
||||
* - permute_indices_per_token: Indices mapping for reconstructing original
|
||||
* order Shape: [moe_topk, total_tokens] dtype: int32
|
||||
* - top_k_indices: Indices of selected top-k experts for each token
|
||||
* Shape: [total_tokens, moe_topk]
|
||||
* dtype: int32
|
||||
@@ -157,8 +186,11 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
* - For optimal performance, hidden_size should be a multiple of 128
|
||||
*/
|
||||
PD_BUILD_STATIC_OP(moe_expert_reduce)
|
||||
.Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token",
|
||||
"top_k_indices", paddle::Optional("down_proj_bias")})
|
||||
.Inputs({"ffn_out",
|
||||
"top_k_weight",
|
||||
"permute_indices_per_token",
|
||||
"top_k_indices",
|
||||
paddle::Optional("down_proj_bias")})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
@@ -27,35 +26,42 @@
|
||||
using namespace phi;
|
||||
|
||||
template <typename T>
|
||||
void moe_redundant_topk_select_kernel(const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
const int* expert_id_to_ep_rank_array,
|
||||
const int* expert_in_rank_num_list,
|
||||
int* tokens_per_expert_stats_list,
|
||||
int64_t* indices,
|
||||
int64_t* indices_tmp,
|
||||
int* source_row,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int redundant_ep_rank_num_plus_one,
|
||||
cudaStream_t stream,
|
||||
const bool apply_norm_weight = false,
|
||||
const bool enable_softmax_top_k_fused = false
|
||||
) {
|
||||
void moe_redundant_topk_select_kernel(
|
||||
const T* input,
|
||||
const T* bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
const int* expert_id_to_ep_rank_array,
|
||||
const int* expert_in_rank_num_list,
|
||||
int* tokens_per_expert_stats_list,
|
||||
int64_t* indices,
|
||||
int64_t* indices_tmp,
|
||||
int* source_row,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int redundant_ep_rank_num_plus_one,
|
||||
cudaStream_t stream,
|
||||
const bool apply_norm_weight = false,
|
||||
const bool enable_softmax_top_k_fused = false) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
break; \
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(input, \
|
||||
bias, \
|
||||
output, \
|
||||
indices, \
|
||||
source_row, \
|
||||
num_rows, \
|
||||
num_experts, \
|
||||
k, \
|
||||
stream); \
|
||||
break; \
|
||||
}
|
||||
int64_t tem_num_experts = num_experts;
|
||||
if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
|
||||
if (bias != nullptr || apply_norm_weight) tem_num_experts = 0;
|
||||
switch (tem_num_experts) {
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
@@ -70,60 +76,60 @@ void moe_redundant_topk_select_kernel(const T* input,
|
||||
static constexpr int TPB = 256;
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
if (!enable_softmax_top_k_fused) {
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
if (apply_norm_weight) {
|
||||
moe_redundant_top_k_normed<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
|
||||
bias,
|
||||
expert_id_to_ep_rank_array,
|
||||
expert_in_rank_num_list,
|
||||
tokens_per_expert_stats_list,
|
||||
output,
|
||||
indices,
|
||||
indices_tmp,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows,
|
||||
redundant_ep_rank_num_plus_one);
|
||||
} else {
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
if (apply_norm_weight) {
|
||||
moe_redundant_top_k_normed<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
|
||||
softmax,
|
||||
bias,
|
||||
expert_id_to_ep_rank_array,
|
||||
expert_in_rank_num_list,
|
||||
tokens_per_expert_stats_list,
|
||||
output,
|
||||
indices,
|
||||
indices_tmp,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows,
|
||||
redundant_ep_rank_num_plus_one);
|
||||
} else {
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
} else {
|
||||
assert(k <= TPB);
|
||||
if (apply_norm_weight) {
|
||||
moe_softmax_top_k_fused<T, TPB, true>
|
||||
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(
|
||||
input,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
} else {
|
||||
moe_softmax_top_k_fused<T, TPB, false>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
else {
|
||||
assert(k<=TPB);
|
||||
if (apply_norm_weight) {
|
||||
moe_softmax_top_k_fused<T, TPB, true>
|
||||
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
} else {
|
||||
moe_softmax_top_k_fused<T, TPB, false>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
|
||||
bias,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,24 +201,25 @@ std::vector<paddle::Tensor> MoERedundantTopKSelectKernel(
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
moe_redundant_topk_select_kernel<float>(gating_logits.data<float>(),
|
||||
bias ? bias.get().data<float>() : nullptr,
|
||||
topk_weights.data<float>(),
|
||||
softmax_out_,
|
||||
expert_id_to_ep_rank_array.data<int>(),
|
||||
expert_in_rank_num_list.data<int>(),
|
||||
tokens_per_expert_stats_list.data<int>(),
|
||||
topk_ids_data,
|
||||
topk_ids_tmp_data,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
redundant_ep_rank_num_plus_one,
|
||||
stream,
|
||||
apply_norm_weight,
|
||||
enable_softmax_top_k_fused);
|
||||
moe_redundant_topk_select_kernel<float>(
|
||||
gating_logits.data<float>(),
|
||||
bias ? bias.get().data<float>() : nullptr,
|
||||
topk_weights.data<float>(),
|
||||
softmax_out_,
|
||||
expert_id_to_ep_rank_array.data<int>(),
|
||||
expert_in_rank_num_list.data<int>(),
|
||||
tokens_per_expert_stats_list.data<int>(),
|
||||
topk_ids_data,
|
||||
topk_ids_tmp_data,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
redundant_ep_rank_num_plus_one,
|
||||
stream,
|
||||
apply_norm_weight,
|
||||
enable_softmax_top_k_fused);
|
||||
return {topk_ids, topk_weights};
|
||||
}
|
||||
|
||||
@@ -235,8 +242,7 @@ std::vector<std::vector<int64_t>> MoERedundantTopKSelectKernelInferShape(
|
||||
}
|
||||
const int num_rows = token_rows;
|
||||
|
||||
return {{num_rows, moe_topk},
|
||||
{num_rows, moe_topk}};
|
||||
return {{num_rows, moe_topk}, {num_rows, moe_topk}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
|
||||
@@ -249,18 +255,22 @@ std::vector<paddle::DataType> MoERedundantTopKSelectKernelInferDtype(
|
||||
const bool apply_norm_weight,
|
||||
const bool enable_softmax_top_k_fused,
|
||||
const int redundant_ep_rank_num_plus_one) {
|
||||
return {paddle::DataType::INT64,
|
||||
paddle::DataType::FLOAT32};
|
||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_redundant_topk_select)
|
||||
.Inputs({"gating_logits", "expert_id_to_ep_rank_array", "expert_in_rank_num_list", "tokens_per_expert_stats_list", paddle::Optional("bias")})
|
||||
.Outputs({"topk_ids",
|
||||
"topk_weights",
|
||||
"tokens_per_expert_stats_list_out"})
|
||||
.Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool", "redundant_ep_rank_num_plus_one:int"})
|
||||
.SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
|
||||
.Inputs({"gating_logits",
|
||||
"expert_id_to_ep_rank_array",
|
||||
"expert_in_rank_num_list",
|
||||
"tokens_per_expert_stats_list",
|
||||
paddle::Optional("bias")})
|
||||
.Outputs({"topk_ids", "topk_weights", "tokens_per_expert_stats_list_out"})
|
||||
.Attrs({"moe_topk:int",
|
||||
"apply_norm_weight:bool",
|
||||
"enable_softmax_top_k_fused:bool",
|
||||
"redundant_ep_rank_num_plus_one:int"})
|
||||
.SetInplaceMap({{"tokens_per_expert_stats_list",
|
||||
"tokens_per_expert_stats_list_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "paddle/phi/api/include/api.h"
|
||||
|
||||
@@ -41,15 +41,18 @@ class ScalarType {
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||
int32_t bias, bool finite_values_only = false,
|
||||
constexpr ScalarType(uint8_t exponent,
|
||||
uint8_t mantissa,
|
||||
bool signed_,
|
||||
int32_t bias,
|
||||
bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
nan_repr(nan_repr){};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
@@ -67,16 +70,17 @@ class ScalarType {
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||
static constexpr ScalarType float_(uint8_t exponent,
|
||||
uint8_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
|
||||
// "use `float_IEEE754` constructor for floating point types that "
|
||||
// "follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||
nan_repr);
|
||||
// "use `float_IEEE754` constructor for floating point types
|
||||
// that " "follow IEEE 754 conventions");
|
||||
return ScalarType(
|
||||
exponent, mantissa, true, 0, finite_values_only, nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
@@ -103,7 +107,9 @@ class ScalarType {
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||
static constexpr auto reduce_members_helper(Fn f,
|
||||
Init val,
|
||||
Member member,
|
||||
Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
@@ -116,8 +122,14 @@ class ScalarType {
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||
finite_values_only, nan_repr);
|
||||
return reduce_members_helper(f,
|
||||
init,
|
||||
exponent,
|
||||
mantissa,
|
||||
signed_,
|
||||
bias,
|
||||
finite_values_only,
|
||||
nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
@@ -194,7 +206,8 @@ class ScalarType {
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
"Cannot represent max/min as a double for type ",
|
||||
str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
@@ -204,7 +217,8 @@ class ScalarType {
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
PADDLE_ENFORCE(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
"Cannot represent max/min as a double for type ",
|
||||
str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
@@ -327,23 +341,32 @@ using ScalarTypeId = MARLIN_NAMESPACE_NAME::ScalarType::Id;
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = MARLIN_NAMESPACE_NAME::ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kU4B8 =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = MARLIN_NAMESPACE_NAME::ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = MARLIN_NAMESPACE_NAME::ScalarType::uint(8, 128);
|
||||
static inline constexpr auto kU8B128 =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(2, 1, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(
|
||||
2, 1, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(3, 2, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(
|
||||
3, 2, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(4, 3, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 10);
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_(
|
||||
4, 3, true, MARLIN_NAMESPACE_NAME::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// // Fixed width style names, generally following:
|
||||
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
// //
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
constexpr auto kInt4 = kS4;
|
||||
constexpr auto kUint4 = kU4;
|
||||
constexpr auto kUint4b8 = kU4B8;
|
||||
@@ -369,4 +392,4 @@ constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
|
||||
constexpr auto kFloat32 = phi::DataType::FLOAT32;
|
||||
constexpr auto kByte = phi::DataType::INT8;
|
||||
|
||||
}; // namespace Marlin
|
||||
}; // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@@ -92,7 +92,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id,
|
||||
template <typename scalar_t2,
|
||||
MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id,
|
||||
bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
@@ -106,8 +107,8 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
const int MASK = 0x000f000f;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
@@ -120,8 +121,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(i
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@@ -143,14 +144,14 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@@ -172,7 +173,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU4.id(), false>(in
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
@@ -189,7 +191,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(),
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
|
||||
|
||||
@@ -200,13 +203,15 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(),
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), true>(q, frag_b);
|
||||
|
||||
@@ -225,8 +230,9 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU4.id(), fa
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
@@ -239,8 +245,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
@@ -251,14 +258,14 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), false
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
@@ -269,7 +276,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kU8.id(), false>(in
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@@ -287,14 +295,15 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8B128.id()
|
||||
fp32_intermediates[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
bf16_result_ptr[0] = __byte_perm(
|
||||
fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(
|
||||
fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@@ -312,15 +321,16 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kU8.id(), fa
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
bf16_result_ptr[0] = __byte_perm(
|
||||
fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(
|
||||
fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||
@@ -337,8 +347,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
@@ -355,7 +366,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), fals
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
@@ -374,7 +386,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
@@ -396,8 +409,9 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE4M3fn.id(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
||||
@@ -414,8 +428,9 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
__device__ inline void
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
@@ -432,7 +447,8 @@ __device__ inline void dequant<half2, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
@@ -450,7 +466,8 @@ __device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id()
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
|
||||
__device__ inline void
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, MARLIN_NAMESPACE_NAME::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
#include "moe/moe_wna16_marlin_utils/marlin.cuh"
|
||||
#include "moe/moe_wna16_marlin_utils/marlin_dtypes.cuh"
|
||||
@@ -22,7 +22,8 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
|
||||
// ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
@@ -51,10 +51,10 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr,
|
||||
const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
@@ -64,7 +64,9 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
"r"(smem),
|
||||
"l"(glob_ptr),
|
||||
"n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
@@ -74,7 +76,8 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
"l"(glob_ptr),
|
||||
"n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {
|
||||
@@ -93,7 +96,6 @@ inline void cp_async_fence() {}
|
||||
template <int n>
|
||||
inline void cp_async_wait() {}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME::kernel_types {
|
||||
@@ -78,6 +78,6 @@ class ScalarType<nv_bfloat16> {
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
} // namespace MARLIN_NAMESPACE_NAME::kernel_types
|
||||
|
||||
#endif
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "moe/moe_wna16_marlin_utils/marlin.cuh"
|
||||
@@ -38,7 +38,8 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
|
||||
// ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -86,9 +87,10 @@ __global__ void Marlin(
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <typename scalar_t>
|
||||
__device__ inline void mma(const typename kernel_types::ScalarType<scalar_t>::FragA& a_frag,
|
||||
const typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragC& frag_c) {
|
||||
__device__ inline void mma(
|
||||
const typename kernel_types::ScalarType<scalar_t>::FragA& a_frag,
|
||||
const typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
@@ -97,15 +99,31 @@ __device__ inline void mma(const typename kernel_types::ScalarType<scalar_t>::Fr
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
: "r"(a[0]),
|
||||
"r"(a[1]),
|
||||
"r"(a[2]),
|
||||
"r"(a[3]),
|
||||
"r"(b[0]),
|
||||
"r"(b[1]),
|
||||
"f"(c[0]),
|
||||
"f"(c[1]),
|
||||
"f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
: "r"(a[0]),
|
||||
"r"(a[1]),
|
||||
"r"(a[2]),
|
||||
"r"(a[3]),
|
||||
"r"(b[0]),
|
||||
"r"(b[1]),
|
||||
"f"(c[0]),
|
||||
"f"(c[1]),
|
||||
"f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
@@ -126,15 +144,31 @@ __device__ inline void mma_trans(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
: "r"(b[0]),
|
||||
"r"(b2[0]),
|
||||
"r"(b[1]),
|
||||
"r"(b2[1]),
|
||||
"r"(a[0]),
|
||||
"r"(a[1]),
|
||||
"f"(c[0]),
|
||||
"f"(c[1]),
|
||||
"f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
: "r"(b[0]),
|
||||
"r"(b2[0]),
|
||||
"r"(b[1]),
|
||||
"r"(b2[1]),
|
||||
"r"(a[0]),
|
||||
"r"(a[1]),
|
||||
"f"(c[0]),
|
||||
"f"(c[1]),
|
||||
"f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
@@ -143,8 +177,9 @@ __device__ inline void mma_trans(
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, typename scalar_t>
|
||||
__device__ inline void ldsm(typename kernel_types::ScalarType<scalar_t>::FragA& frag_a,
|
||||
const void* smem_ptr) {
|
||||
__device__ inline void ldsm(
|
||||
typename kernel_types::ScalarType<scalar_t>::FragA& frag_a,
|
||||
const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
if constexpr (count == 4) {
|
||||
@@ -168,19 +203,22 @@ __device__ inline void ldsm(typename kernel_types::ScalarType<scalar_t>::FragA&
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s,
|
||||
int i) {
|
||||
__device__ inline void scale(
|
||||
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s,
|
||||
int i) {
|
||||
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 s =
|
||||
kernel_types::ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
||||
scalar_t2 s = kernel_types::ScalarType<scalar_t>::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
||||
frag_b[0] = __hmul2(frag_b[0], s);
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale_and_sub(
|
||||
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
|
||||
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
scalar_t s,
|
||||
scalar_t zp) {
|
||||
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 s2 = kernel_types::ScalarType<scalar_t>::num2num2(s);
|
||||
scalar_t2 zp2 = kernel_types::ScalarType<scalar_t>::num2num2(zp);
|
||||
@@ -189,24 +227,26 @@ __device__ inline void scale_and_sub(
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ inline void sub_zp(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::scalar_t2& frag_zp,
|
||||
int i) {
|
||||
__device__ inline void sub_zp(
|
||||
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::scalar_t2& frag_zp,
|
||||
int i) {
|
||||
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 zp =
|
||||
kernel_types::ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
|
||||
scalar_t2 zp = kernel_types::ScalarType<scalar_t>::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_zp)[i]);
|
||||
frag_b[0] = __hsub2(frag_b[0], zp);
|
||||
frag_b[1] = __hsub2(frag_b[1], zp);
|
||||
}
|
||||
|
||||
// Same as above, but for act_order (each K is multiplied individually)
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale4(typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_1,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_2,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_3,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_4,
|
||||
int i) {
|
||||
__device__ inline void scale4(
|
||||
typename kernel_types::ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_1,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_2,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_3,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& frag_s_4,
|
||||
int i) {
|
||||
using scalar_t2 = typename kernel_types::ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 s_val_1_2;
|
||||
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
|
||||
@@ -222,11 +262,13 @@ __device__ inline void scale4(typename kernel_types::ScalarType<scalar_t>::FragB
|
||||
|
||||
// Given 2 floats multiply by 2 scales (halves)
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale_float(float* c,
|
||||
typename kernel_types::ScalarType<scalar_t>::FragS& s) {
|
||||
__device__ inline void scale_float(
|
||||
float* c, typename kernel_types::ScalarType<scalar_t>::FragS& s) {
|
||||
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
|
||||
c[0] = __fmul_rn(c[0], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[0]));
|
||||
c[1] = __fmul_rn(c[1], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[1]));
|
||||
c[0] =
|
||||
__fmul_rn(c[0], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[0]));
|
||||
c[1] =
|
||||
__fmul_rn(c[1], kernel_types::ScalarType<scalar_t>::num2float(s_ptr[1]));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
@@ -279,7 +321,8 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
||||
}
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const MARLIN_NAMESPACE_NAME::ScalarTypeId w_type_id, // weight
|
||||
// ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -341,10 +384,14 @@ __global__ void Marlin(
|
||||
using FragZP = typename kernel_types::ScalarType<scalar_t>::FragZP;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
static constexpr auto w_type = MARLIN_NAMESPACE_NAME::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8;
|
||||
constexpr bool is_int_type = w_type == MARLIN_NAMESPACE_NAME::kU4 || w_type == MARLIN_NAMESPACE_NAME::kU8 ||
|
||||
w_type == MARLIN_NAMESPACE_NAME::kU4B8 || w_type == MARLIN_NAMESPACE_NAME::kU8B128;
|
||||
static constexpr auto w_type =
|
||||
MARLIN_NAMESPACE_NAME::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == MARLIN_NAMESPACE_NAME::kU4 ||
|
||||
w_type == MARLIN_NAMESPACE_NAME::kU8;
|
||||
constexpr bool is_int_type = w_type == MARLIN_NAMESPACE_NAME::kU4 ||
|
||||
w_type == MARLIN_NAMESPACE_NAME::kU8 ||
|
||||
w_type == MARLIN_NAMESPACE_NAME::kU4B8 ||
|
||||
w_type == MARLIN_NAMESPACE_NAME::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
@@ -361,7 +408,8 @@ __global__ void Marlin(
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 16 : 8);
|
||||
prob_n * prob_k / group_size /
|
||||
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@@ -444,12 +492,12 @@ __global__ void Marlin(
|
||||
// block_sorted_ids / block_num_valid_tokens / block_topk_weights
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
int4 sorted_token_ids_int4 = reinterpret_cast<const int4*>(
|
||||
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
if (sorted_token_ids[j] >= prob_m * top_k) {
|
||||
block_num_valid_tokens = i * 4 + j;
|
||||
@@ -465,20 +513,21 @@ __global__ void Marlin(
|
||||
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
|
||||
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
sh_rd_block_sorted_ids[tid4 * 4 + i] =
|
||||
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
|
||||
|
||||
if (mul_topk_weights) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int idx = tid4 * 4 + i;
|
||||
idx = idx < block_num_valid_tokens ? idx : 0;
|
||||
if constexpr (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
|
||||
sh_block_topk_weights[idx] = __hmul2(
|
||||
global_scale, Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
||||
sh_block_topk_weights[idx] =
|
||||
__hmul2(global_scale,
|
||||
Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
||||
} else {
|
||||
sh_block_topk_weights[idx] = Dtype::num2num2(
|
||||
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
|
||||
@@ -631,7 +680,8 @@ __global__ void Marlin(
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks / (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1)
|
||||
? thread_k_blocks / group_blocks /
|
||||
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1)
|
||||
: 1;
|
||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||
int s_gl_rd_delta = s_gl_stride;
|
||||
@@ -714,7 +764,8 @@ __global__ void Marlin(
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
int s_sh_rd;
|
||||
if constexpr (group_blocks != -1 && w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
|
||||
if constexpr (group_blocks != -1 &&
|
||||
w_type == MARLIN_NAMESPACE_NAME::kFE2M1f) {
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
int warp_row = warp_id / n_warps;
|
||||
@@ -767,13 +818,13 @@ __global__ void Marlin(
|
||||
// loop unrolls, all shared memory accesses are static, we simply precompute
|
||||
// both transformed reads and writes.
|
||||
int a_sh_wr_trans[a_sh_wr_iters];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++)
|
||||
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
||||
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < thread_m_blocks; j++)
|
||||
a_sh_rd_trans[i][j] =
|
||||
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
||||
@@ -784,7 +835,7 @@ __global__ void Marlin(
|
||||
// maintining multiple pointers (we have enough registers), a tiny
|
||||
// optimization.
|
||||
const int4* B_ptr[b_sh_wr_iters];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
||||
|
||||
@@ -824,7 +875,7 @@ __global__ void Marlin(
|
||||
|
||||
// Zero accumulators.
|
||||
auto zero_accums = [&]() {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
||||
reinterpret_cast<float*>(frag_c)[i] = 0;
|
||||
};
|
||||
@@ -832,39 +883,39 @@ __global__ void Marlin(
|
||||
int sh_first_group_id = -1;
|
||||
int sh_num_groups = -1;
|
||||
|
||||
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
|
||||
int last_group_id) {
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
auto fetch_act_order_scales_to_shared =
|
||||
[&](bool is_async, int first_group_id, int last_group_id) {
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
if (sh_first_group_id + sh_num_groups > num_groups) {
|
||||
sh_num_groups = num_groups - sh_first_group_id;
|
||||
}
|
||||
|
||||
int row_offset = first_group_id * s_gl_stride;
|
||||
|
||||
if (is_async) {
|
||||
for (int i = 0; i < sh_num_groups; i++) {
|
||||
if (threadIdx.x < s_sh_stride) {
|
||||
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
||||
&scales_ptr[row_offset + (i * s_gl_stride) +
|
||||
slice_n_offset + threadIdx.x]);
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < sh_num_groups; i++) {
|
||||
if (threadIdx.x < s_sh_stride) {
|
||||
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
||||
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
||||
threadIdx.x];
|
||||
|
||||
if (sh_first_group_id + sh_num_groups > num_groups) {
|
||||
sh_num_groups = num_groups - sh_first_group_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int row_offset = first_group_id * s_gl_stride;
|
||||
|
||||
if (is_async) {
|
||||
for (int i = 0; i < sh_num_groups; i++) {
|
||||
if (threadIdx.x < s_sh_stride) {
|
||||
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
||||
&scales_ptr[row_offset + (i * s_gl_stride) +
|
||||
slice_n_offset + threadIdx.x]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < sh_num_groups; i++) {
|
||||
if (threadIdx.x < s_sh_stride) {
|
||||
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
||||
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
||||
threadIdx.x];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Asynchronously fetch the next A, B and s tile from global to the next
|
||||
// shared memory pipeline location.
|
||||
@@ -872,101 +923,102 @@ __global__ void Marlin(
|
||||
int max_num_stage_groups =
|
||||
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
|
||||
max_num_stage_groups = max(max_num_stage_groups, 1);
|
||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
|
||||
int pipe_a = 0) {
|
||||
if (pred) {
|
||||
if (should_load_a) {
|
||||
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
|
||||
int64_t sorted_row = 0;
|
||||
if (!m_block_size_8 || row < 8)
|
||||
sorted_row = sh_rd_block_sorted_ids[row];
|
||||
int64_t true_idx =
|
||||
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
|
||||
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
|
||||
row < block_num_valid_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < b_thread_vecs; j++) {
|
||||
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
|
||||
B_ptr[i] + j + B_expert_off);
|
||||
}
|
||||
|
||||
B_ptr[i] += b_gl_rd_delta_o;
|
||||
}
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
// Fetch g_idx thread-block portion
|
||||
int full_pipe = a_off;
|
||||
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
|
||||
if (cur_k < prob_k && cur_k < slice_k_finish) {
|
||||
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
||||
|
||||
int4 const* cur_g_idx_stage_ptr =
|
||||
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
|
||||
|
||||
if (threadIdx.x < g_idx_stage) {
|
||||
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
|
||||
&cur_g_idx_stage_ptr[threadIdx.x]);
|
||||
auto fetch_to_shared =
|
||||
[&](int pipe, int a_off, bool pred = true, int pipe_a = 0) {
|
||||
if (pred) {
|
||||
if (should_load_a) {
|
||||
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
|
||||
int64_t sorted_row = 0;
|
||||
if (!m_block_size_8 || row < 8)
|
||||
sorted_row = sh_rd_block_sorted_ids[row];
|
||||
int64_t true_idx = sorted_row * a_gl_stride + a_gl_rd_col +
|
||||
a_gl_rd_delta_o * a_off;
|
||||
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]],
|
||||
&A[true_idx],
|
||||
row < block_num_valid_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if constexpr (group_blocks != -1) {
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch scales if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < b_thread_vecs; j++) {
|
||||
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
|
||||
B_ptr[i] + j + B_expert_off);
|
||||
}
|
||||
|
||||
B_ptr[i] += b_gl_rd_delta_o;
|
||||
}
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
// Fetch g_idx thread-block portion
|
||||
int full_pipe = a_off;
|
||||
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
|
||||
if (cur_k < prob_k && cur_k < slice_k_finish) {
|
||||
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
||||
|
||||
int4 const* cur_g_idx_stage_ptr =
|
||||
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
|
||||
|
||||
if (threadIdx.x < g_idx_stage) {
|
||||
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
|
||||
&cur_g_idx_stage_ptr[threadIdx.x]);
|
||||
}
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < s_tb_groups; i++) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
|
||||
&scales_ptr[s_gl_rd]);
|
||||
if constexpr (group_blocks != -1) {
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch scales if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
}
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < s_tb_groups; i++) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
|
||||
&scales_ptr[s_gl_rd]);
|
||||
}
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks != -1) {
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch zero-points if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < zp_tb_groups; i++) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
||||
&zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks != -1) {
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch zero-points if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < zp_tb_groups; i++) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
||||
&zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert a fence even when we are winding down the pipeline to ensure that
|
||||
// waiting is also correct at this point.
|
||||
cp_async_fence();
|
||||
};
|
||||
// Insert a fence even when we are winding down the pipeline to ensure
|
||||
// that waiting is also correct at this point.
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto fetch_col_zp_to_shared = [&]() {
|
||||
if (zp_sh_wr_pred) {
|
||||
@@ -994,13 +1046,13 @@ __global__ void Marlin(
|
||||
// into the current register buffer.
|
||||
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
|
||||
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++)
|
||||
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
|
||||
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_thread_vecs; i++) {
|
||||
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
||||
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
||||
@@ -1058,7 +1110,8 @@ __global__ void Marlin(
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id =
|
||||
k_blocks / (group_blocks * (w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1));
|
||||
k_blocks / (group_blocks *
|
||||
(w_type == MARLIN_NAMESPACE_NAME::kFE2M1f ? 2 : 1));
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
@@ -1129,10 +1182,10 @@ __global__ void Marlin(
|
||||
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
||||
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
||||
|
||||
constexpr int k_frag_offsets[4] = {0, 1, 8,
|
||||
9}; // Tensor core offsets per thread
|
||||
constexpr int k_frag_offsets[4] = {
|
||||
0, 1, 8, 9}; // Tensor core offsets per thread
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int actual_k = cur_k + k_frag_offsets[i];
|
||||
|
||||
@@ -1156,7 +1209,7 @@ __global__ void Marlin(
|
||||
if constexpr (group_blocks == -1) {
|
||||
// load only when starting a new slice
|
||||
if (k == 0 && full_pipe == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
||||
}
|
||||
@@ -1167,7 +1220,7 @@ __global__ void Marlin(
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
@@ -1186,16 +1239,16 @@ __global__ void Marlin(
|
||||
int cur_group_id = 0;
|
||||
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
@@ -1227,10 +1280,10 @@ __global__ void Marlin(
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
@@ -1287,9 +1340,9 @@ __global__ void Marlin(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
FragB frag_b0;
|
||||
FragB frag_b1;
|
||||
@@ -1319,10 +1372,18 @@ __global__ void Marlin(
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
static_assert(group_blocks != -1);
|
||||
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
scale4<scalar_t>(frag_b0,
|
||||
act_frag_s[k2][0][j],
|
||||
act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j],
|
||||
act_frag_s[k2][3][j],
|
||||
0);
|
||||
scale4<scalar_t>(frag_b1,
|
||||
act_frag_s[k2][0][j],
|
||||
act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j],
|
||||
act_frag_s[k2][3][j],
|
||||
1);
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
||||
group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
@@ -1343,7 +1404,7 @@ __global__ void Marlin(
|
||||
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);
|
||||
@@ -1372,12 +1433,12 @@ __global__ void Marlin(
|
||||
// unnecessary read or write iterations, e.g., for two warps we write only
|
||||
// once by warp 1 and read only once by warp 0.
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = red_off; i > 0; i /= 2) {
|
||||
if (i <= red_idx && red_idx < 2 * i) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {
|
||||
int red_sh_wr =
|
||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||
@@ -1385,7 +1446,7 @@ __global__ void Marlin(
|
||||
float* c_rd = reinterpret_cast<float*>(
|
||||
&sh_red[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 4; k++)
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||
c_rd[k] + c_wr[k];
|
||||
@@ -1397,11 +1458,11 @@ __global__ void Marlin(
|
||||
__syncthreads();
|
||||
}
|
||||
if (red_idx == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||
c_rd[j];
|
||||
@@ -1444,7 +1505,7 @@ __global__ void Marlin(
|
||||
|
||||
if (!first) {
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
|
||||
int c_idx;
|
||||
if constexpr (m_block_size_8)
|
||||
@@ -1461,11 +1522,11 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
|
||||
if (!first) {
|
||||
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2 * 4; j++) {
|
||||
int delta = 0;
|
||||
if constexpr (m_block_size_8) {
|
||||
@@ -1478,7 +1539,7 @@ __global__ void Marlin(
|
||||
}
|
||||
if (!last) {
|
||||
int4 c;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2 * 4; j++) {
|
||||
int delta = 0;
|
||||
if constexpr (m_block_size_8) {
|
||||
@@ -1527,7 +1588,7 @@ __global__ void Marlin(
|
||||
|
||||
if (!first) {
|
||||
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
if (k % 2) continue;
|
||||
@@ -1540,7 +1601,7 @@ __global__ void Marlin(
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
||||
|
||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int f = 0; f < 4; f++) {
|
||||
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
||||
}
|
||||
@@ -1549,7 +1610,7 @@ __global__ void Marlin(
|
||||
|
||||
if (!last) {
|
||||
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
if (k % 2) continue;
|
||||
@@ -1619,26 +1680,38 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
int wr = c_sh_wr + 16 * j;
|
||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||
write(wr,
|
||||
frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||
write(wr + 8,
|
||||
frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
} else {
|
||||
int wr = c_sh_wr + 8 * j;
|
||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 0,
|
||||
frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 0,
|
||||
frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 4,
|
||||
frag_c[i][j][1][0],
|
||||
frag_c[i][j][1][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 4,
|
||||
frag_c[i][j][1][2],
|
||||
frag_c[i][j][1][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
}
|
||||
}
|
||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||
@@ -1646,7 +1719,7 @@ __global__ void Marlin(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0;
|
||||
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
||||
i++) {
|
||||
@@ -1660,7 +1733,7 @@ __global__ void Marlin(
|
||||
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]);
|
||||
scalar_t2* sh_red_half2 =
|
||||
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int a = 0; a < 4; a++) {
|
||||
scalar_t2 res = sh_red_half2[a];
|
||||
if (mul_topk_weights) {
|
||||
@@ -1686,15 +1759,15 @@ __global__ void Marlin(
|
||||
// Start global fetch and register load pipelines.
|
||||
auto start_pipes = [&]() {
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < stages - 1; i++) {
|
||||
if (has_act_order && i == 0) {
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
fetch_act_order_scales_to_shared(true, g_idx[slice_k_start],
|
||||
g_idx[last_g_idx]);
|
||||
fetch_act_order_scales_to_shared(
|
||||
true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
||||
}
|
||||
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
@@ -1732,9 +1805,9 @@ __global__ void Marlin(
|
||||
|
||||
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
|
||||
stage_group_id++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < stages;) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||
int idx =
|
||||
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
|
||||
@@ -1747,8 +1820,8 @@ __global__ void Marlin(
|
||||
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
|
||||
? (pipe - 1)
|
||||
: (pipe + (stage_group_id + 1) * stages - 1);
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages, idx);
|
||||
fetch_to_shared(
|
||||
(pipe + stages - 1) % stages, pipe, slice_iters >= stages, idx);
|
||||
pipe++;
|
||||
wait_for_stage();
|
||||
init_same_group(pipe % stages);
|
||||
@@ -1775,8 +1848,8 @@ __global__ void Marlin(
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||
last_group_id);
|
||||
fetch_act_order_scales_to_shared(
|
||||
false, first_group_id, last_group_id);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
@@ -1816,7 +1889,7 @@ __global__ void Marlin(
|
||||
if constexpr (m_block_size_8) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
frag_s_half2[i] = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);
|
||||
@@ -1833,9 +1906,9 @@ __global__ void Marlin(
|
||||
w_type.size_bits() == 8 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
scale_float<scalar_t>(
|
||||
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
||||
@@ -1896,11 +1969,11 @@ __global__ void Marlin(
|
||||
|
||||
if (slice_iters) {
|
||||
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||
if (slice_col == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ struct Tensor {
|
||||
|
||||
int64_t size(int64_t d) const { return raw_tensor_.dims().at(d); }
|
||||
|
||||
|
||||
template <typename T>
|
||||
T *data_ptr() const {
|
||||
return const_cast<T *>(raw_tensor_.data<T>());
|
||||
@@ -74,7 +73,6 @@ struct Tensor {
|
||||
return raw_tensor_.place().GetType() == phi::AllocationType::GPU;
|
||||
}
|
||||
|
||||
|
||||
// MARLIN_NAMESPACE_NAME::ScalarType scalar_type() const {
|
||||
// return raw_tensor_.dtype();
|
||||
// }
|
||||
|
||||
+127
-124
@@ -18,168 +18,171 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
// dim3 grid(256)
|
||||
// dim3 block(512)
|
||||
template <typename T, int VecSize>
|
||||
__global__ void swigluoai_interleave_kernel(T* act_out,
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
|
||||
for(int i = 0; i < times; i++)
|
||||
{
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x ;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
for (int i = 0; i < times; i++) {
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
|
||||
if(row < seq_len && col < col_size)
|
||||
{
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize*2], &src_vec0);
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize*2 + VecSize], &src_vec1);
|
||||
if (row < seq_len && col < col_size) {
|
||||
Load<T, VecSize>(&input[row * hidden_dim * 2 + col * VecSize * 2],
|
||||
&src_vec0);
|
||||
Load<T, VecSize>(
|
||||
&input[row * hidden_dim * 2 + col * VecSize * 2 + VecSize],
|
||||
&src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize/2; ++j) {
|
||||
float a = static_cast<float>(src_vec0[2*j]);
|
||||
float b = static_cast<float>(src_vec0[2*j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b,-limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
for (int j = 0; j < VecSize/2; ++j) {
|
||||
float a = static_cast<float>(src_vec1[2*j]);
|
||||
float b = static_cast<float>(src_vec1[2*j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b,-limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j + VecSize/2] = static_cast<T>(res);
|
||||
}
|
||||
for (int j = 0; j < VecSize / 2; ++j) {
|
||||
float a = static_cast<float>(src_vec0[2 * j]);
|
||||
float b = static_cast<float>(src_vec0[2 * j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b, -limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
for (int j = 0; j < VecSize / 2; ++j) {
|
||||
float a = static_cast<float>(src_vec1[2 * j]);
|
||||
float b = static_cast<float>(src_vec1[2 * j + 1]);
|
||||
a = fminf(a, limit);
|
||||
b = fminf(fmaxf(b, -limit), limit);
|
||||
float res = (b + 1) * a / (1.f + expf(-a * alpha));
|
||||
res_vec[j + VecSize / 2] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[row*hidden_dim + col*VecSize]);
|
||||
}
|
||||
Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// dim3 grid(256)
|
||||
// dim3 block(512)
|
||||
template <typename T, int VecSize>
|
||||
__global__ void swigluoai_norm_kernel(T* act_out,
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
const T* input,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const int64_t seq_len,
|
||||
const int64_t hidden_dim) {
|
||||
int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t num = seq_len * hidden_dim;
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec0, src_vec1;
|
||||
LoadT res_vec;
|
||||
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
int64_t vec_num = hidden_dim / VecSize * seq_len;
|
||||
int64_t col_size = hidden_dim / VecSize;
|
||||
int64_t times = (vec_num - 1) / (gridDim.x * blockDim.x) + 1;
|
||||
|
||||
for(int i = 0; i < times; i++)
|
||||
{
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x ;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
for (int i = 0; i < times; i++) {
|
||||
int64_t index = tid + i * gridDim.x * blockDim.x;
|
||||
int64_t row = index / col_size;
|
||||
int64_t col = index % col_size;
|
||||
|
||||
if(row < seq_len && col < col_size)
|
||||
{
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + col*VecSize], &src_vec0);
|
||||
Load<T, VecSize>(&input[row*hidden_dim*2 + hidden_dim + col*VecSize], &src_vec1);
|
||||
if (row < seq_len && col < col_size) {
|
||||
Load<T, VecSize>(&input[row * hidden_dim * 2 + col * VecSize], &src_vec0);
|
||||
Load<T, VecSize>(
|
||||
&input[row * hidden_dim * 2 + hidden_dim + col * VecSize], &src_vec1);
|
||||
|
||||
for (int j = 0; j < VecSize; ++j) {
|
||||
float a = static_cast<float>(src_vec0[j]);
|
||||
float b = static_cast<float>(src_vec1[j]);
|
||||
float z = fminf(fmaxf(a * alpha, -limit), limit);
|
||||
float res = b * a / (1.f + expf(-z));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
for (int j = 0; j < VecSize; ++j) {
|
||||
float a = static_cast<float>(src_vec0[j]);
|
||||
float b = static_cast<float>(src_vec1[j]);
|
||||
float z = fminf(fmaxf(a * alpha, -limit), limit);
|
||||
float res = b * a / (1.f + expf(-z));
|
||||
res_vec[j] = static_cast<T>(res);
|
||||
}
|
||||
|
||||
Store<T, VecSize>(res_vec, &act_out[row*hidden_dim + col*VecSize]);
|
||||
}
|
||||
Store<T, VecSize>(res_vec, &act_out[row * hidden_dim + col * VecSize]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
paddle::Tensor SwigluOAI(const paddle::Tensor &fc1_out_tensor, const float alpha, const float limit, const std::string& type)
|
||||
{
|
||||
// const int64_t group_size = fc1_out_tensor.shape()[1];
|
||||
const int64_t seq_len = fc1_out_tensor.shape()[0];
|
||||
const int64_t hidden_dim = fc1_out_tensor.shape()[1] / 2;
|
||||
auto act_out_tensor = GetEmptyTensor({seq_len, hidden_dim}, fc1_out_tensor.dtype(), fc1_out_tensor.place());
|
||||
paddle::Tensor SwigluOAI(const paddle::Tensor& fc1_out_tensor,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const std::string& type) {
|
||||
// const int64_t group_size = fc1_out_tensor.shape()[1];
|
||||
const int64_t seq_len = fc1_out_tensor.shape()[0];
|
||||
const int64_t hidden_dim = fc1_out_tensor.shape()[1] / 2;
|
||||
auto act_out_tensor = GetEmptyTensor(
|
||||
{seq_len, hidden_dim}, fc1_out_tensor.dtype(), fc1_out_tensor.place());
|
||||
|
||||
constexpr int VecSize = 8;
|
||||
PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
|
||||
PD_CHECK(hidden_dim % VecSize == 0);
|
||||
constexpr int VecSize = 8;
|
||||
PD_CHECK(fc1_out_tensor.dtype() == paddle::DataType::BFLOAT16);
|
||||
PD_CHECK(hidden_dim % VecSize == 0);
|
||||
|
||||
constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
constexpr paddle::DataType D = paddle::DataType::BFLOAT16;
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int block_size = 512;
|
||||
const int grid_size = 256;
|
||||
const int block_size = 512;
|
||||
const int grid_size = 256;
|
||||
|
||||
#define dispatch_norm() do {\
|
||||
swigluoai_norm_kernel<DataType_, VecSize><<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>(\
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(act_out_tensor.data<data_t>())),\
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()),\
|
||||
alpha,\
|
||||
limit,\
|
||||
seq_len,\
|
||||
hidden_dim\
|
||||
);} while(0)
|
||||
#define dispatch_norm() \
|
||||
do { \
|
||||
swigluoai_norm_kernel<DataType_, VecSize> \
|
||||
<<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>( \
|
||||
reinterpret_cast<DataType_*>( \
|
||||
const_cast<data_t*>(act_out_tensor.data<data_t>())), \
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()), \
|
||||
alpha, \
|
||||
limit, \
|
||||
seq_len, \
|
||||
hidden_dim); \
|
||||
} while (0)
|
||||
|
||||
#define dispatch_interleave() do {\
|
||||
swigluoai_interleave_kernel<DataType_, VecSize><<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>(\
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(act_out_tensor.data<data_t>())),\
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()),\
|
||||
alpha,\
|
||||
limit,\
|
||||
seq_len,\
|
||||
hidden_dim\
|
||||
);} while(0)
|
||||
#define dispatch_interleave() \
|
||||
do { \
|
||||
swigluoai_interleave_kernel<DataType_, VecSize> \
|
||||
<<<grid_size, block_size, 0, fc1_out_tensor.stream()>>>( \
|
||||
reinterpret_cast<DataType_*>( \
|
||||
const_cast<data_t*>(act_out_tensor.data<data_t>())), \
|
||||
reinterpret_cast<const DataType_*>(fc1_out_tensor.data<data_t>()), \
|
||||
alpha, \
|
||||
limit, \
|
||||
seq_len, \
|
||||
hidden_dim); \
|
||||
} while (0)
|
||||
|
||||
if(type == "interleave")
|
||||
{
|
||||
dispatch_interleave();
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatch_norm();
|
||||
}
|
||||
// if (token_nums_per_expert.dtype() == paddle::DataType::INT64) {
|
||||
// dispatch_by_index(int64_t);
|
||||
// } else if(token_nums_per_expert.dtype() == paddle::DataType::INT32) {
|
||||
// dispatch_by_index(int32_t);
|
||||
// } else {
|
||||
// PD_THROW("Unsupported token_nums_per_expert's data dtype.");
|
||||
// }
|
||||
if (type == "interleave") {
|
||||
dispatch_interleave();
|
||||
} else {
|
||||
dispatch_norm();
|
||||
}
|
||||
// if (token_nums_per_expert.dtype() == paddle::DataType::INT64) {
|
||||
// dispatch_by_index(int64_t);
|
||||
// } else if(token_nums_per_expert.dtype() == paddle::DataType::INT32) {
|
||||
// dispatch_by_index(int32_t);
|
||||
// } else {
|
||||
// PD_THROW("Unsupported token_nums_per_expert's data dtype.");
|
||||
// }
|
||||
|
||||
return act_out_tensor;
|
||||
return act_out_tensor;
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> SwigluOAIWrapper(
|
||||
const paddle::Tensor& fc1_out_tensor,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const std::string& type) {
|
||||
return {SwigluOAI(fc1_out_tensor, alpha, limit, type)};
|
||||
return {SwigluOAI(fc1_out_tensor, alpha, limit, type)};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(swigluoai)
|
||||
|
||||
@@ -15,5 +15,7 @@
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
|
||||
paddle::Tensor
|
||||
SwigluOAI(const paddle::Tensor &fc1_out_tensor, const float alpha, const float limit, const std::string& type);
|
||||
paddle::Tensor SwigluOAI(const paddle::Tensor& fc1_out_tensor,
|
||||
const float alpha,
|
||||
const float limit,
|
||||
const std::string& type);
|
||||
|
||||
@@ -15,13 +15,14 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define CEILDIV(a,b) (((a+b-1)/b))
|
||||
#define CEILDIV(a, b) (((a + b - 1) / b))
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ cumsum_buffer,
|
||||
size_t numel) {
|
||||
__global__ void count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ cumsum_buffer,
|
||||
size_t numel) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
@@ -33,16 +34,17 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
|
||||
}
|
||||
|
||||
template <typename scalar_t, int num_experts>
|
||||
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t GEMM_BLOCK_SIZE_M,
|
||||
size_t numel,
|
||||
int32_t* __restrict__ cumsum_buffer) {
|
||||
__global__ void moe_align_block_size_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t GEMM_BLOCK_SIZE_M,
|
||||
size_t numel,
|
||||
int32_t* __restrict__ cumsum_buffer) {
|
||||
__shared__ int32_t tokens_per_ep[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
|
||||
tokens_per_ep[i] = 0;
|
||||
tokens_per_ep[i] = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -57,8 +59,10 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum_buffer[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
int expert_count = tokens_per_ep[i-1];
|
||||
cumsum_buffer[i] = cumsum_buffer[i - 1] + CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
|
||||
int expert_count = tokens_per_ep[i - 1];
|
||||
cumsum_buffer[i] =
|
||||
cumsum_buffer[i - 1] +
|
||||
CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum_buffer[num_experts];
|
||||
}
|
||||
@@ -66,33 +70,39 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1]; i += GEMM_BLOCK_SIZE_M) {
|
||||
for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1];
|
||||
i += GEMM_BLOCK_SIZE_M) {
|
||||
expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> tritonmoe_preprocessInferShape(
|
||||
const std::vector<int64_t>& topk_ids,
|
||||
int64_t num_experts,
|
||||
int64_t GEMM_BLOCK_SIZE_M) {
|
||||
int topk_ids_numel = topk_ids[0] * topk_ids[1];
|
||||
int max_num_tokens_padded =
|
||||
topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
|
||||
|
||||
std::vector<std::vector<int64_t>> tritonmoe_preprocessInferShape(const std::vector<int64_t>& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
|
||||
std::vector<int64_t> sorted_ids = {max_num_tokens_padded};
|
||||
|
||||
int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
|
||||
std::vector<int64_t> expert_ids = {max_num_m_blocks};
|
||||
std::vector<int64_t> num_tokens_post_pad = {1};
|
||||
|
||||
int topk_ids_numel = topk_ids[0] * topk_ids[1];
|
||||
int max_num_tokens_padded = topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
|
||||
|
||||
std::vector<int64_t> sorted_ids = {max_num_tokens_padded};
|
||||
|
||||
int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
|
||||
std::vector<int64_t> expert_ids = {max_num_m_blocks};
|
||||
std::vector<int64_t> num_tokens_post_pad = {1};
|
||||
|
||||
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
||||
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> tritonmoe_preprocessIferDtype(const paddle::DataType& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
|
||||
return {paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
std::vector<paddle::DataType> tritonmoe_preprocessIferDtype(
|
||||
const paddle::DataType& topk_ids,
|
||||
int64_t num_experts,
|
||||
int64_t GEMM_BLOCK_SIZE_M) {
|
||||
return {paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
supporse num_experts = 8, GEMM_BLOCK_SIZE_M = 4,
|
||||
topk_ids.shape = [4,4], means=topk=4
|
||||
@@ -113,85 +123,74 @@ Then return value `sorted_ids` is
|
||||
0,16,16,16
|
||||
*/
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(
|
||||
const paddle::Tensor& topk_ids,
|
||||
int64_t num_experts,
|
||||
int64_t GEMM_BLOCK_SIZE_M) {
|
||||
int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
|
||||
int max_num_tokens_padded =
|
||||
topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
|
||||
auto sorted_ids = paddle::full({max_num_tokens_padded},
|
||||
topk_ids_numel,
|
||||
paddle::DataType::INT32,
|
||||
topk_ids.place());
|
||||
|
||||
int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
|
||||
int max_num_tokens_padded = topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);
|
||||
int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
|
||||
|
||||
auto sorted_ids = paddle::full(
|
||||
{max_num_tokens_padded},
|
||||
topk_ids_numel,
|
||||
paddle::DataType::INT32,
|
||||
topk_ids.place()
|
||||
);
|
||||
auto expert_ids = paddle::empty(
|
||||
{max_num_m_blocks}, paddle::DataType::INT32, topk_ids.place());
|
||||
|
||||
int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
|
||||
auto num_tokens_post_pad =
|
||||
paddle::empty({1}, paddle::DataType::INT32, topk_ids.place());
|
||||
|
||||
auto expert_ids = paddle::empty(
|
||||
{max_num_m_blocks}, paddle::DataType::INT32,
|
||||
topk_ids.place()
|
||||
);
|
||||
auto cumsum_buffer = paddle::empty(
|
||||
{num_experts + 1}, paddle::DataType::INT32, topk_ids.place());
|
||||
|
||||
auto num_tokens_post_pad = paddle::empty(
|
||||
{1},
|
||||
paddle::DataType::INT32,
|
||||
topk_ids.place()
|
||||
);
|
||||
auto stream = topk_ids.stream();
|
||||
using scalar_t = int64_t;
|
||||
|
||||
auto cumsum_buffer = paddle::empty(
|
||||
{num_experts + 1},
|
||||
paddle::DataType::INT32,
|
||||
topk_ids.place()
|
||||
);
|
||||
#define run_align_kernel(num_experts) \
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t, num_experts>; \
|
||||
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data<scalar_t>(), \
|
||||
expert_ids.data<int32_t>(), \
|
||||
num_tokens_post_pad.data<int32_t>(), \
|
||||
GEMM_BLOCK_SIZE_M, \
|
||||
topk_ids_numel, \
|
||||
cumsum_buffer.data<int32_t>());
|
||||
|
||||
auto stream = topk_ids.stream();
|
||||
using scalar_t = int64_t;
|
||||
if (num_experts == 8) {
|
||||
run_align_kernel(8);
|
||||
} else if (num_experts == 256) {
|
||||
run_align_kernel(256);
|
||||
} else if (num_experts == 2) {
|
||||
run_align_kernel(2);
|
||||
} else if (num_experts == 64) {
|
||||
run_align_kernel(64);
|
||||
} else if (num_experts == 128) {
|
||||
run_align_kernel(128);
|
||||
} else if (num_experts == 160) {
|
||||
run_align_kernel(160);
|
||||
} else if (num_experts == 32) {
|
||||
run_align_kernel(32);
|
||||
} else {
|
||||
PD_THROW("Not support num_experts: %d", num_experts);
|
||||
}
|
||||
|
||||
# define run_align_kernel(num_experts) \
|
||||
auto align_kernel = moe_align_block_size_kernel<scalar_t, num_experts>; \
|
||||
align_kernel<<<1, 1024, 0, stream>>>( \
|
||||
topk_ids.data<scalar_t>(), \
|
||||
expert_ids.data<int32_t>(), \
|
||||
num_tokens_post_pad.data<int32_t>(), \
|
||||
GEMM_BLOCK_SIZE_M, \
|
||||
topk_ids_numel, \
|
||||
cumsum_buffer.data<int32_t>());
|
||||
const int block_threads = 256;
|
||||
const int num_blocks = CEILDIV(topk_ids_numel, block_threads);
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
if (num_experts == 8) {
|
||||
run_align_kernel(8);
|
||||
} else if (num_experts == 256) {
|
||||
run_align_kernel(256);
|
||||
} else if (num_experts == 2) {
|
||||
run_align_kernel(2);
|
||||
} else if (num_experts == 64) {
|
||||
run_align_kernel(64);
|
||||
} else if (num_experts == 128) {
|
||||
run_align_kernel(128);
|
||||
} else if (num_experts == 160) {
|
||||
run_align_kernel(160);
|
||||
} else if (num_experts == 32) {
|
||||
run_align_kernel(32);
|
||||
}
|
||||
else {
|
||||
PD_THROW("Not support num_experts: %d", num_experts);
|
||||
}
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
|
||||
const int block_threads = 256;
|
||||
const int num_blocks = CEILDIV(topk_ids_numel, block_threads);
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
|
||||
topk_ids.data<scalar_t>(),
|
||||
sorted_ids.data<int32_t>(),
|
||||
cumsum_buffer.data<int32_t>(),
|
||||
topk_ids_numel);
|
||||
|
||||
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
|
||||
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data<scalar_t>(),
|
||||
sorted_ids.data<int32_t>(),
|
||||
cumsum_buffer.data<int32_t>(),
|
||||
topk_ids_numel);
|
||||
|
||||
|
||||
|
||||
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
||||
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(tritonmoe_preprocess)
|
||||
|
||||
@@ -17,12 +17,17 @@
|
||||
|
||||
template <typename T, int TileRows, int TileColumns, int NumThreads>
|
||||
__global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
|
||||
const T *super_scale_ptr, T *weight_ptr,
|
||||
const int64_t batch, const int64_t num_rows,
|
||||
const T *super_scale_ptr,
|
||||
T *weight_ptr,
|
||||
const int64_t batch,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_columns) {
|
||||
using UnzipFunctor =
|
||||
cutlass::gemm::threadblock::UnzipAndDequantFunctor<T, cutlass::WintQuantMethod::kWeightOnlyInt25, TileRows,
|
||||
TileColumns, NumThreads>;
|
||||
using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
|
||||
T,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt25,
|
||||
TileRows,
|
||||
TileColumns,
|
||||
NumThreads>;
|
||||
|
||||
__shared__ T smem[TileRows * TileColumns];
|
||||
|
||||
@@ -41,7 +46,8 @@ __global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
|
||||
|
||||
// unzip to shared memory
|
||||
UnzipFunctor unzip_functor;
|
||||
unzip_functor(block_zipped_weight_ptr, block_super_scale_ptr, smem, num_columns);
|
||||
unzip_functor(
|
||||
block_zipped_weight_ptr, block_super_scale_ptr, smem, num_columns);
|
||||
|
||||
// write back to global memory
|
||||
for (int row = 0; row < TileRows; ++row) {
|
||||
@@ -55,19 +61,26 @@ __global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
|
||||
}
|
||||
|
||||
template <typename T, int64_t TileRows, int64_t TileColumns, int NumThreads>
|
||||
__global__ void
|
||||
Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
|
||||
const uint8_t *local_scale_ptr, const float *code_scale_ptr,
|
||||
const float *code_zp_ptr, const T *super_scale_ptr,
|
||||
T *weight_ptr, const int64_t batch, const int64_t num_rows,
|
||||
const int64_t num_columns) {
|
||||
using UnzipFunctor =
|
||||
cutlass::gemm::threadblock::UnzipAndDequantFunctor<T, cutlass::WintQuantMethod::kWeightOnlyInt2, TileRows,
|
||||
TileColumns, NumThreads>;
|
||||
__global__ void Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
|
||||
const uint8_t *local_scale_ptr,
|
||||
const float *code_scale_ptr,
|
||||
const float *code_zp_ptr,
|
||||
const T *super_scale_ptr,
|
||||
T *weight_ptr,
|
||||
const int64_t batch,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_columns) {
|
||||
using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
|
||||
T,
|
||||
cutlass::WintQuantMethod::kWeightOnlyInt2,
|
||||
TileRows,
|
||||
TileColumns,
|
||||
NumThreads>;
|
||||
|
||||
constexpr bool kUseAsyncLoad = true;
|
||||
|
||||
__shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes + UnzipFunctor::kColumnWiseSmemBytes];
|
||||
__shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes +
|
||||
UnzipFunctor::kColumnWiseSmemBytes];
|
||||
__shared__ T smem[TileRows * TileColumns];
|
||||
|
||||
int64_t block_start_column = blockIdx.x * TileColumns;
|
||||
@@ -95,15 +108,21 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
|
||||
? super_scale_ptr + blockIdx.z * num_columns + block_start_column
|
||||
: nullptr;
|
||||
|
||||
typename UnzipFunctor::Arguments args(zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);
|
||||
typename UnzipFunctor::Arguments args(
|
||||
zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);
|
||||
|
||||
// unzip to shared memory
|
||||
UnzipFunctor functor;
|
||||
|
||||
if (kUseAsyncLoad) {
|
||||
functor.LoadAsync(block_zipped_weight_ptr, block_local_scale_ptr,
|
||||
block_code_scale_ptr, block_code_zp_ptr, block_super_scale_ptr,
|
||||
&args, num_columns, true);
|
||||
functor.LoadAsync(block_zipped_weight_ptr,
|
||||
block_local_scale_ptr,
|
||||
block_code_scale_ptr,
|
||||
block_code_zp_ptr,
|
||||
block_super_scale_ptr,
|
||||
&args,
|
||||
num_columns,
|
||||
true);
|
||||
|
||||
// 发起 cp.async 的收束
|
||||
cutlass::arch::cp_async_fence();
|
||||
@@ -112,9 +131,14 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
} else {
|
||||
functor.Load(block_zipped_weight_ptr, block_local_scale_ptr,
|
||||
block_code_scale_ptr, block_code_zp_ptr, block_super_scale_ptr,
|
||||
&args, num_columns, true);
|
||||
functor.Load(block_zipped_weight_ptr,
|
||||
block_local_scale_ptr,
|
||||
block_code_scale_ptr,
|
||||
block_code_zp_ptr,
|
||||
block_super_scale_ptr,
|
||||
&args,
|
||||
num_columns,
|
||||
true);
|
||||
}
|
||||
|
||||
functor.Compute(args, smem, block_start_row);
|
||||
@@ -132,8 +156,10 @@ Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
|
||||
|
||||
template <typename T>
|
||||
void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
|
||||
const T *supper_scale, T *weight,
|
||||
const int64_t batch, const int64_t num_rows,
|
||||
const T *supper_scale,
|
||||
T *weight,
|
||||
const int64_t batch,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_columns) {
|
||||
constexpr int kTileRows = 64;
|
||||
constexpr int kTileColumns = 128;
|
||||
@@ -146,16 +172,19 @@ void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
|
||||
dim3 grid_dim(block_dim_x, block_dim_y, batch);
|
||||
|
||||
Wint25UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
|
||||
<<<grid_dim, block_dim>>>(zipped_weight, supper_scale, weight, batch,
|
||||
num_rows, num_columns);
|
||||
<<<grid_dim, block_dim>>>(
|
||||
zipped_weight, supper_scale, weight, batch, num_rows, num_columns);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
|
||||
const uint8_t *local_scale,
|
||||
const float *code_scale, const float *code_zp,
|
||||
const T *supper_scale, T *weight,
|
||||
const int64_t batch, const int64_t num_rows,
|
||||
const float *code_scale,
|
||||
const float *code_zp,
|
||||
const T *supper_scale,
|
||||
T *weight,
|
||||
const int64_t batch,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_columns) {
|
||||
constexpr int kTileRows = 64;
|
||||
constexpr int kTileColumns = 256;
|
||||
@@ -168,8 +197,14 @@ void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
|
||||
dim3 grid_dim(block_dim_x, block_dim_y, batch);
|
||||
|
||||
Wint2UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
|
||||
<<<grid_dim, block_dim>>>(zipped_weight, local_scale, code_scale, code_zp,
|
||||
supper_scale, weight, batch, num_rows,
|
||||
<<<grid_dim, block_dim>>>(zipped_weight,
|
||||
local_scale,
|
||||
code_scale,
|
||||
code_zp,
|
||||
supper_scale,
|
||||
weight,
|
||||
batch,
|
||||
num_rows,
|
||||
num_columns);
|
||||
}
|
||||
|
||||
@@ -179,7 +214,8 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
|
||||
const paddle::optional<paddle::Tensor> &code_scale,
|
||||
const paddle::optional<paddle::Tensor> &code_zp,
|
||||
const paddle::optional<paddle::Tensor> &super_scale,
|
||||
paddle::Tensor &weight, const std::string &quant_method) {
|
||||
paddle::Tensor &weight,
|
||||
const std::string &quant_method) {
|
||||
using data_t = typename PDTraits<T>::data_t;
|
||||
using NvType = typename PDTraits<T>::DataType;
|
||||
|
||||
@@ -199,7 +235,10 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
|
||||
Wint25UnzipKernelLauncher<NvType>(
|
||||
reinterpret_cast<const uint16_t *>(zipped_weight_ptr),
|
||||
reinterpret_cast<const NvType *>(super_scale_ptr),
|
||||
reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
|
||||
reinterpret_cast<NvType *>(weight_ptr),
|
||||
batch,
|
||||
num_rows,
|
||||
num_columns);
|
||||
} else if (quant_method == "weight_only_int2") {
|
||||
paddle::Tensor *local_scale_tensor =
|
||||
const_cast<paddle::Tensor *>(local_scale.get_ptr());
|
||||
@@ -209,22 +248,27 @@ void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
|
||||
const_cast<paddle::Tensor *>(code_zp.get_ptr());
|
||||
|
||||
Wint2UnzipKernelLauncher<NvType>(
|
||||
zipped_weight.data<uint8_t>(), local_scale_tensor->data<uint8_t>(),
|
||||
code_scale_tensor->data<float>(), code_zp_tensor->data<float>(),
|
||||
zipped_weight.data<uint8_t>(),
|
||||
local_scale_tensor->data<uint8_t>(),
|
||||
code_scale_tensor->data<float>(),
|
||||
code_zp_tensor->data<float>(),
|
||||
reinterpret_cast<const NvType *>(super_scale_ptr),
|
||||
reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
|
||||
reinterpret_cast<NvType *>(weight_ptr),
|
||||
batch,
|
||||
num_rows,
|
||||
num_columns);
|
||||
} else {
|
||||
PD_THROW("Unsupported quant_method for WintxUnzip.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
WintXUnzip(const paddle::Tensor &zipped_weight,
|
||||
const paddle::optional<paddle::Tensor> &local_scale,
|
||||
const paddle::optional<paddle::Tensor> &code_scale,
|
||||
const paddle::optional<paddle::Tensor> &code_zp,
|
||||
const paddle::optional<paddle::Tensor> &super_scale,
|
||||
const std::string &quant_method) {
|
||||
std::vector<paddle::Tensor> WintXUnzip(
|
||||
const paddle::Tensor &zipped_weight,
|
||||
const paddle::optional<paddle::Tensor> &local_scale,
|
||||
const paddle::optional<paddle::Tensor> &code_scale,
|
||||
const paddle::optional<paddle::Tensor> &code_zp,
|
||||
const paddle::optional<paddle::Tensor> &super_scale,
|
||||
const std::string &quant_method) {
|
||||
paddle::Tensor *local_scale_tensor =
|
||||
const_cast<paddle::Tensor *>(local_scale.get_ptr());
|
||||
paddle::Tensor *super_scale_tensor =
|
||||
@@ -251,18 +295,26 @@ WintXUnzip(const paddle::Tensor &zipped_weight,
|
||||
auto output_tensor = GetEmptyTensor(output_dims, dtype, place);
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
WintxUnzipKernel<paddle::DataType::BFLOAT16>(
|
||||
zipped_weight, local_scale, code_scale, code_zp, super_scale,
|
||||
output_tensor, quant_method);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
WintxUnzipKernel<paddle::DataType::FLOAT16>(
|
||||
zipped_weight, local_scale, code_scale, code_zp, super_scale,
|
||||
output_tensor, quant_method);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for WintxUnzip");
|
||||
case paddle::DataType::BFLOAT16:
|
||||
WintxUnzipKernel<paddle::DataType::BFLOAT16>(zipped_weight,
|
||||
local_scale,
|
||||
code_scale,
|
||||
code_zp,
|
||||
super_scale,
|
||||
output_tensor,
|
||||
quant_method);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
WintxUnzipKernel<paddle::DataType::FLOAT16>(zipped_weight,
|
||||
local_scale,
|
||||
code_scale,
|
||||
code_zp,
|
||||
super_scale,
|
||||
output_tensor,
|
||||
quant_method);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for WintxUnzip");
|
||||
}
|
||||
return {output_tensor};
|
||||
}
|
||||
@@ -306,8 +358,10 @@ std::vector<paddle::DataType> WintXUnzipInferDtype(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(winx_unzip)
|
||||
.Inputs({"zipped_weight", paddle::Optional("local_scale"),
|
||||
paddle::Optional("code_scale"), paddle::Optional("code_zp"),
|
||||
.Inputs({"zipped_weight",
|
||||
paddle::Optional("local_scale"),
|
||||
paddle::Optional("code_scale"),
|
||||
paddle::Optional("code_zp"),
|
||||
paddle::Optional("super_scale")})
|
||||
.Outputs({"weight"})
|
||||
.Attrs({"quant_method:std::string"})
|
||||
|
||||
Reference in New Issue
Block a user