[Feature] FD_USE_PHI_FP8_QUANT (#6320)

* add ut

* add use_fd_quant env

* rm mask_per_token_quant

* add make ops list

* USE_FD_FP8_QUANT -> FD_USE_PHI_FP8_QUANT (defaults to true)

* modify comments

* use bool type

* Add function declaration
This commit is contained in:
fxyfxy777
2026-02-04 14:33:03 +08:00
committed by GitHub
parent 2ffcb3d9ed
commit 36547cfdb3
8 changed files with 634 additions and 51 deletions
@@ -114,12 +114,20 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=not disable_ue8m0_cast,
using_ue8m0_scale=not disable_ue8m0_cast,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, quant_config_weight_block_size_0
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=not disable_ue8m0_cast,
using_ue8m0_scale=not disable_ue8m0_cast,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
ffn_out = paddle.empty(
(permute_input.shape[0], layer_added_weight_attrs_1.shape[1]),
@@ -262,17 +270,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
topk_ids_hookfunc(topk_ids=topk_idx)
# 2. Dynamic compute blockwise quantization scales
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
x_scale_tensor = (
x_scale_tensor[: x.shape[0]]
if not self.quant_config.deepgemm_scale_ue8m0
else x_scale_tensor.T[: x.shape[0]]
)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0]
)
else:
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
x_scale_tensor = (
x_scale_tensor[: x.shape[0]]
if not self.quant_config.deepgemm_scale_ue8m0
else x_scale_tensor.T[: x.shape[0]]
)
event = deep_ep.Buffer.capture()
let_another_thread_run()
@@ -348,12 +361,18 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0]
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
del ffn_out
ffn_out = paddle.empty(
@@ -505,17 +524,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
recv_x_scale = (
recv_x_scale[: recv_x.shape[0]]
if not self.quant_config.deepgemm_scale_ue8m0
else recv_x_scale.T[: recv_x.shape[0]]
)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
else:
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
recv_x_scale = (
recv_x_scale[: recv_x.shape[0]]
if not self.quant_config.deepgemm_scale_ue8m0
else recv_x_scale.T[: recv_x.shape[0]]
)
(
permute_input,
permute_scale,
@@ -1232,10 +1232,13 @@ def python_op_fused_moe_kernel_paddle(
from .triton_moe_kernels import fused_moe_kernel_paddle
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x, using_pow2_scale=False, output_scale_transpose=False
)
x_scale = x_scale[: x.shape[0]]
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0])
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x, using_pow2_scale=False, output_scale_transpose=False
)
x_scale = x_scale[: x.shape[0]]
fused_moe_kernel_paddle[grid](
x_q,
@@ -1285,11 +1288,15 @@ def python_op_fused_moe_kernel_paddle(
intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
)
x_scale = x_scale[: x_q.shape[0]]
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
intermediate_cache2, quant_config.weight_block_size[0]
)
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
)
x_scale = x_scale[: x_q.shape[0]]
fused_moe_kernel_paddle[grid](
x_q,