[XPU] refine moe_expert_ffn ut (#5743)

This commit is contained in:
RuohengMa
2025-12-25 10:35:24 +08:00
committed by GitHub
parent 9624bf3c6e
commit e154c03416
@@ -81,16 +81,17 @@ def weight_quant_wint4(w_fp32):
return w_int4, w_max.reshape([-1])
def weight_quant(w_fp32, algo="weight_only_int8"):
if algo == "weight_only_int8":
def weight_quant(w_fp32, algo="w_channelwise_int8_a_float32"):
if algo == "w_channelwise_int8_a_float32":
return weight_quant_wint8(w_fp32)
elif algo == "weight_only_int4":
elif algo == "w_channelwise_int4_a_tokenwise_int15":
return weight_quant_wint4(w_fp32)
else:
return None, None
quant_method = "weight_only_int4"
quant_method = "w_channelwise_int4_a_tokenwise_int15"
# quant_method = "w_channelwise_int8_a_float32"
print(f"quant_method={quant_method}, used_in_ep_low_latency={used_in_ep_low_latency}")
ffn1_quant_w, ffn1_w_scale = weight_quant(ffn1_w, quant_method)
ffn2_quant_w, ffn2_w_scale = weight_quant(ffn2_w, quant_method)
@@ -127,10 +128,10 @@ def weight_dequant_wint4(w_int, w_scale):
return w_fp32
def weight_dequant(w_int, w_scale, algo="weight_only_int8"):
if algo == "weight_only_int8":
def weight_dequant(w_int, w_scale, algo="w_channelwise_int8_a_float32"):
if algo == "w_channelwise_int8_a_float32":
return weight_dequant_wint8(w_int, w_scale)
elif algo == "weight_only_int4":
elif algo == "w_channelwise_int4_a_tokenwise_int15":
return weight_dequant_wint4(w_int, w_scale)
else:
return None, None