[Graph Optimization][CINN] Use CINN in PaddleOCR-VL ViT part (#5223)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nyakku Shigure
2025-12-09 14:37:00 +08:00
committed by GitHub
parent 8d99bac532
commit e1c4a12e34
8 changed files with 120 additions and 10 deletions
@@ -281,7 +281,6 @@ class SiglipMLP(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = get_activation_fn(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -304,7 +303,7 @@ class SiglipMLP(nn.Layer):
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# MLP forward pass: project with fc1, apply the activation, project with fc2,
# and return the result.
hidden_states = self.fc1(hidden_states)
# NOTE(review): this text is a diff hunk whose +/- markers were stripped. The
# next two statements look like the removed and the added version of the same
# line (cached `self.activation_fn` vs. resolving the activation from
# `self.config.hidden_act` at call time) rather than two sequential steps —
# confirm against the real file before reading them as consecutive code.
# Both versions index element 0 of the fc1 result; the reason is not visible here.
hidden_states = self.activation_fn(hidden_states[0])
hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
hidden_states = self.fc2(hidden_states)
return hidden_states
@@ -318,7 +317,6 @@ class SiglipEncoderLayer(paddle.nn.Layer):
self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
# @paddle.jit.to_static
def forward(
self,
hidden_states,
@@ -527,7 +525,37 @@ class SiglipEncoder(nn.Layer):
else:
attn_cu_seqlens = cu_seqlens
max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
return self._run_encoder_layer(
encoder_states=encoder_states,
all_attentions=all_attentions,
attn_cu_seqlens=attn_cu_seqlens,
output_hidden_states=output_hidden_states,
reversed_window_indices=reversed_window_indices if output_hidden_states else None,
use_window_attn=use_window_attn,
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
cos_emb=cos_emb,
sin_emb=sin_emb,
)
# This function will be compiled with CINN when graph_opt_level >= 2
# TODO(SigureMo): Use a new decorator to mark the function for CINN compilation
def _run_encoder_layer(
self,
encoder_states: Optional[Tuple[()]],
all_attentions: Optional[Tuple[()]],
attn_cu_seqlens: Optional[paddle.Tensor],
output_hidden_states: Optional[bool],
reversed_window_indices: paddle.Tensor,
use_window_attn: bool,
hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor],
output_attentions: bool,
cos_emb: Optional[paddle.Tensor],
sin_emb: Optional[paddle.Tensor],
) -> paddle.Tensor:
max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().cpu()
for encoder_layer in self.layers:
if output_hidden_states: