mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
This reverts commit 8c3513a410.
This commit is contained in:
@@ -105,14 +105,14 @@ class RMSNorm(nn.Layer):
|
||||
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = self.fd_config.parallel_config.tp_group
|
||||
is_input_norm = prefix.endswith(".input_layernorm")
|
||||
is_last_norm = prefix.endswith(".norm")
|
||||
self.is_last_norm = prefix.endswith(".norm")
|
||||
self.split_x = (
|
||||
self.fd_config.parallel_config.use_sequence_parallel_moe
|
||||
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
|
||||
and is_input_norm
|
||||
)
|
||||
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
|
||||
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
|
||||
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm)
|
||||
)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
@@ -592,6 +592,9 @@ class DeepSeekV3Model(nn.Layer):
|
||||
)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -477,6 +477,9 @@ class Ernie4_5_Model(nn.Layer):
|
||||
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
|
||||
out = forward_meta.attn_backend.reverse_transpose(out)
|
||||
|
||||
|
||||
@@ -325,7 +325,10 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = self.norm(hidden_states, residual)[0]
|
||||
hidden_states = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -396,7 +399,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
|
||||
),
|
||||
)
|
||||
|
||||
def compute_logits(self, hidden_states: paddle.Tensor):
|
||||
def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
|
||||
"""
|
||||
compute logits
|
||||
"""
|
||||
|
||||
@@ -556,6 +556,9 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
|
||||
out = forward_meta.attn_backend.reverse_transpose(out)
|
||||
|
||||
|
||||
@@ -370,6 +370,9 @@ class Glm4MoeModel(nn.Layer):
|
||||
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -213,8 +213,12 @@ class GptOssModel(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = self.norm(hidden_states, residual)[0]
|
||||
return hidden_states
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
|
||||
@@ -282,6 +282,9 @@ class Qwen3MoeModel(nn.Layer):
|
||||
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -1004,7 +1004,7 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
|
||||
if self.enable_logprob and self.enable_draft_logprob and substep == 0:
|
||||
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
|
||||
|
||||
@@ -1118,7 +1118,7 @@ class MTPProposer(Proposer):
|
||||
model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
|
||||
)
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
|
||||
sampled_token_ids, sampler_output = self.sampler(
|
||||
logits,
|
||||
self.sampling_metadata,
|
||||
|
||||
Reference in New Issue
Block a user