[BugFix] Dev: fix unstable result of custom all-reduce (AR) (#4437)

This commit is contained in:
chen
2025-10-17 11:47:16 +08:00
committed by GitHub
parent 6160145f82
commit b134e6afe6
17 changed files with 25 additions and 24 deletions
+2 -1
View File
@@ -59,7 +59,7 @@ try:
global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
# TODO: supports different_group custom allreduce
_TP_AR.custom_all_reduce(input_)
input_ = _TP_AR.custom_all_reduce(input_)
elif paddle.in_dynamic_mode():
if group_ is not None:
dist.all_reduce(input_, group=group_)
@@ -69,6 +69,7 @@ try:
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
return input_
except:
tensor_model_parallel_all_reduce = None