// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif #include "helper.h" #include "paddle/extension.h" #include "w4afp8_gemm_template.h" #include "w4afp8_gemm.h" void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) { assert(K % 64 == 0); for (int b = 0; b < batch; ++b) { for (int m = 0; m < M; ++m) { for (int k = 0; k < K; k+=64) { for (int k_inner = 0; k_inner < 32; ++k_inner) { uint8_t temp = 0; uint8_t left = weight[b * M * K + m * K + k + k_inner]; uint8_t right = weight[b * M * K + m * K + k + k_inner + 32]; temp |= left << 4; temp |= right; weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast(&temp); } } } } } template class NVTraits; template <> class NVTraits<__nv_fp8_e4m3> { public: typedef cutlass::float_e4m3_t data_t; }; template <> class NVTraits<__nv_bfloat16>{ public: typedef cutlass::bfloat16_t data_t; }; template <> class NVTraits{ public: typedef cutlass::half_t data_t; }; template void DisPatchW4AFp8Gemm( const cutlass::float_e4m3_t* input, const cutlass::float_e4m3_t* weight, const int64_t * tokens, const float * input_row_sum, const float * weight_scale, OutputType * out, const int64_t token_padding_size, const int64_t max_tokens, const int batch_size, const int64_t M, const int64_t K, cudaStream_t stream) { int kBlockN = 256; int TailN = 0; if constexpr (std::is_same_v) { GEMM_SWITCH_BF16( M, K, batch_size, token_padding_size, kBlockN, TailN, weight, input, out, weight_scale, input_row_sum, tokens, max_tokens, stream) } else { PD_THROW("Only supported dtype in ['BFLOAT16']."); } } std::vector W4AFp8Gemm( const paddle::Tensor& input, const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group const paddle::Tensor& input_row_sum, const paddle::Tensor& weight_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16) { const int batch_size = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2] * 2; if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) { PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN']."); } if (token_padding_size == 0) { const int all_tokens = input.dims()[0]; if (is_bfloat16) { paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::BFLOAT16, input.place()); phi::dtype::bfloat16 *out_data = out.data(); DisPatchW4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), input_row_sum.data(), weight_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, batch_size, M, K, input.stream()); return {out}; } else { PD_THROW("Only supported dtype in ['BFLOAT16']."); } } else { if (is_bfloat16) { paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place()); phi::dtype::bfloat16 * out_data = out.data(); DisPatchW4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), input_row_sum.data(), weight_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, batch_size, M, K, input.stream()); return {out}; } else { PD_THROW("Only supported dtype in ['BFLOAT16']."); } } } template void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t* total_rows_before_expert, const float* input_row_sum, const float* row_scale, const float* weight_scale, OutputType * out, const int64_t token_padding_size, const int64_t max_tokens, const int num_experts, const int64_t M, const int64_t K, cudaStream_t stream) { using InType = typename NVTraits::data_t; using OutType = typename NVTraits::data_t; DisPatchW4AFp8Gemm( reinterpret_cast(input), reinterpret_cast(weight), total_rows_before_expert, input_row_sum, weight_scale, reinterpret_cast(out), token_padding_size, max_tokens, num_experts, M, K, stream); } std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { const int batch_size = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2]; paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place()); weight_convert(weight.data(), weight_new.data(), batch_size, M, K); return {weight_new}; } template __global__ void permute_scale_kernel( T* input_data, const int numel) { using LoadT = AlignedVector; LoadT input_vec; LoadT dst_vec; const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize; if (load_idx >= numel) { return; } Load(&input_data[load_idx], &input_vec); for (int i = 0; i < kPackSize; i+=2) { dst_vec[i] = input_vec[i / 2]; dst_vec[i + 1] = input_vec[i / 2 + 8]; } Store(dst_vec, &input_data[load_idx]); } void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; if (col % 16 != 0) { PD_THROW("Only supported when col is divisible by 16."); } const int numel = row * col; const int threads = 128; const int kPackSize = 16; const int grid_size = (numel / kPackSize + threads - 1) / threads; if (scale.dtype() == paddle::DataType::BFLOAT16) { permute_scale_kernel<<>>( const_cast(scale.data()), numel ); } else if (scale.dtype() == paddle::DataType::FLOAT16) { permute_scale_kernel<<>>( const_cast(scale.data()), numel ); } else if (scale.dtype() == paddle::DataType::FLOAT32) { permute_scale_kernel<<>>( const_cast(scale.data()), numel ); } } PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute) .Inputs({"weight_scale"}) .Outputs({"permute_scale"}) .SetInplaceMap({{"weight_scale", "permute_scale"}}) .SetKernelFn(PD_KERNEL(W4AFp8GemmScalePermute)); PD_BUILD_STATIC_OP(w4afp8_gemm) .Inputs({"input", "weight", "tokens", "input_row_sum", "weight_scale"}) .Outputs({"out"}) .Attrs({"token_padding_size: int64_t", "max_tokens: int64_t", "is_bfloat16: bool"}) .SetKernelFn(PD_KERNEL(W4AFp8Gemm)); PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert) .Inputs({"weight"}) .Outputs({"converted_weight"}) .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert)); template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, const float * input_row_sum, const float * row_scale, const float * weight_scale, __nv_bfloat16 * out, const int64_t token_padding_size, const int64_t max_tokens, const int num_experts, const int64_t M, const int64_t K, cudaStream_t stream ); template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, const float * input_row_sum, const float * row_scale, const float * weight_scale, half * out, const int64_t token_padding_size, const int64_t max_tokens, const int num_experts, const int64_t M, const int64_t K, cudaStream_t stream );