[Optimization] default compile rdma, reduce cudagraph buffer size in mm, fix some config bugs (#5121)

* default compile rdma, reduce cudagraph buffer size in mm, fix some config logic

* update

* update

* fix bug

* enhance rdma compile

* fix
This commit is contained in:
Yuanle Liu
2025-11-20 17:19:47 +08:00
committed by GitHub
parent 6fa34102e8
commit 7ac25935c7
8 changed files with 126 additions and 37 deletions
@@ -152,10 +152,11 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
self.model = Qwen2_5_VLModel(fd_config=fd_config)
# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
if fd_config.graph_opt_config.use_cudagraph:
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
@@ -290,10 +291,13 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
self._input_embeddings.copy_(input_embeddings, False)
if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings
hidden_states = self.model(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
image_features=image_features,
forward_meta=forward_meta,