[Metax] adapt to the latest develop (#6282)

This commit is contained in:
xiaozude
2026-01-30 15:21:20 +08:00
committed by GitHub
parent 18ebce9dec
commit 030647521a
14 changed files with 754 additions and 370 deletions
+3 -3
View File
@@ -108,7 +108,7 @@ class MTPProposer(Proposer):
if current_platform.is_xpu():
self._propose = self._propose_xpu
elif current_platform.is_cuda():
elif current_platform.is_cuda() or current_platform.is_maca():
self._propose = self._propose_cuda
else:
raise RuntimeError("Unsupported platform.")
@@ -350,7 +350,7 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
if current_platform.is_xpu():
if current_platform.is_xpu() or current_platform.is_maca():
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).cpu()
@@ -1308,7 +1308,7 @@ class MTPProposer(Proposer):
elif current_platform.is_xpu():
paddle.device.xpu.empty_cache()
else:
raise NotImplementedError
paddle.device.empty_cache()
def _get_cache_type(self):
cache_type = None