diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 4087a50ff4..378e0105ad 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -233,14 +233,28 @@ jobs: curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_sot.yaml\", \"--enable-logprob\": \"False\"}" + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_sot_wint4.yaml\", \"--enable-logprob\": \"False\"}" check_service 360 export TEMPLATE=TOKEN_NORMAL python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_cinn.yaml\", \"--enable-logprob\": \"False\"}" + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_cinn_wint4.yaml\", \"--enable-logprob\": \"False\"}" + check_service 360 + export TEMPLATE=TOKEN_NORMAL + python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 + + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_sot_fp8.yaml\", \"--enable-logprob\": \"False\"}" + check_service 360 + export TEMPLATE=TOKEN_NORMAL + python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 + + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_cinn_fp8.yaml\", \"--enable-logprob\": \"False\"}" check_service 360 export TEMPLATE=TOKEN_NORMAL python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py index 922fbb3df8..039f545c32 100644 --- a/fastdeploy/distributed/communication.py +++ b/fastdeploy/distributed/communication.py @@ -20,6 +20,8 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet +from fastdeploy.utils import register_custom_python_op + _TP_AR = None @@ -50,7 +52,18 @@ def custom_ar_clear_ipc_handles(): try: - @paddle.jit.marker.unified + def tensor_model_parallel_all_reduce_infer_meta(x: "paddle.static.MetaTensor", group_) -> paddle.static.MetaTensor: + return paddle.static.MetaTensor(shape=x.shape, dtype=x.dtype) + + @register_custom_python_op( + name="tensor_model_parallel_all_reduce", + infer_meta=tensor_model_parallel_all_reduce_infer_meta, + input_names=[ + "input_", + ], + output_names=["out"], + inplace_map={}, + ) def tensor_model_parallel_all_reduce( input_: paddle.Tensor, group_: paddle.distributed.communication.group.Group = None, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index da705357c1..c332adf8cc 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -28,7 +28,7 @@ from fastdeploy.model_executor.utils import ( set_weight_attrs, weight_fully_copied, ) -from fastdeploy.utils import ceil_div +from fastdeploy.utils import ceil_div, register_custom_python_op from ..quantization.quant_base import QuantMethodBase @@ -1141,6 +1141,196 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): return out +def python_op_fused_moe_kernel_paddle_infer_meta( + x, + layer_added_weight_attrs_0, + layer_added_scale_attrs_0, + layer_added_weight_attrs1, + layer_added_scale_attrs1, + gate_out, + gate_correction_bias, + top_k: int, + N1: int, + N2: int, + num_local_experts: int, + moe_intermediate_size: int, + hidden_size: int, + config: dict, + quant_config, + topk_ids_hookfunc, +): + token_num = x.shape[0] + return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype) + + +@register_custom_python_op( + name="python_op_fused_moe_kernel_paddle", + infer_meta=python_op_fused_moe_kernel_paddle_infer_meta, + input_names=[ + "x", + "layer_added_weight_attrs_0", + "layer_added_scale_attrs_0", + "layer_added_weight_attrs1", + "layer_added_scale_attrs1", + "gate_out", + "gate_correction_bias", + ], + output_names=["out"], + inplace_map={}, +) +def python_op_fused_moe_kernel_paddle( + x: paddle.Tensor, + layer_added_weight_attrs_0: paddle.Tensor, + layer_added_scale_attrs_0: paddle.Tensor, + layer_added_weight_attrs1: paddle.Tensor, + layer_added_scale_attrs1: paddle.Tensor, + gate_out: paddle.Tensor, + gate_correction_bias: paddle.Tensor, + top_k: int, + N1: int, + N2: int, + num_local_experts: int, + moe_intermediate_size: int, + hidden_size: int, + config: dict, + quant_config, + topk_ids_hookfunc, +): + + token_num = x.shape[0] + if x.shape[0] == 0: + return paddle.zeros([token_num, hidden_size], dtype=x.dtype) + + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + gate_correction_bias, + top_k, + True, # apply_norm_weight + False, + ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, config["BLOCK_SIZE_M"] + ) + # cache13 = create_empty_tensor(tuple([token_num * top_k * max(N1, N2)]), x.dtype) + cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype) + intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1]) + max_num_tokens_padded = sorted_token_ids.shape[0] + + grid = ( + ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), + ) + + from .triton_moe_kernels import fused_moe_kernel_paddle + + x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0]) + + fused_moe_kernel_paddle[grid]( + x_q, + layer_added_weight_attrs_0, + intermediate_cache1, + x_scale, + layer_added_scale_attrs_0, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_num_tokens_padded, + token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, + stride_am=x_q.strides[0], + stride_ak=x_q.strides[1], + stride_be=layer_added_weight_attrs_0.strides[0], + stride_bk=layer_added_weight_attrs_0.strides[2], + stride_bn=layer_added_weight_attrs_0.strides[1], + stride_cm=intermediate_cache1.strides[0], + stride_cn=intermediate_cache1.strides[1], + # + stride_asm=x_scale.strides[0], # only used in blockwise fp8 + stride_ask=x_scale.strides[1], # only used in blockwise fp8 + stride_bse=layer_added_scale_attrs_0.strides[0], + stride_bsk=layer_added_scale_attrs_0.strides[2], + stride_bsn=layer_added_scale_attrs_0.strides[1], + group_n=quant_config.weight_block_size[1], + group_k=quant_config.weight_block_size[0], + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type_enum=1, + use_fp8_w8a8=True, + use_int8_w8a16=False, + per_channel_quant=False, + even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, + ) + + intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1) + + intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2]) + + grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),) + + x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant( + intermediate_cache2, quant_config.weight_block_size[0] + ) + + fused_moe_kernel_paddle[grid]( + x_q, + layer_added_weight_attrs1, + intermediate_cache3, + x_scale, + layer_added_scale_attrs1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_num_tokens_padded, + token_num * top_k, + N=hidden_size, + K=moe_intermediate_size, + stride_am=x_q.strides[0], + stride_ak=x_q.strides[1], + stride_be=layer_added_weight_attrs1.strides[0], + stride_bk=layer_added_weight_attrs1.strides[2], + stride_bn=layer_added_weight_attrs1.strides[1], + stride_cm=intermediate_cache3.strides[0], + stride_cn=intermediate_cache3.strides[1], + stride_asm=x_scale.strides[0], # only used in blockwise fp8 + stride_ask=x_scale.strides[1], # only used in blockwise fp8 + stride_bse=layer_added_scale_attrs1.strides[0], + stride_bsk=layer_added_scale_attrs1.strides[2], + stride_bsn=layer_added_scale_attrs1.strides[1], + group_n=quant_config.weight_block_size[1], + group_k=quant_config.weight_block_size[0], + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + top_k=1, + compute_type_enum=1, + use_fp8_w8a8=True, + use_int8_w8a16=False, + per_channel_quant=False, + even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, + ) + + intermediate_cache3.reshape_([token_num, top_k, hidden_size]) + out = intermediate_cache3.sum(axis=1) + + return out + + class BlockWiseFP8MoEMethod(QuantMethodBase): """ Use Triton Group Gemm to compute Fused BlockWise FP8 Quant MoE. @@ -1479,9 +1669,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): """ Triton compute Fused MoE. """ - token_num = x.shape[0] - if token_num == 0: - return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) + gate_out = gate(x.cast("float32")) top_k = layer.top_k num_local_experts = layer.num_local_experts @@ -1490,15 +1678,12 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape N2 = getattr(layer, self.added_weight_attrs[1]).shape[1] - topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( - gate_out, - layer.gate_correction_bias, - layer.top_k, - True, # apply_norm_weight - False, - ) - if topk_ids_hookfunc is not None: - topk_ids_hookfunc(topk_ids=topk_ids) + gate_correction_bias = layer.gate_correction_bias + # for triton op input + layer_added_weight_attrs_0 = getattr(layer, self.added_weight_attrs[0]) + layer_added_scale_attrs_0 = getattr(layer, self.added_scale_attrs[0]) + layer_added_weight_attrs1 = getattr(layer, self.added_weight_attrs[1]) + layer_added_scale_attrs1 = getattr(layer, self.added_scale_attrs[1]) config = { "BLOCK_SIZE_M": 64, @@ -1508,123 +1693,22 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): "num_warps": 4, "num_stages": 3, } - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func - sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( - topk_ids, num_local_experts, config["BLOCK_SIZE_M"] + return python_op_fused_moe_kernel_paddle( + x, + layer_added_weight_attrs_0, + layer_added_scale_attrs_0, + layer_added_weight_attrs1, + layer_added_scale_attrs1, + gate_out, + gate_correction_bias, + top_k, + N1, + N2, + num_local_experts, + moe_intermediate_size, + hidden_size, + config, + self.quant_config, + topk_ids_hookfunc, ) - # cache13 = create_empty_tensor(tuple([token_num * top_k * max(N1, N2)]), x.dtype) - cache13 = paddle.empty([token_num * top_k * max(N1, N2)], dtype=x.dtype) - intermediate_cache1 = cache13[: token_num * top_k * N1].view([token_num * top_k, N1]) - max_num_tokens_padded = sorted_token_ids.shape[0] - - grid = ( - ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) - * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), - ) - - from .triton_moe_kernels import fused_moe_kernel_paddle - - x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, self.quant_config.weight_block_size[0]) - - fused_moe_kernel_paddle[grid]( - x_q, - getattr(layer, self.added_weight_attrs[0]), - intermediate_cache1, - x_scale, - getattr(layer, self.added_scale_attrs[0]), - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - max_num_tokens_padded, - token_num * top_k, - N=moe_intermediate_size * 2, - K=hidden_size, - stride_am=x_q.strides[0], - stride_ak=x_q.strides[1], - stride_be=getattr(layer, self.added_weight_attrs[0]).strides[0], - stride_bk=getattr(layer, self.added_weight_attrs[0]).strides[2], - stride_bn=getattr(layer, self.added_weight_attrs[0]).strides[1], - stride_cm=intermediate_cache1.strides[0], - stride_cn=intermediate_cache1.strides[1], - # - stride_asm=x_scale.strides[0], # only used in blockwise fp8 - stride_ask=x_scale.strides[1], # only used in blockwise fp8 - stride_bse=getattr(layer, self.added_scale_attrs[0]).strides[0], - stride_bsk=getattr(layer, self.added_scale_attrs[0]).strides[2], - stride_bsn=getattr(layer, self.added_scale_attrs[0]).strides[1], - group_n=self.quant_config.weight_block_size[1], - group_k=self.quant_config.weight_block_size[0], - # Meta-parameters - BLOCK_SIZE_M=config["BLOCK_SIZE_M"], - BLOCK_SIZE_N=config["BLOCK_SIZE_N"], - BLOCK_SIZE_K=config["BLOCK_SIZE_K"], - GROUP_SIZE_M=config["GROUP_SIZE_M"], - MUL_ROUTED_WEIGHT=False, - top_k=top_k, - compute_type_enum=1, - use_fp8_w8a8=True, - use_int8_w8a16=False, - per_channel_quant=False, - even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, - ) - - intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1) - - intermediate_cache3 = cache13[: token_num * top_k * N2].view([token_num * top_k, N2]) - - grid = ( - ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) * ceil_div(hidden_size, config["BLOCK_SIZE_N"]), - ) - - x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant( - intermediate_cache2, self.quant_config.weight_block_size[0] - ) - - fused_moe_kernel_paddle[grid]( - x_q, - getattr(layer, self.added_weight_attrs[1]), - intermediate_cache3, - x_scale, - getattr(layer, self.added_scale_attrs[1]), - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - max_num_tokens_padded, - token_num * top_k, - N=hidden_size, - K=moe_intermediate_size, - stride_am=x_q.strides[0], - stride_ak=x_q.strides[1], - stride_be=getattr(layer, self.added_weight_attrs[1]).strides[0], - stride_bk=getattr(layer, self.added_weight_attrs[1]).strides[2], - stride_bn=getattr(layer, self.added_weight_attrs[1]).strides[1], - stride_cm=intermediate_cache3.strides[0], - stride_cn=intermediate_cache3.strides[1], - stride_asm=x_scale.strides[0], # only used in blockwise fp8 - stride_ask=x_scale.strides[1], # only used in blockwise fp8 - stride_bse=getattr(layer, self.added_scale_attrs[1]).strides[0], - stride_bsk=getattr(layer, self.added_scale_attrs[1]).strides[2], - stride_bsn=getattr(layer, self.added_scale_attrs[1]).strides[1], - group_n=self.quant_config.weight_block_size[1], - group_k=self.quant_config.weight_block_size[0], - # Meta-parameters - BLOCK_SIZE_M=config["BLOCK_SIZE_M"], - BLOCK_SIZE_N=config["BLOCK_SIZE_N"], - BLOCK_SIZE_K=config["BLOCK_SIZE_K"], - GROUP_SIZE_M=config["GROUP_SIZE_M"], - MUL_ROUTED_WEIGHT=True, - top_k=1, - compute_type_enum=1, - use_fp8_w8a8=True, - use_int8_w8a16=False, - per_channel_quant=False, - even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, - ) - - intermediate_cache3.reshape_([token_num, top_k, hidden_size]) - out = intermediate_cache3.sum(axis=1) - - return out diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index 59daa23848..d5af8106ff 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -31,6 +31,7 @@ from fastdeploy.model_executor.utils import ( process_weight_transpose, set_weight_attrs, ) +from fastdeploy.utils import register_custom_python_op from ..utils import get_tensor, per_block_cast_to_fp8 from .quant_base import QuantConfigBase, QuantMethodBase @@ -81,6 +82,43 @@ class BlockWiseFP8Config(QuantConfigBase): return BlockWiseFP8LinearMethod(self) +def deep_gemm_fp8_fp8_bf16_nt_infer_meta( + x_meta: "paddle.static.MetaTensor", + x_scale_tensor_meta: "paddle.static.MetaTensor", + layer_weight_meta: "paddle.static.MetaTensor", + layer_weight_scale_inv_meta: "paddle.static.MetaTensor", + linear_out_meta: "paddle.static.MetaTensor", + layer_output_size: int, +): + return paddle.static.MetaTensor(shape=[x_meta.shape[0], layer_output_size], dtype=paddle.bfloat16) + + +@register_custom_python_op( + name="deep_gemm_fp8_fp8_bf16_nt", + infer_meta=deep_gemm_fp8_fp8_bf16_nt_infer_meta, + input_names=["x", "x_scale_tensor", "layer_weight", "layer_weight_scale_inv", "linear_out_empty"], + output_names=["linear_out"], + inplace_map={}, +) +def deep_gemm_fp8_fp8_bf16_nt( + x: paddle.Tensor, + x_scale_tensor: paddle.Tensor, + layer_weight: paddle.Tensor, + layer_weight_scale_inv: paddle.Tensor, + linear_out: paddle.Tensor, + layer_output_size: int, +): + from fastdeploy.model_executor.ops.gpu import deep_gemm + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x, x_scale_tensor), + (layer_weight, layer_weight_scale_inv), + linear_out, + ) + + return linear_out + + class BlockWiseFP8LinearMethod(QuantMethodBase): """ block wise quantization method for linear @@ -230,12 +268,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): x, self.quant_config.weight_block_size[0] ) linear_out = paddle.empty((x.shape[0], layer.output_size), dtype=paddle.bfloat16) - from fastdeploy.model_executor.ops.gpu import deep_gemm - - deep_gemm.gemm_fp8_fp8_bf16_nt( - (x, x_scale_tensor), - (layer.weight, layer.weight_scale_inv), - linear_out, + linear_out = deep_gemm_fp8_fp8_bf16_nt( + x, x_scale_tensor, layer.weight, layer.weight_scale_inv, linear_out, layer.output_size ) if layer.with_bias: linear_out = paddle.add(linear_out, layer.bias) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index f8330c3586..240feb1129 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -1229,3 +1229,18 @@ def to_tensor(tasks: List[Any]): multimodal_inputs[key] = [paddle.to_tensor(v) for v in value] except Exception as e: llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}") + + +def do_nothing(*args, **kwargs): + def decorator(func): + return func + + return decorator + + +if hasattr(paddle.static, "register_op"): + from paddle.static import register_op +else: + register_op = do_nothing + +register_custom_python_op = register_op diff --git a/tests/ce/deploy/ernie45t_21b_cinn_fp8.yaml b/tests/ce/deploy/ernie45t_21b_cinn_fp8.yaml new file mode 100644 index 0000000000..d02af4a6c6 --- /dev/null +++ b/tests/ce/deploy/ernie45t_21b_cinn_fp8.yaml @@ -0,0 +1,8 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 1 +quantization: block_wise_fp8 +graph_optimization_config: + graph_opt_level: 2 + sot_warmup_sizes: [2,16,32,64] + use_cudagraph: True diff --git a/tests/ce/deploy/ernie45t_21b_cinn.yaml b/tests/ce/deploy/ernie45t_21b_cinn_wint4.yaml similarity index 100% rename from tests/ce/deploy/ernie45t_21b_cinn.yaml rename to tests/ce/deploy/ernie45t_21b_cinn_wint4.yaml diff --git a/tests/ce/deploy/ernie45t_21b_sot_fp8.yaml b/tests/ce/deploy/ernie45t_21b_sot_fp8.yaml new file mode 100644 index 0000000000..269afb1004 --- /dev/null +++ b/tests/ce/deploy/ernie45t_21b_sot_fp8.yaml @@ -0,0 +1,8 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 1 +quantization: block_wise_fp8 +graph_optimization_config: + graph_opt_level: 1 + sot_warmup_sizes: [2,16,32,64] + use_cudagraph: True diff --git a/tests/ce/deploy/ernie45t_21b_sot.yaml b/tests/ce/deploy/ernie45t_21b_sot_wint4.yaml similarity index 100% rename from tests/ce/deploy/ernie45t_21b_sot.yaml rename to tests/ce/deploy/ernie45t_21b_sot_wint4.yaml