V1 loader default (#4251)

* v1 loader

* update

* update
This commit is contained in:
bukejiyu
2025-10-15 16:49:17 +08:00
committed by GitHub
parent e98c1c2f47
commit bcaa98ff9c
4 changed files with 68 additions and 9 deletions
+51
View File
@@ -14,14 +14,18 @@
# limitations under the License.
"""
import os
import re
from contextlib import contextmanager
from typing import Any, Optional, Union
import paddle
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
class BitMaskTracker:
@@ -194,6 +198,53 @@ def default_weight_loader(fd_config: FDConfig) -> None:
return fn
def is_pre_sliced_weight(model_path):
    """Return True when *model_path* contains more than one per-rank weight subdirectory.

    Pre-sliced checkpoints are stored as sibling directories named
    ``rank0``, ``rank1``, ... under the model path; a single (or no)
    rank directory does not count as pre-sliced.
    """
    rank_dir_count = 0
    for entry in os.listdir(model_path):
        # Only directories whose name starts with "rank" qualify.
        if entry.startswith("rank") and os.path.isdir(os.path.join(model_path, entry)):
            rank_dir_count += 1
    return rank_dir_count > 1
def v1_loader_support(fd_config) -> bool:
    """Decide whether the v1 weight loader can handle this configuration.

    Args:
        fd_config: The FDConfig holding model, parallel and quantization
            settings for the deployment.

    Returns:
        bool: True if the v1 loader supports the configuration; False when
        any unsupported feature is detected (the reason is logged and the
        caller is expected to fall back to the v0 loader).
    """
    _v1_no_support_archs = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

    def _err_msg(msg: str) -> None:
        # Fixed: annotated `-> str` before, but this only logs and returns None.
        logger.info(msg + "; fallback to the v0 loader for model loading.")

    if not current_platform.is_cuda():
        # Fixed message: "v1loader" -> "v1 loader" for consistency with the other messages.
        _err_msg("v1 loader currently does not support backends other than CUDA")
        return False
    if is_pre_sliced_weight(fd_config.model_config.model):
        _err_msg("v1 loader currently does not support pre-sliced weights")
        return False
    if fd_config.parallel_config.use_ep:
        _err_msg("v1 loader currently does not support expert parallelism")
        return False
    if envs.FD_MOE_BACKEND.lower() == "marlin":
        _err_msg("v1 loader currently does not support marlin backend")
        return False
    if fd_config.quant_config is not None:
        # "mix_quant" carries separate dense/MoE quant types; otherwise both are the same.
        if fd_config.quant_config.name() == "mix_quant":
            moe_quant_type = fd_config.quant_config.moe_quant_type
            dense_quant_type = fd_config.quant_config.dense_quant_type
        else:
            moe_quant_type = fd_config.quant_config.name()
            dense_quant_type = fd_config.quant_config.name()
        unsupported_quant = {"w4a8", "w4afp8", "wint2"}
        if unsupported_quant & {moe_quant_type, dense_quant_type}:
            # Fixed message typo: "win2" -> "wint2" (matches the actual quant name checked above).
            _err_msg("v1 loader currently does not support w4a8/w4afp8/wint2 quantization")
            return False
    if fd_config.model_config.architectures[0] in _v1_no_support_archs:
        _err_msg(f"v1 loader currently does not support {fd_config.model_config.architectures[0]}")
        return False
    return True
@contextmanager
def temporary_dtype(dtype: str):
"""Temporarily set Paddle default dtype"""