// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Wi4A16 weight quantization: per-group symmetric int4 with scale = max|w|/7,
// packed two int4 per int8 along the output dimension, matching Python
// fastdeploy.model_executor.ops.iluvatar.utils.wi4a16_weight_quantize_cuda.

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "helper.h"

namespace {

// Dtype-dispatch helpers so the quantize kernel can be shared between
// float16 and bfloat16 weights.
__device__ __forceinline__ float ToFloat(__half v) { return __half2float(v); }

__device__ __forceinline__ float ToFloat(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

__device__ __forceinline__ void WriteScale(__half* scales, int idx, float s) {
  scales[idx] = __float2half(s);
}

__device__ __forceinline__ void WriteScale(__nv_bfloat16* scales,
                                           int idx,
                                           float s) {
  scales[idx] = __float2bfloat16_rn(s);
}

// Quantizes one [group_size]-long slice of a weight column to int4 values
// stored in int8 (range [-8, 7]).
//
// Launch layout: one block per (output column nn, k-group gj), i.e.
// gridDim.x == n * num_groups_per_row; blockDim.x == group_size; dynamic
// shared memory of group_size * sizeof(float) for the max-|w| reduction.
// Preconditions: group_size is a power of two (the host enforces 128), and
// w is a contiguous [k, n] row-major matrix.
//
// Outputs: q[k, n] holds the signed int4 codes; scales[num_groups_per_row, n]
// holds one per-group scale = max|w| / 7.
template <typename T>
__global__ void Wi4A16QuantizeGroupsKernel(const T* __restrict__ w,
                                           int8_t* __restrict__ q,
                                           T* __restrict__ scales,
                                           int k,
                                           int n,
                                           int group_size,
                                           int num_groups_per_row) {
  const int gid = blockIdx.x;
  const int tid = threadIdx.x;
  // With blockDim.x == group_size this never triggers; it only guards an
  // over-sized launch (and then the whole tail warp exits before any
  // __syncthreads(), keeping the barrier non-divergent among survivors).
  if (tid >= group_size) return;

  const int nn = gid / num_groups_per_row;  // output column
  const int gj = gid % num_groups_per_row;  // group index along k
  const int kk = gj * group_size + tid;     // weight row handled by this thread

  // Shared-memory tree reduction for max|w| over the group; relies on
  // group_size being a power of two.
  extern __shared__ float sdata[];
  sdata[tid] = fabsf(ToFloat(w[kk * n + nn]));
  __syncthreads();
  for (int s = group_size >> 1; s > 0; s >>= 1) {
    if (tid < s) {
      sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
    }
    __syncthreads();
  }

  __shared__ float s_scale;
  if (tid == 0) {
    const float max_abs = sdata[0];
    s_scale = max_abs / 7.0f;
  }
  __syncthreads();
  const float scale = s_scale;
  if (tid == 0) {
    WriteScale(scales, gj * n + nn, scale);
  }

  const float wval = ToFloat(w[kk * n + nn]);
  // Guard the all-zero group: scale == 0 would make wval / scale NaN and the
  // subsequent float->int cast undefined; such groups quantize to 0 exactly.
  float qf = (scale > 0.0f) ? roundf(wval / scale) : 0.0f;
  if (qf > 7.f) qf = 7.f;
  if (qf < -8.f) qf = -8.f;
  q[kk * n + nn] = static_cast<int8_t>(static_cast<int>(qf));
}

// Packs adjacent pairs of int4 codes along the output dimension (n) into one
// int8 each — even column in the low nibble, odd column in the high nibble —
// while transposing the result from [k, n] to the [n/2, k] packed layout.
//
// Launch layout: grid (ceil(n/2 / blockDim.x), k), 1-D blocks.
// Precondition: n is even (checked on the host).
__global__ void Wi4A16PackInt4Kernel(const int8_t* __restrict__ q,
                                     int8_t* __restrict__ packed,
                                     int k,
                                     int n) {
  const int nn = blockIdx.x * blockDim.x + threadIdx.x;
  const int kk = blockIdx.y;
  const int nhalf = n >> 1;
  if (nn >= nhalf || kk >= k) return;

  const int8_t q0 = q[kk * n + (nn << 1)];
  const int8_t q1 = q[kk * n + (nn << 1) + 1];
  // Go through uint8_t so sign-extension of negative codes cannot leak into
  // the other nibble.
  const uint32_t b0 = static_cast<uint32_t>(static_cast<uint8_t>(q0)) & 0xFU;
  const uint32_t b1 = static_cast<uint32_t>(static_cast<uint8_t>(q1)) & 0xFU;
  packed[nn * k + kk] = static_cast<int8_t>(b0 | (b1 << 4));
}

}  // namespace

// Host entry point: quantizes a [k, n] float16/bfloat16 weight matrix to
// packed int4.
//
// Returns {quant_weight [n/2, k] int8, scales [k/group_size, n] (w dtype),
// zeros [k/group_size, n] (w dtype, all zero — symmetric quantization)}.
std::vector<paddle::Tensor> Wi4A16Quantize(const paddle::Tensor& w,
                                           int32_t group_size) {
  PD_CHECK(w.dims().size() == 2,
           "wi4a16_weight_quantize: weight must be 2D [k, n]");
  PD_CHECK(group_size == 128,
           "wi4a16_weight_quantize CUDA: group_size must be 128");
  const int64_t k = w.dims()[0];
  const int64_t n = w.dims()[1];
  PD_CHECK(n % 2 == 0, "wi4a16_weight_quantize: n (dim 1) must be even");
  PD_CHECK(k % group_size == 0,
           "wi4a16_weight_quantize: k must be divisible by group_size");
  PD_CHECK(w.dtype() == paddle::DataType::FLOAT16 ||
               w.dtype() == paddle::DataType::BFLOAT16,
           "wi4a16_weight_quantize: weight dtype must be float16 or bfloat16");
  PD_CHECK(w.is_contiguous(),
           "wi4a16_weight_quantize: weight must be contiguous");

  // NOTE(review): CustomContext matches the iluvatar custom-device backend
  // (its stream() returns void*, hence the cudaStream_t cast) — confirm
  // against the rest of this ops directory.
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(w.place()));
  auto stream = static_cast<cudaStream_t>(dev_ctx->stream());

  auto packed = GetEmptyTensor({n / 2, k}, paddle::DataType::INT8, w.place());
  auto scales = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());
  auto zeros = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());
  // Symmetric quantization: zero-points are identically zero; a byte-wise
  // memset of 0 is a valid all-zero fill for fp16/bf16.
  CUDA_CHECK(cudaMemsetAsync(
      zeros.data(),
      0,
      static_cast<size_t>(zeros.numel()) * phi::SizeOf(zeros.dtype()),
      stream));

  // Scratch buffer holding unpacked int4 codes before the pack/transpose.
  auto q_tmp = GetEmptyTensor({k, n}, paddle::DataType::INT8, w.place());
  int8_t* q_ptr = q_tmp.data<int8_t>();
  int8_t* packed_ptr = packed.data<int8_t>();

  const int num_groups_per_row = static_cast<int>(k / group_size);
  const int total_groups = static_cast<int>(n * num_groups_per_row);
  const int threads = group_size;
  const size_t shmem = static_cast<size_t>(group_size) * sizeof(float);

  if (w.dtype() == paddle::DataType::FLOAT16) {
    Wi4A16QuantizeGroupsKernel<__half>
        <<<total_groups, threads, shmem, stream>>>(
            reinterpret_cast<const __half*>(w.data<paddle::float16>()),
            q_ptr,
            reinterpret_cast<__half*>(scales.data<paddle::float16>()),
            static_cast<int>(k),
            static_cast<int>(n),
            group_size,
            num_groups_per_row);
  } else {
    Wi4A16QuantizeGroupsKernel<__nv_bfloat16>
        <<<total_groups, threads, shmem, stream>>>(
            reinterpret_cast<const __nv_bfloat16*>(
                w.data<paddle::bfloat16>()),
            q_ptr,
            reinterpret_cast<__nv_bfloat16*>(scales.data<paddle::bfloat16>()),
            static_cast<int>(k),
            static_cast<int>(n),
            group_size,
            num_groups_per_row);
  }
  CUDA_CHECK(cudaGetLastError());  // surface launch-configuration errors

  const int nhalf = static_cast<int>(n >> 1);
  dim3 block(256);
  dim3 grid((nhalf + block.x - 1) / block.x, static_cast<unsigned int>(k));
  Wi4A16PackInt4Kernel<<<grid, block, 0, stream>>>(
      q_ptr, packed_ptr, static_cast<int>(k), static_cast<int>(n));
  CUDA_CHECK(cudaGetLastError());

  return {packed, scales, zeros};
}

// Output shapes: packed weight [n/2, k], scales and zeros [k/group_size, n].
std::vector<std::vector<int64_t>> Wi4A16QuantizeInferShape(
    const std::vector<int64_t>& w_shape, int32_t group_size) {
  const int64_t k = w_shape[0];
  const int64_t n = w_shape[1];
  const int64_t k_groups = k / group_size;
  return {{n / 2, k}, {k_groups, n}, {k_groups, n}};
}

// Output dtypes: int8 packed weight; scales/zeros share the weight dtype.
std::vector<paddle::DataType> Wi4A16QuantizeInferDtype(
    const paddle::DataType& w_dtype, int32_t group_size) {
  return {paddle::DataType::INT8, w_dtype, w_dtype};
}

PD_BUILD_STATIC_OP(wi4a16_weight_quantize_cuda)
    .Inputs({"w"})
    .Outputs({"quant_weight", "scales", "zeros"})
    .Attrs({"group_size: int"})
    .SetKernelFn(PD_KERNEL(Wi4A16Quantize))
    .SetInferShapeFn(PD_INFER_SHAPE(Wi4A16QuantizeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(Wi4A16QuantizeInferDtype));