mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[FDConfig] disable chunked_mm_input in ernie5 (#5774)
* disable chunked_mm_input in ernie5 * update code * update code * update test case * update testcase * upate case
This commit is contained in:
+14
-6
@@ -133,6 +133,11 @@ class ErnieArchitectures:
|
||||
"Ernie4_5_VLMoeForProcessRewardModel",
|
||||
}
|
||||
|
||||
ERNIE5_MODELS = {
|
||||
"Ernie5ForCausalLM",
|
||||
"Ernie5MoeForCausalLM",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_ernie_model_arch(cls, model_class):
|
||||
if model_class.name().startswith("Ernie") and model_class.name() not in cls.ARCHITECTURES:
|
||||
@@ -148,6 +153,11 @@ class ErnieArchitectures:
|
||||
"""Check if the given architecture is an ERNIE architecture."""
|
||||
return architecture in cls.ARCHITECTURES
|
||||
|
||||
@classmethod
|
||||
def is_ernie5_arch(cls, architectures):
|
||||
"""Check if the given architecture is an ERNIE5 architecture."""
|
||||
return any(arch in architectures for arch in cls.ERNIE5_MODELS)
|
||||
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"top_p": 1.0,
|
||||
@@ -248,12 +258,6 @@ class ModelConfig:
|
||||
|
||||
self._post_init()
|
||||
|
||||
def disable_mm_prefill_batch(self):
|
||||
"""
|
||||
check if the model architecture disable for mm prefill
|
||||
"""
|
||||
return self._architecture in ["Ernie5ForCausalLM", "Ernie5MoeForCausalLM"]
|
||||
|
||||
def _post_init(self):
|
||||
self.is_unified_ckpt = check_unified_ckpt(self.model)
|
||||
self.runner_type = self._get_runner_type(self.architectures, self.runner)
|
||||
@@ -1805,6 +1809,10 @@ class FDConfig:
|
||||
# It will hang when real batch_size < tp_size
|
||||
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
|
||||
|
||||
if ErnieArchitectures.is_ernie5_arch(self.model_config.architectures):
|
||||
# ernie5 model not support chunked_mm_input
|
||||
self.cache_config.disable_chunked_mm_input = True
|
||||
|
||||
self.postprocess_devices_and_ports()
|
||||
|
||||
def postprocess_devices_and_ports(self):
|
||||
|
||||
Reference in New Issue
Block a user