[TSP] last_norm allgather move to model.py (#5924)

* support_lastnorm_gather_split_dev

* support_lastnorm_gather_split_dev1

* support_lastnorm_gather_split_dev3

* support_lastnorm_gather_split_dev4

* support_lastnorm_gather_split_dev5
This commit is contained in:
xiaoluomi
2026-01-08 15:36:33 +08:00
committed by GitHub
parent 8e11d719f3
commit 2bb838fed9
9 changed files with 30 additions and 8 deletions
@@ -105,14 +105,14 @@ class RMSNorm(nn.Layer):
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
is_last_norm = prefix.endswith(".norm")
self.is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.fd_config.parallel_config.use_sequence_parallel_moe
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
and is_input_norm
)
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm)
)
self.init_weight()