mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
check paddle version for v1 loader (#4473)
This commit is contained in:
@@ -205,6 +205,22 @@ def is_pre_sliced_weight(model_path):
|
||||
return len(rank_dirs) > 1
|
||||
|
||||
|
||||
def is_paddle_support_v1_loader():
    """Probe whether the installed Paddle can fill a tensor through sliced views.

    A zero tensor of shape ``[1, 32, 64]`` is filled by copying a ones tensor
    of shape ``[32, 32]`` into the lower and then the upper half of its last
    axis, each time via ``copy_`` on an indexed/sliced result. The probe
    succeeds only if the writes actually landed in the original tensor.

    Returns:
        bool: True when every element of the target tensor equals 1 after the
        two copies, False otherwise.
    """
    source_dims = [32, 32]
    target_dims = [1, 32, 64]
    half_width = target_dims[2] // 2

    ones_block = paddle.ones(source_dims, dtype="float32")
    target = paddle.zeros(target_dims, dtype="float32")

    # First the "gate" half (leading columns), then the "up" half (trailing columns).
    column_halves = (slice(None, half_width), slice(half_width, None))

    for expert_idx in range(target_dims[0]):
        for columns in column_halves:
            # Index the target afresh for every copy, as each write must go
            # through its own sliced result of the original tensor.
            # NOTE(review): the second argument of copy_ is presumably the
            # blocking flag -- confirm against the Paddle API.
            target[expert_idx][..., columns].copy_(ones_block, False)

    # If the copies only touched temporaries, zeros remain and this is False.
    return bool(paddle.all(target == 1))
|
||||
|
||||
|
||||
def v1_loader_support(fd_config):
|
||||
_v1_no_support_archs = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
|
||||
|
||||
@@ -242,6 +258,10 @@ def v1_loader_support(fd_config):
|
||||
if fd_config.model_config.architectures[0] in _v1_no_support_archs:
|
||||
_err_msg(f"v1 loader currently does not support {fd_config.model_config.architectures[0]}")
|
||||
return False
|
||||
|
||||
if not is_paddle_support_v1_loader():
|
||||
_err_msg("The installed Paddle does not support v1 loader")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user