[Optimization] Support multimodal runner for image/video feature processing (#7485)

* [NewFeature] Support multimodal (mm) runner

* [NewFeature] support mm runner part1

* support mm runner part2

* support mm runner part3

* support mm runner part4
This commit is contained in:
xiaoxiaohehe001
2026-04-22 11:58:33 +08:00
committed by GitHub
parent 76b960cb5b
commit a09792e085
3 changed files with 135 additions and 14 deletions
+40
View File
@@ -233,7 +233,11 @@ class InputBatch:
)
if self.is_mm_model:
self.image_features = None
self.image_grid_thws = None
self.image_features_list = None
self.video_features = None
self.video_grid_thws = None
self.video_infinity_scales = None
# Set block tables
pre_max_block_num = (
@@ -342,7 +346,26 @@ class InputBatch:
dtype="float32",
)
self.image_features = None # Built before the forward
self.image_grid_thws = None
self.image_features_list = None
self.video_features = None
self.video_grid_thws = None
self.video_infinity_scales = None
decode_states_len = self.speculative_config.num_speculative_tokens + 1 if self.speculative_decoding else 1
self.decode_states = paddle.full(
[self.scheduler_config.max_num_seqs, decode_states_len],
-1,
dtype="int32",
)
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
fill_value=-1,
dtype="int32",
)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)
# For logits processors
self.logits_processors = build_logits_processors(self.fd_config)
@@ -409,6 +432,7 @@ class InputBatch:
swap_data(self.ori_seq_lens_encoder, i1, i2)
swap_data(self.system_lens, i1, i2)
swap_data(self.system_ids, i1, i2)
swap_data(self.generated_modality, i1, i2)
swap_data(self.enable_thinking, i1, i2)
swap_data(self.max_think_lens, i1, i2)
swap_data(self.limit_think_status, i1, i2)
@@ -451,6 +475,8 @@ class InputBatch:
self.image_features_list[i1],
)
swap_data(self.share_inputs["rope_emb"], i1, i2)
swap_data(self.decode_states, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)
# Swap mask rollback
swap_data(self.mask_rollback, i1, i2)
@@ -578,6 +604,7 @@ class InputBatch:
fill_paddle_tensor(self, "ori_seq_lens_encoder", 0)
fill_paddle_tensor(self, "system_lens", 0)
fill_paddle_tensor(self, "system_ids", -1)
fill_paddle_tensor(self, "generated_modality", -1)
fill_paddle_tensor(self, "ids_remove_padding", 0)
fill_paddle_tensor(self, "batch_id_per_token", 0)
@@ -662,7 +689,14 @@ class InputBatch:
dtype="float32",
)
self.image_features = None
self.image_grid_thws = None
self.image_features_list = None
self.video_features = None
self.video_grid_thws = None
self.video_infinity_scales = None
fill_paddle_tensor(self, "decode_states", -1)
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
else:
# Reset non-multimodal rope_emb
self.rope_emb = get_rope(
@@ -674,7 +708,11 @@ class InputBatch:
)
if self.is_mm_model:
self.image_features = None
self.image_grid_thws = None
self.image_features_list = None
self.video_features = None
self.video_grid_thws = None
self.video_infinity_scales = None
# Reset other miscellaneous tensors
fill_paddle_tensor(self, "mask_rollback", 0)
@@ -892,6 +930,8 @@ class ProposerInputBatch(InputBatch):
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)
if self.enable_mm:
swap_data(self.decode_states, i1, i2)
swap_data(self.attn_mask_offsets, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)