mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -14,14 +14,18 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
class BitMaskTracker:
|
||||
@@ -194,6 +198,53 @@ def default_weight_loader(fd_config: FDConfig) -> None:
|
||||
return fn
|
||||
|
||||
|
||||
def is_pre_sliced_weight(model_path):
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
|
||||
]
|
||||
return len(rank_dirs) > 1
|
||||
|
||||
|
||||
def v1_loader_support(fd_config):
|
||||
_v1_no_support_archs = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
|
||||
|
||||
def _err_msg(msg: str) -> str:
|
||||
logger.info(msg + "; fallback to the v0 loader for model loading.")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
_err_msg("v1loader currently does not support backends other than CUDA")
|
||||
return False
|
||||
|
||||
if is_pre_sliced_weight(fd_config.model_config.model):
|
||||
_err_msg("v1 loader currently does not support pre-sliced weights")
|
||||
return False
|
||||
|
||||
if fd_config.parallel_config.use_ep:
|
||||
_err_msg("v1 loader currently does not support expert parallelism")
|
||||
return False
|
||||
|
||||
if envs.FD_MOE_BACKEND.lower() == "marlin":
|
||||
_err_msg("v1 loader currently does not support marlin backend")
|
||||
return False
|
||||
|
||||
if fd_config.quant_config is not None:
|
||||
if fd_config.quant_config.name() == "mix_quant":
|
||||
moe_quant_type = fd_config.quant_config.moe_quant_type
|
||||
dense_quant_type = fd_config.quant_config.dense_quant_type
|
||||
else:
|
||||
moe_quant_type = fd_config.quant_config.name()
|
||||
dense_quant_type = fd_config.quant_config.name()
|
||||
unsupported_quant = {"w4a8", "w4afp8", "wint2"}
|
||||
|
||||
if unsupported_quant & {moe_quant_type, dense_quant_type}:
|
||||
_err_msg("v1 loader currently does not support w4a8/w4afp8/win2 quantization")
|
||||
return False
|
||||
if fd_config.model_config.architectures[0] in _v1_no_support_archs:
|
||||
_err_msg(f"v1 loader currently does not support {fd_config.model_config.architectures[0]}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_dtype(dtype: str):
|
||||
"""Temporarily set Paddle default dtype"""
|
||||
|
||||
Reference in New Issue
Block a user