From 36547cfdb347389eadd7206fc05d493c8502135c Mon Sep 17 00:00:00 2001
From: fxyfxy777 <137464345+fxyfxy777@users.noreply.github.com>
Date: Wed, 4 Feb 2026 14:33:03 +0800
Subject: [PATCH] [Feature] FD_USE_PHI_FP8_QUANT (#6320)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add ut
* add use_fd_quant env
* rm mask_per_token_quant
* add make ops list
* USE_FD_FP8_QUANT -> FD_USE_PHI_FP8_QUANT (defaults to true)
* modify comments
* use bool type
* Add function declaration
---
 custom_ops/gpu_ops/cpp_extensions.cc          |  17 +
 custom_ops/gpu_ops/per_token_quant_fp8.cu     | 340 ++++++++++++++++++
 custom_ops/setup_ops.py                       |   1 +
 fastdeploy/envs.py                            |   2 +
 .../layers/moe/fused_moe_deepgemm_backend.py  |  92 +++--
 .../layers/moe/fused_moe_triton_backend.py    |  25 +-
 .../layers/quantization/block_wise_fp8.py     |  21 +-
 tests/operators/test_per_token_quant.py       | 187 ++++++++++
 8 files changed, 634 insertions(+), 51 deletions(-)
 create mode 100644 custom_ops/gpu_ops/per_token_quant_fp8.cu
 create mode 100644 tests/operators/test_per_token_quant.py

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 17c452aab9..1836c70c8f 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -299,6 +299,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const bool use_in_ep,
     const int token_nums_this_rank_padded);

+std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
+                                          const int block_size);
+std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor& input,
+                                                 const int block_size);
+
 std::vector<paddle::Tensor> FusedMaskSwigluFP8Quant(
     paddle::Tensor& input,
     paddle::Tensor& token_nums_per_expert,
@@ -1258,6 +1263,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("routed_scaling_factor"),
         "ep moe export combine function");

+  m.def("per_token_quant",
+        &PerTokenQuant,
+        py::arg("input"),
+        py::arg("block_size"),
+        "per token per block quant");
+
+  m.def("per_token_quant_padding",
+        &PerTokenQuantPadding,
+        py::arg("input"),
+        py::arg("block_size"),
+        "per token per block quant and padding transpose scale");
+
   m.def("fused_mask_swiglu_fp8_quant",
         &FusedMaskSwigluFP8Quant,
         py::arg("input"),
diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu
new file mode 100644
index 0000000000..a386096ec3
--- /dev/null
+++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu
@@ -0,0 +1,340 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
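+
+// Per-token, per-block FP8 quantization: each row (token) of the input is
+// split into blocks of 128 elements, and each block is scaled independently
+// so that its absolute maximum maps onto 448 (the e4m3 finite max):
+//   scale = max|x_block| / 448,   q = x * 448 / max|x_block|
+// stored as float8_e4m3fn plus one float32 scale per block. Worked example:
+// a block whose max |x| is 2.0 stores scale = 2.0 / 448, and an element
+// x = 1.0 quantizes to 1.0 * 448 / 2.0 = 224.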
+
+#include "helper.h"
+
+constexpr float epsilon = 1e-10;
+
+template <typename T>
+__global__ void quant_per_token_per_block(
+    const T *input,
+    phi::dtype::float8_e4m3fn *quanted_res,
+    float *quanted_scale,
+    const int token_num,
+    const int hidden_size,
+    const int hidden_size_scale,
+    const bool use_finegrained_range) {
+  const int bid = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int warp_id = tid / 32;
+  const int lane_id = tid % 32;
+  const int num_warp = blockDim.x / 32;
+  static constexpr int NUM_PER_THREADS = 128 / 32;  // 4
+  static constexpr float MAX_VALUE = 448.f;
+  // Note(ZKK) use ceil_div!!
+  const int end_iter = (hidden_size + 127) / 128;  // warp_iter_num
+  AlignedVector<T, NUM_PER_THREADS> load_vec;
+  AlignedVector<float, NUM_PER_THREADS> load_vec_float;
+  AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
+  for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
+    const T *input_now = input + token_idx * hidden_size;
+    phi::dtype::float8_e4m3fn *quanted_res_now =
+        quanted_res + token_idx * hidden_size;
+    float *quanted_scale_now = quanted_scale + token_idx * hidden_size_scale;
+    // deal a block per warp
+    for (int iter = warp_id; iter < end_iter; iter += num_warp) {
+      const int start_offset = iter * 128;
+
+      const bool is_valid_data =
+          start_offset + lane_id * NUM_PER_THREADS < hidden_size;
+
+      if (is_valid_data) {
+        Load<T, NUM_PER_THREADS>(
+            input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
+      } else {
+#pragma unroll
+        for (int vid = 0; vid < NUM_PER_THREADS; vid++) load_vec[vid] = T(0.f);
+      }
+      // get max value per thread
+      float max_value_thread = -5e4;
+#pragma unroll
+      for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
+        load_vec_float[vid] = static_cast<float>(load_vec[vid]);
+        max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
+      }
+      // get max value per warp
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1),
+                             max_value_thread);
+      // broadcast max_value
+      max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
+      max_value_thread = max(max_value_thread, epsilon);
+
+      if (use_finegrained_range) {
+        max_value_thread *= 7.0f;
+      }
+
+      float scale_to_store = max_value_thread / MAX_VALUE;
+      // quant
+#pragma unroll
+      for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
+        res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
+            load_vec_float[vid] * MAX_VALUE / max_value_thread);
+      }
+      // store
+      if (is_valid_data)
+        Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
+            res_vec,
+            quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
+      if (lane_id == 0) {
+        quanted_scale_now[iter] = scale_to_store;
+      }
+    }
+  }
+}
+
+std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
+                                          const int block_size) {
+  auto input_dim = input.dims();
+  const int token_num = input_dim[0];
+  const int hidden_size = input_dim[1];
+  // Note(ZKK): we use ceil_div here to support 4.5T running on 8 GPUs,
+  // where moe_intermediate_size is 448, which is not divisible by 128.
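+  // e.g. hidden_size = 448, block_size = 128 -> hidden_size_scale = 4; the
+  // last block covers only 64 valid elements, and the kernel zero-fills the
+  // out-of-range lanes before taking the block max (see is_valid_data above).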
+  const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
+
+  auto quanted_x = GetEmptyTensor(
+      {token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
+  auto quanted_scale = GetEmptyTensor(
+      {token_num, hidden_size_scale}, paddle::DataType::FLOAT32, input.place());
+  const int gridx = min(132 * 8, token_num);
+  const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
+  switch (input.dtype()) {
+    case paddle::DataType::BFLOAT16:
+      quant_per_token_per_block<phi::dtype::bfloat16>
+          <<<gridx, blockx, 0, input.stream()>>>(
+              input.data<phi::dtype::bfloat16>(),
+              quanted_x.data<phi::dtype::float8_e4m3fn>(),
+              quanted_scale.data<float>(),
+              token_num,
+              hidden_size,
+              hidden_size_scale,
+              use_finegrained_range);
+      break;
+    case paddle::DataType::FLOAT16:
+      quant_per_token_per_block<phi::dtype::float16>
+          <<<gridx, blockx, 0, input.stream()>>>(
+              input.data<phi::dtype::float16>(),
+              quanted_x.data<phi::dtype::float8_e4m3fn>(),
+              quanted_scale.data<float>(),
+              token_num,
+              hidden_size,
+              hidden_size_scale,
+              use_finegrained_range);
+      break;
+    default:
+      PD_THROW("Unsupported data type for PerTokenQuant");
+  }
+  return {quanted_x, quanted_scale};
+}
+
+std::vector<std::vector<int64_t>> PerTokenQuantInferShape(
+    std::vector<int64_t> input_shape, const int block_size) {
+  const int token_num = input_shape[0];
+  const int hidden_size = input_shape[1];
+  const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
+  return {{token_num, hidden_size}, {token_num, hidden_size_scale}};
+}
+
+std::vector<paddle::DataType> PerTokenQuantInferDtype(
+    paddle::DataType input_dtype, const int block_size) {
+  return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
+}
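+
+// Padding variant for DeepGEMM-style consumers: identical quantization math,
+// but the scale is written transposed in memory (block index major, laid out
+// as [hidden_size_scale, padded_token_num]), with the token dimension padded
+// up to a multiple of 4 floats (16 bytes) to satisfy TMA alignment.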
+template <typename T>
+__global__ void quant_per_token_per_block_padding(
+    const T *input,
+    phi::dtype::float8_e4m3fn *quanted_res,
+    float *quanted_scale,
+    const int token_num,
+    const int padded_token_num,
+    const int hidden_size,
+    const int hidden_size_scale,
+    const bool use_finegrained_range) {
+  const int bid = blockIdx.x;
+  const int tid = threadIdx.x;
+  const int warp_id = tid / 32;
+  const int lane_id = tid % 32;
+  const int num_warp = blockDim.x / 32;
+  static constexpr int NUM_PER_THREADS = 128 / 32;  // 4
+  static constexpr float MAX_VALUE = 448.f;
+  const int end_iter = hidden_size / 128;  // warp_iter_num
+  AlignedVector<T, NUM_PER_THREADS> load_vec;
+  AlignedVector<float, NUM_PER_THREADS> load_vec_float;
+  AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
+  for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
+    const T *input_now = input + token_idx * hidden_size;
+    phi::dtype::float8_e4m3fn *quanted_res_now =
+        quanted_res + token_idx * hidden_size;
+    // deal a block per warp
+    for (int iter = warp_id; iter < end_iter; iter += num_warp) {
+      float *quanted_scale_now =
+          quanted_scale + iter * padded_token_num + token_idx;
+      const int start_offset = iter * 128;
+      Load<T, NUM_PER_THREADS>(
+          input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
+      // get max value per thread
+      float max_value_thread = -5e4;
+#pragma unroll
+      for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
+        load_vec_float[vid] = static_cast<float>(load_vec[vid]);
+        max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
+      }
+      // get max value per warp
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2),
+                             max_value_thread);
+      max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1),
+                             max_value_thread);
+      // broadcast max_value
+      max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
+      max_value_thread = max(max_value_thread, epsilon);
+
+      if (use_finegrained_range) {
+        max_value_thread *= 7.0f;
+      }
+
+      float scale_to_store = max_value_thread / MAX_VALUE;
+      // quant
+#pragma unroll
+      for (int vid = 0; vid < NUM_PER_THREADS; vid++) {
+        res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(
+            load_vec_float[vid] * MAX_VALUE / max_value_thread);
+      }
+      // store
+      Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(
+          res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
+      if (lane_id == 0) {
+        *quanted_scale_now = scale_to_store;
+      }
+    }
+  }
+}
+
+std::vector<paddle::Tensor> PerTokenQuantPadding(paddle::Tensor &input,
+                                                 const int block_size) {
+  using ScaleDtype = float;
+
+  auto input_dim = input.dims();
+  const int token_num = input_dim[0];
+  const int hidden_size = input_dim[1];
+
+  PADDLE_ENFORCE(block_size == 128, "now only support block_size = 128");
+  PADDLE_ENFORCE(hidden_size % 128 == 0,
+                 "hidden_size must be divisible by 128");
+
+  const int hidden_size_scale = hidden_size / block_size;
+  auto quanted_x = GetEmptyTensor(
+      {token_num, hidden_size}, paddle::DataType::FLOAT8_E4M3FN, input.place());
+
+  const int tma_alignment_bytes = 16;
+  const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
+  const int padded_token_num =
+      ((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
+      tma_alignment_elements;
+  auto quanted_scale = GetEmptyTensor({padded_token_num, hidden_size_scale},
+                                      {1, padded_token_num},
+                                      paddle::DataType::FLOAT32,
+                                      input.place());
+  const int gridx = min(132 * 8, token_num);
+  const int blockx = min(1024, hidden_size / 128 * 32);
+
+  bool use_finegrained_range = false;
+  char *env_var = getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE");
+  if (env_var) {
+    use_finegrained_range = static_cast<bool>(std::stoi(env_var));
+  }
+
+  switch (input.dtype()) {
+    case paddle::DataType::BFLOAT16:
+      quant_per_token_per_block_padding<phi::dtype::bfloat16>
+          <<<gridx, blockx, 0, input.stream()>>>(
+              input.data<phi::dtype::bfloat16>(),
+              quanted_x.data<phi::dtype::float8_e4m3fn>(),
+              quanted_scale.data<float>(),
+              token_num,
+              padded_token_num,
+              hidden_size,
+              hidden_size_scale,
+              use_finegrained_range);
+      break;
+    case paddle::DataType::FLOAT16:
+      quant_per_token_per_block_padding<phi::dtype::float16>
+          <<<gridx, blockx, 0, input.stream()>>>(
+              input.data<phi::dtype::float16>(),
+              quanted_x.data<phi::dtype::float8_e4m3fn>(),
+              quanted_scale.data<float>(),
+              token_num,
+              padded_token_num,
+              hidden_size,
+              hidden_size_scale,
+              use_finegrained_range);
+      break;
+    default:
+      PD_THROW("Unsupported data type for PerTokenQuant");
+  }
+  return {quanted_x, quanted_scale};
+}
+
+std::vector<std::vector<int64_t>> PerTokenQuantPaddingInferShape(
+    std::vector<int64_t> input_shape, const int block_size) {
+  using ScaleDtype = float;
+
+  const int token_num = input_shape[0];
+  const int hidden_size = input_shape[1];
+  const int hidden_size_scale = hidden_size / block_size;
+
+  const int tma_alignment_bytes = 16;
+  const int tma_alignment_elements = tma_alignment_bytes / sizeof(ScaleDtype);
+  const int padded_token_num =
+      ((token_num + tma_alignment_elements - 1) / tma_alignment_elements) *
+      tma_alignment_elements;
+
+  return {{token_num, hidden_size}, {padded_token_num, hidden_size_scale}};
+}
+
+std::vector<paddle::DataType> PerTokenQuantPaddingInferDtype(
+    paddle::DataType input_dtype) {
+  return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT32};
+}
+
+PD_BUILD_STATIC_OP(per_token_quant)
+    .Inputs({"input"})
+    .Outputs({"output", "output_scale"})
+    .Attrs({"block_size: int"})
+    .SetKernelFn(PD_KERNEL(PerTokenQuant))
+    .SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantInferDtype));
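+
+// Python-side usage, as wired up elsewhere in this patch via
+// fastdeploy.model_executor.ops.gpu:
+//   quanted_x, scale = per_token_quant(x, 128)          # scale: [token_num, ceil(h / 128)]
+//   quanted_x, scale = per_token_quant_padding(x, 128)  # scale: [padded_token_num, h / 128]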
+
+PD_BUILD_STATIC_OP(per_token_quant_padding)
+    .Inputs({"input"})
+    .Outputs({"output", "output_scale"})
+    .Attrs({"block_size: int"})
+    .SetKernelFn(PD_KERNEL(PerTokenQuantPadding))
+    .SetInferShapeFn(PD_INFER_SHAPE(PerTokenQuantPaddingInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(PerTokenQuantPaddingInferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index ab435498f0..efb984702a 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -294,6 +294,7 @@ elif paddle.is_compiled_with_cuda():
                 "gpu_ops/cpp_extensions.cc",
                 "gpu_ops/share_external_data.cu",
                 "gpu_ops/fused_mask_swiglu_fp8_quant_kernel.cu",
+                "gpu_ops/per_token_quant_fp8.cu",
                 "gpu_ops/update_split_fuse_input.cu",
                 "gpu_ops/text_image_index_out.cu",
                 "gpu_ops/text_image_gather_scatter.cu",
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 3c87296ca4..54aa45b371 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -186,6 +186,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
    # Whether to enable low latency in mixed scenario
    "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
+    # Whether to use Paddle's built-in (phi) FP8 quantization; if 1, use the Paddle default instead of the custom op.
+    "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
    # Reserve output blocks for decoding requests when schedule new prefill requests
    "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
        os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 3d6d27e3af..754957faa7 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -114,12 +114,20 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
     ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)

     # down_proj
-    ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        ffn_out,
-        using_pow2_scale=not disable_ue8m0_cast,
-        using_ue8m0_scale=not disable_ue8m0_cast,
-    )
-    ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
+    if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+        ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            ffn_out, quant_config_weight_block_size_0
+        )
+
+        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
+        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
+    else:
+        ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            ffn_out,
+            using_pow2_scale=not disable_ue8m0_cast,
+            using_ue8m0_scale=not disable_ue8m0_cast,
+        )
+        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]

     ffn_out = paddle.empty(
         (permute_input.shape[0], layer_added_weight_attrs_1.shape[1]),
@@ -262,17 +270,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
            topk_ids_hookfunc(topk_ids=topk_idx)

        # 2. Dynamic compute blockwise quantization scales
-        x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x,
-            using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
-            using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
-        )
-        x_scale_tensor = (
-            x_scale_tensor[: x.shape[0]]
-            if not self.quant_config.deepgemm_scale_ue8m0
-            else x_scale_tensor.T[: x.shape[0]]
-        )
+        if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+            x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
+                x, self.quant_config.weight_block_size[0]
+            )
+        else:
+            x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x,
+                using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
+                output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
+                using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
+            )
+            x_scale_tensor = (
+                x_scale_tensor[: x.shape[0]]
+                if not self.quant_config.deepgemm_scale_ue8m0
+                else x_scale_tensor.T[: x.shape[0]]
+            )

        event = deep_ep.Buffer.capture()
        let_another_thread_run()
@@ -348,12 +361,18 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)

        # down_proj
-        ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            ffn_out,
-            using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
-        )
-        ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
+        if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+            ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
+                ffn_out, self.quant_config.weight_block_size[0]
+            )
+            ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
+        else:
+            ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                ffn_out,
+                using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
+                using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
+            )
+            ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]

        del ffn_out
        ffn_out = paddle.empty(
@@ -505,17 +524,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):

        tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)

-        recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x,
-            using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
-            using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
-        )
-        recv_x_scale = (
-            recv_x_scale[: recv_x.shape[0]]
-            if not self.quant_config.deepgemm_scale_ue8m0
-            else recv_x_scale.T[: recv_x.shape[0]]
-        )
+        if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+            recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
+        else:
+
+            recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x,
+                using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
+                output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
+                using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
+            )
+            recv_x_scale = (
+                recv_x_scale[: recv_x.shape[0]]
+                if not self.quant_config.deepgemm_scale_ue8m0
+                else recv_x_scale.T[: recv_x.shape[0]]
+            )
+
        (
            permute_input,
            permute_scale,
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index e56af13009..3f773b5528 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -1232,10 +1232,13 @@ def python_op_fused_moe_kernel_paddle(

    from .triton_moe_kernels import fused_moe_kernel_paddle

-    x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        x, using_pow2_scale=False, output_scale_transpose=False
-    )
-    x_scale = x_scale[: x.shape[0]]
+    if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0])
+    else:
+        x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            x, using_pow2_scale=False, output_scale_transpose=False
+        )
+        x_scale = x_scale[: x.shape[0]]

    fused_moe_kernel_paddle[grid](
        x_q,
@@ -1285,11 +1288,15 @@ def python_op_fused_moe_kernel_paddle(
    intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])

    grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
-
-    x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
-    )
-    x_scale = x_scale[: x_q.shape[0]]
+    if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            intermediate_cache2, quant_config.weight_block_size[0]
+        )
+    else:
+        x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
+        )
+        x_scale = x_scale[: x_q.shape[0]]

    fused_moe_kernel_paddle[grid](
        x_q,
diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
index 60f242682a..5d5c134619 100644
--- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
+++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py
@@ -18,6 +18,7 @@ from typing import Optional

 import paddle

+import fastdeploy
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -289,14 +290,18 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
        linear_out = paddle.empty((x.shape[0], layer.output_size), dtype=paddle.bfloat16)
        if x.shape[0] == 0:
            return linear_out
-
-        x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x,
-            using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
-            output_scale_transpose=True,
-            using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
-        )
-        x_scale_tensor = x_scale_tensor.T[: x.shape[0], ...]
+        if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
+            x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
+                x, self.quant_config.weight_block_size[0]
+            )
+        else:
+            x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x,
+                using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
+                output_scale_transpose=True,
+                using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
+            )
+            x_scale_tensor = x_scale_tensor.T[: x.shape[0], ...]

        deep_gemm_fp8_gemm_nt(
            x,
            x_scale_tensor,
diff --git a/tests/operators/test_per_token_quant.py b/tests/operators/test_per_token_quant.py
new file mode 100644
index 0000000000..23972ce53c
--- /dev/null
+++ b/tests/operators/test_per_token_quant.py
@@ -0,0 +1,187 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+from fastdeploy.model_executor.ops.gpu import per_token_quant, per_token_quant_padding
+
+paddle.seed(2024)
+
+
+def per_token_quant_paddle(input_tensor, block_size):
+    MAX_VALUE = 448.0
+    epsilon = 1e-10
+
+    input_shape = input_tensor.shape
+    token_num = input_shape[0]
+    hidden_size = input_shape[1]
+
+    # According to https://github.com/PaddlePaddle/FastDeploy/pull/3659
+    padding_size = (block_size - hidden_size % block_size) % block_size
+
+    padded_input = input_tensor
+    if padding_size > 0:
+        padded_input = F.pad(input_tensor, pad=[0, padding_size], mode="constant", value=0.0)
+
+    padded_hidden_size = hidden_size + padding_size
+    hidden_size_scale = padded_hidden_size // block_size
+
+    reshaped_input = paddle.reshape(padded_input, [token_num, hidden_size_scale, block_size]).astype("float32")
+
+    max_abs_val = paddle.max(paddle.abs(reshaped_input), axis=-1, keepdim=True)
+    max_abs_val = paddle.clip(max_abs_val, min=epsilon)
+    scale = max_abs_val / MAX_VALUE
+
+    quanted_value = reshaped_input / scale
+
+    quanted_x_padded_reshaped = quanted_value.to(paddle.float8_e4m3fn)
+    quanted_x_padded = paddle.reshape(quanted_x_padded_reshaped, [token_num, padded_hidden_size])
+
+    quanted_x = quanted_x_padded[:, :hidden_size]
+
+    quanted_scale = paddle.squeeze(scale, axis=-1)
+
+    return quanted_x, quanted_scale
+
+
+def per_token_quant_padding_paddle(input_tensor, block_size, dtype):
+    quanted_x, intermediate_scale = per_token_quant_paddle(input_tensor, block_size)
+    token_num = input_tensor.shape[0]
+
+    tma_alignment_elements = 4
+    padded_token_num = ((token_num + tma_alignment_elements - 1) // tma_alignment_elements) * tma_alignment_elements
+
+    hidden_size_scale = intermediate_scale.shape[1]
+    padded_scale = paddle.zeros([padded_token_num, hidden_size_scale], dtype="float32")
+
+    padded_scale[:token_num, :] = intermediate_scale
+
+    return quanted_x, padded_scale
+
+
+class TestPerTokenQuant(unittest.TestCase):
+    def get_input(self, shape, dtype):
+        return paddle.randn(shape=shape, dtype=dtype)
+
+    def setUp(self) -> None:
+        self.dtype = paddle.float16
+        self.token_num = 4
+        self.hidden_size = 500
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+    def test_per_token_quant(self):
+        paddle_output, paddle_output_scale = per_token_quant_paddle(self.input_tensor, self.block_size)
+        output, output_scale = per_token_quant(self.input_tensor, self.block_size)
+
+        np.testing.assert_allclose(paddle_output_scale.numpy(), output_scale.numpy(), rtol=1e-6)
+
+        output_rel_diff = paddle.mean(
+            paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
+        ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)))
+
+        assert output_rel_diff < 0.001
+
+
+class TestPerTokenQuantCase1(TestPerTokenQuant):
+    def setUp(self) -> None:
+        self.dtype = paddle.float16
+        self.token_num = 4
+        self.hidden_size = 128 * 6
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
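+
+
+# NOTE: the CUDA kernels additionally honor PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE
+# (see per_token_quant_fp8.cu): when set to 1, the per-block max is multiplied by
+# 7.0 before the scale is derived, so quantized values land in roughly [-64, 64].
+# A hypothetical extra case, not part of this suite, could exercise it like:
+#   os.environ["PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE"] = "1"
+#   output, output_scale = per_token_quant(x, 128)  # scale == max|x_block| * 7 / 448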
+
+
+class TestPerTokenQuantCase2(TestPerTokenQuant):
+    def setUp(self) -> None:
+        self.dtype = paddle.bfloat16
+        self.token_num = 4
+        self.hidden_size = 500
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+
+class TestPerTokenQuantCase3(TestPerTokenQuant):
+    def setUp(self) -> None:
+        self.dtype = paddle.bfloat16
+        self.token_num = 4
+        self.hidden_size = 128 * 6
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+
+class TestPerTokenQuantPadding(TestPerTokenQuant):
+    def setUp(self) -> None:
+        self.dtype = paddle.float16
+        self.token_num = 6
+        self.hidden_size = 128 * 4
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+    def test_per_token_quant_padding(self):
+        paddle_output, paddle_output_scale = per_token_quant_padding_paddle(
+            self.input_tensor, self.block_size, self.dtype
+        )
+        output, output_scale = per_token_quant_padding(self.input_tensor, self.block_size)
+
+        self.assertEqual(paddle_output_scale.shape, output_scale.shape)
+        np.testing.assert_allclose(
+            paddle_output_scale[0 : self.token_num].numpy(),
+            output_scale[0 : self.token_num].numpy(),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
+        output_rel_diff = paddle.mean(
+            paddle.abs(output.to(paddle.float32) - paddle_output.to(paddle.float32))
+        ) / paddle.mean(paddle.abs(paddle_output.to(paddle.float32)) + 1e-9)
+
+        assert output_rel_diff < 0.001
+
+
+class TestPerTokenQuantPaddingCase1(TestPerTokenQuantPadding):
+    def setUp(self) -> None:
+        self.dtype = paddle.float16
+        self.token_num = 8
+        self.hidden_size = 128 * 4
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+
+class TestPerTokenQuantPaddingCase2(TestPerTokenQuantPadding):
+    def setUp(self) -> None:
+        self.dtype = paddle.bfloat16
+        self.token_num = 6
+        self.hidden_size = 128 * 4
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+
+class TestPerTokenQuantPaddingCase3(TestPerTokenQuantPadding):
+    def setUp(self) -> None:
+        self.dtype = paddle.bfloat16
+        self.token_num = 8
+        self.hidden_size = 128 * 4
+        self.block_size = 128
+        self.input_tensor = self.get_input(shape=[self.token_num, self.hidden_size], dtype=self.dtype)
+
+
+if __name__ == "__main__":
+    unittest.main()
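+
+# Running this suite assumes the custom ops were built and installed via
+# custom_ops/setup_ops.py; the pytest invocation below is illustrative and
+# not part of this patch:
+#   python -m pytest tests/operators/test_per_token_quant.py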