mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-25 01:55:45 +08:00
[Metax] adapt to the latest develop (#6282)
This commit is contained in:
@@ -606,7 +606,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
""" """
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.is_cuda() or current_platform.is_maca():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
@@ -972,7 +972,7 @@ class MTPSampler(nn.Layer):
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
""" """
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.is_cuda() or current_platform.is_maca():
|
||||
self.forward = self.forward_cuda
|
||||
elif current_platform.is_xpu():
|
||||
self.forward = self.forward_xpu
|
||||
|
||||
Reference in New Issue
Block a user