[Feat] ernie4_5_vl_moe support CudaGraph (#3226)

* delete dynamic control flow for decode

* coda-style

* fix scatter/gather typos and use input stream instead default stream

* support 0-Size Tensor

* update runner and model

* using static mem address as input

* fix mem leak

* refine code

* update mm_buffer

* fix typo

* fix buffersize

* fix unk token

* refine code

* refine

* support other arch

* open cudagraph in vlci

* fix

* update

* update

* update

* fix cmd

* update

---------

Co-authored-by: aquagull <hongyuh@qq.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
Ayakouji
2025-09-10 13:11:57 +08:00
committed by GitHub
parent 9d0074a91a
commit 453487d5b0
9 changed files with 207 additions and 98 deletions
@@ -32,6 +32,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.graph_optimization.decorator import (
cuda_graph_buffers,
support_graph_optimization,
)
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
@@ -66,12 +67,23 @@ class Ernie4_5_VLAttention(Ernie4_5_Attention):
@dataclass
class VLMoEMeta:
image_input: Optional[paddle.Tensor] = None
text_input: Optional[paddle.Tensor] = None
text_index: Optional[paddle.Tensor] = None
image_index: Optional[paddle.Tensor] = None
token_type_ids: Optional[paddle.Tensor] = None
fake_hidden_states: Optional[paddle.Tensor] = None
image_input: paddle.Tensor
text_input: paddle.Tensor
text_index: paddle.Tensor
image_index: paddle.Tensor
token_type_ids: paddle.Tensor
image_token_num: paddle.Tensor
def __str__(self):
return (
f"VLMoEMeta(\n"
f" image_input: {self.image_input}, pointer: {self.image_input.data_ptr()}\n"
f" text_input: {self.text_input}, pointer: {self.text_input.data_ptr()}\n"
f" text_index: {self.text_index}, pointer: {self.text_index.data_ptr()}\n"
f" image_index: {self.image_index}, pointer: {self.image_index.data_ptr()}\n"
f" token_type_ids: {self.token_type_ids}, pointer: {self.token_type_ids.data_ptr()}\n\n"
f")"
)
class Ernie4_5_VLMoeBlock(nn.Layer):
@@ -266,31 +278,26 @@ class Ernie4_5_VLMoE(nn.Layer):
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
if self.num_shared_experts > 0:
shared_experts_out = self.shared_experts(hidden_states)
if vl_moe_meta.image_input is not None:
text_image_gather_scatter(
hidden_states,
vl_moe_meta.text_input,
vl_moe_meta.image_input,
vl_moe_meta.token_type_ids,
vl_moe_meta.text_index,
vl_moe_meta.image_index,
True,
)
text_out = self.text_fused_moe(vl_moe_meta.text_input)
image_out = self.image_fused_moe(vl_moe_meta.image_input)
text_image_gather_scatter(
hidden_states,
text_out,
image_out,
vl_moe_meta.token_type_ids,
vl_moe_meta.text_index,
vl_moe_meta.image_index,
False,
)
else:
hidden_states = self.text_fused_moe(hidden_states)
if vl_moe_meta.fake_hidden_states is not None:
self.image_fused_moe(vl_moe_meta.fake_hidden_states)
text_image_gather_scatter(
hidden_states,
vl_moe_meta.text_input,
vl_moe_meta.image_input,
vl_moe_meta.token_type_ids,
vl_moe_meta.text_index,
vl_moe_meta.image_index,
True,
)
text_out = self.text_fused_moe(vl_moe_meta.text_input)
image_out = self.image_fused_moe(vl_moe_meta.image_input)
text_image_gather_scatter(
hidden_states,
text_out,
image_out,
vl_moe_meta.token_type_ids,
vl_moe_meta.text_index,
vl_moe_meta.image_index,
False,
)
if self.num_shared_experts > 0:
hidden_states += shared_experts_out
if self.tp_size > 1:
@@ -394,6 +401,40 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
return hidden_states, residual
@cuda_graph_buffers(
{
"text_input": {
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
"dtype": "model_config.dtype",
"value": 1,
},
"image_input": {
"shape": ["parallel_config.max_model_len", "model_config.hidden_size"],
"dtype": "model_config.dtype",
"value": 1,
},
"text_index": {
"shape": ["parallel_config.max_model_len"],
"dtype": "int32",
"value": 0,
},
"image_index": {
"shape": ["parallel_config.max_model_len"],
"dtype": "int32",
"value": 0,
},
"token_type_ids": {
"shape": ["parallel_config.max_model_len"],
"dtype": "int32",
"value": -1,
},
"image_token_num": {
"shape": [1],
"dtype": "int64",
"value": 0,
},
}
)
@support_graph_optimization
class Ernie4_5_VLModel(nn.Layer):
def __init__(
@@ -454,59 +495,46 @@ class Ernie4_5_VLModel(nn.Layer):
logger.info(f"Start load layer {i}")
self.layers[i].load_state_dict(state_dict)
def forward(
def prepare_vl_moe_meta(
self,
ids_remove_padding: paddle.Tensor,
image_features: Optional[paddle.Tensor],
forward_meta: ForwardMeta,
):
text_input = None
image_input = None
text_index = None
image_index = None
fake_hidden_states = None
) -> VLMoEMeta:
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
token_num, hidden_dim = hidden_states.shape
# -----------------------
image_mask = ids_remove_padding == self.im_patch_id
token_type_ids = image_mask.cast("int32")
image_token_num = image_mask.sum()
token_num = ids_remove_padding.shape[0]
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
token_type_ids = image_mask.cast("int32")
if self.fd_config.parallel_config.use_ep is True:
fake_hidden_states = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
text_input = fake_hidden_states
# The scenario requiring padding is CUDA graph, thus we only need to pad the maximum capture size.
self._mm_buffers["token_type_ids"][: self.fd_config.graph_opt_config.max_capture_size].fill_(-1)
self._mm_buffers["token_type_ids"].copy_(token_type_ids, False)
self._mm_buffers["image_token_num"].copy_(image_token_num, False)
if image_token_num > 0:
hidden_states[image_mask] = image_features.cast(self._dtype)
text_input = paddle.ones(
shape=[text_token_num, hidden_dim],
dtype=self._dtype,
)
image_input = paddle.ones(
shape=[image_token_num, hidden_dim],
dtype=self._dtype,
)
text_index = paddle.zeros_like(image_mask, dtype="int32")
image_index = paddle.zeros_like(image_mask, dtype="int32")
text_image_index_out(token_type_ids, text_index, image_index)
vl_moe_meta = VLMoEMeta(
text_input=text_input,
image_input=image_input,
text_index=text_index,
image_index=image_index,
token_type_ids=token_type_ids,
fake_hidden_states=fake_hidden_states,
return VLMoEMeta(
text_input=self._mm_buffers["text_input"][:text_token_num],
image_input=self._mm_buffers["image_input"][:image_token_num],
text_index=self._mm_buffers["text_index"][:token_num],
image_index=self._mm_buffers["image_index"][:token_num],
token_type_ids=self._mm_buffers["token_type_ids"][:token_num],
image_token_num=self._mm_buffers["image_token_num"],
)
# -----------------------
def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
return self.embed_tokens(ids_remove_padding=ids_remove_padding)
def forward(
self,
input_embeddings: paddle.Tensor,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
vl_moe_meta: VLMoEMeta,
):
text_image_index_out(vl_moe_meta.token_type_ids, vl_moe_meta.text_index, vl_moe_meta.image_index)
hidden_states = input_embeddings
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](
forward_meta,
@@ -517,17 +545,15 @@ class Ernie4_5_VLModel(nn.Layer):
hidden_states = hidden_states + residual
# -----------------------
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time, k=1)
hidden_states = extract_text_token_output(
max_seq_len,
max_seq_len_index.cast("int32"),
image_token_num.cast("int32"),
vl_moe_meta.image_token_num.cast("int32"),
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
hidden_states.cast("float32"),
).cast(self._dtype)
# -----------------------
out = self.norm(hidden_states)
@@ -552,6 +578,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
# ernie
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = ParallelLMHead(
@@ -733,16 +765,33 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
def get_input_embeddings(
self,
ids_remove_padding: paddle.Tensor,
image_features: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
input_embeddings = self.ernie.get_input_embeddings(ids_remove_padding=ids_remove_padding)
if image_features is not None and len(image_features) > 0:
input_embeddings[ids_remove_padding == self.ernie.im_patch_id] = image_features.cast(self.ernie._dtype)
return input_embeddings
def forward(
self,
ids_remove_padding: paddle.Tensor,
image_features: Optional[paddle.Tensor],
forward_meta: ForwardMeta,
):
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
self._input_embeddings.copy_(input_embeddings, False)
vl_moe_meta = self.ernie.prepare_vl_moe_meta(ids_remove_padding=ids_remove_padding)
hidden_states = self.ernie(
input_embeddings=self._input_embeddings,
ids_remove_padding=ids_remove_padding,
image_features=image_features,
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
)
return hidden_states