[Metax] adapt to the latest develop (#6282)

This commit is contained in:
xiaozude
2026-01-30 15:21:20 +08:00
committed by GitHub
parent 18ebce9dec
commit 030647521a
14 changed files with 754 additions and 370 deletions
@@ -606,7 +606,7 @@ class SpeculativeSampler(nn.Layer):
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu
@@ -972,7 +972,7 @@ class MTPSampler(nn.Layer):
def __init__(self, fd_config: FDConfig):
""" """
super().__init__()
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
self.forward = self.forward_cuda
elif current_platform.is_xpu():
self.forward = self.forward_xpu