Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Feature] support w4afp8 v1_loader and v0_loader(tp>1) (#5757)
* support
* fix
* support w4afp8 v1_loader and v0_loader
* fix
* fix test
* fix test
* fix test
* fix moe.py
* add test_ernie_4_5_w4afp8
* add test
* delete tensor
* fix test
* fix
* add
* fix test
@@ -14,7 +14,8 @@
 import os
 import re
 
-file_dir = "./gpu_ops/w4afp8_gemm/"
+script_dir = os.path.dirname(os.path.abspath(__file__))
+file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep
 
 gemm_template_head = """
 #pragma once
@@ -85,7 +86,15 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
 """
 
 # [M, K, Number of experts, token Padding Size, weight K group size]
-gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
+gemm_case = [
+    [256, 256, 2, 0, 128],
+    [512, 256, 2, 0, 128],
+    [256, 5120, 128, 0, 128],
+    [3072, 2560, 64, 0, 128],
+    [2560, 1536, 64, 0, 128],
+    [1536, 2560, 64, 0, 128],
+    [2560, 768, 64, 0, 128],
+]
 
 dtype = ["BF16"]
 
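Note: each gemm_case row supplies the M, K, expert-count, padding, and group-size fields of the kernel-name template in gemm_template_head. A minimal sketch of that mapping in plain Python, not the generator's actual loop; the N value below is an assumed placeholder (only TYPE="BF16" is taken from the dtype list in this diff):

template = "w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}"

for M, K, experts, padding, group_size in [[256, 256, 2, 0, 128], [3072, 2560, 64, 0, 128]]:
    # N=256 is illustrative only; the real generator fills N and TYPE elsewhere.
    name = template.format(M=M, N=256, GROUPSIZE=group_size, K=K, EXPERTS=experts, PADDING=padding, TYPE="BF16")
    print(name)  # e.g. w4afp8_gemm_M256_N256_G128_K256_E2_P0_BF16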
@@ -884,35 +884,93 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         """
         Paddle cutlass create weight process.
         """
-        self.weight_dtype = "int8"
+        self.model_format = extra_weight_attrs.get("model_format")
 
         self.ffn1_weight_shape = [
             layer.num_local_experts,
-            layer.hidden_size // 2,
+            layer.hidden_size // 2,  # 4-bit packing
             layer.moe_intermediate_size * 2,
         ]
         self.ffn2_weight_shape = [
             layer.num_local_experts,
-            layer.moe_intermediate_size // 2,
+            layer.moe_intermediate_size // 2,  # 4-bit packing
             layer.hidden_size,
         ]
-        setattr(
-            layer,
-            self.added_weight_attrs[0],
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_weight_attrs[1],
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        if not self.quant_config.is_quantized and layer.fd_config.load_config.load_choices == "default_v1":
+            if self.model_format != "torch":
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size * 2,
+                ]
+                down_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size,
+                    layer.hidden_size,
+                ]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
+                }
+            else:
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                    layer.hidden_size,
+                ]
+                down_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size,
+                ]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": 1},
+                }
+
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=up_gate_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.down_proj_weight = layer.create_parameter(
+                shape=down_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(layer.up_gate_proj_weight, up_gate_proj_attrs)
+            set_weight_attrs(layer.down_proj_weight, down_proj_attrs)
+        else:
+            self.weight_dtype = "int8"
+            setattr(
+                layer,
+                self.added_weight_attrs[0],  # "up_gate_proj_weight"
+                layer.create_parameter(
+                    shape=self.ffn1_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                self.added_weight_attrs[1],  # "down_proj_weight"
+                layer.create_parameter(
+                    shape=self.ffn2_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
 
         self.create_w4afp8_scale_weights(layer, layer.weight_key_map)
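Note: the `// 2` in ffn1_weight_shape/ffn2_weight_shape reflects 4-bit packing: two int4 values share one int8 element, so the packed dimension is half the logical one. A minimal standalone sketch of the idea (this is not FastDeploy's actual pack() helper):

def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed int4 values (range [-8, 7]) into one byte."""
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    return ((hi & 0xF) << 4) | (lo & 0xF)

def unpack_lo(byte: int) -> int:
    """Recover the low int4 value, sign-extending from 4 bits."""
    v = byte & 0xF
    return v - 16 if v >= 8 else v

assert unpack_lo(pack_int4_pair(-3, 7)) == -3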
@@ -922,22 +980,175 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
             dtype=layer.weight_dtype,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
 
         layer.down_proj_bias = layer.create_parameter(
             shape=[layer.num_experts, layer.hidden_size],
             dtype=layer.weight_dtype,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
-        set_weight_attrs(
-            layer.up_gate_proj_bias,
-            extra_weight_attrs,
-        )
-        set_weight_attrs(
-            layer.down_proj_bias,
-            extra_weight_attrs,
-        )
+        set_weight_attrs(layer.up_gate_proj_bias, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_bias, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: nn.Layer) -> None:
+        from ..utils import get_orthogonal_matrix
+
+        def _rotate_down_proj_weight():
+            """
+            Apply Hadamard rotation to down_proj weight
+            """
+            Q_ffn2, moe_block_size = get_orthogonal_matrix(size=layer.moe_intermediate_size, mode="hadamard_ffn2")
+            down_proj_weight = layer.down_proj_weight
+            original_dtype = down_proj_weight.dtype  # bfloat16
+
+            expert_list = [down_proj_weight[i] for i in range(layer.num_local_experts)]
+
+            moe_weight = paddle.concat(expert_list, axis=-1)
+
+            new_moe_weight = Q_ffn2.cast("float32").T @ moe_weight.cast("float32").to(Q_ffn2.place)
+            rotated_list = []
+            for expert_id in range(layer.num_local_experts):
+                start_idx = expert_id * layer.hidden_size
+                end_idx = (expert_id + 1) * layer.hidden_size
+                rotated_weight = new_moe_weight[:, start_idx:end_idx]
+                rotated_list.append(rotated_weight)
+
+            rotated_stacked = paddle.stack(rotated_list, axis=0).cast(original_dtype)
+            layer.down_proj_weight.set_value(rotated_stacked)
+
+            del moe_weight, new_moe_weight, expert_list, rotated_list
+            paddle.device.cuda.empty_cache()
+
+            return moe_block_size
+
+        def _process_quantize(weight_type: str):
+
+            weight_idx = 0 if weight_type == "gate_up" else 1
+            weight_name = self.added_weight_attrs[weight_idx]  # "up_gate_proj_weight" or "down_proj_weight"
+            scale_name = self.added_scale_attrs[weight_idx]  # "up_gate_proj_weight_scale" or "down_proj_weight_scale"
+
+            weight_dtype = "int8"
+            scale_dtype = "float32"
+
+            block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 512)
+
+            quant_weight_list = []
+            scale_list = []
+
+            for expert_id in range(layer.num_local_experts):
+                expert_weight = getattr(layer, weight_name)[expert_id]
+
+                quant_weight, weight_scale = group_wise_int4_weight_quantize(expert_weight, group_size=128)
+
+                quant_weight = pack(quant_weight.transpose([1, 0]), bits=4)
+
+                if weight_type == "down":
+                    weight_scale = weight_scale / (block_size**0.5)
+
+                quant_weight = w4afp8_gemm_weight_convert(quant_weight)
+
+                quant_weight_list.append(quant_weight)
+                scale_list.append(weight_scale)
+
+            free_tensor(getattr(layer, weight_name))
+
+            stacked_quant_weight = paddle.stack(quant_weight_list, axis=0)
+            stacked_scale = paddle.stack(scale_list, axis=0)
+
+            setattr(
+                layer,
+                weight_name,
+                layer.create_parameter(
+                    shape=stacked_quant_weight.shape,
+                    dtype=weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            processed_scale = stacked_scale / (448 * 7 * 2 ** (-9))
+
+            if len(processed_scale.shape) == 3:
+                if weight_type == "gate_up" and processed_scale.shape[-1] * 128 != layer.hidden_size:
+                    assert (
+                        layer.hidden_size // 128 % processed_scale.shape[-1] == 0
+                    ), "weight_scale_group_size must be a multiple of 128"
+                    processed_scale = processed_scale.repeat_interleave(
+                        layer.hidden_size // 128 // processed_scale.shape[-1], axis=-1
+                    )
+                elif weight_type == "down" and processed_scale.shape[-1] * 128 != layer.moe_intermediate_size:
+                    assert (
+                        layer.moe_intermediate_size // 128 % processed_scale.shape[-1] == 0
+                    ), "weight_scale_group_size must be a multiple of 128"
+                    processed_scale = processed_scale.repeat_interleave(
+                        layer.moe_intermediate_size // 128 // processed_scale.shape[-1], axis=-1
+                    )
+
+                origin_shape = processed_scale.shape
+                processed_scale = processed_scale.transpose([0, 2, 1])
+                processed_scale = processed_scale.reshape([-1, processed_scale.shape[-1]])
+                processed_scale = w4afp8_gemm_scale_permute(processed_scale)
+                processed_scale = processed_scale.reshape(
+                    [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128]
+                )
+                processed_scale = processed_scale.transpose([0, 2, 1, 3])
+            else:
+                processed_scale = w4afp8_gemm_scale_permute(processed_scale)
+            setattr(
+                layer,
+                scale_name,
+                layer.create_parameter(
+                    shape=processed_scale.shape,
+                    dtype=scale_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+            getattr(layer, weight_name).copy_(stacked_quant_weight, False)
+            getattr(layer, scale_name).copy_(processed_scale, False)
+
+            in_scale_name = scale_name.replace("_weight_scale", "_in_scale")
+            if hasattr(layer, in_scale_name):
+                getattr(layer, in_scale_name).set_value(paddle.ones([layer.num_local_experts], dtype="float32"))
+
+            del quant_weight_list, scale_list, stacked_quant_weight, stacked_scale, processed_scale
+            paddle.device.cuda.empty_cache()
+
+        up_gate_ready = hasattr(layer, "up_gate_proj_weight") and weight_fully_copied(layer.up_gate_proj_weight)
+        down_ready = hasattr(layer, "down_proj_weight") and weight_fully_copied(layer.down_proj_weight)
+
+        if not up_gate_ready and not down_ready:
+            return
+
+        if not self.quant_config.is_quantized:
+            if up_gate_ready and not getattr(self, "_up_gate_processed", False):
+                weight_type = "gate_up"
+                self._up_gate_processed = True
+
+                logger.info(f"Online quantizing layer.{layer.layer_idx}.mlp.experts.up_gate_proj.weight...")
+
+                if self.model_format == "torch":
+                    process_weight_transpose(layer, "up_gate_proj_weight")
+
+                _process_quantize(weight_type)
+
+            elif down_ready and not getattr(self, "_down_processed", False):
+                weight_type = "down"
+                self._down_processed = True
+
+                logger.info(f"Rotating and online quantizing layer.{layer.layer_idx}.mlp.experts.down_proj.weight...")
+
+                if self.model_format == "torch":
+                    process_weight_transpose(layer, "down_proj_weight")
+
+                _rotate_down_proj_weight()
+
+                _process_quantize(weight_type)
+
+            if getattr(self, "_up_gate_processed", False) and getattr(self, "_down_processed", False):
+                logger.info(f"Layer {layer.layer_idx} MoE W4AFP8 online quantization completed.")
+                del self._up_gate_processed
+                del self._down_processed
+
+        else:
+            return
+
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
         Paddle cutlass load weight process.
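Note on the math behind _rotate_down_proj_weight: multiplying the down_proj weight by the transpose of an orthogonal matrix is output-preserving as long as the activations feeding it are rotated by the same matrix, since (xQ)(QᵀW) = xW; the weight_scale / (block_size**0.5) step likely folds the Hadamard normalization into the scale. A minimal NumPy check of the identity, assuming only that get_orthogonal_matrix returns an orthogonal Q (this is not the FastDeploy code path):

import numpy as np

rng = np.random.default_rng(0)
k, n = 8, 4
Q, _ = np.linalg.qr(rng.standard_normal((k, k)))  # orthogonal: Q.T @ Q == I
x = rng.standard_normal((2, k))                   # activations entering down_proj
W = rng.standard_normal((k, n))                   # down_proj weight

# Rotating W by Q.T (as the hunk does) and x by Q leaves the product unchanged.
assert np.allclose((x @ Q) @ (Q.T @ W), x @ W)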
@@ -960,6 +1171,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
         up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
             layer.extract_moe_ffn_weights(state_dict)
         )
+
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
         up_gate_proj_weight_scales = []
@@ -88,7 +88,8 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
     elif quant_config_name == "w4afp8":
         quantization_config["dense_quant_type"] = "block_wise_fp8"
         quantization_config["moe_quant_type"] = "w4afp8"
-        quantization_config["hadamard_block_size"] = 512
+        tp_size = getattr(args, "tensor_parallel_size", 1)
+        quantization_config["hadamard_block_size"] = 512 // tp_size
        quantization_config["quantization"] = "mix_quant"
         quant_config_name = "mix_quant"
     else:
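Note: worked values for the new formula, which is the tp>1 part of this PR. The division presumably reflects each tensor-parallel rank holding a 1/tp_size slice of the rotated intermediate dimension (the tp_size values below are illustrative):

for tp_size in (1, 2, 4):
    print(tp_size, 512 // tp_size)
# 1 -> 512
# 2 -> 256
# 4 -> 128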
@@ -41,6 +41,7 @@ class W4AFP8Config(QuantConfigBase):
         self.is_permuted = is_permuted
         self.hadamard_block_size = hadamard_block_size
         self.is_quantized = is_quantized
+        self.is_checkpoint_bf16 = not is_quantized
 
     def name(self) -> str:
         return "w4afp8"
@@ -110,7 +110,11 @@ class Ernie4_5_MoE(nn.Layer):
         if hasattr(fd_config.quant_config, "moe_quant_type"):
             moe_quant_type = fd_config.quant_config.moe_quant_type
 
-        if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
+        if moe_quant_type == "w4a8" or (
+            moe_quant_type == "w4afp8"
+            and fd_config.model_config.is_quantized
+            and not fd_config.quant_config.moe_dynamic_quant
+        ):
             weight_key_map = {
                 "gate_weight_key": f"{prefix}.gate.weight",
                 "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
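Note: a rough sketch of the reworked dispatch for w4afp8 checkpoints, combining the condition in this hunk with the elif added in the next one. The labels are descriptive only, and the bf16 case is handled by branches outside this diff:

def w4afp8_branch(is_quantized: bool, moe_dynamic_quant: bool) -> str:
    # Mirrors the conditions visible in this commit; not the module's actual code.
    if is_quantized and not moe_dynamic_quant:
        return "static ckpt keys: quant_weight + weight_scale + activation_scale"
    if is_quantized and moe_dynamic_quant:
        return "dynamic ckpt keys: quant_weight + weight_scale (no activation_scale)"
    return "bf16 checkpoint: key map chosen elsewhere"

assert w4afp8_branch(True, False).startswith("static")
assert w4afp8_branch(True, True).startswith("dynamic")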
@@ -121,6 +125,19 @@ class Ernie4_5_MoE(nn.Layer):
                 "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
                 "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
             }
+        elif (
+            moe_quant_type == "w4afp8"
+            and fd_config.model_config.is_quantized
+            and fd_config.quant_config.moe_dynamic_quant
+        ):
+            weight_key_map = {
+                "gate_weight_key": f"{prefix}.gate.weight",
+                "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
+                "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
+                "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
+                "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
+                "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
+            }
         elif moe_quant_type == "w4w2":
             weight_key_map = {
                 "gate_weight_key": f"{prefix}.gate.weight",
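Note: a small illustration of the double-brace pattern in these key maps. The f-string resolves {prefix} immediately and leaves a literal {} to be filled with the expert index later; the prefix value and expert id below are made up for the example:

prefix = "ernie.layers.1.mlp"  # illustrative, not a value from this diff
key_template = f"{prefix}.experts.{{}}.up_gate_proj.quant_weight"
print(key_template)            # ernie.layers.1.mlp.experts.{}.up_gate_proj.quant_weight
print(key_template.format(7))  # ernie.layers.1.mlp.experts.7.up_gate_proj.quant_weight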
@@ -223,6 +240,7 @@ class Ernie4_5_MoE(nn.Layer):
             gate=self.gate,
             forward_meta=forward_meta,
         )
+
         if self.num_shared_experts > 0:
             s_x = self.shared_experts(hidden_states)
             out = out + s_x
@@ -390,7 +390,7 @@ def v1_loader_support(fd_config):
 
     def _get_unsupported_quant():
         if current_platform.is_cuda():
-            return {"w4a8", "w4afp8", "wint2"}
+            return {"w4a8", "wint2"}
         elif current_platform.is_xpu():
             return {"w4a8", "w8a8"}
         return set()
@@ -0,0 +1,343 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import signal
+import subprocess
+import sys
+import time
+
+import openai
+import pytest
+
+tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, tests_dir)
+
+from e2e.utils.serving_utils import (
+    FD_API_PORT,
+    FD_CACHE_QUEUE_PORT,
+    FD_ENGINE_QUEUE_PORT,
+    FD_METRICS_PORT,
+    clean_ports,
+    is_port_open,
+)
+
+os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
+
+W4AFP8_CONFIGS = [
+    {
+        "id": "w4afp8_default",
+        "load_choices": "default",
+        "model_name": "ernie-4_5-21b-a3b-bf16-paddle",
+        "model_subdir": None,
+    },
+    {
+        "id": "w4afp8_default_v1",
+        "load_choices": "default_v1",
+        "model_name": "ERNIE-4.5-21B-A3B-PT",
+        "model_subdir": "torch",
+    },
+]
+
+
+def get_model_path(config):
+    """Get model path based on config and MODEL_PATH environment variable."""
+    base_path = os.getenv("MODEL_PATH")
+    model_name = config["model_name"]
+    model_subdir = config.get("model_subdir")
+
+    if base_path:
+        if model_subdir:
+            model_path = os.path.join(base_path, model_subdir, model_name)
+        else:
+            model_path = os.path.join(base_path, model_name)
+    else:
+        if model_subdir:
+            model_path = os.path.join(".", model_subdir, model_name)
+        else:
+            model_path = f"./{model_name}"
+
+    return model_path
+
+
+@pytest.fixture(scope="module", params=W4AFP8_CONFIGS, ids=lambda x: x["id"])
+def setup_w4afp8_server(request):
+    """
+    Setup W4AFP8 server for each config.
+    This fixture is parameterized to run with different configurations.
+    """
+    config = request.param
+    config_id = config["id"]
+    load_choices = config["load_choices"]
+
+    print(f"\n{'='*60}")
+    print(f"Starting W4AFP8 server with config: {config_id}")
+    print(f"  load_choices: {load_choices}")
+    print(f"  api_port: {FD_API_PORT}")
+    print(f"{'='*60}")
+
+    # Clean ports before starting
+    clean_ports()
+    time.sleep(5)
+
+    model_path = get_model_path(config)
+
+    # Check model path exists
+    print(f"Model path: {model_path}")
+    if not os.path.exists(model_path):
+        pytest.skip(f"Model path does not exist: {model_path}")
+
+    log_path = f"server_{config_id}.log"
+    log_dir = f"log_{config_id}"
+
+    if os.path.exists(log_dir):
+        shutil.rmtree(log_dir)
+    os.makedirs(log_dir, exist_ok=True)
+
+    cmd = [
+        sys.executable,
+        "-m",
+        "fastdeploy.entrypoints.openai.api_server",
+        "--model",
+        model_path,
+        "--port",
+        str(FD_API_PORT),
+        "--tensor-parallel-size",
+        "2",
+        "--engine-worker-queue-port",
+        str(FD_ENGINE_QUEUE_PORT),
+        "--metrics-port",
+        str(FD_METRICS_PORT),
+        "--cache-queue-port",
+        str(FD_CACHE_QUEUE_PORT),
+        "--max-model-len",
+        "32768",
+        "--max-num-seqs",
+        "128",
+        "--quantization",
+        "w4afp8",
+        "--load-choices",
+        load_choices,
+        "--graph-optimization-config",
+        '{"cudagraph_capture_sizes": [1]}',
+    ]
+
+    print(f"Starting server with command: {' '.join(cmd)}")
+
+    with open(log_path, "w") as logfile:
+        process = subprocess.Popen(
+            cmd,
+            stdout=logfile,
+            stderr=subprocess.STDOUT,
+            start_new_session=True,
+            env={**os.environ, "FD_LOG_DIR": log_dir},
+        )
+
+    print(f"Server process started with PID: {process.pid}")
+
+    # Wait for server to start
+    server_started = False
+    for i in range(300):
+        # Check if process is still alive
+        if process.poll() is not None:
+            print(f"[ERROR] Server process exited early with code: {process.returncode}")
+            break
+
+        if is_port_open("127.0.0.1", FD_API_PORT):
+            print(f"API server [{config_id}] is up on port {FD_API_PORT}")
+            server_started = True
+            break
+
+        if i % 30 == 0:
+            print(f"Waiting for server [{config_id}] to start... ({i}s)")
+        time.sleep(1)
+
+    if not server_started:
+        print(f"[TIMEOUT] API server [{config_id}] failed to start in 5 minutes.")
+
+        # Print log content for debugging
+        print(f"\n{'='*60}")
+        print(f"Server log [{config_id}]:")
+        print(f"{'='*60}")
+        try:
+            with open(log_path, "r") as f:
+                log_content = f.read()
+                # Print last 100 lines
+                lines = log_content.split("\n")
+                print("\n".join(lines[-100:]))
+        except Exception as e:
+            print(f"Failed to read log: {e}")
+        print(f"{'='*60}\n")
+
+        # Cleanup
+        try:
+            os.killpg(process.pid, signal.SIGTERM)
+        except Exception as e:
+            print(f"Failed to kill process group: {e}")
+
+        clean_ports()
+        raise RuntimeError(f"API server [{config_id}] did not start on port {FD_API_PORT}")
+
+    yield {"process": process, "config": config}
+
+    # Cleanup after test
+    print(f"\n===== Cleanup W4AFP8 server [{config_id}]... =====")
+
+    # Graceful shutdown
+    try:
+        process.terminate()
+        process.wait(timeout=30)
+        print(f"API server [{config_id}] (pid={process.pid}) terminated gracefully")
+    except subprocess.TimeoutExpired:
+        print(f"Timeout waiting for server [{config_id}], force killing...")
+        try:
+            os.killpg(process.pid, signal.SIGKILL)
+            process.wait(timeout=10)
+        except Exception as e:
+            print(f"Failed to force kill: {e}")
+    except Exception as e:
+        print(f"Failed to terminate API server [{config_id}]: {e}")
+        try:
+            os.killpg(process.pid, signal.SIGKILL)
+        except:
+            pass
+
+    # Clean ports after shutdown
+    clean_ports()
+    time.sleep(10)
+    print(f"Cleanup [{config_id}] completed")
+
+
+@pytest.fixture(scope="module")
+def openai_client(setup_w4afp8_server):
+    """
+    Returns an OpenAI client for the W4AFP8 quantization service.
+    Depends on setup_w4afp8_server to ensure the server is running.
+    """
+    client = openai.OpenAI(
+        base_url=f"http://127.0.0.1:{FD_API_PORT}/v1",
+        api_key="EMPTY_API_KEY",
+    )
+    return client
+
+
+@pytest.fixture(scope="module")
+def current_config(setup_w4afp8_server):
+    """
+    Returns the current server config for the test module.
+    """
+    return setup_w4afp8_server["config"]
+
+
+@pytest.fixture
+def consistent_payload():
+    """
+    Returns a fixed payload for consistency testing,
+    including a fixed random seed and temperature.
+    """
+    return {
+        "messages": [
+            {
+                "role": "user",
+                "content": "北京天安门在哪里?",
+            }
+        ],
+        "temperature": 0.8,
+        "top_p": 0,  # fix top_p to reduce randomness
+        "seed": 13,  # fixed random seed
+    }
+
+
+# ==========================
+# Helper function to calculate difference rate between two texts
+# ==========================
+def calculate_diff_rate(text1, text2):
+    """
+    Calculate the difference rate between two strings
+    based on the normalized Levenshtein edit distance.
+    Returns a float in [0,1], where 0 means identical.
+    """
+    if text1 == text2:
+        return 0.0
+
+    len1, len2 = len(text1), len(text2)
+    dp = [[0] * (len2 + 1) for _ in range(len1 + 1)]
+
+    for i in range(len1 + 1):
+        for j in range(len2 + 1):
+            if i == 0 or j == 0:
+                dp[i][j] = i + j
+            elif text1[i - 1] == text2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
+
+    edit_distance = dp[len1][len2]
+    max_len = max(len1, len2)
+    return edit_distance / max_len if max_len > 0 else 0.0
+
+
+# ==========================
+# Test Cases
+# ==========================
+def test_w4afp8_consistency_between_runs(openai_client, consistent_payload, current_config):
+    """
+    Test that two runs with the same fixed input produce similar outputs.
+    This test runs for each W4AFP8 config (default and default_v1).
+    """
+    config_id = current_config["id"]
+    load_choices = current_config["load_choices"]
+
+    print(f"\n[{config_id}] Testing consistency with load_choices={load_choices}")
+
+    # First request
+    resp1 = openai_client.chat.completions.create(
+        model="default",
+        stream=False,
+        max_tokens=256,
+        **consistent_payload,
+    )
+    content1 = resp1.choices[0].message.content
+
+    # Second request with same parameters
+    resp2 = openai_client.chat.completions.create(
+        model="default",
+        stream=False,
+        max_tokens=256,
+        **consistent_payload,
+    )
+    content2 = resp2.choices[0].message.content
+
+    # Check required keywords
+    required_keywords = ["北京", "天安门"]
+    for keyword in required_keywords:
+        assert keyword in content1, (
+            f"[{config_id}] First response missing keyword '{keyword}', " f"response content: {content1}"
+        )
+        assert keyword in content2, (
+            f"[{config_id}] Second response missing keyword '{keyword}', " f"response content: {content2}"
+        )
+
+    # Check consistency between runs
+    diff_rate = calculate_diff_rate(content1, content2)
+    print(f"[{config_id}] Diff rate between two runs: {diff_rate:.4%}")
+
+    assert diff_rate < 0.05, (
+        f"[{config_id}] Output difference too large ({diff_rate:.4%})\n"
+        f"Response 1: {content1}\n"
+        f"Response 2: {content2}"
+    )
+
+    print(f"[{config_id}] Consistency test passed! Diff rate: {diff_rate:.4%}")
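Usage note for calculate_diff_rate in the new test: it is a normalized Levenshtein distance, so for example two 4-character strings differing in one position score 0.25. A quick standalone check, assuming the function is imported as defined above:

assert calculate_diff_rate("abcd", "abcd") == 0.0
assert calculate_diff_rate("abcd", "abce") == 0.25  # 1 edit / max length 4
assert calculate_diff_rate("", "xyz") == 1.0        # 3 inserts / max length 3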