Split enable_mm (#7183)

Co-authored-by: liuruian <liuruian@MacBook-Pro.local>

Author: K11OntheBoat
Date: 2026-04-08 11:25:41 +08:00 (committed by GitHub)
Parent: 8496ec71a6
Commit: bb48bcbaa2
33 changed files with 109 additions and 69 deletions
+30 -3
@@ -1992,6 +1992,7 @@ class FDConfig:
                 int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0
                 and self.model_config is not None
                 and self.model_config.enable_mm
+                and self.deploy_modality != DeployModality.TEXT
             ):
                 self.max_prefill_batch = 1  # TODO: V0 multimodal prefill currently only supports a parallelism of 1; to be optimized
             else:
@@ -2031,6 +2032,20 @@ class FDConfig:
         self.check()
         # self.print()  # NOTE: it's better to explicitly call .print() when FDConfig is initialized

+    @property
+    def enable_mm_runtime(self) -> bool:
+        return (
+            self.model_config is not None
+            and self.model_config.enable_mm
+            and self.deploy_modality != DeployModality.TEXT
+        )
+
+    @property
+    def enable_rope_3d_runtime(self) -> bool:
+        return self.enable_mm_runtime and (
+            getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False)
+        )
+
     def _disable_sequence_parallel_moe_if_needed(self, mode_name):
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
             self.parallel_config.use_sequence_parallel_moe = False
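
Aside: the split introduced by these two properties is easiest to see side by side. Below is a minimal, self-contained sketch (not FastDeploy code; the DeployModality members and the config wiring are assumptions based on this diff) showing that a multimodal-capable checkpoint deployed text-only reports enable_mm_runtime == False while model_config.enable_mm stays True:

# Illustrative sketch only: capability vs. runtime split.
from dataclasses import dataclass
from enum import Enum


class DeployModality(Enum):  # assumed members, mirroring the diff's TEXT comparison
    TEXT = "text"
    MULTIMODAL = "multimodal"


@dataclass
class ModelCfg:
    enable_mm: bool = True  # the checkpoint *can* take images/videos
    rope_3d: bool = True


@dataclass
class Cfg:
    model_config: ModelCfg
    deploy_modality: DeployModality

    @property
    def enable_mm_runtime(self) -> bool:
        # mm is active only if the model supports it AND we are not deployed text-only
        return self.model_config.enable_mm and self.deploy_modality != DeployModality.TEXT

    @property
    def enable_rope_3d_runtime(self) -> bool:
        # 3D RoPE only matters on the multimodal runtime path
        return self.enable_mm_runtime and self.model_config.rope_3d


mm_model_as_text = Cfg(ModelCfg(), DeployModality.TEXT)
assert not mm_model_as_text.enable_mm_runtime       # capability stays True, runtime is off
assert not mm_model_as_text.enable_rope_3d_runtime  # backends fall back to 2D RoPE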
@@ -2069,9 +2084,21 @@ class FDConfig:
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)

+        if (
+            self.model_config is not None
+            and self.model_config.enable_mm
+            and self.deploy_modality == DeployModality.TEXT
+        ):
+            if getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False):
+                logger.info(
+                    "Deploy modality is text; forcing the multimodal-capable model onto the 2D RoPE runtime path."
+                )
+                setattr(self.model_config, "rope_3d", False)
+                setattr(self.model_config, "use_3d_rope", False)
+
         self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size)
         self.cache_config.postprocess(self.get_max_chunk_tokens(), self.scheduler_config.max_num_seqs)
-        if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+        if self.model_config is not None and self.enable_mm_runtime and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.cache_config.enable_prefix_caching = False
         if (
             self.structured_outputs_config is not None
@@ -2097,7 +2124,7 @@ class FDConfig:
                 f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]"
             )

-        if self.model_config.enable_mm:
+        if self.enable_mm_runtime:
             if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
                 self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens
             elif self.cache_config.max_encoder_cache != 0:
@@ -2404,7 +2431,7 @@ class FDConfig:
             num_tokens = self.scheduler_config.max_num_seqs
         else:
             num_tokens = self.scheduler_config.max_num_batched_tokens
-        if mm_max_tokens_per_item is not None and self.deploy_modality != DeployModality.TEXT:
+        if self.enable_mm_runtime and mm_max_tokens_per_item is not None:
             max_mm_tokens = max(
                 mm_max_tokens_per_item.get("image", 0),
                 mm_max_tokens_per_item.get("video", 0),
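
The rewritten guard above means a text-only deployment reserves no multimodal token budget at all. A small illustrative sketch (covering only the image/video keys visible in this hunk; the real method may consider further modalities):

# Sketch of the gated budget computation (illustrative, not the real method).
def max_mm_tokens(mm_max_tokens_per_item, enable_mm_runtime: bool) -> int:
    # With the mm runtime off (e.g. text-only deployment), no mm budget is reserved.
    if not enable_mm_runtime or mm_max_tokens_per_item is None:
        return 0
    return max(
        mm_max_tokens_per_item.get("image", 0),
        mm_max_tokens_per_item.get("video", 0),
    )

assert max_mm_tokens({"image": 6400, "video": 16384}, enable_mm_runtime=True) == 16384
assert max_mm_tokens({"image": 6400, "video": 16384}, enable_mm_runtime=False) == 0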
+2 -1
@@ -294,6 +294,7 @@ class AsyncLLM(EngineServiceClient):
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.tool_parser,
+            enable_mm_runtime=cfg.enable_mm_runtime,
         )
         # Create data processor
         self.data_processor = self.input_processor.create_processor()
@@ -446,7 +447,7 @@ class AsyncLLM(EngineServiceClient):
         )
         if envs.ZMQ_SEND_BATCH_DATA and self.connection_manager is not None:
             request["zmq_worker_pid"] = self.connection_manager.worker_pid
-        if self.cfg.model_config.enable_mm:
+        if self.cfg.enable_mm_runtime:
             self.request_client.send_pyobj(request)
         else:
             self.request_client.send_json(request)
+5 -3
@@ -330,6 +330,7 @@ class EngineService:
             self.cfg.limit_mm_per_prompt,
             self.cfg.mm_processor_kwargs,
             self.cfg.tool_parser,
+            enable_mm_runtime=self.cfg.enable_mm_runtime,
         )
         self.data_processor = self.input_processor.create_processor()
         self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item(
@@ -601,7 +602,7 @@ class EngineService:
                 LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "")
             )
         if not is_prefill:
-            if not self.cfg.model_config.enable_mm:
+            if not self.cfg.enable_mm_runtime:
                 self.update_requests_chunk_size(tasks)
             else:
                 self.update_mm_requests_chunk_size(tasks)
@@ -1217,7 +1218,7 @@ class EngineService:
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
-                if not self.cfg.model_config.enable_mm:
+                if not self.cfg.enable_mm_runtime:
                     err, data = self.recv_request_server.receive_json_once(block)
                 else:
                     err, data = self.recv_request_server.receive_pyobj_once(block)
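
Note the symmetry with AsyncLLM's send path earlier in this commit: sender and receiver must agree on serialization, and both now key off the same runtime flag. An illustrative sketch (the client/server objects are stand-ins for the diff's request_client and recv_request_server):

# Sketch: why send/recv must pick the same serialization, keyed off enable_mm_runtime.
def send_request(client, request: dict, enable_mm_runtime: bool) -> None:
    if enable_mm_runtime:
        client.send_pyobj(request)  # pickle: tolerates tensors / numpy image payloads
    else:
        client.send_json(request)   # JSON: cheaper for plain text requests


def recv_request(server, block: bool, enable_mm_runtime: bool):
    if enable_mm_runtime:
        return server.receive_pyobj_once(block)
    return server.receive_json_once(block)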
@@ -1275,6 +1276,7 @@ class EngineService:
                 err_msg = None
                 try:
                     request = Request.from_dict(data)
                     request.metrics.scheduler_recv_req_time = time.time()
                     main_process_metrics.requests_number.inc()
                     trace_carrier = data.get("trace_carrier")
@@ -2377,7 +2379,7 @@ class EngineService:
         if self.cfg.scheduler_config.splitwise_role == "prefill":
             variables["FLAGS_fmt_write_cache_completed_signal"] = 1
-        if self.cfg.model_config.enable_mm:
+        if self.cfg.enable_mm_runtime:
             variables["FLAGS_max_partition_size"] = 1024

         command_prefix = ""
@@ -205,11 +205,11 @@ class ResourceManagerV1(ResourceManager):
         self.need_block_num_map = dict()

         self.encoder_cache = None
-        if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0:
+        if config.enable_mm_runtime and config.cache_config.max_encoder_cache > 0:
             self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)

         self.processor_cache = None
-        if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0:
+        if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0:
             max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
             self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
@@ -550,7 +550,7 @@ class ResourceManagerV1(ResourceManager):
         num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
         request.with_image = False
-        if not self.config.model_config.enable_mm:
+        if not self.config.enable_mm_runtime:
             return num_new_tokens

         inputs = request.multimodal_inputs
+3 -1
@@ -84,7 +84,7 @@ class EngineClient:
     def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
         self.fd_config = fd_config
         self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
-        self.enable_mm = self.fd_config.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
         self.max_logprobs = max_logprobs
         input_processor = InputPreprocessor(
             self.fd_config.model_config,
@@ -93,6 +93,7 @@ class EngineClient:
             self.fd_config.mm_processor_kwargs,
             self.fd_config.tool_parser,
             self.enable_mm and self.fd_config.cache_config.max_processor_cache > 0,
+            enable_mm_runtime=self.enable_mm,
         )
         self.enable_logprob = self.fd_config.model_config.enable_logprob
         self.data_processor = input_processor.create_processor()
@@ -358,6 +359,7 @@ class EngineClient:
             task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
         min_tokens = task.get("min_tokens", 1)
         if "messages" in task:
             task["messages"] = None
         api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")
+4 -1
@@ -48,6 +48,7 @@ class InputPreprocessor:
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         tool_parser: str = None,
         enable_processor_cache: bool = False,
+        enable_mm_runtime: Optional[bool] = None,
     ) -> None:
         self.model_config = model_config
         self.model_name_or_path = self.model_config.model
@@ -56,6 +57,7 @@ class InputPreprocessor:
         self.mm_processor_kwargs = mm_processor_kwargs
         self.tool_parser = tool_parser
         self.enable_processor_cache = enable_processor_cache
+        self.enable_mm_runtime = self.model_config.enable_mm if enable_mm_runtime is None else enable_mm_runtime

     def create_processor(self):
         reasoning_parser_obj = None
@@ -77,10 +79,11 @@ class InputPreprocessor:
                 reasoning_parser_obj=reasoning_parser_obj,
                 tool_parser_obj=tool_parser_obj,
                 mm_processor_kwargs=self.mm_processor_kwargs,
+                enable_mm_runtime=self.enable_mm_runtime,
             )
         except Exception as e:
             logger.info(f"Plugin input processor not available ({e}), using built-in processor")

-        if not self.model_config.enable_mm:
+        if not self.enable_mm_runtime:
             from fastdeploy.input.text_processor import TextProcessor

             tokenizer_type = "ernie4_5" if ErnieArchitectures.contains_ernie_arch(architecture) else "auto"
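
With the new keyword, callers can pin the runtime behavior without mutating the model config. A hypothetical construction (the import path and the model_config value are assumptions for illustration, not taken from this diff):

# Hypothetical usage: an mm-capable model deployed text-only gets the text pipeline.
from fastdeploy.input.preprocess import InputPreprocessor  # assumed module path

preprocessor = InputPreprocessor(
    model_config,                 # a ModelConfig whose enable_mm may be True
    enable_mm_runtime=False,      # ...but the text-only deployment overrides it
)
processor = preprocessor.create_processor()  # takes the TextProcessor branch above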
@@ -549,7 +549,6 @@ class EngineWorkerQueue:
                 self.lock.release()
                 time.sleep(0.001)
                 self.lock.acquire()
-
         if envs.FD_ENABLE_MAX_PREFILL or envs.FD_ENABLE_E2W_TENSOR_CONVERT:
             # multimodal input numpy -> tensor
             to_tensor(tasks[0])
@@ -571,7 +570,6 @@ class EngineWorkerQueue:
         """
         tasks: List[Any] = list()
         self.lock.acquire()
-
         tasks.extend(self.tasks)
         self.client_read_flag[self.client_id] = 1
         all_client_read: bool = np.sum(self.client_read_flag) == self.num_client
@@ -138,9 +138,7 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
@@ -136,7 +136,7 @@ class DSAAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -269,9 +269,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.rank, self.device_id = init_rank_and_device_id(fd_config)

-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         # Note(ZKK): here must be consistent with append_attn_backend.py
@@ -123,9 +123,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
         self.rank, self.device_id = init_rank_and_device_id(fd_config)

-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
@@ -263,7 +263,7 @@ class MLAAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -89,7 +89,7 @@ class MhaAttnBackend(AttentionBackend):
         # note: scale need to change if using MLA
         self.scale = 1.0 / sqrt(head_dim)
         self.dtype = paddle.get_default_dtype()
-        self.enable_mm = fd_config.model_config.enable_mm
+        self.enable_mm = fd_config.enable_mm_runtime
         self.rope_batch_stride = self.max_context_len * self.head_dim if self.enable_mm else 0
         if "paddleocr" in fd_config.model_config.model_type:
             self.is_interleaved_rope_mode = False
@@ -219,7 +219,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         self.block_size = llm_config.cache_config.block_size
         self.max_seq_len = llm_config.model_config.max_model_len
         self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
-        self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
+        self.rope_3d = llm_config.enable_rope_3d_runtime
         self.causal = getattr(llm_config.model_config, "causal", True)
         self.speculative_method = llm_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -101,7 +101,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -128,7 +128,7 @@ class FlashAttentionBackend(AttentionBackend):
             fd_config.parallel_config.expert_parallel_rank = 0
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
-        self.enable_mm = fd_config.model_config.enable_mm
+        self.enable_mm = fd_config.enable_mm_runtime
         self.model_type = fd_config.model_config.model_type
         self.is_neox_style = False
         if "paddleocr" in fd_config.model_config.model_type:
@@ -105,7 +105,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -88,9 +88,7 @@ class XPUAttentionBackend(AttentionBackend):
         self.rope_theta: float = (
             10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
         )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
-            fd_config.model_config, "use_3d_rope", False
-        )
+        self.rope_3d: bool = fd_config.enable_rope_3d_runtime
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
         self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
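
All the attention-backend hunks above replace variants of the same getattr chain. A condensed before/after sketch; note that several backends (DSA, MLA, HPU, Metax) previously omitted the "use_3d_rope" fallback, so centralizing also removes that drift:

# Consolidation sketch: the per-backend getattr chains all become one property read.
def rope_3d_old(model_config) -> bool:
    # pattern previously duplicated across backends (some omitted the second getattr)
    return getattr(model_config, "rope_3d", False) or getattr(model_config, "use_3d_rope", False)


def rope_3d_new(fd_config) -> bool:
    # single source of truth; False whenever the mm runtime is disabled (text-only deploy)
    return fd_config.enable_rope_3d_runtime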
+1
@@ -947,6 +947,7 @@ class TokenProcessor:
                 if not is_prefill:
                     self._record_completion_metrics(task, current_time)
                 llm_logger.info(f"task {task_id} received eos token. Recycling.")
+
                 if (
                     envs.ENABLE_V1_KVCACHE_SCHEDULER
                     and self.cfg.cache_config.enable_prefix_caching
+1 -1
@@ -71,7 +71,7 @@ class Proposer(ABC):
         self.max_ngram_size = self.speculative_config.max_ngram_size
         self.min_ngram_size = self.speculative_config.min_ngram_size
-        self.enable_mm = self.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime

         spec_logger.info(f"Speculate config: {self.speculative_config}")
+1 -1
@@ -103,7 +103,7 @@ class MTPProposer(Proposer):
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
         self.device_id = device_id
-        self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text"
+        self.use_attn_mask_offset = self.enable_mm
         self._update_mtp_config(main_model)
         self._load_model()
+1 -1
@@ -62,7 +62,7 @@ class GCUModelRunner(ModelRunnerBase):
         local_rank: int,
     ):
         super().__init__(fd_config=fd_config, device=device)
-        self.enable_mm = self.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
+3 -1
@@ -119,7 +119,7 @@ class GPUModelRunner(ModelRunnerBase):
     ):
         super().__init__(fd_config=fd_config, device=device)
         self.MAX_INFER_SEED = 9223372036854775806
-        self.enable_mm = self.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
@@ -1118,10 +1118,12 @@ class GPUModelRunner(ModelRunnerBase):
     def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None:
         """Prepare the model inputs"""
         if self.enable_mm and self.share_inputs["image_features_list"] is not None:
             tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)]
             if tensor_feats:
                 self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0)

         recover_decode_task(
             self.share_inputs["stop_flags"],
             self.share_inputs["seq_lens_this_time"],
+1 -1
@@ -40,7 +40,7 @@ class IluvatarWorker(GpuWorker):
         local_rank: int,
         rank: int,
     ):
-        if fd_config.model_config.enable_mm:
+        if fd_config.enable_mm_runtime:
             paddle.set_flags({"FLAGS_enable_ixattnbkd": True, "FLAGS_enable_ixdnn_attn": False})
         super(IluvatarWorker, self).__init__(
             fd_config=fd_config,
+23 -26
@@ -17,13 +17,7 @@
 import paddle
 from paddleformers.utils.log import logger

-from fastdeploy.config import (
-    CacheConfig,
-    DeployModality,
-    FDConfig,
-    ModelConfig,
-    SpeculativeConfig,
-)
+from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.logits_processor import build_logits_processors
 from fastdeploy.platforms import current_platform
@@ -101,7 +95,8 @@ class InputBatch:
         self.scheduler_config = fd_config.scheduler_config
         self.speculative_config: SpeculativeConfig = fd_config.speculative_config
         self.speculative_decoding = self.speculative_config.method is not None
-        self.enable_mm = self.model_config.enable_mm
+        self.is_mm_model = self.model_config.enable_mm
+        self.enable_mm = fd_config.enable_mm_runtime
         self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel
         self.index_to_batch_id = {}
         self.enable_pd_reorder = False
@@ -231,6 +226,9 @@ class InputBatch:
             model_config=self.model_config,
             partial_rotary_factor=self.model_config.partial_rotary_factor,
         )
+        if self.is_mm_model:
+            self.image_features = None
+            self.image_features_list = None

         # Set block tables
         pre_max_block_num = (
@@ -677,6 +675,9 @@ class InputBatch:
             model_config=self.model_config,
             partial_rotary_factor=self.model_config.partial_rotary_factor,
         )
+        if self.is_mm_model:
+            self.image_features = None
+            self.image_features_list = None

         # Reset other miscellaneous tensors
         fill_paddle_tensor(self, "mask_rollback", 0)
@@ -689,7 +690,7 @@ class InputBatch:

 class ProposerInputBatch(InputBatch):
     def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None:
-        self.enable_mm = fd_config.model_config.enable_mm
+        self.enable_mm = fd_config.enable_mm_runtime
         self.num_model_steps = fd_config.speculative_config.num_model_steps
         self.index_to_batch_id = {}
         self.target_model_input_batch = target_model_input_batch
@@ -863,18 +864,15 @@ class ProposerInputBatch(InputBatch):
                 -1,
                 dtype="int32",
             )
-            if self.fd_config.deploy_modality != DeployModality.TEXT:
-                self.attn_mask_offsets = paddle.full(
-                    shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
-                    fill_value=-1,
-                    dtype="int32",
-                )
-                self.attn_mask_offsets_full = paddle.full(
-                    [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
-                )
-                self.attn_mask_offsets_decoder = paddle.full(
-                    [self.scheduler_config.max_num_seqs, 1], -1, dtype="int32"
-                )
+            self.attn_mask_offsets = paddle.full(
+                shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
+                fill_value=-1,
+                dtype="int32",
+            )
+            self.attn_mask_offsets_full = paddle.full(
+                [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
+            )
+            self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")

     def swap_states(self, i1, i2) -> None:
         def swap_data(tensor, idx1, idx2):
@@ -896,7 +894,7 @@ class ProposerInputBatch(InputBatch):
         swap_data(self.input_ids_len, i1, i2)
         swap_data(self.mask_rollback, i1, i2)
         swap_data(self.recompute_token_num, i1, i2)
-        if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT:
+        if self.enable_mm:
             swap_data(self.attn_mask_offsets_full, i1, i2)
             swap_data(self.attn_mask_offsets_decoder, i1, i2)
@@ -1030,10 +1028,9 @@ class ProposerInputBatch(InputBatch):
             # Reset multimodal tensors if enabled
             if self.enable_mm:
                 fill_paddle_tensor(self, "decode_states", -1)
-                if self.fd_config.deploy_modality != DeployModality.TEXT:
-                    fill_paddle_tensor(self, "attn_mask_offsets", -1)
-                    fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
-                    fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
+                fill_paddle_tensor(self, "attn_mask_offsets", -1)
+                fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
+                fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)

             logger.info("model_inputs reset completed")
         except Exception as e:
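
The InputBatch hunks above depend on two deliberately distinct flags: is_mm_model (static checkpoint capability, which still controls the image-feature slots) and enable_mm (the runtime decision, which controls the attention-mask-offset tensors). An illustrative standalone sketch of that split (field names from the diff, buffer shapes invented):

# Condensed sketch of the InputBatch split (illustrative, not the real class).
class BatchSketch:
    def __init__(self, is_mm_model: bool, enable_mm_runtime: bool, max_seqs: int = 4):
        self.is_mm_model = is_mm_model      # model can produce image features
        self.enable_mm = enable_mm_runtime  # deployment actually feeds it images
        if self.is_mm_model:
            # slots exist whenever the model is mm-capable, even if deployed text-only
            self.image_features = None
            self.image_features_list = None
        if self.enable_mm:
            # mask offsets are only needed when mm inputs can actually arrive
            self.attn_mask_offsets = [-1] * max_seqs


# An mm-capable model deployed text-only keeps the feature slots but skips mm masks:
b = BatchSketch(is_mm_model=True, enable_mm_runtime=False)
assert hasattr(b, "image_features") and not hasattr(b, "attn_mask_offsets")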
+1 -1
@@ -97,7 +97,7 @@ class MetaxModelRunner(ModelRunnerBase):
     ):
         super().__init__(fd_config=fd_config, device=device)
         self.MAX_INFER_SEED = 9223372036854775806
-        self.enable_mm = self.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
+2 -2
@@ -138,7 +138,7 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:

 def update_fd_config_for_mm(fd_config: FDConfig) -> None:
     architectures = fd_config.model_config.architectures
-    if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures):
+    if fd_config.enable_mm_runtime and ErnieArchitectures.contains_ernie_arch(architectures):
         fd_config.model_config.tensor_model_parallel_size = fd_config.parallel_config.tensor_parallel_size
         fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
         fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype
@@ -506,7 +506,7 @@ class PaddleDisWorkerProc:
             if tp_rank == 0:
                 if self.task_queue.exist_tasks():
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
-                        self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
+                        self.fd_config.enable_mm_runtime and self.worker.exist_prefill()
                     ):
                         if self.nnode > 1:
                             self.task_queue.read_finish_flag.set(1)
+1 -1
@@ -97,7 +97,7 @@ class XPUModelRunner(ModelRunnerBase):
         local_rank: int,
     ):
         super().__init__(fd_config=fd_config, device=device)
-        self.enable_mm = self.model_config.enable_mm
+        self.enable_mm = self.fd_config.enable_mm_runtime
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
+2 -1
@@ -92,6 +92,7 @@ class MockFDConfig:
     model_config = MockModelConfig()
     cache_config = MockCacheConfig()
     speculative_config = MockSpecaulativeConfig()
+    enable_mm_runtime = MockModelConfig.enable_mm

     def get_max_chunk_tokens(self, mm_max_tokens_per_item=None):
         return 8192
@@ -139,7 +140,7 @@ class TestChunkedMoE(unittest.TestCase):
         model_runner.model_config = mock_model_config
         model_runner.cache_config = mock_cache_config
         model_runner.attn_backends = [MockAttentionBackend()]
-        model_runner.enable_mm = True
+        model_runner.enable_mm = mock_fd_config.enable_mm_runtime
         model_runner.cudagraph_only_prefill = False
         model_runner.use_cudagraph = False
         model_runner.speculative_decoding = False
+7
@@ -102,6 +102,7 @@ def create_mock_fd_config(
     mock_config.structured_outputs_config = Mock()
     mock_config.structured_outputs_config.reasoning_parser = None
     mock_config.tool_parser = None
+    mock_config.enable_mm_runtime = enable_mm

     return mock_config
@@ -181,6 +182,7 @@ class TestEngineClient(unittest.IsolatedAsyncioTestCase):
         mock_config.structured_outputs_config = Mock()
         mock_config.structured_outputs_config.reasoning_parser = None
         mock_config.node_rank = 0
+        mock_config.enable_mm_runtime = mock_model_config.enable_mm

         # Create mocks for all the external dependencies
         mock_input_processor = Mock()
@@ -363,6 +365,7 @@ class TestEngineClientValidParameters(unittest.TestCase):
         mock_config.structured_outputs_config = MagicMock()  # Add this
         mock_config.structured_outputs_config.reasoning_parser = None
         mock_config.tool_parser = None  # Add this attribute
+        mock_config.enable_mm_runtime = mock_model_config.enable_mm

         # Mock IPCSignal to avoid file system dependencies
         with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal:
@@ -655,6 +658,7 @@ class TestEngineClientValidParameters(unittest.TestCase):
         mock_config.structured_outputs_config = Mock()
         mock_config.structured_outputs_config.reasoning_parser = None
         mock_config.tool_parser = None
+        mock_config.enable_mm_runtime = mock_config.model_config.enable_mm

         client = EngineClient(
             pid=5678,
@@ -1078,6 +1082,7 @@ class TestEngineClientValidParameters(unittest.TestCase):
         mock_config = Mock()
         mock_config.model_config = mock_model_config
+        mock_config.enable_mm_runtime = mock_model_config.enable_mm

         mock_config.eplb_config = Mock()
         mock_config.eplb_config.enable_eplb = False
@@ -1131,6 +1136,7 @@ class TestEngineClientValidParameters(unittest.TestCase):
         mock_config = Mock()
         mock_config.model_config = mock_model_config
+        mock_config.enable_mm_runtime = mock_model_config.enable_mm

         mock_config.eplb_config = Mock()
         mock_config.eplb_config.enable_eplb = False
@@ -1408,6 +1414,7 @@ class TestEngineClientValidParameters(unittest.TestCase):
         mock_config = Mock()
         mock_config.model_config = mock_model_config
+        mock_config.enable_mm_runtime = mock_model_config.enable_mm

         mock_config.eplb_config = Mock()
         mock_config.eplb_config.enable_eplb = False
@@ -92,6 +92,7 @@ class DummyFDConfig:
                 "max_model_len": 2048,
                 "head_dim": 128,
                 "num_hidden_layers": 2,
+                "enable_mm": False,
                 "causal": True,
                 "start_layer_index": 0,
                 "rope_3d": False,
@@ -124,6 +125,8 @@ class DummyFDConfig:
                 "model_type": "main",
             },
         )()
+        self.enable_mm_runtime = self.model_config.enable_mm
+        self.enable_rope_3d_runtime = self.model_config.enable_mm


 class DummyLayer:
@@ -78,6 +78,7 @@ class StubConfig:
         self.cache_config = CacheConfig()
         self.parallel_config = ParallelConfig()
         self.speculative_config = SpeculativeConfig()
+        self.enable_mm_runtime = self.model_config.enable_mm


 # ---------------------------------------------------------------------------
@@ -168,6 +169,7 @@ class TestChunkedPrefillDeterminism(unittest.TestCase):
     def _create_mm_resource_manager(self):
         config = StubConfig()
         config.model_config.enable_mm = True
+        config.enable_mm_runtime = config.model_config.enable_mm
         return self._create_resource_manager(config)

     # ==================== 1. Deterministic disabled ====================
+1
@@ -64,6 +64,7 @@ class MockConfig:
     scheduler_config = SchedulerConfig()
     cache_config = CacheConfig()
     parallel_config = ParallelConfig()
+    enable_mm_runtime = model_config.enable_mm

     def get_max_chunk_tokens(self, mm_max_tokens_per_item=None):
         return 8192
@@ -83,6 +83,7 @@ def create_mock_config():
     fd_config.parallel_config = parallel_config
     fd_config.structured_outputs_config = structured_outputs_config
     fd_config.pad_to = 8
+    fd_config.enable_mm_runtime = model_config.enable_mm

     def get_max_chunk_tokens(mm_max_tokens_per_item=None):
         return 100