mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Iluvatar] Support wi4a16 group_gemm (#7078)
This commit is contained in:
@@ -15,11 +15,12 @@
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& prefix_sum,
|
||||
const int32_t group_size) {
|
||||
std::vector<paddle::Tensor> W8A16GroupGemm(const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& weight_zeros,
|
||||
const paddle::Tensor& prefix_sum,
|
||||
const int32_t group_size) {
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
@@ -174,28 +175,22 @@ std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GroupGemmInferShape(
|
||||
std::vector<std::vector<int64_t>> W8A16GroupGemmInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& weight_shape,
|
||||
const std::vector<int64_t>& weight_scale_shape,
|
||||
const std::vector<int64_t>& prefix_sum_shape) {
|
||||
const std::vector<int64_t>& weight_shape) {
|
||||
return {{x_shape[0], weight_shape[1]}};
|
||||
}
|
||||
std::vector<paddle::DataType> GroupGemmInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& weight_output_dtype,
|
||||
const paddle::DataType& weight_scale_dtype,
|
||||
const paddle::DataType& prefix_sum_dtype,
|
||||
const int moe_topk) {
|
||||
std::vector<paddle::DataType> W8A16GroupGemmInferDtype(
|
||||
const paddle::DataType& input_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(w8a16_group_gemm)
|
||||
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
|
||||
.Inputs({"x", "weight", "weight_scale", "weight_zeros", "prefix_sum"})
|
||||
.Outputs({"output"})
|
||||
.Attrs({
|
||||
"group_size:int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(GroupGemm))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemmInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemmInferDtype));
|
||||
.SetKernelFn(PD_KERNEL(W8A16GroupGemm))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(W8A16GroupGemmInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(W8A16GroupGemmInferDtype));
|
||||
|
||||
@@ -15,18 +15,20 @@
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& prefix_sum,
|
||||
const int32_t group_size) {
|
||||
std::vector<paddle::Tensor> W8A16GroupGemv(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& weight_zeros,
|
||||
const paddle::Tensor& tokens_per_expert,
|
||||
const int32_t group_size) {
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
const auto& x_dims = x.dims();
|
||||
const auto& w_dims = weight.dims();
|
||||
const auto& ws_dims = weight_scale.dims();
|
||||
const auto& prefix_sum_dims = prefix_sum.dims();
|
||||
const auto& tokens_per_expert_dims = tokens_per_expert.dims();
|
||||
// [m, k]
|
||||
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||
// [n_experts, n, k]
|
||||
@@ -34,7 +36,8 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
||||
// [n_experts, n]
|
||||
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
||||
// [n_experts]
|
||||
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
||||
PD_CHECK(tokens_per_expert_dims.size() == 1,
|
||||
"tokens_per_expert should be 1D");
|
||||
PD_CHECK(group_size == -1);
|
||||
auto m = x_dims[0];
|
||||
auto k = x_dims[1];
|
||||
@@ -43,9 +46,9 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
||||
PD_CHECK(w_dims[2] == k);
|
||||
PD_CHECK(ws_dims[0] == n_experts);
|
||||
PD_CHECK(ws_dims[1] == n);
|
||||
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
||||
PD_CHECK(tokens_per_expert_dims[0] == n_experts);
|
||||
|
||||
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(tokens_per_expert.dtype() == paddle::DataType::INT32);
|
||||
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||
x.dtype() == paddle::DataType::FLOAT16);
|
||||
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||
@@ -87,7 +90,7 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
||||
cust_device_param.sortedId = nullptr;
|
||||
cust_device_param.bias = nullptr;
|
||||
cust_device_param.scale = weight_scale.data();
|
||||
cust_device_param.nSize = prefix_sum.data<int32_t>();
|
||||
cust_device_param.nSize = tokens_per_expert.data<int32_t>();
|
||||
|
||||
int lda = k;
|
||||
int ldb = k;
|
||||
@@ -152,28 +155,23 @@ std::vector<paddle::Tensor> GroupGemv(const paddle::Tensor& x,
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GroupGemvInferShape(
|
||||
std::vector<std::vector<int64_t>> W8A16GroupGemvInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& weight_shape,
|
||||
const std::vector<int64_t>& weight_scale_shape,
|
||||
const std::vector<int64_t>& prefix_sum_shape) {
|
||||
const std::vector<int64_t>& weight_shape) {
|
||||
return {{x_shape[0], weight_shape[1]}};
|
||||
}
|
||||
std::vector<paddle::DataType> GroupGemvInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& weight_output_dtype,
|
||||
const paddle::DataType& weight_scale_dtype,
|
||||
const paddle::DataType& prefix_sum_dtype,
|
||||
const int moe_topk) {
|
||||
std::vector<paddle::DataType> W8A16GroupGemvInferDtype(
|
||||
const paddle::DataType& input_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(w8a16_group_gemv)
|
||||
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
|
||||
.Inputs(
|
||||
{"x", "weight", "weight_scale", "weight_zeros", "tokens_per_expert"})
|
||||
.Outputs({"output"})
|
||||
.Attrs({
|
||||
"group_size:int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(GroupGemv))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemvInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemvInferDtype));
|
||||
.SetKernelFn(PD_KERNEL(W8A16GroupGemv))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(W8A16GroupGemvInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(W8A16GroupGemvInferDtype));
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
std::vector<paddle::Tensor> WI4A16GroupGemm(const paddle::Tensor& x,
|
||||
const paddle::Tensor& weight,
|
||||
const paddle::Tensor& weight_scale,
|
||||
const paddle::Tensor& weight_zeros,
|
||||
const paddle::Tensor& prefix_sum,
|
||||
const int32_t group_size) {
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||
auto prefix_sum_cpu = prefix_sum.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
const auto& x_dims = x.dims();
|
||||
const auto& w_dims = weight.dims();
|
||||
const auto& ws_dims = weight_scale.dims();
|
||||
const auto& prefix_sum_dims = prefix_sum.dims();
|
||||
const auto& zeros_dims = weight_zeros.dims();
|
||||
// [m, k]
|
||||
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||
// [n_experts, n // 2, k]
|
||||
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
|
||||
// [n_experts, k // group_size, n]
|
||||
PD_CHECK(ws_dims.size() == 3, "weight_scale should be 3D");
|
||||
// [n_experts, k // group_size, n]
|
||||
PD_CHECK(zeros_dims.size() == 3, "weight_zeros should be 3D");
|
||||
// [n_experts]
|
||||
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
||||
PD_CHECK(group_size == 128);
|
||||
auto m = x_dims[0];
|
||||
auto k = x_dims[1];
|
||||
auto n_experts = w_dims[0];
|
||||
auto n = w_dims[1] * 2;
|
||||
PD_CHECK(w_dims[2] == k);
|
||||
PD_CHECK(ws_dims[0] == n_experts);
|
||||
PD_CHECK(ws_dims[1] == k / group_size);
|
||||
PD_CHECK(ws_dims[2] == n);
|
||||
PD_CHECK(zeros_dims[0] == n_experts);
|
||||
PD_CHECK(zeros_dims[1] == k / group_size);
|
||||
PD_CHECK(zeros_dims[2] == n);
|
||||
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
||||
|
||||
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||
x.dtype() == paddle::DataType::FLOAT16);
|
||||
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||
PD_CHECK(weight_scale.dtype() == x.dtype());
|
||||
PD_CHECK(weight_zeros.dtype() == x.dtype());
|
||||
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
|
||||
|
||||
PD_CHECK(x.is_contiguous());
|
||||
PD_CHECK(weight.is_contiguous());
|
||||
PD_CHECK(weight_scale.is_contiguous());
|
||||
PD_CHECK(weight_zeros.is_contiguous());
|
||||
PD_CHECK(prefix_sum.is_contiguous());
|
||||
|
||||
const int64_t* prefix_sum_cpu_ptr = prefix_sum_cpu.data<int64_t>();
|
||||
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
|
||||
int16_t* out_data = static_cast<int16_t*>(output.data());
|
||||
const int16_t* x_data = static_cast<const int16_t*>(x.data());
|
||||
const int8_t* weight_data = weight.data<int8_t>();
|
||||
const int16_t* weight_scale_data =
|
||||
static_cast<const int16_t*>(weight_scale.data());
|
||||
const int16_t* weight_zeros_data =
|
||||
static_cast<const int16_t*>(weight_zeros.data());
|
||||
|
||||
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
|
||||
cuinferOperation_t transa = CUINFER_OP_T;
|
||||
cuinferOperation_t transb = CUINFER_OP_N;
|
||||
cudaDataType_t Atype = CUDA_R_4I;
|
||||
cudaDataType_t Btype;
|
||||
if (x.dtype() == paddle::DataType::FLOAT16) {
|
||||
Btype = CUDA_R_16F;
|
||||
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
|
||||
Btype = CUDA_R_16BF;
|
||||
} else {
|
||||
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
|
||||
}
|
||||
cudaDataType_t Ctype = Btype;
|
||||
cudaDataType_t computeType = CUDA_R_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
|
||||
|
||||
cuinferQuantGEMMHostParam cust_host_param;
|
||||
cuinferCustomGemmHostParamInit(&cust_host_param);
|
||||
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
|
||||
cust_host_param.persistent = 0;
|
||||
cust_host_param.groupSize = group_size;
|
||||
|
||||
cuinferQuantGEMMDeviceParam cust_device_param;
|
||||
cust_device_param.size = sizeof(cuinferQuantGEMMDeviceParam);
|
||||
cust_device_param.bias = nullptr;
|
||||
|
||||
int lda = k;
|
||||
int ldb = k;
|
||||
int ldc = n;
|
||||
float beta = 0.f;
|
||||
float alpha = 1.f;
|
||||
int batch_count = 1;
|
||||
size_t pre = 0;
|
||||
|
||||
auto* allocator = paddle::GetAllocator(x.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace;
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
size_t expert_i_end = prefix_sum_cpu_ptr[i];
|
||||
size_t cur_len = expert_i_end - pre;
|
||||
pre = expert_i_end;
|
||||
if (cur_len != 0) {
|
||||
cust_device_param.scale = weight_scale_data;
|
||||
cust_device_param.zero = weight_zeros_data;
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&workspace_size));
|
||||
if (workspace_size > 0) {
|
||||
tmp_workspace = allocator->Allocate(workspace_size);
|
||||
cust_device_param.workspace = tmp_workspace->ptr();
|
||||
} else {
|
||||
cust_device_param.workspace = nullptr;
|
||||
}
|
||||
|
||||
if (cur_len <= 1) {
|
||||
CUINFER_CHECK(cuinferCustomGemmEx(handle,
|
||||
stream,
|
||||
cuinfer_ptr_mode,
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
&alpha,
|
||||
weight_data,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
x_data,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
&beta,
|
||||
out_data,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&cust_host_param,
|
||||
&cust_device_param,
|
||||
customOption,
|
||||
cust_device_param.workspace));
|
||||
} else {
|
||||
CUINFER_CHECK(cuinferCustomGemm(handle,
|
||||
stream,
|
||||
cuinfer_ptr_mode,
|
||||
transa,
|
||||
transb,
|
||||
n,
|
||||
cur_len,
|
||||
k,
|
||||
&alpha,
|
||||
weight_data,
|
||||
Atype,
|
||||
lda,
|
||||
lda,
|
||||
x_data,
|
||||
Btype,
|
||||
ldb,
|
||||
ldb,
|
||||
&beta,
|
||||
out_data,
|
||||
Ctype,
|
||||
ldc,
|
||||
ldc,
|
||||
batch_count,
|
||||
computeType,
|
||||
scaleType,
|
||||
&cust_host_param,
|
||||
&cust_device_param,
|
||||
customOption));
|
||||
}
|
||||
}
|
||||
x_data += cur_len * k;
|
||||
weight_data += k * n / 2;
|
||||
weight_scale_data += k * n / group_size;
|
||||
weight_zeros_data += k * n / group_size;
|
||||
out_data += cur_len * n;
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> WI4A16GroupGemmInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const std::vector<int64_t>& weight_shape) {
|
||||
return {{x_shape[0], weight_shape[1] * 2}};
|
||||
}
|
||||
std::vector<paddle::DataType> WI4A16GroupGemmInferDtype(
|
||||
const paddle::DataType& input_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(wi4a16_group_gemm)
|
||||
.Inputs({"x", "weight", "weight_scale", "weight_zeros", "prefix_sum"})
|
||||
.Outputs({"output"})
|
||||
.Attrs({
|
||||
"group_size:int",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(WI4A16GroupGemm))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(WI4A16GroupGemmInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(WI4A16GroupGemmInferDtype));
|
||||
@@ -0,0 +1,198 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Wi4A16 weight quantization: per-group symmetric int4 with scale = max|w|/7,
|
||||
// packed two int4 per int8 along the output dimension, matching Python
|
||||
// fastdeploy.model_executor.ops.iluvatar.utils.wi4a16_weight_quantize_cuda.
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
namespace {
|
||||
|
||||
__device__ __forceinline__ float ToFloat(__half v) { return __half2float(v); }
|
||||
|
||||
__device__ __forceinline__ float ToFloat(__nv_bfloat16 v) {
|
||||
return __bfloat162float(v);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void WriteScale(__half* scales, int idx, float s) {
|
||||
scales[idx] = __float2half(s);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void WriteScale(__nv_bfloat16* scales,
|
||||
int idx,
|
||||
float s) {
|
||||
scales[idx] = __float2bfloat16_rn(s);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Wi4A16QuantizeGroupsKernel(const T* __restrict__ w,
|
||||
int8_t* __restrict__ q,
|
||||
T* __restrict__ scales,
|
||||
int k,
|
||||
int n,
|
||||
int group_size,
|
||||
int num_groups_per_row) {
|
||||
const int gid = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
if (tid >= group_size) return;
|
||||
|
||||
const int nn = gid / num_groups_per_row;
|
||||
const int gj = gid % num_groups_per_row;
|
||||
const int kk = gj * group_size + tid;
|
||||
|
||||
extern __shared__ float sdata[];
|
||||
const float v = fabsf(ToFloat(w[kk * n + nn]));
|
||||
sdata[tid] = v;
|
||||
__syncthreads();
|
||||
|
||||
for (int s = group_size >> 1; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__shared__ float s_scale;
|
||||
if (tid == 0) {
|
||||
const float max_abs = sdata[0];
|
||||
s_scale = max_abs / 7.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale = s_scale;
|
||||
if (tid == 0) {
|
||||
WriteScale(scales, gj * n + nn, scale);
|
||||
}
|
||||
|
||||
const float wval = ToFloat(w[kk * n + nn]);
|
||||
float qf = roundf(wval / scale);
|
||||
if (qf > 7.f) qf = 7.f;
|
||||
if (qf < -8.f) qf = -8.f;
|
||||
q[kk * n + nn] = static_cast<int8_t>(static_cast<int>(qf));
|
||||
}
|
||||
|
||||
__global__ void Wi4A16PackInt4Kernel(const int8_t* __restrict__ q,
|
||||
int8_t* __restrict__ packed,
|
||||
int k,
|
||||
int n) {
|
||||
const int nn = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int kk = blockIdx.y;
|
||||
const int nhalf = n >> 1;
|
||||
if (nn >= nhalf || kk >= k) return;
|
||||
|
||||
const int8_t q0 = q[kk * n + (nn << 1)];
|
||||
const int8_t q1 = q[kk * n + (nn << 1) + 1];
|
||||
const uint32_t b0 = static_cast<uint32_t>(static_cast<uint8_t>(q0)) & 0xFU;
|
||||
const uint32_t b1 = static_cast<uint32_t>(static_cast<uint8_t>(q1)) & 0xFU;
|
||||
packed[nn * k + kk] = static_cast<int8_t>(b0 | (b1 << 4));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<paddle::Tensor> Wi4A16Quantize(const paddle::Tensor& w,
|
||||
int32_t group_size) {
|
||||
PD_CHECK(w.dims().size() == 2,
|
||||
"wi4a16_weight_quantize: weight must be 2D [k, n]");
|
||||
PD_CHECK(group_size == 128,
|
||||
"wi4a16_weight_quantize CUDA: group_size must be 128");
|
||||
const int64_t k = w.dims()[0];
|
||||
const int64_t n = w.dims()[1];
|
||||
PD_CHECK(n % 2 == 0, "wi4a16_weight_quantize: n (dim 1) must be even");
|
||||
PD_CHECK(k % group_size == 0,
|
||||
"wi4a16_weight_quantize: k must be divisible by group_size");
|
||||
|
||||
PD_CHECK(w.dtype() == paddle::DataType::FLOAT16 ||
|
||||
w.dtype() == paddle::DataType::BFLOAT16,
|
||||
"wi4a16_weight_quantize: weight dtype must be float16 or bfloat16");
|
||||
PD_CHECK(w.is_contiguous(),
|
||||
"wi4a16_weight_quantize: weight must be contiguous");
|
||||
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(w.place()));
|
||||
auto stream = static_cast<cudaStream_t>(dev_ctx->stream());
|
||||
|
||||
auto packed = GetEmptyTensor({n / 2, k}, paddle::DataType::INT8, w.place());
|
||||
auto scales = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());
|
||||
auto zeros = GetEmptyTensor({k / group_size, n}, w.dtype(), w.place());
|
||||
|
||||
CUDA_CHECK(cudaMemsetAsync(
|
||||
zeros.data(),
|
||||
0,
|
||||
static_cast<size_t>(zeros.numel()) * phi::SizeOf(zeros.dtype()),
|
||||
stream));
|
||||
|
||||
auto q_tmp = GetEmptyTensor({k, n}, paddle::DataType::INT8, w.place());
|
||||
int8_t* q_ptr = q_tmp.data<int8_t>();
|
||||
int8_t* packed_ptr = packed.data<int8_t>();
|
||||
|
||||
const int num_groups_per_row = static_cast<int>(k / group_size);
|
||||
const int total_groups = static_cast<int>(n * num_groups_per_row);
|
||||
const int threads = group_size;
|
||||
const size_t shmem = static_cast<size_t>(group_size) * sizeof(float);
|
||||
|
||||
if (w.dtype() == paddle::DataType::FLOAT16) {
|
||||
Wi4A16QuantizeGroupsKernel<__half>
|
||||
<<<total_groups, threads, shmem, stream>>>(
|
||||
reinterpret_cast<const __half*>(w.data()),
|
||||
q_ptr,
|
||||
reinterpret_cast<__half*>(scales.data()),
|
||||
static_cast<int>(k),
|
||||
static_cast<int>(n),
|
||||
group_size,
|
||||
num_groups_per_row);
|
||||
} else {
|
||||
Wi4A16QuantizeGroupsKernel<__nv_bfloat16>
|
||||
<<<total_groups, threads, shmem, stream>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(w.data()),
|
||||
q_ptr,
|
||||
reinterpret_cast<__nv_bfloat16*>(scales.data()),
|
||||
static_cast<int>(k),
|
||||
static_cast<int>(n),
|
||||
group_size,
|
||||
num_groups_per_row);
|
||||
}
|
||||
|
||||
const int nhalf = static_cast<int>(n >> 1);
|
||||
dim3 block(256);
|
||||
dim3 grid((nhalf + block.x - 1) / block.x, static_cast<unsigned>(k));
|
||||
Wi4A16PackInt4Kernel<<<grid, block, 0, stream>>>(
|
||||
q_ptr, packed_ptr, static_cast<int>(k), static_cast<int>(n));
|
||||
|
||||
return {packed, scales, zeros};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> Wi4A16QuantizeInferShape(
|
||||
const std::vector<int64_t>& w_shape, int32_t group_size) {
|
||||
const int64_t k = w_shape[0];
|
||||
const int64_t n = w_shape[1];
|
||||
const int64_t k_groups = k / group_size;
|
||||
return {{n / 2, k}, {k_groups, n}, {k_groups, n}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> Wi4A16QuantizeInferDtype(
|
||||
const paddle::DataType& w_dtype, int32_t group_size) {
|
||||
return {paddle::DataType::INT8, w_dtype, w_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(wi4a16_weight_quantize_cuda)
|
||||
.Inputs({"w"})
|
||||
.Outputs({"quant_weight", "scales", "zeros"})
|
||||
.Attrs({"group_size: int"})
|
||||
.SetKernelFn(PD_KERNEL(Wi4A16Quantize))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(Wi4A16QuantizeInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(Wi4A16QuantizeInferDtype));
|
||||
@@ -584,14 +584,13 @@ elif paddle.is_compiled_with_cuda():
|
||||
elif paddle.is_compiled_with_xpu():
|
||||
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
|
||||
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
_iluvatar_clang_cuda_flags = ["-Wno-non-pod-varargs", "-DPADDLE_DEV", "-DPADDLE_WITH_CUSTOM_DEVICE"]
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
ext_modules=CUDAExtension(
|
||||
extra_compile_args={
|
||||
"nvcc": [
|
||||
"-DPADDLE_DEV",
|
||||
"-DPADDLE_WITH_CUSTOM_DEVICE",
|
||||
]
|
||||
"cxx": _iluvatar_clang_cuda_flags,
|
||||
"nvcc": _iluvatar_clang_cuda_flags,
|
||||
},
|
||||
sources=[
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
@@ -625,6 +624,8 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
"iluvatar_ops/mixed_fused_attn.cu",
|
||||
"iluvatar_ops/w8a16_group_gemm.cu",
|
||||
"iluvatar_ops/w8a16_group_gemv.cu",
|
||||
"iluvatar_ops/wi4a16_group_gemm.cu",
|
||||
"iluvatar_ops/wi4a16_weight_quantize.cu",
|
||||
"iluvatar_ops/restore_tokens_per_expert.cu",
|
||||
"iluvatar_ops/runtime/iluvatar_context.cc",
|
||||
"iluvatar_ops/cpp_extensions.cc",
|
||||
|
||||
Reference in New Issue
Block a user