mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[TSP] last_norm allgather move to model.py (#5924)
* support_lastnorm_gather_split_dev
* support_lastnorm_gather_split_dev1
* support_lastnorm_gather_split_dev3
* support_lastnorm_gather_split_dev4
* support_lastnorm_gather_split_dev5
This commit is contained in:
@@ -105,14 +105,14 @@ class RMSNorm(nn.Layer):
|
||||
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = self.fd_config.parallel_config.tp_group
|
||||
is_input_norm = prefix.endswith(".input_layernorm")
|
||||
is_last_norm = prefix.endswith(".norm")
|
||||
self.is_last_norm = prefix.endswith(".norm")
|
||||
self.split_x = (
|
||||
self.fd_config.parallel_config.use_sequence_parallel_moe
|
||||
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
|
||||
and is_input_norm
|
||||
)
|
||||
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
|
||||
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
|
||||
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm)
|
||||
)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
Reference in New Issue
Block a user