[Iluvatar] Fix cuda graph error for tp > 1 in ernie models (#7126)

This commit is contained in:
yzwu
2026-04-01 19:13:34 +08:00
committed by GitHub
parent fdfc908e2f
commit ceaf5df350
5 changed files with 75 additions and 11 deletions
@@ -43,7 +43,7 @@ class DefaultModelLoader(BaseModelLoader):
def clean_memory_fragments(self, state_dict: dict) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda() or current_platform.is_maca():
if current_platform.is_cuda() or current_platform.is_maca() or current_platform.is_iluvatar():
if state_dict:
for k, v in state_dict.items():
if isinstance(v, paddle.Tensor):