Split enable_mm (#7183) (#7233)

Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
Co-authored-by: liuruian <liuruian@MacBook-Pro.local>
This commit is contained in:
YuBaoku
2026-04-08 16:32:04 +08:00
committed by GitHub
parent 403ce139c7
commit 6b78981dde
33 changed files with 109 additions and 69 deletions
+23 -26
View File
@@ -17,13 +17,7 @@
import paddle
from paddleformers.utils.log import logger
from fastdeploy.config import (
CacheConfig,
DeployModality,
FDConfig,
ModelConfig,
SpeculativeConfig,
)
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.platforms import current_platform
@@ -101,7 +95,8 @@ class InputBatch:
self.scheduler_config = fd_config.scheduler_config
self.speculative_config: SpeculativeConfig = fd_config.speculative_config
self.speculative_decoding = self.speculative_config.method is not None
self.enable_mm = self.model_config.enable_mm
self.is_mm_model = self.model_config.enable_mm
self.enable_mm = fd_config.enable_mm_runtime
self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel
self.index_to_batch_id = {}
self.enable_pd_reorder = False
@@ -231,6 +226,9 @@ class InputBatch:
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
if self.is_mm_model:
self.image_features = None
self.image_features_list = None
# Set block tables
pre_max_block_num = (
@@ -677,6 +675,9 @@ class InputBatch:
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
if self.is_mm_model:
self.image_features = None
self.image_features_list = None
# Reset other miscellaneous tensors
fill_paddle_tensor(self, "mask_rollback", 0)
@@ -689,7 +690,7 @@ class InputBatch:
class ProposerInputBatch(InputBatch):
def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None:
self.enable_mm = fd_config.model_config.enable_mm
self.enable_mm = fd_config.enable_mm_runtime
self.num_model_steps = fd_config.speculative_config.num_model_steps
self.index_to_batch_id = {}
self.target_model_input_batch = target_model_input_batch
@@ -863,18 +864,15 @@ class ProposerInputBatch(InputBatch):
-1,
dtype="int32",
)
if self.fd_config.deploy_modality != DeployModality.TEXT:
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
fill_value=-1,
dtype="int32",
)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)
self.attn_mask_offsets_decoder = paddle.full(
[self.scheduler_config.max_num_seqs, 1], -1, dtype="int32"
)
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
fill_value=-1,
dtype="int32",
)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)
self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")
def swap_states(self, i1, i2) -> None:
def swap_data(tensor, idx1, idx2):
@@ -896,7 +894,7 @@ class ProposerInputBatch(InputBatch):
swap_data(self.input_ids_len, i1, i2)
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)
if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT:
if self.enable_mm:
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)
@@ -1030,10 +1028,9 @@ class ProposerInputBatch(InputBatch):
# Reset multimodal tensors if enabled
if self.enable_mm:
fill_paddle_tensor(self, "decode_states", -1)
if self.fd_config.deploy_modality != DeployModality.TEXT:
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
logger.info("model_inputs reset completed")
except Exception as e: