[FDConfig] disable chunked_mm_input in ernie5 (#5774)

* disable chunked_mm_input in ernie5

* update code

* update code

* update test case

* update testcase

* update case
This commit is contained in:
kevin
2025-12-26 15:31:27 +08:00
committed by GitHub
parent 03363cab4c
commit 894f4e312b
17 changed files with 40 additions and 16 deletions
+14 -6
View File
@@ -133,6 +133,11 @@ class ErnieArchitectures:
"Ernie4_5_VLMoeForProcessRewardModel",
}
ERNIE5_MODELS = {
"Ernie5ForCausalLM",
"Ernie5MoeForCausalLM",
}
@classmethod
def register_ernie_model_arch(cls, model_class):
if model_class.name().startswith("Ernie") and model_class.name() not in cls.ARCHITECTURES:
@@ -148,6 +153,11 @@ class ErnieArchitectures:
"""Check if the given architecture is an ERNIE architecture."""
return architecture in cls.ARCHITECTURES
@classmethod
def is_ernie5_arch(cls, architectures):
    """Return True when any name in *architectures* is an ERNIE5 model."""
    # Non-empty intersection with the known ERNIE5 model set means a match.
    return not cls.ERNIE5_MODELS.isdisjoint(architectures)
PRETRAINED_INIT_CONFIGURATION = {
"top_p": 1.0,
@@ -248,12 +258,6 @@ class ModelConfig:
self._post_init()
def disable_mm_prefill_batch(self):
"""
check if the model architecture disable for mm prefill
"""
return self._architecture in ["Ernie5ForCausalLM", "Ernie5MoeForCausalLM"]
def _post_init(self):
self.is_unified_ckpt = check_unified_ckpt(self.model)
self.runner_type = self._get_runner_type(self.architectures, self.runner)
@@ -1805,6 +1809,10 @@ class FDConfig:
# It will hang when real batch_size < tp_size
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
if ErnieArchitectures.is_ernie5_arch(self.model_config.architectures):
# ernie5 model not support chunked_mm_input
self.cache_config.disable_chunked_mm_input = True
self.postprocess_devices_and_ports()
def postprocess_devices_and_ports(self):