mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] FD_USE_PHI_FP8_QUANT (#6320)
* add ut * add use_fd_quant env * rm mask_per_token_quant * add make ops list * USE_FD_FP8_QUANT -> FD_USE_PHI_FP8_QUANT 默认是true * modify comments * use bool type * Add function declaration
This commit is contained in:
@@ -114,12 +114,20 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
|
||||
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
|
||||
|
||||
# down_proj
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=not disable_ue8m0_cast,
|
||||
using_ue8m0_scale=not disable_ue8m0_cast,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, quant_config_weight_block_size_0
|
||||
)
|
||||
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
else:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=not disable_ue8m0_cast,
|
||||
using_ue8m0_scale=not disable_ue8m0_cast,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer_added_weight_attrs_1.shape[1]),
|
||||
@@ -262,17 +270,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
# 2. Dynamic compute blockwise quantization scales
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
x_scale_tensor = (
|
||||
x_scale_tensor[: x.shape[0]]
|
||||
if not self.quant_config.deepgemm_scale_ue8m0
|
||||
else x_scale_tensor.T[: x.shape[0]]
|
||||
)
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, self.quant_config.weight_block_size[0]
|
||||
)
|
||||
else:
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
x_scale_tensor = (
|
||||
x_scale_tensor[: x.shape[0]]
|
||||
if not self.quant_config.deepgemm_scale_ue8m0
|
||||
else x_scale_tensor.T[: x.shape[0]]
|
||||
)
|
||||
|
||||
event = deep_ep.Buffer.capture()
|
||||
let_another_thread_run()
|
||||
@@ -348,12 +361,18 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
|
||||
|
||||
# down_proj
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, self.quant_config.weight_block_size[0]
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
|
||||
else:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
|
||||
del ffn_out
|
||||
ffn_out = paddle.empty(
|
||||
@@ -505,17 +524,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
|
||||
|
||||
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
recv_x_scale = (
|
||||
recv_x_scale[: recv_x.shape[0]]
|
||||
if not self.quant_config.deepgemm_scale_ue8m0
|
||||
else recv_x_scale.T[: recv_x.shape[0]]
|
||||
)
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
|
||||
else:
|
||||
|
||||
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
recv_x_scale = (
|
||||
recv_x_scale[: recv_x.shape[0]]
|
||||
if not self.quant_config.deepgemm_scale_ue8m0
|
||||
else recv_x_scale.T[: recv_x.shape[0]]
|
||||
)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
permute_scale,
|
||||
|
||||
@@ -1232,10 +1232,13 @@ def python_op_fused_moe_kernel_paddle(
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x, using_pow2_scale=False, output_scale_transpose=False
|
||||
)
|
||||
x_scale = x_scale[: x.shape[0]]
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0])
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x, using_pow2_scale=False, output_scale_transpose=False
|
||||
)
|
||||
x_scale = x_scale[: x.shape[0]]
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
@@ -1285,11 +1288,15 @@ def python_op_fused_moe_kernel_paddle(
|
||||
intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2])
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),)
|
||||
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
|
||||
)
|
||||
x_scale = x_scale[: x_q.shape[0]]
|
||||
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
intermediate_cache2, quant_config.weight_block_size[0]
|
||||
)
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
|
||||
)
|
||||
x_scale = x_scale[: x_q.shape[0]]
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
|
||||
Reference in New Issue
Block a user