[Feature] support w4afp8 v1_loader and v0_loader (tp > 1) (#5757)

* support

* fix

* support w4afp8 v1_loader and v0_loader

* fix

* fix test

* fix test

* fix test

* fix moe.py

* add test_ernie_4_5_w4afp8

* add test

* delete tensor

* fix test

* fix

* add

* fix test
This commit is contained in:
lizexu123
2025-12-30 14:11:52 +08:00
committed by GitHub
parent e78e22ebd5
commit 44a13e4557
7 changed files with 615 additions and 31 deletions
@@ -14,7 +14,8 @@
import os
import re
file_dir = "./gpu_ops/w4afp8_gemm/"
script_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep
gemm_template_head = """
#pragma once
@@ -85,7 +86,15 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
"""
# [M, K, Number of experts, token Padding Size, weight K group size]
gemm_case = [[256, 256, 2, 0, 128], [512, 256, 2, 0, 128], [256, 5120, 128, 0, 128]]
gemm_case = [
[256, 256, 2, 0, 128],
[512, 256, 2, 0, 128],
[256, 5120, 128, 0, 128],
[3072, 2560, 64, 0, 128],
[2560, 1536, 64, 0, 128],
[1536, 2560, 64, 0, 128],
[2560, 768, 64, 0, 128],
]
dtype = ["BF16"]