[Features] add audio request & fix embedding bug (#5201)

* [Features] add audio request & fix embedding bug

* fix bug
Author: ming1753
Date: 2025-12-01 11:12:17 +08:00
Committed by: GitHub
Parent: 9f4977eb74
Commit: 70ec1e17c1
2 changed files with 46 additions and 14 deletions
Changes to the file shown below: +17 -12
@@ -106,6 +106,8 @@ class VocabParallelEmbedding(nn.Layer):
         params_dtype: str = "bfloat16",
         prefix="",
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        org_num_embeddings: int | None = None,
+        general=False,
     ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.
@@ -132,17 +134,23 @@ class VocabParallelEmbedding(nn.Layer):
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype

-        self.padding_size = padding_size
-        self.org_vocab_size = num_embeddings
+        self.general = general  # used for general Embedding
         self.num_embeddings = num_embeddings
-        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.padding_size = padding_size
+        if self.general:
+            self.org_vocab_size = num_embeddings
+            self.num_embeddings_padded = num_embeddings
+            self.org_vocab_size_padded = num_embeddings
+        else:
+            self.org_vocab_size = org_num_embeddings or num_embeddings
+            num_added_embeddings = num_embeddings - self.org_vocab_size
+            self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
+            self.num_embeddings_padded = pad_vocab_size(
+                self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+            )
+            assert self.org_vocab_size_padded <= self.num_embeddings_padded
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
-        self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
-        )
-        assert self.org_vocab_size_padded <= self.num_embeddings_padded

         self.shard_indices = self._get_indices(
             self.num_embeddings_padded,
             self.org_vocab_size_padded,
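
Note: the new else-branch pads the checkpoint's original vocab (org_num_embeddings, from the signature change above) before re-padding after any added tokens, while general=True bypasses padding entirely. A minimal standalone sketch of that arithmetic, assuming only that pad_vocab_size (defined elsewhere in this file, not shown in the diff) rounds a size up to a multiple of padding_size; the concrete numbers and the padding size of 64 are illustrative:

    # Illustrative sketch, not part of the commit.
    def pad_vocab_size(vocab_size: int, padding_size: int = 64) -> int:
        # Assumed behaviour: round up to the nearest multiple of padding_size.
        return ((vocab_size + padding_size - 1) // padding_size) * padding_size

    def padded_sizes(num_embeddings: int, org_num_embeddings: int | None, padding_size: int = 64):
        org_vocab_size = org_num_embeddings or num_embeddings
        num_added = num_embeddings - org_vocab_size
        org_vocab_size_padded = pad_vocab_size(org_vocab_size, padding_size)
        num_embeddings_padded = pad_vocab_size(org_vocab_size_padded + num_added, padding_size)
        return org_vocab_size_padded, num_embeddings_padded

    # Original vocab of 151_665 tokens plus 3 added special tokens:
    print(padded_sizes(151_668, 151_665))  # (151680, 151744)
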
@@ -152,9 +160,6 @@ class VocabParallelEmbedding(nn.Layer):
             self.world_size,
         )

-        if num_embeddings % self.world_size != 0:
-            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
-
         if not self.column_cut:
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 self.num_embeddings_padded,
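
Note: the deleted branch recomputed num_embeddings_padded from num_embeddings alone whenever the vocab was not divisible by the tensor-parallel world size, which can disagree with the value derived above from the padded original vocab plus added tokens (possibly the embedding bug the commit title refers to). Reusing the illustrative pad_vocab_size from the sketch above:

    # Illustrative sketch, not part of the commit.
    org, added = 151_665, 3
    kept_formula    = pad_vocab_size(pad_vocab_size(org) + added)  # 151744 (branch kept above)
    removed_formula = pad_vocab_size(org + added)                  # 151680 (deleted override)
    print(kept_formula == removed_formula)                         # False
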
@@ -188,7 +193,7 @@ class VocabParallelEmbedding(nn.Layer):
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        if self.tie_word_embeddings:
+        if self.tie_word_embeddings and not self.general:
             weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
         else:
             weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
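
Note: with the added guard, a general (non-vocab) embedding always pops and owns its weight even when the model config ties word embeddings; only the tied vocab embedding reads the weight without popping, presumably so the output projection can still find it. A small sketch of that decision with a plain dict standing in for the checkpoint (get_tensor, dtype casting, and the real weight prefixes are omitted; the key names here are made up):

    # Illustrative sketch, not part of the commit.
    def take_embedding_weight(state_dict: dict, prefix: str,
                              tie_word_embeddings: bool, general: bool):
        if tie_word_embeddings and not general:
            # Tied word embedding: read without popping so the weight stays available.
            return state_dict[prefix + ".weight"]
        # General embedding (e.g. an audio embedding table) or untied vocab: take ownership.
        return state_dict.pop(prefix + ".weight")

    ckpt = {"embed_tokens.weight": "W_tok", "audio_embed.weight": "W_audio"}
    take_embedding_weight(ckpt, "embed_tokens", tie_word_embeddings=True, general=False)
    take_embedding_weight(ckpt, "audio_embed", tie_word_embeddings=True, general=True)
    print(sorted(ckpt))  # ['embed_tokens.weight'] - only the tied vocab weight remains
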