[BugFix]Dev fix custom ar unstable result (#4437)

This commit is contained in:
chen
2025-10-17 11:47:16 +08:00
committed by GitHub
parent 6160145f82
commit b134e6afe6
17 changed files with 25 additions and 24 deletions
@@ -298,7 +298,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
return fused_moe_out
@@ -594,6 +594,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
1.0,
)[0]
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(tmp_ffn_out)
tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
return tmp_ffn_out
@@ -354,6 +354,6 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = ffn_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(ffn_out)
ffn_out = tensor_model_parallel_all_reduce(ffn_out)
return ffn_out
@@ -393,7 +393,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
out = tensor_model_parallel_all_reduce(out)
return out
@@ -767,7 +767,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
out = tensor_model_parallel_all_reduce(out)
return out
@@ -1056,7 +1056,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
out = tensor_model_parallel_all_reduce(out)
return out
@@ -1460,6 +1460,6 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
out = intermediate_cache3.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
out = tensor_model_parallel_all_reduce(out)
return out
@@ -318,7 +318,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
@@ -488,6 +488,6 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(fused_moe_out)
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out