Revert "【New Feature】W4afp8 supports per group quantization (#4272)" (#4854)

This reverts commit 93fcf7e4ec.
Author: YuBaoku
Date:   2025-11-06 17:48:28 +08:00
Committed by: GitHub
Parent: 3478d20262
Commit: 819b2dbbae

26 changed files with 1718 additions and 4378 deletions (this file: +65 -89)
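For context: the reverted feature made the W4AFP8 MoE GEMMs consume a grouped weight-scale tensor, with the group size inferred from the scale's shape (a 2-D scale meant a single group spanning the whole reduction dimension; a 4-D scale carried the group size in dims()[3], per the weight_scale_group_size logic in the diff below). A minimal sketch of the dequantization arithmetic this implies; the function name and scale layout are illustrative assumptions, not FastDeploy's actual kernel or layout.

    #include <vector>

    // Hypothetical reference for per-group W4 dequantization (not the CUTLASS
    // kernel). `scale` holds one float per (output channel n, K-group); with
    // group_size == K it collapses to the per-channel scaling the revert
    // restores.
    float dequant_w4_ref(int w4,                           // int4 weight in [-8, 7]
                         const std::vector<float>& scale,  // [N * (K / group_size)]
                         int n, int k, int K, int group_size) {
      const int groups_per_channel = K / group_size;
      return static_cast<float>(w4) *
             scale[n * groups_per_channel + k / group_size];
    }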
@@ -18,8 +18,8 @@
 #include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
 #include "group_swiglu_with_masked.h"
 #include "helper.h"
-#include "moe/fast_hardmard/fast_hardamard_kernel.h"
 #include "moe/fused_moe_helper.h"
+#include "moe/moe_fast_hardamard_kernel.h"
 #include "swigluoai.h"
 #include "w4afp8_gemm/w4afp8_gemm.h"
@@ -28,7 +28,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
     const paddle::Tensor& up_gate_proj_weight,
     const paddle::Tensor& down_proj_weight,
-    const paddle::optional<paddle::Tensor>& up_proj_in_scale,
     const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
     const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -198,20 +197,31 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
     typedef typename traits_fp8::DataType DataType_fp8;
     typedef typename traits_fp8::data_t data_t_fp8;
-    paddle::Tensor weight_scale_tensor =
-        *const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr());
-    const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
-                                            ? hidden_size
-                                            : weight_scale_tensor.dims()[3];
-    const float* input_dequant_scale =
-        up_proj_in_scale ? up_proj_in_scale.get().data<float>() : nullptr;
+    Allocator::AllocationPtr ffn1_input_row_sum;
+    ffn1_input_row_sum =
+        allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
+    compute_row_sum(
+        permute_input.data<data_t_fp8>(),
+        expanded_active_expert_rows,
+        hidden_size,
+        reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        stream);
+    float* row_scale = nullptr;
     DisPatchW4AFp8GemmWrapper(
         reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
         reinterpret_cast<const DataType_fp8*>(
             up_gate_proj_weight.data<int8_t>()),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
-        input_dequant_scale,
-        weight_scale_tensor.data<float>(),
+        reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
+        row_scale,
+        const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
+            ->data<float>(),
         reinterpret_cast<NvType*>(fc1_out),
         used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
         used_in_ep_low_latency ? num_max_tokens_per_expert
@@ -219,7 +229,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         num_experts,
         inter_size,
         hidden_size,
-        weight_scale_group_size,
         stream);
   } else {
     typename cutlass::WintQuantTraits<
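The restored fc1 path above hands DisPatchW4AFp8GemmWrapper a per-row sum of the FP8 activations (plus a separate row_scale, here null) where the reverted code passed a per-row dequantization scale. A CPU sketch of what compute_row_sum plausibly computes; the EP low-latency/token-padding bookkeeping driven by tokens_expert_prefix_sum and num_max_tokens_per_expert is omitted, and a float input stands in for FP8 E4M3, so this is an assumption from the call site rather than the kernel's actual implementation.

    #include <cstdint>

    // One float per expanded token row: the sum of that row's activations
    // over the hidden dimension.
    void compute_row_sum_ref(const float* input,  // [rows, hidden_size]
                             int64_t rows,
                             int64_t hidden_size,
                             float* row_sum) {    // [rows], preallocated
      for (int64_t m = 0; m < rows; ++m) {
        float acc = 0.0f;
        for (int64_t k = 0; k < hidden_size; ++k) {
          acc += input[m * hidden_size + k];
        }
        row_sum[m] = acc;
      }
    }

A row sum like this is what a weight-only-int4 epilogue needs to fold a uniform zero-point z out of the dot product: sum_k a[m,k] * (w[k,n] - z) * s[n] = (sum_k a[m,k] * w[k,n] - z * row_sum[m]) * s[n]. That rationale is a guess from the call site, not something the diff states.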
@@ -346,84 +355,60 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
   } else if (quant_method == "w4afp8") {
     data_t* ffn2_shift = nullptr;
     data_t* ffn2_smooth = nullptr;
-    float* input_dequant_scale = nullptr;
+    float* row_scale = nullptr;
     Allocator::AllocationPtr fp8_act_out;
     fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
                                       act_out_tensor.numel());
+    Allocator::AllocationPtr ffn2_input_row_sum;
+    ffn2_input_row_sum =
+        allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
-    if (down_proj_in_scale) {
-      MoeFastHardamardWrapper<data_t, data_t_fp8>(
-          act_out_tensor.data<data_t>(),
-          expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
-                               : nullptr,
-          const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
-          ffn2_shift,
-          ffn2_smooth,
-          down_proj_in_scale
-              ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
-                    ->data<float>()
-              : nullptr,
-          1,
-          448.0f,
-          -448.0f,
-          expanded_active_expert_rows,
-          inter_size / 2,
-          num_max_tokens_per_expert,
-          used_in_ep_low_latency,
-          hadamard_block_size,
-          reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
-          stream);
-    } else {
-      Allocator::AllocationPtr ffn2_input_dequant_scale;
-      ffn2_input_dequant_scale =
-          allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
-      input_dequant_scale =
-          reinterpret_cast<float*>(ffn2_input_dequant_scale->ptr());
-      MoeFastHardamardWrapper<data_t, data_t>(
-          act_out_tensor.data<data_t>(),
-          expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
-                               : nullptr,
-          const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
-          ffn2_shift,   // ffn2_shift->data<T>(),
-          ffn2_smooth,  // ffn2_smooth->data<T>(),
-          nullptr,
-          1,
-          448.0f,
-          -448.0f,
-          expanded_active_expert_rows,
-          inter_size / 2,
-          num_max_tokens_per_expert,
-          used_in_ep_low_latency,
-          hadamard_block_size,
-          act_out_tensor.data<data_t>(),
-          stream);
+    // note(yuanxiaolan): optimize this
+    MoeFastHardamardWrapper<data_t, data_t>(
+        act_out_tensor.data<data_t>(),
+        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
+                             : nullptr,
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        ffn2_shift,   // ffn2_shift->data<T>(),
+        ffn2_smooth,  // ffn2_smooth->data<T>(),
+        nullptr,
+        1,
+        448.0f,
+        -448.0f,
+        expanded_active_expert_rows,
+        inter_size / 2,
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        hadamard_block_size,
+        act_out_tensor.data<data_t>(),
+        stream);
-      quantize_moe_input<data_t, data_t_fp8>(
-          act_out_tensor.data<data_t>(),
-          expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
-                               : nullptr,
-          expanded_active_expert_rows,
-          inter_size / 2,
-          input_dequant_scale,
-          const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
-          num_max_tokens_per_expert,
-          used_in_ep_low_latency,
-          reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
-          stream);
-    }
-    paddle::Tensor weight_scale_tensor =
-        *const_cast<paddle::Tensor*>(down_proj_scale.get_ptr());
-    const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2
-                                            ? inter_size / 2
-                                            : weight_scale_tensor.dims()[3];
+    quantize_moe_input<data_t, data_t_fp8>(
+        act_out_tensor.data<data_t>(),
+        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
+                             : nullptr,
+        down_proj_in_scale
+            ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
+                  ->data<float>()
+            : nullptr,
+        448.0f,
+        -448.0f,
+        expanded_active_expert_rows,
+        inter_size / 2,
+        reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
+        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
+        num_max_tokens_per_expert,
+        used_in_ep_low_latency,
+        reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
+        stream);
     DisPatchW4AFp8GemmWrapper(
         reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
         reinterpret_cast<const DataType_fp8*>(down_proj_weight.data<int8_t>()),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
-        input_dequant_scale,
-        weight_scale_tensor.data<float>(),
+        reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
+        row_scale,
+        const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
         reinterpret_cast<NvType*>(ffn_out_data),
         used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
         used_in_ep_low_latency ? num_max_tokens_per_expert
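In the restored fc2 path above, quantize_moe_input receives the optional static down_proj_in_scale, the FP8 E4M3 saturation bounds (448.0f / -448.0f, the format's maximum finite magnitude), and buffers for both the FP8 output and the per-row sums consumed by the following GEMM. A hedged per-element reference; whether the kernel multiplies or divides by the scale, and all expert-indexing details, are assumptions.

    #include <algorithm>

    // Scale (multiply assumed) and saturate one activation into the FP8 E4M3
    // representable range before the value is narrowed to data_t_fp8.
    float quantize_to_fp8_range_ref(float x,
                                    const float* in_scale,  // may be null
                                    float max_bound,        // 448.0f
                                    float min_bound) {      // -448.0f
      const float v = in_scale != nullptr ? x * (*in_scale) : x;
      return std::min(max_bound, std::max(min_bound, v));
    }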
@@ -431,7 +416,6 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         num_experts,
         hidden_size,
         inter_size / 2,
-        weight_scale_group_size,
         stream);
   } else {
     typename cutlass::WintQuantTraits<
@@ -458,7 +442,6 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::Tensor& tokens_expert_prefix_sum,
     const paddle::Tensor& up_gate_proj_weight,
     const paddle::Tensor& down_proj_weight,
-    const paddle::optional<paddle::Tensor>& up_proj_in_scale,
     const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
     const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -483,7 +466,6 @@ paddle::Tensor MoeExpertFFNFunc(
         tokens_expert_prefix_sum,
         up_gate_proj_weight,
         down_proj_weight,
-        up_proj_in_scale,
         up_gate_proj_bias,
         up_gate_proj_scale,
         down_proj_scale,
@@ -501,7 +483,6 @@ paddle::Tensor MoeExpertFFNFunc(
         tokens_expert_prefix_sum,
         up_gate_proj_weight,
         down_proj_weight,
-        up_proj_in_scale,
         up_gate_proj_bias,
         up_gate_proj_scale,
         down_proj_scale,
@@ -525,7 +506,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const paddle::Tensor& tokens_expert_prefix_sum,
     const paddle::Tensor& up_gate_proj_weight,
     const paddle::Tensor& down_proj_weight,
-    const paddle::optional<paddle::Tensor>& up_proj_in_scale,
     const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
     const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
     const paddle::optional<paddle::Tensor>& down_proj_scale,
@@ -540,7 +520,6 @@ std::vector<paddle::Tensor> MoeExpertFFN(
         tokens_expert_prefix_sum,
         up_gate_proj_weight,
         down_proj_weight,
-        up_proj_in_scale,
         up_gate_proj_bias,
         up_gate_proj_scale,
         down_proj_scale,
@@ -558,7 +537,6 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
     const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
     const std::vector<int64_t>& up_gate_proj_weight_shape,
     const std::vector<int64_t>& down_proj_weight_shape,
-    const paddle::optional<std::vector<int64_t>>& up_proj_in_scale_shape,
     const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
     const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
     const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
@@ -577,7 +555,6 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
     const paddle::DataType& tokens_expert_prefix_sum_dtype,
     const paddle::DataType& up_gate_proj_weight_dtype,
     const paddle::DataType& down_proj_weight_dtype,
-    const paddle::optional<paddle::DataType>& up_proj_in_scale_dtype,
     const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
     const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
     const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
@@ -655,7 +632,6 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
         "tokens_expert_prefix_sum",
         "up_gate_proj_weight",
         "down_proj_weight",
-        paddle::Optional("up_proj_in_scale"),
         paddle::Optional("up_gate_proj_bias"),
         paddle::Optional("up_gate_proj_scale"),
         paddle::Optional("down_proj_scale"),