Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[Feature] support w4afp8 v1_loader and v0_loader(tp>1) (#5757)
* support w4afp8 v1_loader and v0_loader
* fix moe.py
* add test_ernie_4_5_w4afp8
* delete tensor
* assorted fixes and test fixes
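For orientation before the diff: W4AFP8 runs the MoE GEMMs with INT4 weights against FP8 (E4M3) activations, and this commit adds the online path that quantizes bf16 checkpoints at load time. The hunks below call FastDeploy's custom helpers (group_wise_int4_weight_quantize, pack, w4afp8_gemm_weight_convert); the numpy sketch here only illustrates the group-wise symmetric INT4 scheme those helpers implement, assuming a [-7, 7] code range with one scale per 128-row group per output column:

import numpy as np

# Minimal sketch of group-wise symmetric INT4 quantization. The real helper
# is FastDeploy's group_wise_int4_weight_quantize; the [-7, 7] range and the
# grouping axis here are assumptions for illustration.
def group_wise_int4_quantize(w, group_size=128):
    in_dim, out_dim = w.shape
    assert in_dim % group_size == 0
    groups = w.reshape(in_dim // group_size, group_size, out_dim)
    scales = np.abs(groups).max(axis=1) / 7.0             # [n_groups, out_dim]
    codes = np.clip(np.round(groups / scales[:, None, :]), -7, 7).astype(np.int8)
    return codes.reshape(in_dim, out_dim), scales

w = np.random.randn(256, 64).astype(np.float32)
codes, scales = group_wise_int4_quantize(w)
dequant = codes.reshape(2, 128, 64) * scales[:, None, :]
print(np.abs(dequant.reshape(256, 64) - w).max())         # bounded by scale / 2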
@@ -884,35 +884,93 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         """
         Paddle cutlass create weight process.
         """
-        self.weight_dtype = "int8"
+        self.model_format = extra_weight_attrs.get("model_format")
+
         self.ffn1_weight_shape = [
             layer.num_local_experts,
-            layer.hidden_size // 2,
+            layer.hidden_size // 2,  # 4-bit packing
             layer.moe_intermediate_size * 2,
         ]
         self.ffn2_weight_shape = [
             layer.num_local_experts,
-            layer.moe_intermediate_size // 2,
+            layer.moe_intermediate_size // 2,  # 4-bit packing
             layer.hidden_size,
         ]
-        setattr(
-            layer,
-            self.added_weight_attrs[0],
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_weight_attrs[1],
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+
+        if not self.quant_config.is_quantized and layer.fd_config.load_config.load_choices == "default_v1":
+            if self.model_format != "torch":
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size * 2,
+                ]
+                down_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size,
+                    layer.hidden_size,
+                ]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
+                }
+            else:
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                    layer.hidden_size,
+                ]
+                down_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size,
+                ]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
+                }
+
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=up_gate_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.down_proj_weight = layer.create_parameter(
+                shape=down_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(layer.up_gate_proj_weight, up_gate_proj_attrs)
+            set_weight_attrs(layer.down_proj_weight, down_proj_attrs)
+        else:
+            self.weight_dtype = "int8"
+            setattr(
+                layer,
+                self.added_weight_attrs[0],  # "up_gate_proj_weight"
+                layer.create_parameter(
+                    shape=self.ffn1_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                self.added_weight_attrs[1],  # "down_proj_weight"
+                layer.create_parameter(
+                    shape=self.ffn2_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
 
         self.create_w4afp8_scale_weights(layer, layer.weight_key_map)
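A note on the shapes in this hunk: the v1-loader branch allocates full-width bf16 buffers in either checkpoint layout (paddle-format [E, in_dim, out_dim] vs torch-format [E, out_dim, in_dim], hence the mirrored output_dim flags on the TensorTracker), while the int8 branch halves one dimension because two INT4 codes share a byte. A minimal sketch of that packing, assuming pairing along the halved dimension (the real pack(..., bits=4) and w4afp8_gemm_weight_convert use a kernel-specific nibble order and tile layout):

import numpy as np

# Illustrative only: why the int8 buffers above use hidden_size // 2 and
# moe_intermediate_size // 2.
def pack_int4(codes):
    assert codes.shape[0] % 2 == 0
    u = codes.astype(np.uint8) & 0x0F        # keep the low nibble of each code
    return (u[0::2] | (u[1::2] << 4)).astype(np.int8)

codes = np.random.randint(-7, 8, size=(128, 32), dtype=np.int8)
packed = pack_int4(codes)
print(packed.shape)                          # (64, 32): first dim halved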
@@ -922,22 +980,175 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
             dtype=layer.weight_dtype,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
         layer.down_proj_bias = layer.create_parameter(
             shape=[layer.num_experts, layer.hidden_size],
             dtype=layer.weight_dtype,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        set_weight_attrs(
-            layer.up_gate_proj_bias,
-            extra_weight_attrs,
-        )
-        set_weight_attrs(
-            layer.down_proj_bias,
-            extra_weight_attrs,
-        )
+        set_weight_attrs(layer.up_gate_proj_bias, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_bias, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: nn.Layer) -> None:
+        from ..utils import get_orthogonal_matrix
+
+        def _rotate_down_proj_weight():
+            """
+            Apply Hadamard rotation to the down_proj weight.
+            """
+            Q_ffn2, moe_block_size = get_orthogonal_matrix(size=layer.moe_intermediate_size, mode="hadamard_ffn2")
+            down_proj_weight = layer.down_proj_weight
+            original_dtype = down_proj_weight.dtype  # bfloat16
+
+            expert_list = [down_proj_weight[i] for i in range(layer.num_local_experts)]
+
+            moe_weight = paddle.concat(expert_list, axis=-1)
+
+            new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.cast("float32").to(Q_ffn2.place)
+            rotated_list = []
+            for expert_id in range(layer.num_local_experts):
+                start_idx = expert_id * layer.hidden_size
+                end_idx = (expert_id + 1) * layer.hidden_size
+                rotated_weight = new_moe_weight[:, start_idx:end_idx]
+                rotated_list.append(rotated_weight)
+
+            rotated_stacked = paddle.stack(rotated_list, axis=0).cast(original_dtype)
+            layer.down_proj_weight.set_value(rotated_stacked)
+
+            del moe_weight, new_moe_weight, expert_list, rotated_list
+            paddle.device.cuda.empty_cache()
+
+            return moe_block_size
+
+        def _process_quantize(weight_type: str):
+            weight_idx = 0 if weight_type == "gate_up" else 1
+            weight_name = self.added_weight_attrs[weight_idx]  # "up_gate_proj_weight" or "down_proj_weight"
+            scale_name = self.added_scale_attrs[weight_idx]  # "up_gate_proj_weight_scale" or "down_proj_weight_scale"
+
+            weight_dtype = "int8"
+            scale_dtype = "float32"
+
+            block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512)
+
+            quant_weight_list = []
+            scale_list = []
+
+            for expert_id in range(layer.num_local_experts):
+                expert_weight = getattr(layer, weight_name)[expert_id]
+
+                quant_weight, weight_scale = group_wise_int4_weight_quantize(expert_weight, group_size=128)
+
+                quant_weight = pack(quant_weight.transpose([1, 0]), bits=4)
+
+                if weight_type == "down":
+                    weight_scale = weight_scale / (block_size**0.5)
+
+                quant_weight = w4afp8_gemm_weight_convert(quant_weight)
+
+                quant_weight_list.append(quant_weight)
+                scale_list.append(weight_scale)
+
+            free_tensor(getattr(layer, weight_name))
+
+            stacked_quant_weight = paddle.stack(quant_weight_list, axis=0)
+            stacked_scale = paddle.stack(scale_list, axis=0)
+
+            setattr(
+                layer,
+                weight_name,
+                layer.create_parameter(
+                    shape=stacked_quant_weight.shape,
+                    dtype=weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+            processed_scale = stacked_scale / (448 * 7 * 2 ** (-9))
+
+            if len(processed_scale.shape) == 3:
+                if weight_type == "gate_up" and processed_scale.shape[-1] * 128 != layer.hidden_size:
+                    assert (
+                        layer.hidden_size // 128 % processed_scale.shape[-1] == 0
+                    ), "weight_scale_group_size must be a multiple of 128"
+                    processed_scale = processed_scale.repeat_interleave(
+                        layer.hidden_size // 128 // processed_scale.shape[-1], axis=-1
+                    )
+                elif weight_type == "down" and processed_scale.shape[-1] * 128 != layer.moe_intermediate_size:
+                    assert (
+                        layer.moe_intermediate_size // 128 % processed_scale.shape[-1] == 0
+                    ), "weight_scale_group_size must be a multiple of 128"
+                    processed_scale = processed_scale.repeat_interleave(
+                        layer.moe_intermediate_size // 128 // processed_scale.shape[-1], axis=-1
+                    )
+
+                origin_shape = processed_scale.shape
+                processed_scale = processed_scale.transpose([0, 2, 1])
+                processed_scale = processed_scale.reshape([-1, processed_scale.shape[-1]])
+                processed_scale = w4afp8_gemm_scale_permute(processed_scale)
+                processed_scale = processed_scale.reshape(
+                    [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
+                )
+                processed_scale = processed_scale.transpose([0, 2, 1, 3])
+            else:
+                processed_scale = w4afp8_gemm_scale_permute(processed_scale)
+
+            setattr(
+                layer,
+                scale_name,
+                layer.create_parameter(
+                    shape=processed_scale.shape,
+                    dtype=scale_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+            getattr(layer, weight_name).copy_(stacked_quant_weight, False)
+            getattr(layer, scale_name).copy_(processed_scale, False)
+
+            in_scale_name = scale_name.replace("_weight_scale", "_in_scale")
+            if hasattr(layer, in_scale_name):
+                getattr(layer, in_scale_name).set_value(paddle.ones([layer.num_local_experts], dtype="float32"))
+
+            del quant_weight_list, scale_list, stacked_quant_weight, stacked_scale, processed_scale
+            paddle.device.cuda.empty_cache()
+
+        up_gate_ready = hasattr(layer, "up_gate_proj_weight") and weight_fully_copied(layer.up_gate_proj_weight)
+        down_ready = hasattr(layer, "down_proj_weight") and weight_fully_copied(layer.down_proj_weight)
+
+        if not up_gate_ready and not down_ready:
+            return
+
+        if not self.quant_config.is_quantized:
+            if up_gate_ready and not getattr(self, "_up_gate_processed", False):
+                weight_type = "gate_up"
+                self._up_gate_processed = True
+
+                logger.info(f"Online quantizing layer.{layer.layer_idx}.mlp.experts.up_gate_proj.weight...")
+
+                if self.model_format == "torch":
+                    process_weight_transpose(layer, "up_gate_proj_weight")
+
+                _process_quantize(weight_type)
+
+            elif down_ready and not getattr(self, "_down_processed", False):
+                weight_type = "down"
+                self._down_processed = True
+
+                logger.info(f"Rotating and online quantizing layer.{layer.layer_idx}.mlp.experts.down_proj.weight...")
+
+                if self.model_format == "torch":
+                    process_weight_transpose(layer, "down_proj_weight")
+
+                _rotate_down_proj_weight()
+
+                _process_quantize(weight_type)
+
+            if getattr(self, "_up_gate_processed", False) and getattr(self, "_down_processed", False):
+                logger.info(f"Layer {layer.layer_idx} MoE W4AFP8 online quantization completed.")
+                del self._up_gate_processed
+                del self._down_processed
+
+        else:
+            return
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
         Paddle cutlass load weight process.
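Two details in process_weights_after_loading are worth unpacking. First, _rotate_down_proj_weight left-multiplies each expert's down_proj weight by Q^T from get_orthogonal_matrix; for an orthogonal Q this is output-preserving as long as the runtime rotates the down_proj input by Q, and the rotation spreads activation outliers so FP8 quantizes them better. Second, the scale arithmetic: weight_scale / block_size**0.5 is consistent with an unnormalized ±1 Hadamard of block size B (H / sqrt(B) is orthogonal), and the later division by 448 * 7 * 2**(-9) (= 6.125) combines the E4M3 maximum of 448 with the INT4 maximum of 7; read both as inference from the constants, not documented behavior. A small numpy check of the rotation identity:

import numpy as np

# Rotating the weight by an orthogonal Q is output-preserving when the
# activation is rotated by the same Q. Q here is generic;
# get_orthogonal_matrix / "hadamard_ffn2" are FastDeploy internals and
# may construct it differently.
rng = np.random.default_rng(0)
k, n = 256, 64
x = rng.standard_normal((4, k))     # activations entering down_proj
w = rng.standard_normal((k, n))     # one expert's down_proj weight

q, _ = np.linalg.qr(rng.standard_normal((k, k)))
rotated_w = q.T @ w                 # what _rotate_down_proj_weight stores

print(np.abs((x @ q) @ rotated_w - x @ w).max())   # ~1e-12: identical output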
@@ -960,6 +1171,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
             layer.extract_moe_ffn_weights(state_dict)
         )
+
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
         up_gate_proj_weight_scales = []
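The v1 path in the first hunk and the deferred quantization in the second fit together through shard tracking: each parameter carries a tensor_track attribute, and process_weights_after_loading only quantizes once weight_fully_copied reports that every tensor-parallel shard has landed. TensorTracker and weight_fully_copied are FastDeploy internals; the toy sketch below, with all names hypothetical, only illustrates the bookkeeping idea.

import numpy as np

# Toy sketch of shard-completion tracking (all names hypothetical; the real
# TensorTracker / weight_fully_copied live inside FastDeploy's loader).
class ToyTensorTracker:
    def __init__(self, shape):
        self.copied = np.zeros(shape[0], dtype=bool)  # track along dim 0 only

    def mark(self, start, stop):
        self.copied[start:stop] = True

    def fully_copied(self) -> bool:
        return bool(self.copied.all())

# Two tensor-parallel ranks each deliver half of dim 0; quantization is
# deferred until both halves have arrived.
tracker = ToyTensorTracker((8, 4))
tracker.mark(0, 4)
print(tracker.fully_copied())  # False: still waiting for the second shard
tracker.mark(4, 8)
print(tracker.fully_copied())  # True: safe to quantize in place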