[FDConfig] disable chunked_mm_input in ernie5 (#5774)

* disable chunked_mm_input in ernie5

* update code

* update code

* update test case

* update testcase

* update case
This commit is contained in:
kevin
2025-12-26 15:31:27 +08:00
committed by GitHub
parent 03363cab4c
commit 894f4e312b
17 changed files with 40 additions and 16 deletions
+14 -6
View File
@@ -133,6 +133,11 @@ class ErnieArchitectures:
"Ernie4_5_VLMoeForProcessRewardModel",
}
ERNIE5_MODELS = {
"Ernie5ForCausalLM",
"Ernie5MoeForCausalLM",
}
@classmethod
def register_ernie_model_arch(cls, model_class):
if model_class.name().startswith("Ernie") and model_class.name() not in cls.ARCHITECTURES:
@@ -148,6 +153,11 @@ class ErnieArchitectures:
"""Check if the given architecture is an ERNIE architecture."""
return architecture in cls.ARCHITECTURES
@classmethod
def is_ernie5_arch(cls, architectures):
"""Check if the given architecture is an ERNIE5 architecture."""
return any(arch in architectures for arch in cls.ERNIE5_MODELS)
PRETRAINED_INIT_CONFIGURATION = {
"top_p": 1.0,
@@ -248,12 +258,6 @@ class ModelConfig:
self._post_init()
def disable_mm_prefill_batch(self):
"""
check if the model architecture disable for mm prefill
"""
return self._architecture in ["Ernie5ForCausalLM", "Ernie5MoeForCausalLM"]
def _post_init(self):
self.is_unified_ckpt = check_unified_ckpt(self.model)
self.runner_type = self._get_runner_type(self.architectures, self.runner)
@@ -1805,6 +1809,10 @@ class FDConfig:
# It will hang when real batch_size < tp_size
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
if ErnieArchitectures.is_ernie5_arch(self.model_config.architectures):
# ernie5 model does not support chunked_mm_input
self.cache_config.disable_chunked_mm_input = True
self.postprocess_devices_and_ports()
def postprocess_devices_and_ports(self):
@@ -32,6 +32,7 @@ from fastdeploy.cache_manager.multimodal_cache_manager import (
EncoderCacheManager,
ProcessorCacheManager,
)
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.request import (
ImagePosition,
Request,
@@ -680,7 +681,7 @@ class ResourceManagerV1(ResourceManager):
request = self.waiting[0]
if (
self.config.model_config.disable_mm_prefill_batch()
ErnieArchitectures.is_ernie5_arch(self.config.model_config.architectures)
and self._is_mm_request(request)
and self.exist_mm_prefill(scheduled_reqs)
) or (paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)):
+5 -8
View File
@@ -1078,7 +1078,6 @@ def check_download_links(bos_client, links, timeout=1):
def init_bos_client():
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.exception import BceHttpClientError, BceServerError
from baidubce.services.bos.bos_client import BosClient
cfg = BceClientConfiguration(
@@ -1089,14 +1088,12 @@ def init_bos_client():
try:
client = BosClient(cfg)
client.list_buckets()
except BceServerError as e:
if e.status_code == 403:
raise Exception("BOS authentication failed: Invalid AK/SK") from e
raise Exception(f"BOS connection failed: {str(e)}") from e
except BceHttpClientError as e:
raise Exception(f"Invalid BOS endpoint configuration: {str(e)}") from e
except Exception as e:
raise Exception(f"BOS client validation error: {str(e)}") from e
raise Exception(
"Create BOSClient Error, Please check your ENV [ ENCODE_FEATURE_BOS_AK, ENCODE_FEATURE_BOS_SK, ENCODE_FEATURE_ENDPOINT ] \n"
f"Current ENV AK: {envs.ENCODE_FEATURE_BOS_AK}, SK: {envs.ENCODE_FEATURE_BOS_SK}, Endpoint: {envs.ENCODE_FEATURE_ENDPOINT} \n"
f"{str(e)}"
)
return client
+1
View File
@@ -175,6 +175,7 @@ class TestInitEplbSignals(unittest.TestCase):
model_cfg.moe_num_experts = 64
model_cfg.moe_layer_start_index = 1
model_cfg.model = "/test/model"
model_cfg.architectures = ["test_model"]
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
+1
View File
@@ -55,6 +55,7 @@ class TestRedundantExpertManager(unittest.TestCase):
model_cfg.moe_num_experts = 64
model_cfg.moe_layer_start_index = 1
model_cfg.model = "/test/model"
model_cfg.architectures = ["test_model"]
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
@@ -159,6 +159,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_capture_size=scheduler_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -112,6 +112,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 5120
model_config.architectures = ["test_model"]
fd_config = FDConfig(
graph_opt_config=graph_opt_config,
scheduler_config=scheduler_config,
@@ -105,6 +105,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_capture_size=scheduler_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=scheduler_config.max_num_seqs)
@@ -97,6 +97,7 @@ class TestGraphOptBackend(unittest.TestCase):
baseline_parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
self.baseline_fd_config = FDConfig(
graph_opt_config=baseline_graph_opt_config,
scheduler_config=baseline_scheduler_config,
@@ -144,6 +145,7 @@ class TestGraphOptBackend(unittest.TestCase):
parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
# Create FD config
return FDConfig(
@@ -97,6 +97,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
parallel_config = ParallelConfig(args={})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
graph_opt_config=graph_opt_config,
scheduler_config=scheduler_config,
+1
View File
@@ -83,6 +83,7 @@ def _create_default_sampling_metadata(
def _create_fd_config(max_model_len):
model_config: Mock = Mock()
model_config.max_model_len = max_model_len
model_config.architectures = ["test_model"]
speculative_config = SpeculativeConfig({})
graph_opt_config = GraphOptimizationConfig({})
scheduler_config = SchedulerConfig({})
+1
View File
@@ -61,6 +61,7 @@ class FakeModelConfig:
self.enable_mm = False
self.max_model_len = 512
self.logprobs_mode = "raw_logprobs"
self.architectures = ["test_model"]
def get_default_test_fd_config():
+5
View File
@@ -39,6 +39,7 @@ class TestConfig(unittest.TestCase):
scheduler_config = SchedulerConfig({})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
parallel_config=parallel_config,
graph_opt_config=graph_opt_config,
@@ -60,6 +61,7 @@ class TestConfig(unittest.TestCase):
scheduler_config = SchedulerConfig({})
model_config = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
parallel_config=parallel_config,
graph_opt_config=graph_opt_config,
@@ -81,6 +83,7 @@ class TestConfig(unittest.TestCase):
scheduler_config = SchedulerConfig({})
model_config: Mock = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
parallel_config=parallel_config,
@@ -120,6 +123,7 @@ class TestConfig(unittest.TestCase):
scheduler_config.splitwise_role = "prefill"
model_config: Mock = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
parallel_config=parallel_config,
@@ -162,6 +166,7 @@ class TestConfig(unittest.TestCase):
scheduler_config = SchedulerConfig({})
model_config: Mock = Mock()
model_config.max_model_len = 512
model_config.architectures = ["test_model"]
fd_config = FDConfig(
parallel_config=parallel_config,
+1 -1
View File
@@ -127,7 +127,7 @@ class TestInitBosClient(unittest.TestCase):
with self.assertRaises(Exception) as context:
init_bos_client()
self.assertIn("BOS client validation error", str(context.exception))
self.assertIn("Create BOSClient Error, Please check your ENV", str(context.exception))
os.environ.clear()
@@ -33,6 +33,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196)
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
model_cfg.architectures = ["test_model"]
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)
@@ -35,6 +35,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196)
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
model_cfg.architectures = ["test_model"]
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)
+1
View File
@@ -44,6 +44,7 @@ class TestResourceManagerV1(unittest.TestCase):
speculative_cfg = SimpleNamespace(method=None)
model_cfg.print = print
model_cfg.max_model_len = 5120
model_cfg.architectures = ["test_model"]
cache_cfg.bytes_per_layer_per_block = 1
parallel_cfg = ParallelConfig(args)
scheduler_cfg = SchedulerConfig(args)