[Feature] support w4afp8 v1_loader and v0_loader (tp>1) (#5757)

* support

* fix

* support w4afp8 v1_loader and v0_loader

* fix

* fix test

* fix test

* fix test

* fix moe.py

* add test_ernie_4_5_w4afp8

* add test

* delete tensor

* fix test

* fix

* add

* fix test
This commit is contained in:
lizexu123
2025-12-30 14:11:52 +08:00
committed by GitHub
parent e78e22ebd5
commit 44a13e4557
7 changed files with 615 additions and 31 deletions
@@ -884,35 +884,93 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
self.weight_dtype = "int8"
self.model_format = extra_weight_attrs.get("model_format")
self.ffn1_weight_shape = [
layer.num_local_experts,
layer.hidden_size // 2,
layer.hidden_size // 2, # 4-bit packing
layer.moe_intermediate_size * 2,
]
self.ffn2_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size // 2,
layer.moe_intermediate_size // 2, # 4-bit packing
layer.hidden_size,
]
setattr(
layer,
self.added_weight_attrs[0],
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
if not self.quant_config.is_quantized and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_weight_attrs[1],
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
)
layer.down_proj_weight = layer.create_parameter(
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
)
set_weight_attrs(layer.up_gate_proj_weight, up_gate_proj_attrs)
set_weight_attrs(layer.down_proj_weight, down_proj_attrs)
else:
self.weight_dtype = "int8"
setattr(
layer,
self.added_weight_attrs[0], # "up_gate_proj_weight"
layer.create_parameter(
shape=self.ffn1_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_weight_attrs[1], # "down_proj_weight"
layer.create_parameter(
shape=self.ffn2_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
self.create_w4afp8_scale_weights(layer, layer.weight_key_map)
@@ -922,22 +980,175 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_bias = layer.create_parameter(
shape=[layer.num_experts, layer.hidden_size],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(layer.up_gate_proj_bias, extra_weight_attrs)
set_weight_attrs(layer.down_proj_bias, extra_weight_attrs)
set_weight_attrs(
layer.up_gate_proj_bias,
extra_weight_attrs,
def process_weights_after_loading(self, layer: nn.Layer) -> None:
from ..utils import get_orthogonal_matrix
def _rotate_down_proj_weight():
"""
Apply Hadamard rotation to down_proj weight
"""
Q_ffn2, moe_block_size = get_orthogonal_matrix(size=layer.moe_intermediate_size, mode="hadamard_ffn2")
down_proj_weight = layer.down_proj_weight
original_dtype = down_proj_weight.dtype # bfloat16
expert_list = [down_proj_weight[i] for i in range(layer.num_local_experts)]
moe_weight = paddle.concat(expert_list, axis=-1)
new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.cast("float32").to(Q_ffn2.place)
rotated_list = []
for expert_id in range(layer.num_local_experts):
start_idx = expert_id * layer.hidden_size
end_idx = (expert_id + 1) * layer.hidden_size
rotated_weight = new_moe_weight[:, start_idx:end_idx]
rotated_list.append(rotated_weight)
rotated_stacked = paddle.stack(rotated_list, axis=0).cast(original_dtype)
layer.down_proj_weight.set_value(rotated_stacked)
del moe_weight, new_moe_weight, expert_list, rotated_list
paddle.device.cuda.empty_cache()
return moe_block_size
def _process_quantize(weight_type: str):
weight_idx = 0 if weight_type == "gate_up" else 1
weight_name = self.added_weight_attrs[weight_idx] # "up_gate_proj_weight" or "down_proj_weight"
scale_name = self.added_scale_attrs[weight_idx] # "up_gate_proj_weight_scale" or "down_proj_weight_scale"
weight_dtype = "int8"
scale_dtype = "float32"
block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512)
quant_weight_list = []
scale_list = []
for expert_id in range(layer.num_local_experts):
expert_weight = getattr(layer, weight_name)[expert_id]
quant_weight, weight_scale = group_wise_int4_weight_quantize(expert_weight, group_size=128)
quant_weight = pack(quant_weight.transpose([1, 0]), bits=4)
if weight_type == "down":
weight_scale = weight_scale / (block_size**0.5)
quant_weight = w4afp8_gemm_weight_convert(quant_weight)
quant_weight_list.append(quant_weight)
scale_list.append(weight_scale)
free_tensor(getattr(layer, weight_name))
stacked_quant_weight = paddle.stack(quant_weight_list, axis=0)
stacked_scale = paddle.stack(scale_list, axis=0)
setattr(
layer,
weight_name,
layer.create_parameter(
shape=stacked_quant_weight.shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
set_weight_attrs(
layer.down_proj_bias,
extra_weight_attrs,
processed_scale = stacked_scale / (448 * 7 * 2 ** (-9))
if len(processed_scale.shape) == 3:
if weight_type == "gate_up" and processed_scale.shape[-1] * 128 != layer.hidden_size:
assert (
layer.hidden_size // 128 % processed_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
processed_scale = processed_scale.repeat_interleave(
layer.hidden_size // 128 // processed_scale.shape[-1], axis=-1
)
elif weight_type == "down" and processed_scale.shape[-1] * 128 != layer.moe_intermediate_size:
assert (
layer.moe_intermediate_size // 128 % processed_scale.shape[-1] == 0
), "weight_scale_group_size must be a multiple of 128"
processed_scale = processed_scale.repeat_interleave(
layer.moe_intermediate_size // 128 // processed_scale.shape[-1], axis=-1
)
origin_shape = processed_scale.shape
processed_scale = processed_scale.transpose([0, 2, 1])
processed_scale = processed_scale.reshape([-1, processed_scale.shape[-1]])
processed_scale = w4afp8_gemm_scale_permute(processed_scale)
processed_scale = processed_scale.reshape(
[origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
)
processed_scale = processed_scale.transpose([0, 2, 1, 3])
else:
processed_scale = w4afp8_gemm_scale_permute(processed_scale)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=processed_scale.shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(stacked_quant_weight, False)
getattr(layer, scale_name).copy_(processed_scale, False)
in_scale_name = scale_name.replace("_weight_scale", "_in_scale")
if hasattr(layer, in_scale_name):
getattr(layer, in_scale_name).set_value(paddle.ones([layer.num_local_experts], dtype="float32"))
del quant_weight_list, scale_list, stacked_quant_weight, stacked_scale, processed_scale
paddle.device.cuda.empty_cache()
up_gate_ready = hasattr(layer, "up_gate_proj_weight") and weight_fully_copied(layer.up_gate_proj_weight)
down_ready = hasattr(layer, "down_proj_weight") and weight_fully_copied(layer.down_proj_weight)
if not up_gate_ready and not down_ready:
return
if not self.quant_config.is_quantized:
if up_gate_ready and not getattr(self, "_up_gate_processed", False):
weight_type = "gate_up"
self._up_gate_processed = True
logger.info(f"Online quantizing layer.{layer.layer_idx}.mlp.experts.up_gate_proj.weight...")
if self.model_format == "torch":
process_weight_transpose(layer, "up_gate_proj_weight")
_process_quantize(weight_type)
elif down_ready and not getattr(self, "_down_processed", False):
weight_type = "down"
self._down_processed = True
logger.info(f"Rotating and online quantizing layer.{layer.layer_idx}.mlp.experts.down_proj.weight...")
if self.model_format == "torch":
process_weight_transpose(layer, "down_proj_weight")
_rotate_down_proj_weight()
_process_quantize(weight_type)
if getattr(self, "_up_gate_processed", False) and getattr(self, "_down_processed", False):
logger.info(f"Layer {layer.layer_idx} MoE W4AFP8 online quantization completed.")
del self._up_gate_processed
del self._down_processed
else:
return
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
@@ -960,6 +1171,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scales = []