mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
w4afp8 fix quant (#5830)
This commit is contained in:
@@ -41,6 +41,8 @@ struct GpuLaunchConfig {
|
||||
dim3 thread_per_block;
|
||||
};
|
||||
|
||||
constexpr float epsilon = 1e-4;
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
int blocks_x = cols;
|
||||
int blocks_y = 1;
|
||||
@@ -186,7 +188,7 @@ __global__ void masked_quantize_moe_input_kernel(
|
||||
continue;
|
||||
}
|
||||
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||
float abs_max = 0.0f;
|
||||
float abs_max = epsilon;
|
||||
for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
@@ -234,7 +236,7 @@ __global__ void quantize_moe_input_kernel(const T* permuted_inputs,
|
||||
for (int token_idx = blockIdx.x; token_idx < token_num;
|
||||
token_idx += gridDim.x) {
|
||||
int64_t expert_idx = expert_idx_per_token[token_idx];
|
||||
float abs_max = 0.0f;
|
||||
float abs_max = epsilon;
|
||||
for (int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) {
|
||||
int64_t offset = token_idx * dim + idx * VecSize;
|
||||
Load<T, VecSize>(&permuted_inputs[offset], &input_vec);
|
||||
@@ -1280,7 +1282,7 @@ __global__ void initialize_moe_routing_kernel(
|
||||
|
||||
if constexpr (std::is_same<OutT, phi::dtype::float8_e4m3fn>::value) {
|
||||
if (dequant_scale != nullptr) {
|
||||
float abs_max = 0.f;
|
||||
float abs_max = epsilon;
|
||||
for (int tid = threadIdx.x * VecSize; tid < cols;
|
||||
tid += blockDim.x * VecSize) {
|
||||
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
|
||||
|
||||
Reference in New Issue
Block a user