mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feat] ernie4_5_vl_moe support CudaGraph (#3226)
* delete dynamic control flow for decode * coda-style * fix scatter/gather typos and use input stream instead default stream * support 0-Size Tensor * update runner and model * using static mem address as input * fix mem leak * refine code * update mm_buffer * fix typo * fix buffersize * fix unk token * refine code * refine * support other arch * open cudagraph in vlci * fix * update * update * update * fix cmd * update --------- Co-authored-by: aquagull <hongyuh@qq.com> Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
cuda_graph_buffers,
|
||||
support_graph_optimization,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
@@ -66,12 +67,23 @@ class Ernie4_5_VLAttention(Ernie4_5_Attention):
|
||||
|
||||
@dataclass
|
||||
class VLMoEMeta:
|
||||
image_input: Optional[paddle.Tensor] = None
|
||||
text_input: Optional[paddle.Tensor] = None
|
||||
text_index: Optional[paddle.Tensor] = None
|
||||
image_index: Optional[paddle.Tensor] = None
|
||||
token_type_ids: Optional[paddle.Tensor] = None
|
||||
fake_hidden_states: Optional[paddle.Tensor] = None
|
||||
image_input: paddle.Tensor
|
||||
text_input: paddle.Tensor
|
||||
text_index: paddle.Tensor
|
||||
image_index: paddle.Tensor
|
||||
token_type_ids: paddle.Tensor
|
||||
image_token_num: paddle.Tensor
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"VLMoEMeta(\n"
|
||||
f" image_input: {self.image_input}, pointer: {self.image_input.data_ptr()}\n"
|
||||
f" text_input: {self.text_input}, pointer: {self.text_input.data_ptr()}\n"
|
||||
f" text_index: {self.text_index}, pointer: {self.text_index.data_ptr()}\n"
|
||||
f" image_index: {self.image_index}, pointer: {self.image_index.data_ptr()}\n"
|
||||
f" token_type_ids: {self.token_type_ids}, pointer: {self.token_type_ids.data_ptr()}\n\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeBlock(nn.Layer):
|
||||
@@ -266,31 +278,26 @@ class Ernie4_5_VLMoE(nn.Layer):
|
||||
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
|
||||
if self.num_shared_experts > 0:
|
||||
shared_experts_out = self.shared_experts(hidden_states)
|
||||
if vl_moe_meta.image_input is not None:
|
||||
text_image_gather_scatter(
|
||||
hidden_states,
|
||||
vl_moe_meta.text_input,
|
||||
vl_moe_meta.image_input,
|
||||
vl_moe_meta.token_type_ids,
|
||||
vl_moe_meta.text_index,
|
||||
vl_moe_meta.image_index,
|
||||
True,
|
||||
)
|
||||
text_out = self.text_fused_moe(vl_moe_meta.text_input)
|
||||
image_out = self.image_fused_moe(vl_moe_meta.image_input)
|
||||
text_image_gather_scatter(
|
||||
hidden_states,
|
||||
text_out,
|
||||
image_out,
|
||||
vl_moe_meta.token_type_ids,
|
||||
vl_moe_meta.text_index,
|
||||
vl_moe_meta.image_index,
|
||||
False,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.text_fused_moe(hidden_states)
|
||||
if vl_moe_meta.fake_hidden_states is not None:
|
||||
self.image_fused_moe(vl_moe_meta.fake_hidden_states)
|
||||
text_image_gather_scatter(
|
||||
hidden_states,
|
||||
vl_moe_meta.text_input,
|
||||
vl_moe_meta.image_input,
|
||||
vl_moe_meta.token_type_ids,
|
||||
vl_moe_meta.text_index,
|
||||
vl_moe_meta.image_index,
|
||||
True,
|
||||
)
|
||||
text_out = self.text_fused_moe(vl_moe_meta.text_input)
|
||||
image_out = self.image_fused_moe(vl_moe_meta.image_input)
|
||||
text_image_gather_scatter(
|
||||
hidden_states,
|
||||
text_out,
|
||||
image_out,
|
||||
vl_moe_meta.token_type_ids,
|
||||
vl_moe_meta.text_index,
|
||||
vl_moe_meta.image_index,
|
||||
False,
|
||||
)
|
||||
if self.num_shared_experts > 0:
|
||||
hidden_states += shared_experts_out
|
||||
if self.tp_size > 1:
|
||||
@@ -394,6 +401,40 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@cuda_graph_buffers(
|
||||
{
|
||||
"text_input": {
|
||||
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
|
||||
"dtype": "model_config.dtype",
|
||||
"value": 1,
|
||||
},
|
||||
"image_input": {
|
||||
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
|
||||
"dtype": "model_config.dtype",
|
||||
"value": 1,
|
||||
},
|
||||
"text_index": {
|
||||
"shape": ["parallel_config.max_model_len"],
|
||||
"dtype": "int32",
|
||||
"value": 0,
|
||||
},
|
||||
"image_index": {
|
||||
"shape": ["parallel_config.max_model_len"],
|
||||
"dtype": "int32",
|
||||
"value": 0,
|
||||
},
|
||||
"token_type_ids": {
|
||||
"shape": ["parallel_config.max_model_len"],
|
||||
"dtype": "int32",
|
||||
"value": -1,
|
||||
},
|
||||
"image_token_num": {
|
||||
"shape": [1],
|
||||
"dtype": "int64",
|
||||
"value": 0,
|
||||
},
|
||||
}
|
||||
)
|
||||
@support_graph_optimization
|
||||
class Ernie4_5_VLModel(nn.Layer):
|
||||
def __init__(
|
||||
@@ -454,59 +495,46 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
logger.info(f"Start load layer {i}")
|
||||
self.layers[i].load_state_dict(state_dict)
|
||||
|
||||
def forward(
|
||||
def prepare_vl_moe_meta(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor],
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
text_input = None
|
||||
image_input = None
|
||||
text_index = None
|
||||
image_index = None
|
||||
fake_hidden_states = None
|
||||
) -> VLMoEMeta:
|
||||
|
||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
token_num, hidden_dim = hidden_states.shape
|
||||
|
||||
# -----------------------
|
||||
image_mask = ids_remove_padding == self.im_patch_id
|
||||
token_type_ids = image_mask.cast("int32")
|
||||
image_token_num = image_mask.sum()
|
||||
token_num = ids_remove_padding.shape[0]
|
||||
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
|
||||
|
||||
token_type_ids = image_mask.cast("int32")
|
||||
if self.fd_config.parallel_config.use_ep is True:
|
||||
fake_hidden_states = paddle.empty(
|
||||
shape=[0, self.fd_config.model_config.hidden_size],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
text_input = fake_hidden_states
|
||||
# The scenario requiring padding is CUDA graph, thus we only need to pad the maximum capture size.
|
||||
self._mm_buffers["token_type_ids"][: self.fd_config.graph_opt_config.max_capture_size].fill_(-1)
|
||||
self._mm_buffers["token_type_ids"].copy_(token_type_ids, False)
|
||||
self._mm_buffers["image_token_num"].copy_(image_token_num, False)
|
||||
|
||||
if image_token_num > 0:
|
||||
hidden_states[image_mask] = image_features.cast(self._dtype)
|
||||
text_input = paddle.ones(
|
||||
shape=[text_token_num, hidden_dim],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
image_input = paddle.ones(
|
||||
shape=[image_token_num, hidden_dim],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
text_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||
image_index = paddle.zeros_like(image_mask, dtype="int32")
|
||||
text_image_index_out(token_type_ids, text_index, image_index)
|
||||
|
||||
vl_moe_meta = VLMoEMeta(
|
||||
text_input=text_input,
|
||||
image_input=image_input,
|
||||
text_index=text_index,
|
||||
image_index=image_index,
|
||||
token_type_ids=token_type_ids,
|
||||
fake_hidden_states=fake_hidden_states,
|
||||
return VLMoEMeta(
|
||||
text_input=self._mm_buffers["text_input"][:text_token_num],
|
||||
image_input=self._mm_buffers["image_input"][:image_token_num],
|
||||
text_index=self._mm_buffers["text_index"][:token_num],
|
||||
image_index=self._mm_buffers["image_index"][:token_num],
|
||||
token_type_ids=self._mm_buffers["token_type_ids"][:token_num],
|
||||
image_token_num=self._mm_buffers["image_token_num"],
|
||||
)
|
||||
# -----------------------
|
||||
|
||||
def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
|
||||
return self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_embeddings: paddle.Tensor,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
vl_moe_meta: VLMoEMeta,
|
||||
):
|
||||
text_image_index_out(vl_moe_meta.token_type_ids, vl_moe_meta.text_index, vl_moe_meta.image_index)
|
||||
|
||||
hidden_states = input_embeddings
|
||||
residual = None
|
||||
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](
|
||||
forward_meta,
|
||||
@@ -517,17 +545,15 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# -----------------------
|
||||
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
|
||||
hidden_states = extract_text_token_output(
|
||||
max_seq_len,
|
||||
max_seq_len_index.cast("int32"),
|
||||
image_token_num.cast("int32"),
|
||||
vl_moe_meta.image_token_num.cast("int32"),
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
hidden_states.cast("float32"),
|
||||
).cast(self._dtype)
|
||||
# -----------------------
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
|
||||
@@ -552,6 +578,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
# ernie
|
||||
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
self._input_embeddings = paddle.zeros(
|
||||
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
|
||||
dtype=fd_config.model_config.dtype,
|
||||
)
|
||||
|
||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -733,16 +765,33 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor] = None,
|
||||
) -> paddle.Tensor:
|
||||
input_embeddings = self.ernie.get_input_embeddings(ids_remove_padding=ids_remove_padding)
|
||||
if image_features is not None and len(image_features) > 0:
|
||||
input_embeddings[ids_remove_padding == self.ernie.im_patch_id] = image_features.cast(self.ernie._dtype)
|
||||
return input_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor],
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
input_embeddings = self.get_input_embeddings(
|
||||
ids_remove_padding=ids_remove_padding, image_features=image_features
|
||||
)
|
||||
self._input_embeddings.copy_(input_embeddings, False)
|
||||
vl_moe_meta = self.ernie.prepare_vl_moe_meta(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
hidden_states = self.ernie(
|
||||
input_embeddings=self._input_embeddings,
|
||||
ids_remove_padding=ids_remove_padding,
|
||||
image_features=image_features,
|
||||
forward_meta=forward_meta,
|
||||
vl_moe_meta=vl_moe_meta,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user