[Metax] adapt to gemm interface on different versions of maca (#5905)

Co-authored-by: root <root@lt-wks-10-0-180-15.pub.metax-tech.com>
This commit is contained in:
MingkunZhang
2026-01-07 10:02:24 +08:00
committed by GitHub
parent 1ee285c2d6
commit 7ad5737560
3 changed files with 83 additions and 12 deletions
+39 -10
View File
@@ -653,20 +653,49 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
metax_extra_compile_args = {
"cxx": ["-O3"],
"nvcc": [
"-O3",
"-Ithird_party/nlohmann_json/include",
"-Igpu_ops",
"-DPADDLE_DEV",
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
],
}
def get_maca_version(version_file: str = "/opt/maca/Version.txt") -> list[int]:
try:
with open(version_file, "r", encoding="utf-8") as f:
version_str = f.readline().strip()
target_version = [int(part) for part in version_str.split(":")[1].split(".")]
except Exception as e:
print(f"Trigger exception: {type(e).__name__} - {e}")
raise
return target_version
maca_version = get_maca_version(f"{maca_path}/Version.txt")
if len(maca_version) == 4:
major_version = maca_version[0]
minor_version = maca_version[1]
patch_version = maca_version[2]
build_version = maca_version[3]
cur_maca_version = (
((major_version & 0xFF) << 24)
| ((minor_version & 0xFF) << 16)
| ((patch_version & 0xFF) << 8)
| ((build_version & 0xFF) << 0)
)
metax_extra_compile_args["nvcc"].append(f"-DMACA_VERSION={cur_maca_version}")
else:
raise ValueError(f"MACA version invalid - {maca_version}")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={
"cxx": ["-O3"],
"nvcc": [
"-O3",
"-Ithird_party/nlohmann_json/include",
"-Igpu_ops",
"-DPADDLE_DEV",
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
],
},
extra_compile_args=metax_extra_compile_args,
library_dirs=[os.path.join(maca_path, "lib")],
extra_link_args=["-lruntime_cu", "-lmctlassEx"],
include_dirs=[