mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Metax] adapt to the latest develop (#6282)
This commit is contained in:
@@ -108,7 +108,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
if current_platform.is_xpu():
|
||||
self._propose = self._propose_xpu
|
||||
elif current_platform.is_cuda():
|
||||
elif current_platform.is_cuda() or current_platform.is_maca():
|
||||
self._propose = self._propose_cuda
|
||||
else:
|
||||
raise RuntimeError("Unsupported platform.")
|
||||
@@ -350,7 +350,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
if current_platform.is_xpu():
|
||||
if current_platform.is_xpu() or current_platform.is_maca():
|
||||
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
|
||||
self.target_model_inputs["decoder_num_blocks_cpu"]
|
||||
).cpu()
|
||||
@@ -1308,7 +1308,7 @@ class MTPProposer(Proposer):
|
||||
elif current_platform.is_xpu():
|
||||
paddle.device.xpu.empty_cache()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
paddle.device.empty_cache()
|
||||
|
||||
def _get_cache_type(self):
|
||||
cache_type = None
|
||||
|
||||
Reference in New Issue
Block a user