mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] refine moe_expert_ffn ut (#5743)
This commit is contained in:
@@ -81,16 +81,17 @@ def weight_quant_wint4(w_fp32):
|
||||
return w_int4, w_max.reshape([-1])
|
||||
|
||||
|
||||
def weight_quant(w_fp32, algo="w_channelwise_int8_a_float32"):
    """Quantize a float32 weight tensor with the named algorithm.

    Args:
        w_fp32: float32 weight tensor to quantize.
        algo: quantization scheme name; one of
            "w_channelwise_int8_a_float32" (channel-wise int8 weights) or
            "w_channelwise_int4_a_tokenwise_int15" (channel-wise int4 weights).

    Returns:
        The ``(quantized_weight, scale)`` pair produced by the matching
        helper, or ``(None, None)`` when *algo* is not recognized
        (i.e. "no quantization" is signalled rather than raising).
    """
    if algo == "w_channelwise_int8_a_float32":
        return weight_quant_wint8(w_fp32)
    elif algo == "w_channelwise_int4_a_tokenwise_int15":
        return weight_quant_wint4(w_fp32)
    else:
        # Unknown scheme: callers treat (None, None) as "leave unquantized".
        return None, None
|
||||
|
||||
# Select the weight-quantization scheme exercised by this unit test.
# Switch to the int8 scheme by swapping the commented assignment.
quant_method = "w_channelwise_int4_a_tokenwise_int15"
# quant_method = "w_channelwise_int8_a_float32"
print(f"quant_method={quant_method}, used_in_ep_low_latency={used_in_ep_low_latency}")
# Quantize both FFN weight matrices with the selected scheme; each call
# yields the quantized weights plus their per-channel scales.
ffn1_quant_w, ffn1_w_scale = weight_quant(ffn1_w, quant_method)
ffn2_quant_w, ffn2_w_scale = weight_quant(ffn2_w, quant_method)
||||
@@ -127,10 +128,10 @@ def weight_dequant_wint4(w_int, w_scale):
|
||||
return w_fp32
|
||||
|
||||
|
||||
def weight_dequant(w_int, w_scale, algo="w_channelwise_int8_a_float32"):
    """Dequantize an integer weight tensor back to float32.

    Mirrors :func:`weight_quant`: the *algo* names must match the scheme
    used at quantization time.

    Args:
        w_int: quantized integer weight tensor.
        w_scale: per-channel scale factors produced by ``weight_quant``.
        algo: quantization scheme name; one of
            "w_channelwise_int8_a_float32" or
            "w_channelwise_int4_a_tokenwise_int15".

    Returns:
        The dequantized float32 tensor from the matching helper, or
        ``(None, None)`` for an unrecognized *algo*.
        NOTE(review): the fallback returns a 2-tuple while the success
        branches return a single tensor — confirm callers handle both.
    """
    if algo == "w_channelwise_int8_a_float32":
        return weight_dequant_wint8(w_int, w_scale)
    elif algo == "w_channelwise_int4_a_tokenwise_int15":
        return weight_dequant_wint4(w_int, w_scale)
    else:
        return None, None
|
||||
|
||||
Reference in New Issue
Block a user