mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
1352 lines
48 KiB
Python
1352 lines
48 KiB
Python
"""
|
|
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import collections
|
|
import hashlib
|
|
import json
|
|
import multiprocessing as mp
|
|
import os
|
|
import random
|
|
import re
|
|
import struct
|
|
from functools import partial
|
|
from typing import Callable, Optional
|
|
|
|
import numpy as np
|
|
from paddlenlp.transformers import PretrainedTokenizer
|
|
from paddlenlp.transformers.model_utils import _add_variant
|
|
from paddlenlp.transformers.utils import paddlenlp_load
|
|
from paddlenlp.transformers.model_utils import load_tp_checkpoint
|
|
from safetensors import safe_open
|
|
|
|
from paddlenlp.utils.env import (
|
|
PADDLE_WEIGHTS_INDEX_NAME,
|
|
SAFE_MASTER_WEIGHTS_INDEX_NAME,
|
|
SAFE_PEFT_WEIGHTS_INDEX_NAME,
|
|
SAFE_WEIGHTS_INDEX_NAME,
|
|
)
|
|
from paddlenlp.utils.log import logger
|
|
from tqdm import tqdm
|
|
|
|
import paddle
|
|
import paddle.distributed as dist
|
|
from paddle.common_ops_import import convert_dtype
|
|
from paddle.distributed import fleet
|
|
from paddlenlp.transformers import PretrainedTokenizer
|
|
from paddlenlp.transformers.model_utils import _add_variant, load_tp_checkpoint
|
|
from paddlenlp.transformers.utils import paddlenlp_load
|
|
from paddlenlp.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
|
|
SAFE_MASTER_WEIGHTS_INDEX_NAME,
|
|
SAFE_PEFT_WEIGHTS_INDEX_NAME,
|
|
SAFE_WEIGHTS_INDEX_NAME)
|
|
from paddlenlp.utils.log import logger
|
|
from safetensors import safe_open
|
|
from tqdm import tqdm
|
|
|
|
from fastdeploy.platforms import current_platform
|
|
|
|
from .tokenizer import ErnieBotTokenizer
|
|
import glob
|
|
|
|
# Dotted module names of the ERNIE Bot model implementations.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by code that imports the modules by name - confirm with callers.
MODEL_LIB_NAMES = [
    "ernie_bot.modeling",
    "ernie_bot.modeling_pp",
    "ernie_bot.modeling_moe",
    "ernie_bot.modeling_rm",
    "ernie_bot.proxy_distill",
]

# Number of per-request slots: `speculate_read_res` keeps one token list per
# slot and uses this value as the offset of the token area in its buffer.
MAX_BSZ = 512
# Width of each request's token region in that buffer; `speculate_read_res`
# uses it as the stride between consecutive requests.
MAX_DRAFT_TOKENS = 6
|
|
|
|
|
|
class UniqueIDGenerator:
    """
    Generator for the id of an exported model.

    An id is the MD5 digest of a cheap weight fingerprint followed by a
    random five-digit suffix.
    """

    def __init__(self):
        pass

    def generate_unique_id(self, state_dict):
        """
        Generate a model id of the form ``"<md5 hex digest>-<5 digits>"``.

        The digest is computed from the string form of the sum of the
        parameter whose name sorts first, cast to float32. The suffix is drawn
        with ``random.randint(10000, 99999)``, so two calls on the same
        weights usually return different ids.

        Args:
            state_dict (dict): maps parameter names to tensor-like objects
                that provide ``cast(dtype)`` and ``sum()``.

        Returns:
            str: the generated id.

        Raises:
            IndexError: if ``state_dict`` is empty.
        """
        try:
            # min() picks the same key as sorted(keys)[0] without sorting
            # every parameter name.
            first_key = min(state_dict.keys())
        except ValueError:
            # Keep the exception type raised by the former sorted(keys)[0].
            raise IndexError(
                "cannot generate a unique id from an empty state_dict"
            ) from None

        first_parameter = state_dict[first_key].cast("float32")
        # Assumes model parameters are unique enough that the first parameter
        # alone can identify the model.
        fingerprint = str(first_parameter.sum()).encode("utf-8")
        model_md5 = hashlib.md5(fingerprint).hexdigest()
        return f"{model_md5}-{random.randint(10000, 99999)}"
|
|
|
|
|
|
def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
    """
    Load a checkpoint stored in ``folder`` and return it as one state dict.

    Single-file checkpoints are tried first, in this order:
    ``model_state.pdparams``, ``lora_model_state.pdparams``,
    ``model.safetensors``. Otherwise a weight index file is looked up
    (safetensors index, then safetensors master-weights index, then the
    Paddle index, then the safetensors PEFT index) and every shard it lists
    is loaded and merged into a single dict.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the checkpoint.
        variant (`str`): The model variant, applied to each file name through
            `_add_variant`.
        return_numpy (`bool`): When False, `np.ndarray` values read from
            safetensors/shard files are wrapped as `paddle.Tensor`.

    Returns:
        dict: parameter name -> weight.

    Raises:
        ValueError: if neither a single-file checkpoint nor any index file exists.
    """

    def _resolve(file_name):
        """Return the variant-aware path of `file_name` inside `folder`."""
        return os.path.join(folder, _add_variant(file_name, variant))

    def _get_safe_load_file():
        """Return PaddleNLP's fast safetensors reader, or the stock one."""
        try:
            from paddlenlp.utils.safetensors import \
                fast_load_file as safe_load_file
        except ImportError:
            from safetensors.numpy import load_file as safe_load_file
        return safe_load_file

    def _ndarrays_to_tensors(state_dict):
        """Replace, in place, every ndarray value with a zero-copy tensor."""
        for key in list(state_dict.keys()):
            if isinstance(state_dict[key], np.ndarray):
                state_dict[key] = paddle.Tensor(state_dict.pop(key),
                                                zero_copy=True)
        return state_dict

    # 1. Single-file checkpoints.
    for single_file in (_resolve("model_state.pdparams"),
                        _resolve("lora_model_state.pdparams")):
        if os.path.isfile(single_file):
            return paddle.load(single_file, return_numpy=return_numpy)

    safetensors_file = _resolve("model.safetensors")
    if os.path.isfile(safetensors_file):
        state_dict = _get_safe_load_file()(safetensors_file)
        if not return_numpy:
            _ndarrays_to_tensors(state_dict)
        return state_dict

    # 2. Sharded checkpoints: pick the index file, in order of preference.
    index_file = _resolve(PADDLE_WEIGHTS_INDEX_NAME)
    safe_index_file = _resolve(SAFE_WEIGHTS_INDEX_NAME)
    safe_master_file = _resolve(SAFE_MASTER_WEIGHTS_INDEX_NAME)
    safe_peft_file = _resolve(SAFE_PEFT_WEIGHTS_INDEX_NAME)

    candidates = (
        (safe_index_file, True),  # load safe due to preference
        (safe_master_file, True),
        (index_file, False),
        (safe_peft_file, True),
    )
    for load_index, load_safe in candidates:
        if os.path.isfile(load_index):
            break
    else:
        # The master-weights index is probed too, so it is listed as well.
        raise ValueError(
            f"Could not find {index_file} or {safe_index_file} or "
            f"{safe_master_file} or {safe_peft_file}")

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Several parameters map to the same shard; load each shard once, in a
    # deterministic (sorted) order.
    shard_files = sorted(set(index["weight_map"].values()))
    if load_safe:
        loader = _get_safe_load_file()
    else:
        loader = partial(paddlenlp_load,
                         map_location="np" if return_numpy else "cpu")

    ret = {}
    for shard_file in tqdm(shard_files):
        ret.update(loader(os.path.join(folder, shard_file)))

    if not return_numpy:
        _ndarrays_to_tensors(ret)

    return ret
|
|
|
|
|
|
def convert_ndarray_dtype(np_array: np.ndarray,
                          target_dtype: str) -> np.ndarray:
    """convert ndarray

    Conversions whose source dtype is ``uint16`` or whose target is
    ``bfloat16`` go through a paddle tensor cast; every other conversion uses
    ``np.ndarray.astype`` directly.
    NOTE(review): presumably uint16 is how bfloat16 data is stored in numpy,
    which is why it cannot be handled by ``astype`` - confirm.

    Args:
        np_array (np.ndarray): numpy ndarray instance
        target_dtype (str): the target dtype

    Returns:
        np.ndarray: converted numpy ndarray instance
    """
    source_dtype = convert_dtype(np_array.dtype)
    if source_dtype == "uint16" or target_dtype == "bfloat16":
        if paddle.is_compiled_with_xpu():
            # xpu not support bf16.
            tensor = paddle.to_tensor(np_array, place=paddle.CPUPlace())
        else:
            tensor = paddle.to_tensor(np_array)
        return paddle.cast(tensor, target_dtype).numpy()

    # A bfloat16 target always returns above, so no bfloat16 -> uint16
    # remapping is needed on this path.
    return np_array.astype(target_dtype)
|
|
|
|
|
|
def ernie_bot_postprocess_past_key_value(past_key_values):
    """
    Turn a packed prefix-cache tensor into a list of ``Cache(k, v)`` tuples.

    The 5-D input is permuted with ``perm=[2, 0, 1, 3, 4]``, split into two
    equal halves along the first axis (keys, then values), and the halves are
    paired element by element.
    """
    Cache = collections.namedtuple("Cache", ["k", "v"])
    # Per the original note, each half has layout
    # (layer_num, bs, prefixlen, head_num/tensor_parallel_degree, head_dim).
    permuted = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4])
    keys, values = permuted.split(2)
    return [Cache(key, value) for key, value in zip(keys, values)]
|
|
|
|
|
|
def ernie_bot_pad_attention_mask(input_ids_shape, num_prefix_tokens,
                                 attention_mask):
    """
    Prepend an all-ones block for the prefix tokens to ``attention_mask``.

    A 2-D mask is first expanded to 4-D by inserting two singleton axes; the
    prefix block then has a query axis of length 1. For any other rank the
    query axis of the prefix block has length ``input_ids_shape[-1]``. The
    prefix block and the mask are concatenated along axis 3.
    """
    batch_size = input_ids_shape[0]
    if attention_mask.dim() == 2:
        attention_mask = attention_mask[:, None, None, :]
        query_len = 1
    else:
        query_len = input_ids_shape[-1]

    prefix_attention_mask = paddle.ones(
        [batch_size, 1, query_len, num_prefix_tokens],
        dtype=attention_mask.dtype,
    )
    return paddle.concat((prefix_attention_mask, attention_mask), axis=3)
|
|
|
|
|
|
def set_seed(seed: int):
    """
    Seed paddle, Python's ``random`` and numpy's global RNG with ``seed``.
    """
    for seed_fn in (paddle.seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
|
|
def get_infer_model_path(input_dir, model_prefix, is_export: bool = False):
    """Return the path prefix of the inference model for the current rank.

    The result is ``{input_dir}/rank_{local_rank}/{model_prefix}`` when more
    than one rank is running, when exporting, or when that ``rank_*``
    directory already exists; otherwise it is ``{input_dir}/{model_prefix}``.

    Args:
        input_dir (str): the base input_dir
        model_prefix (str): the prefix name of model
        is_export (bool): always use the per-rank directory, even if it does
            not exist yet.

    Returns:
        str: the path of infer model path
    """
    n_ranks = dist.get_world_size()
    try:
        local_rank = dist.ParallelEnv().dev_id
    except Exception:
        logger.info(
            "`dist.ParallelEnv().dev_id` is not supported on CPU devices,so set local_rank = 0."
        )
        local_rank = 0

    sub_rank_dir = os.path.join(input_dir, f"rank_{local_rank}")
    # The directory is probed only for single-rank inference.
    if n_ranks > 1 or is_export or os.path.exists(sub_rank_dir):
        return os.path.join(sub_rank_dir, model_prefix)
    return os.path.join(input_dir, model_prefix)
|
|
|
|
|
|
def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
    """Pad the instances to the max sequence length in batch.

    Args:
        insts: a sized collection of token-id sequences.
        pad_id (int): value used to fill the padding positions.
        return_seq_len (bool): also return the unpadded length of every
            sequence.
        pad_style (str): ``"left"`` pads in front of each sequence; any other
            value pads behind it.

    Returns:
        np.ndarray: int64 array of shape ``[len(insts), max_len]``. An empty
        batch yields shape ``[0, 0]`` instead of raising.
        When ``return_seq_len`` is true, a tuple of that array and an integer
        array holding the original lengths.
    """
    seq_lens = [len(inst) for inst in insts]
    # default=0 keeps an empty batch from raising inside max().
    max_len = max(seq_lens, default=0)
    pad_left = pad_style == "left"

    rows = []
    for inst, length in zip(insts, seq_lens):
        padding = [pad_id] * (max_len - length)
        rows.append(padding + list(inst) if pad_left else list(inst) + padding)

    # An explicit row count (rather than -1) stays valid when max_len is 0.
    inst_data = np.array(rows).astype("int64").reshape([len(rows), max_len])
    if return_seq_len:
        return inst_data, np.array(seq_lens, dtype=np.int_)
    return inst_data
|
|
|
|
|
|
def load_prefix_weights(
    prefix_path: str,
    inference: bool = False,
    batch_size: int = 1,
    dtype: str = "bfloat16",
) -> np.ndarray | list[paddle.Tensor]:
    """load prefix weight by path

    Reads ``{prefix_path}/pre_caches.npy``, inserts a new axis at position 2,
    repeats the data ``batch_size`` times along that axis and casts it to
    ``dtype``.

    Args:
        prefix_path (str): the path of prefix weight
        inference (bool): return a numpy array instead of a paddle tensor.
        batch_size (int): number of copies along axis 2.
        dtype (str): dtype the cache is cast to.
    """
    pre_caches = np.load(f"{prefix_path}/pre_caches.npy")
    past_key_values = paddle.to_tensor(pre_caches).unsqueeze(2)

    if batch_size > 1:
        copies = [past_key_values for _ in range(batch_size)]
        past_key_values = paddle.concat(copies, axis=2)

    # .chatglm static model require one tensor, otherwise list of tensor
    past_key_values = past_key_values.astype(dtype)
    return past_key_values.numpy() if inference else past_key_values
|
|
|
|
|
|
def build_for_generation(model, tokenizer: PretrainedTokenizer,
                         generation_kwargs: dict):
    """build `ErnieBotForGenerationFuse` to generate tokens

    The wrapper is configured with the tokenizer's special token ids and a
    set of default generation settings; entries of ``generation_kwargs``
    override those defaults. The wrapped model is switched to eval mode.

    Args:
        model (_type_): ErnieBotModel or ErnieBotFusedModel
        tokenizer (PretrainedTokenizer): pretrained tokenizer
        generation_kwargs (dict): generation_kwargs for model

    Returns:
        PretrainedModel: ErnieBotForGenerationFuse
    """
    from ernie_bot.single_model_fused import ErnieBotForGenerationFuse

    configs = dict(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        initializer_range=0.02,
        fused_linear=False,
        min_dec_len=1,
        max_dec_len=1024,
        top_k=0,
        top_p=0.7,
        temperature=0.95,
        use_topp_sampling=True,
        inference=True,
    )
    configs.update(generation_kwargs)

    generation_model = ErnieBotForGenerationFuse(model, configs=configs)
    generation_model.eval()
    return generation_model
|
|
|
|
|
|
def init_distributed_env() -> tuple[int, int]:
    """init distributed envs, and only support mp in ErnieBotModel

    With a world size of 1 nothing is initialized and the rank is 0. With a
    larger world size, fleet is initialized with every process assigned to
    model parallelism (dp, pp and sharding degrees are all 1).

    Returns:
        tuple[int, int]: tensor_parallel_degree, tensor_parallel_rank
    """
    tensor_parallel_degree = dist.get_world_size()
    if tensor_parallel_degree <= 1:
        return tensor_parallel_degree, 0

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": tensor_parallel_degree,
        "pp_degree": 1,
        "sharding_degree": 1,
    }
    fleet.init(is_collective=True, strategy=strategy)

    tensor_parallel_rank = (
        fleet.get_hybrid_communicate_group().get_model_parallel_rank())
    return tensor_parallel_degree, tensor_parallel_rank
|
|
|
|
|
|
def generate_rank_mapping(output_dir: str):
    """generate current distributed rank mapping file

    Writes ``rank_mapping.csv`` into ``output_dir`` (created if missing). The
    file has a ``[ring_id -> ranks]`` section with one line for ring 0 and one
    for the model-parallel ring, then a ``[rank -> ring_ids]`` section with
    one line per rank.

    Args:
        output_dir (str): the directory of rank_mapping file
    """
    os.makedirs(output_dir, exist_ok=True)

    # must in distributed env
    ring_id = fleet.get_hybrid_communicate_group().get_model_parallel_group().id
    world_size = dist.get_world_size()
    rank_fields = [str(rank) for rank in range(world_size)]

    lines = [
        "[ring_id -> ranks]\n",
        ",".join(["0", *rank_fields]) + "\n",
        ",".join([str(ring_id), *rank_fields]) + "\n",
        "[rank -> ring_ids]\n",
    ]
    lines.extend(f"{rank},0,{ring_id}\n" for rank in range(world_size))

    with open(os.path.join(output_dir, "rank_mapping.csv"), "w") as f:
        f.writelines(lines)
|
|
|
|
|
|
def save_infer_result(trainer, dev_ds, k=100, src_length=256, tgt_length=512):
    """
    Run generation on the first ``k`` samples of ``dev_ds`` and save the
    results in jsonl format.

    Each output line is a JSON object with the keys ``instruction``,
    ``answer`` (the reference from the dataset) and ``output`` (the model
    prediction). The file ``infer_result.json`` is written to
    ``trainer.args.output_dir`` by tensor-parallel rank 0 only.

    Args:
        trainer: object providing ``args``, ``tokenizer`` and ``model``.
        dev_ds: dataset whose ``data`` attribute yields dict samples holding
            either ``instruction``/``output`` or ``src``/``tgt``; when ``src``
            is a list only its first element (and the first ``tgt``) is used.
            Samples with neither key are skipped.
        k (int): number of leading samples to look at.
        src_length (int): forwarded to the predictor.
        tgt_length (int): forwarded to the predictor.
    """
    from predict_generation import Predictor, batchfy_text

    all_instructions = []
    all_answers = []

    # top k instruction from dev_ds
    for i, ds in enumerate(dev_ds.data):
        if i == k:
            break
        if "instruction" in ds:
            all_instructions.append(ds["instruction"])
            all_answers.append(ds["output"])
        elif "src" in ds:
            if isinstance(ds["src"], list):
                all_instructions.append(ds["src"][0])
                all_answers.append(ds["tgt"][0])
            else:
                all_instructions.append(ds["src"])
                all_answers.append(ds["tgt"])

    batch_size = trainer.args.per_device_eval_batch_size
    batch_texts = batchfy_text(all_instructions, batch_size)
    predictor = Predictor(
        tokenizer=trainer.tokenizer,
        model=trainer.model,
        src_length=src_length,
        tgt_length=tgt_length,
    )

    # infer results
    all_output = []
    for bs, texts in enumerate(batch_texts):
        outputs = predictor.predict(texts)
        for i, (text, result) in enumerate(zip(texts, outputs["result"])):
            all_output.append({
                "instruction": text,
                # Position of this sample in the flat answer list.
                "answer": all_answers[bs * batch_size + i],
                "output": result,
            })

    # save results
    if trainer.args.tensor_parallel_rank == 0:
        result_path = os.path.join(trainer.args.output_dir,
                                   "infer_result.json")
        # ensure_ascii=False emits raw non-ASCII text, so the encoding must
        # be explicit rather than the locale default.
        with open(result_path, "w", encoding="utf-8") as f:
            for out in all_output:
                f.write(json.dumps(out, ensure_ascii=False) + "\n")
|
|
|
|
|
|
def w4a8_weight_convert(state_dict):
    """W4A8 weight conversion function.

    Every entry of ``state_dict`` whose dtype equals ``"uint16"`` is cast to
    float32, classified by its value range as an int4, int8 or floating-point
    weight, converted accordingly, and written back into ``state_dict`` as a
    numpy array. ``state_dict`` is modified in place.

    Args:
        state_dict (dict): state_dict of model

    Returns:
        tuple: ``(state_dict, w4a8_weight_bites_layers_map)``. The map holds
        four lists (``qkv_gemm_bits_map``, ``out_gemm_bits_map``,
        ``ffn1_gemm_bits_map``, ``ffn2_gemm_bits_map``) with the detected bit
        width (4 or 8) of each matching weight, in ``state_dict`` order.
    """

    def w4_weight_squash(value, name, w4a8_weight_bites_name_map):
        # Classify `value` by its extreme values and record the detected bit
        # width of `name` in `w4a8_weight_bites_name_map` (int4/int8 only).
        weight_dq = value
        # W4 weights stored in W8 form have an absmax of 112, so +/-112 is
        # used to detect the weight type.
        if weight_dq.max() == 112 or weight_dq.min() == -112:
            weight_dq = weight_dq.cast("int8")
            # Reinterpret the int8 bytes as unsigned, then divide by 16.
            np_weight_dq = np.array(weight_dq, dtype="int8").view("uint8")
            np_weight_dq_left_div_16 = (np_weight_dq / 16).astype("int8")
            # weight_q = (weight_dq/16).cast('int8')
            weight_q = paddle.to_tensor(np_weight_dq_left_div_16, dtype="int8")
            logger.debug(f"int4 weight:{name}")
            w4a8_weight_bites_name_map[name] = 4
            return weight_q.cast("int8")
        elif weight_dq.max() == 127 or weight_dq.min() == -128:
            logger.debug(f"int8 weight:{name}")
            w4a8_weight_bites_name_map[name] = 8
            return weight_dq.cast("int8")
        else:
            # Neither range matched: return the (float32-cast) value as is.
            logger.debug(f"fp16/bf16/float weight:{name}")
            return weight_dq

    w4a8_weight_bites_name_map = {}
    for name, value in state_dict.items():
        if value.dtype == "uint16":
            weight_q = w4_weight_squash(
                paddle.to_tensor(value).cast("float32"),
                name,
                w4a8_weight_bites_name_map,
            )
            state_dict[name] = weight_q.numpy(
            ) if weight_q is not None else value
            del weight_q
    # Group the detected bit widths by the kind of GEMM the weight belongs
    # to, identified by a substring of the parameter name.
    w4a8_weight_bites_layers_map = {}
    w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["out_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"] = []
    for name_keys, gemm_bits in w4a8_weight_bites_name_map.items():
        if "qkv_proj" in name_keys:
            w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits)
        elif "out_proj" in name_keys:
            w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
        elif "linear1" in name_keys:
            w4a8_weight_bites_layers_map["ffn1_gemm_bits_map"].append(
                gemm_bits)
        elif "linear2" in name_keys:
            w4a8_weight_bites_layers_map["ffn2_gemm_bits_map"].append(
                gemm_bits)
    logger.debug(
        f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
    return state_dict, w4a8_weight_bites_layers_map
|
|
|
|
|
|
def _vocab_size_with_padding(vocab_size, div_unit, mp_degree):
|
|
padded_size = vocab_size
|
|
multiple = div_unit * mp_degree
|
|
while (padded_size % multiple) != 0:
|
|
padded_size += 1
|
|
# logger.warning(
|
|
# " > padded vocab (size: {}) with {} dummy tokens "
|
|
# "(new size: {})".format(vocab_size, padded_size - vocab_size, padded_size)
|
|
# )
|
|
return padded_size
|
|
|
|
|
|
def save_test_case(cases: list[list[dict]], file: str):
    """Write ``cases`` to ``file`` as JSON lines, replacing existing content.

    Args:
        cases (list[list[dict]]): the content of case; each element becomes
            one line.
        file (str): the path of saved file
    """
    with open(file, "w+", encoding="utf-8") as out:
        out.writelines(
            json.dumps(case, ensure_ascii=False) + "\n" for case in cases)
|
|
|
|
|
|
def infer_save_test_case(cases: list[list[dict]], file: str):
    """Append ``cases`` to ``file`` as JSON lines, keeping existing content.

    Args:
        cases (list[list[dict]]): the content of case; each element becomes
            one line.
        file (str): the path of saved file
    """
    with open(file, "a+", encoding="utf-8") as out:
        out.writelines(
            json.dumps(case, ensure_ascii=False) + "\n" for case in cases)
|
|
|
|
|
|
def deserialize_from_file(fp):
    """
    deserialize a binary file into an array

    The first byte is a type tag; the rest of the stream is a packed sequence
    of values in native byte order:

    * ``b"0"``: 4-byte floats
    * ``b"1"``: 8-byte signed integers
    * ``b"2"``: 4-byte signed integers

    Args:
        fp: binary file object positioned at the type tag.

    Returns:
        np.ndarray: the decoded values; empty when there is no payload or the
        tag is unknown (in which case ``type error`` is printed).

    Raises:
        struct.error: if the stream is empty or the payload length is not a
            multiple of the element size.
    """
    type_tag = struct.unpack("c", fp.read(1))[0]
    # "q" is 8 bytes on every platform, whereas the size of "l" depends on
    # the platform's C long.
    formats = {b"0": "f", b"1": "q", b"2": "i"}
    fmt = formats.get(type_tag)
    if fmt is None:
        print("type error")
        return np.array([])

    # Decode the whole payload in one pass; an empty payload yields an empty
    # array for every type, including floats.
    payload = fp.read()
    return np.array([value for (value,) in struct.iter_unpack(fmt, payload)])
|
|
|
|
|
|
def read_res(
    model_name_or_path,
    output_tensor_max_shape,
    result_queue: mp.Queue,
    msg_queue_id=None,
    use_ep=False,
    ep_just_for_test=False,
    tokenizer=None,
):
    """Read generated tokens from the output op, decode them and push the
    text into ``result_queue``. This function loops forever.

    Buffer layout as read by this function: ``[0, 0]`` is a status flag (-2
    means nothing was read, any other negative value ends the current batch),
    ``[1, 0]`` is the batch size and rows ``2 .. bsz + 2`` hold the tokens of
    this step. Token value -1 is replaced with 2.
    NOTE(review): 2 is presumably the end/pad token id - confirm.

    Args:
        model_name_or_path: used to load the tokenizer when none is given.
        output_tensor_max_shape: shape of the int64 receive buffer.
        result_queue (mp.Queue): receives ``[index, token_count, text]`` for
            every sequence of a finished batch.
        msg_queue_id: when given, the ``get_output_dynamic`` op is used and
            receives this id; otherwise ``get_output`` is used.
        use_ep (bool): together with ``ep_just_for_test`` being false, prints
            the decoded sequences.
        ep_just_for_test (bool): suppresses that print.
        tokenizer: tokenizer to decode with; loaded with ``ErnieBotTokenizer``
            when None.
    """
    on_gpu_ops = (current_platform.is_cuda() and
                  current_platform.available()) or paddle.is_compiled_with_xpu()
    if msg_queue_id is None:
        if on_gpu_ops:
            from fastdeploy.model_executor.ops.gpu import get_output
        elif paddle.is_compiled_with_custom_device("npu"):
            from paddle_custom_device.npu import get_output
        else:  # CPU
            from fastdeploy.model_executor.ops.cpu import get_output
    else:
        if on_gpu_ops:
            from fastdeploy.model_executor.ops.gpu import get_output_dynamic
        elif paddle.is_compiled_with_custom_device("npu"):
            from paddle_custom_device.npu import get_output_dynamic
        else:  # CPU
            from fastdeploy.model_executor.ops.cpu import get_output_dynamic

    if tokenizer is None:
        tokenizer = ErnieBotTokenizer.from_pretrained(model_name_or_path)

    paddle.device.set_device("cpu")
    paddle.disable_static()
    output_tensor = paddle.full(output_tensor_max_shape,
                                fill_value=2,
                                dtype="int64")

    while True:
        # Collect one step of tokens per read until the batch ends.
        outputs = []
        while True:
            if msg_queue_id is None:
                get_output(output_tensor, 0, True)
            else:
                get_output_dynamic(output_tensor, 0, True, msg_queue_id)
            if int(output_tensor[0, 0]) == -2:  # read none
                continue
            bsz = int(output_tensor[1, 0])
            output_numpy = output_tensor[2:bsz + 2].numpy()
            output_numpy[output_numpy == -1] = 2
            outputs.append(output_numpy)

            if int(output_tensor[0, 0]) < 0:
                break

        output = np.concatenate(outputs, axis=1)
        # Convert once and reuse the rows for decoding and for the lengths.
        token_rows = output.tolist()
        seqs = tokenizer.batch_decode(
            token_rows,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        if use_ep and (not ep_just_for_test):
            print("seqs: ", seqs)
        for i, seq in enumerate(seqs):
            result_queue.put([i, len(token_rows[i]), seq])
|
|
|
|
|
|
def speculate_read_res(
    model_name_or_path,
    output_tensor_max_shape,
    result_queue: mp.Queue,
    msg_queue_id=None,
):
    """Read speculative-decoding results from the output op, decode them and
    push the text into ``result_queue``. This function loops forever.

    Buffer layout as read by this function: element 0 is a status flag (-2
    means nothing was read, -1 ends the current batch), element 1 is the batch
    size, elements ``2 .. bsz + 2`` are the accepted-token counts per request,
    and the tokens of request ``bi`` start at
    ``2 + MAX_BSZ + bi * MAX_DRAFT_TOKENS``.

    Args:
        model_name_or_path: used to load the ``ErnieBotTokenizer``.
        output_tensor_max_shape: shape of the int64 receive buffer.
        result_queue (mp.Queue): receives ``[index, token_count, text]`` for
            each of the ``bsz`` requests of a finished batch.
        msg_queue_id: when given, the ``speculate_get_output_dynamic`` op is
            used and receives this id; otherwise ``speculate_get_output``.
    """
    if msg_queue_id is None:
        from fastdeploy.model_executor.ops.gpu import speculate_get_output
    else:
        from fastdeploy.model_executor.ops.gpu import \
            speculate_get_output_dynamic

    tokenizer = ErnieBotTokenizer.from_pretrained(model_name_or_path)
    paddle.device.set_device("cpu")
    paddle.disable_static()
    output_tensor = paddle.full(output_tensor_max_shape,
                                fill_value=2,
                                dtype="int64")
    token_base = 2 + MAX_BSZ
    while True:
        outputs = [[] for _ in range(MAX_BSZ)]

        while True:
            if msg_queue_id is None:
                speculate_get_output(output_tensor, 0, True)
            else:
                speculate_get_output_dynamic(output_tensor, 0, True,
                                             msg_queue_id)
            if int(output_tensor[0]) == -2:  # read none
                continue
            bsz = int(output_tensor[1])
            # Convert the buffer once per read instead of once per request.
            flat_output = output_tensor.numpy()
            accept_num = flat_output[2:bsz + 2]
            for bi in range(bsz):
                start = token_base + bi * MAX_DRAFT_TOKENS
                outputs[bi].extend(
                    flat_output[start:start + accept_num[bi]].tolist())
            if int(output_tensor[0]) == -1:
                break

        seqs = tokenizer.batch_decode(
            outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        for i in range(bsz):
            result_queue.put([i, len(outputs[i]), seqs[i]])
|
|
|
|
|
|
def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
    """
    Calculate the effective tokens during training.

    The number of batches consumed is ``max_steps *
    per_device_train_batch_size * gradient_accumulation_steps *
    sharding_parallel_degree * data_parallel_degree``. The effective tokens
    are the summed ``token_ids`` lengths over that many leading batches of
    ``train_dataset``; the total tokens are that batch count times
    ``max_seq_len``.

    Args:
        training_args: object with the attributes named above;
            ``data_parallel_degree`` defaults to 1 when missing and a
            ``sharding_parallel_degree`` that is not above 1 counts as 1.
        train_dataset: iterable of batches, each an iterable of items with a
            ``token_ids`` attribute.
        max_seq_len (int): maximum sequence length of one batch entry.

    Returns:
        tuple[int, int]: ``(total_effective_tokens, total_tokens)``.
    """
    # Only a missing attribute falls back to 1; other errors propagate.
    data_parallel_degree = getattr(training_args, "data_parallel_degree", 1)
    sharding_parallel_degree = max(training_args.sharding_parallel_degree, 1)

    total_batch = (training_args.max_steps *
                   training_args.per_device_train_batch_size *
                   training_args.gradient_accumulation_steps *
                   sharding_parallel_degree * data_parallel_degree)

    total_effective_tokens = 0
    for i, data in enumerate(train_dataset):
        if i == total_batch:
            break
        total_effective_tokens += sum(len(dd.token_ids) for dd in data)
    total_tokens = total_batch * max_seq_len

    return total_effective_tokens, total_tokens
|
|
|
|
|
|
def estimate_training(train_dataset, data_args, training_args, model_args):
    """
    Estimate the number of training steps from the training data.

    The dataset is iterated once in estimate mode to count batches and
    tokens. If ``training_args.max_estimate_samples`` limits the number of
    samples looked at, the counts are scaled up to the full sample count. A
    report is written as JSON to ``training_args.estimation_output_file`` when
    that attribute is set.

    Args:
        train_dataset: dataset exposing ``estimate``, ``max_estimate_samples``,
            ``used_samples`` and ``unused_samples``; iterating it yields
            batches of items with ``token_ids``. Its ``estimate`` flag is set
            to True, and its sample counters are rescaled in place when
            sampling is used.
        data_args: data arguments; ``train_task_config`` is the path of a
            JSON file describing the training sources.
        training_args: training arguments (batch size, parallel degrees,
            epochs, ...).
        model_args: unused.

    Returns:
        int: the estimated maximum number of steps, rounded up so that it
        matches the ``max_steps`` written to the report; 0 if the dataset
        has no samples to estimate from.
    """

    def _write_report(report):
        """Add the optional GPU count and dump the report if requested."""
        if hasattr(training_args, "num_of_gpus"):
            report["num_of_gpus"] = training_args.num_of_gpus

        if getattr(training_args, "estimation_output_file", None):
            with open(training_args.estimation_output_file,
                      "w",
                      encoding="utf-8") as f:
                json.dump(report, f)

    train_dataset.estimate = True
    logger.info("Start to estimate max training steps...")
    with open(data_args.train_task_config) as f:
        train_task_group = json.load(f)

    if len(train_task_group) > 1:
        logger.warning(
            "Suggest to use max_steps instead of num_train_epochs for multi source dataset."
        )
        logger.info(
            "Multi source dataset detected, number of samples will be estimated by following rule. "
            "num_samples = (source1_num_samples * prob1 + source2_num_samples * prob2 + ...) * epochs"
        )

    max_samples = train_dataset.max_estimate_samples

    if training_args.max_estimate_samples != -1:
        # Set estimate samples to max_estimate_samples
        logger.warning(
            "The results between sampling and non-sampling methods may differ."
        )
        train_dataset.max_estimate_samples = min(
            training_args.max_estimate_samples,
            train_dataset.max_estimate_samples)

    if train_dataset.max_estimate_samples <= 0:
        res = {
            "num_train_epochs": int(training_args.num_train_epochs),
            "max_steps": 0,
            "gradient_accumulation_steps":
            training_args.gradient_accumulation_steps,
            "train_tokens": 0,
            "per_device_train_batch_size":
            int(training_args.per_device_train_batch_size),
            "tensor_parallel_degree":
            int(training_args.tensor_parallel_degree),
            "pipeline_parallel_degree":
            int(training_args.pipeline_parallel_degree),
            "sharding_parallel_degree":
            int(training_args.sharding_parallel_degree),
            "num_samples_each_epoch": data_args.num_samples_each_epoch,
            "example_from_same_task_prob":
            data_args.example_from_same_task_prob,
            "pseudo_sampling_prob": data_args.pseudo_sampling_prob,
            "trigger_data_prob": data_args.trigger_data_prob,
            "max_seq_len": int(data_args.max_seq_len),
            "seed": data_args.seed,
            "valid": False,
            "train_samples": 0,
        }
        _write_report(res)

        logger.error("No valid data found, please check your dataset format.")
        return 0

    train_batches = 0
    train_tokens = 0
    for sequences in train_dataset:
        # Stop as soon as the dataset leaves estimate mode.
        if not train_dataset.estimate:
            break
        train_batches += 1
        for sequence in sequences:
            train_tokens += len(sequence.token_ids)

    train_tokens *= training_args.num_train_epochs
    train_batches *= training_args.num_train_epochs
    global_batch_size = (training_args.per_device_train_batch_size *
                         training_args.gradient_accumulation_steps *
                         max(training_args.data_parallel_degree, 1) *
                         max(training_args.sharding_parallel_degree, 1))
    max_steps = int(np.ceil(train_batches / global_batch_size))

    if max_samples != train_dataset.max_estimate_samples:
        # Only a subset was looked at: scale the counts to the full size.
        scale = max_samples / train_dataset.max_estimate_samples
        max_steps *= scale
        train_tokens *= scale
        train_dataset.used_samples *= scale
        train_dataset.unused_samples *= scale

    # After scaling, max_steps may be fractional; the report and the return
    # value both use the rounded-up integer.
    rounded_max_steps = int(np.ceil(max_steps))

    res = {
        "num_train_epochs": int(training_args.num_train_epochs),
        "max_steps": rounded_max_steps,
        "train_tokens": int(train_tokens),
        "global_batch_size": int(global_batch_size),
        "gradient_accumulation_steps":
        training_args.gradient_accumulation_steps,
        "warmup_steps": int(np.ceil(0.1 * max_steps)),
        "per_device_train_batch_size":
        int(training_args.per_device_train_batch_size),
        "tensor_parallel_degree": int(training_args.tensor_parallel_degree),
        "pipeline_parallel_degree":
        int(training_args.pipeline_parallel_degree),
        "sharding_parallel_degree":
        int(training_args.sharding_parallel_degree),
        "seed": training_args.seed,
        "num_samples_each_epoch": data_args.num_samples_each_epoch,
        "example_from_same_task_prob": data_args.example_from_same_task_prob,
        "pseudo_sampling_prob": data_args.pseudo_sampling_prob,
        "trigger_data_prob": data_args.trigger_data_prob,
        "max_seq_len": int(data_args.max_seq_len),
        "valid": True,
        "train_samples": int(max_samples * training_args.num_train_epochs),
        "estimate_samples": int(train_dataset.max_estimate_samples),
        "actual_train_samples":
        int(train_dataset.used_samples * training_args.num_train_epochs),
        "skip_samples":
        int(train_dataset.unused_samples * training_args.num_train_epochs),
    }

    # Fewer batches per epoch than one global batch.
    if train_batches / training_args.num_train_epochs / global_batch_size < 1:
        logger.warning(
            "This dataset is too small, you'd better enlarge your dataset."
        )
        res["valid"] = False

    _write_report(res)

    return rounded_max_steps
|
|
|
|
|
|
def get_w4a8_gemm_config_tuple(file_root_path):
    """Read the pre-tuned W4A8 gemm configuration tables.

    Four files are read from ``file_root_path`` in order (qkv, out_linear,
    ffn1, ffn2). Every line is split on single spaces and fields 1 to 7 are
    parsed as integers; the rows are sorted by their first value and
    flattened into a 1-D int32 array. Reading stops at the first failure: a
    warning is logged and the tables not read yet stay empty lists.

    Args:
        file_root_path (str): the directory of w4a8_gemm_config files

    Returns:
        dict: the four ``*_gemm_config_tuple`` entries.
    """

    def _read_table(handle):
        rows = []
        for line in handle:
            fields = line.split(" ")
            rows.append([int(fields[column]) for column in range(1, 8)])
        ordered = sorted(rows, key=lambda row: row[0])
        return np.array(ordered, dtype="int32").flatten()

    # (result key, file name), in the order the files are read.
    tables = (
        ("qkv_gemm_config_tuple", "qkv_tuned_gemm_config.log"),
        ("out_linear_gemm_config_tuple", "out_linear_tuned_gemm_config.log"),
        ("ffn1_gemm_config_tuple", "ffn1_tuned_gemm_config.log"),
        ("ffn2_gemm_config_tuple", "ffn2_tuned_gemm_config.log"),
    )
    w4a8_gemm_config = {key: [] for key, _ in tables}
    try:
        for key, file_name in tables:
            log_path = os.path.join(f"{file_root_path}", file_name)
            with open(log_path) as file:
                w4a8_gemm_config[key] = _read_table(file)
    except Exception:
        logger.warning(
            "Found gemm config for W4A8 failed, using empty gemm tuple list for W4A8"
        )
    return w4a8_gemm_config
|
|
|
|
|
|
def update_refined_recompute(rr, sequence_parallel, lora=False):
    """Parse a refined-recompute specification into a per-op skip-count dict.

    Args:
        rr (str | dict): One of
            - a dict, which is returned unchanged;
            - the empty string, which yields ``{}``;
            - a comma-separated list of ``op_name:skip_num`` entries, e.g.
              ``"flash_attn:1,mlp_row_ln:2"``. Whitespace around entries and
              op names is ignored.
        sequence_parallel (bool): The four ``*_ln`` ops are ignored (with a
            warning) when this is falsy.
        lora (bool): The four ``*_ln`` ops are ignored (with a warning) when
            this is truthy.

    Returns:
        dict: Maps every supported op name to its skip count; ops that were not
        requested, or that were ignored, keep the value ``0``.

    Raises:
        ValueError: If an entry has no ``:``, its count is not an integer, or
            the op name is not supported.
    """
    # An already-parsed configuration is passed through untouched.
    if isinstance(rr, dict):
        return rr
    if rr == "":
        return {}

    # Ops that are only honoured with sequence parallelism and without LoRA.
    ln_ops = (
        "mlp_row_ln",
        "attention_row_ln",
        "attention_column_ln",
        "mlp_column_ln",
    )
    rr_res = {
        "mlp_row_ln": 0,
        "attention_row_ln": 0,
        "attention_column_ln": 0,
        "mlp_column_ln": 0,
        "flash_attn": 0,
    }

    for op in rr.split(","):
        if ":" not in op:
            raise ValueError("Illegal refined_recompute input, please check.")
        # Split once; strip the name so "a:1, b:2" (space after the comma) works.
        fields = op.split(":")
        op_name = fields[0].strip()
        skip_num = int(fields[1])
        if op_name not in rr_res:
            raise ValueError(
                f"Refined recompute do not support {op_name}, please check."
            )

        if op_name in ln_ops:
            if not sequence_parallel:
                logger.warning(
                    f"Currently, the `{op_name}` op is only supported "
                    "when `sequence_parallel=True`. This refined recompute op will be ignored."
                )
                continue
            if lora:
                logger.warning(
                    "Currently, LoRA does not support refined recompute "
                    f"for the `{op_name}` op. This refined recompute op will be ignored."
                )
                continue
        rr_res[op_name] = skip_num

    return rr_res
|
|
|
|
|
|
def model_convert_fp8(model_path, device=None):
    """
    Convert a model checkpoint from bf16/fp16 to fp8 format, in place.

    For every decoder layer the fused ``linear1`` weight and bias are split
    into a ``gate`` part (even columns / elements) and an ``up`` part (odd
    columns / elements). The absolute maximum of each weight half is recorded
    in ``weight_scales_0.json`` and the half is rescaled by ``448 / absmax``.
    The fused ``linear1`` entries are removed from the state dict, which is
    then written back over ``model_state.pdparams``. No dtype cast is
    performed here.

    Args:
        model_path (str): The path to the directory containing the model checkpoint files
            (e.g., config.json, weight_scales_0.json and model_state.pdparams).
        device (str, optional): The device to set for paddle, such as 'cpu' or 'gpu'.
            If None, the default device is used.

    Returns:
        None. Returns early, without touching any file, when
        ``weight_scales_0.json`` already has a gate scale for layer 0.

    Note:
        This function requires non-smooth quantization 'act_scales' to be applied when using the converted model.
    """
    if device is not None:
        paddle.device.set_device(device)

    with open(os.path.join(model_path, "config.json"), "r") as cfg_file:
        layer_count = json.load(cfg_file)["num_layers"]

    with open(os.path.join(model_path, "weight_scales_0.json"), "r") as scales_file:
        weight_scales = json.load(scales_file)

    # A gate scale for layer 0 is taken as the marker of a finished conversion.
    if "ernie.decoder.layers.0.gate.weight_quanter" in weight_scales:
        logger.info("FP8 model checkpoint already converted")
        return
    logger.info("Converting model checkpoint to fp8...")

    # The converted state dict overwrites the file it was loaded from.
    state_path = os.path.join(model_path, "model_state.pdparams")
    params_states = paddle.load(state_path)

    for layer_idx in range(layer_count):
        prefix = f"ernie.decoder.layers.{layer_idx}"
        fused_weight_key = prefix + ".linear1.weight"
        fused_bias_key = prefix + ".linear1.bias"

        # Interleaved layout: even columns -> gate, odd columns -> up.
        fused_weight = params_states[fused_weight_key]
        gate_weight = fused_weight[:, ::2]
        up_weight = fused_weight[:, 1::2]

        gate_absmax = paddle.abs(gate_weight).max()
        up_absmax = paddle.abs(up_weight).max()

        weight_scales[prefix + ".gate.weight_quanter"] = (
            paddle.cast(gate_absmax, "float").numpy().tolist()
        )
        weight_scales[prefix + ".up.weight_quanter"] = (
            paddle.cast(up_absmax, "float").numpy().tolist()
        )

        # 448 is presumably the largest finite float8_e4m3fn value -- TODO confirm.
        params_states[prefix + ".gate.weight"] = gate_weight * 448 / gate_absmax
        params_states[prefix + ".up.weight"] = up_weight * 448 / up_absmax
        del params_states[fused_weight_key]

        # The bias is de-interleaved the same way but is not rescaled.
        fused_bias = params_states[fused_bias_key]
        params_states[prefix + ".gate.bias"] = fused_bias[::2]
        params_states[prefix + ".up.bias"] = fused_bias[1::2]
        del params_states[fused_bias_key]

    with open(model_path + "/weight_scales_0.json", "w") as scales_file:
        json.dump(weight_scales, scales_file)

    paddle.save(params_states, state_path)
|
|
|
|
|
|
|
|
def load_ep_checkpoint(model_path, config, return_numpy=False, return_key_name=True):
    """
    Load an expert-parallel (EP) checkpoint, either as a name map or as tensors.

    Args:
        model_path (str): Checkpoint directory.
        config: Only read when ``return_key_name`` is False. The attributes
            used are ``use_offline_quant``, ``moe_quant_type``,
            ``moe_layer_start_index``, ``num_layers``,
            ``num_experts_start_offset`` and ``num_experts_per_rank``.
        return_numpy (bool): Only used when ``return_key_name`` is False. If
            True the loaded numpy arrays are returned as-is; otherwise each one
            is wrapped in a ``paddle.Tensor`` and copied to Paddle's current
            expected place.
        return_key_name (bool): If True, no tensor data is read and a mapping
            from weight name to a location string is returned.

    Returns:
        dict:
            - ``return_key_name=True`` and ``<model_path>/merged_tp1_state_split``
              is a directory: ``{file base name: file path}`` for every entry
              matched by ``merged_tp1_state_split/*``.
            - ``return_key_name=True`` otherwise: ``{name: "[name]" + shard path}``
              built from ``model.safetensors.index.json``.
            - ``return_key_name=False``: ``{name: tensor}`` for all weights whose
              name does not contain ``"experts"`` plus the expert weights of the
              experts owned by this rank.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")

    if return_key_name:
        merge_path = os.path.join(model_path, "merged_tp1_state_split")
        if os.path.isdir(merge_path):
            # One file per weight: {layer_name: weight_file_name}.
            return {
                os.path.basename(file_name): file_name
                for file_name in glob.glob(model_path + "/merged_tp1_state_split/*")
            }
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]
        return {
            k: "[" + k + "]" + os.path.join(model_path, v)
            for k, v in weight_map.items()
        }

    with open(index_path, "r") as f:
        weight_list = json.load(f)["weight_map"]

    # Start from every non-expert weight; local expert weights are added below.
    filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}

    offline_quant = config.use_offline_quant and config.moe_quant_type != "default"
    quant_suffix = "quant_weight" if offline_quant else ""
    scale_suffix = "quant_scale" if offline_quant else ""
    # NOTE(review): with empty suffixes the keys below end in "weight." (trailing
    # dot); confirm that non-quantized checkpoints really use such names,
    # otherwise no local expert weight is selected in that mode.

    num_local_ffn_keys = []
    for i in range(config.moe_layer_start_index, config.num_layers):
        for j in range(
            config.num_experts_start_offset,
            config.num_experts_start_offset + config.num_experts_per_rank,
        ):
            expert_prefix = f"ernie.layers.{i}.mlp.experts.{j}"
            num_local_ffn_keys.append(f"{expert_prefix}.up_gate_proj.weight.{quant_suffix}")
            num_local_ffn_keys.append(f"{expert_prefix}.down_proj.weight.{quant_suffix}")
            num_local_ffn_keys.append(f"{expert_prefix}.up_gate_proj.weight.{scale_suffix}")
            num_local_ffn_keys.append(f"{expert_prefix}.down_proj.weight.{scale_suffix}")

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    state_dict = {}
    for k, safetensor_path in filtered_map.items():
        with safe_open(
            os.path.join(model_path, safetensor_path), framework="np", device="cpu"
        ) as f:
            # Names listed in the index but absent from the shard are skipped.
            if k in f.keys():
                weight = f.get_tensor(k)
                if not return_numpy:
                    weight = paddle.Tensor(weight, zero_copy=True)
                    weight = weight._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                state_dict[k] = weight
    return state_dict
|
|
|
|
|
|
def get_safe_tensor_file(model_path):
    """
    List the weight names and shard files declared in a safetensors index.

    Reads ``<model_path>/model.safetensors.index.json`` and uses its
    ``"weight_map"`` entry (weight name -> shard file name).

    Args:
        model_path (str): Directory containing ``model.safetensors.index.json``.

    Returns:
        tuple[list[str], list[str]]: ``(key_name_list, safe_tensor_list)`` where
        ``key_name_list`` holds the weight names and ``safe_tensor_list`` holds
        each distinct shard file joined onto ``model_path``. Both lists follow
        the order of the index file (shards by first appearance), so the result
        is the same in every process.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    with open(index_path, "r") as f:
        weight_map = json.load(f)["weight_map"]

    key_name_list = list(weight_map)
    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the order depend on string hash randomisation.
    safe_tensor_list = [
        os.path.join(model_path, shard) for shard in dict.fromkeys(weight_map.values())
    ]

    return key_name_list, safe_tensor_list
|
|
|
|
|
|
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
    """
    Lazily yield ``(name, tensor)`` pairs from a list of safetensors files.

    Each file is opened with ``framework="np"`` and every tensor it contains is
    read and yielded in turn; a tqdm progress bar advances once per file.

    Args:
        safe_tensor_list (list[str]): Paths of the safetensors files to read.

    Yields:
        tuple: ``(name, param)`` for every tensor name reported by each file.
    """
    shards = tqdm(safe_tensor_list, desc="Loading safetensors checkpoint shards")
    for shard_path in shards:
        with safe_open(shard_path, framework="np") as shard:
            for tensor_name in shard.keys():
                yield tensor_name, shard.get_tensor(tensor_name)
|
|
|
|
|
|
def get_state_dict(model_path, config):
    """
    Load every weight stored under this rank's sub-directory into a dict.

    The shard files are taken from
    ``<model_path>/rank<config.tensor_parallel_rank>`` and read one after the
    other; if a name occurs more than once, the last one read wins.

    Args:
        model_path (str): Checkpoint root containing the ``rank<N>`` directories.
        config: Object providing ``tensor_parallel_rank``.

    Returns:
        dict: Weight name -> tensor as yielded by ``safetensors_weights_iterator``.
    """
    rank_dir = os.path.join(model_path, f"rank{config.tensor_parallel_rank}")
    _, shard_files = get_safe_tensor_file(rank_dir)
    return {name: weight for name, weight in safetensors_weights_iterator(shard_files)}
|
|
|
|
|
|
def load_checkpoint(model_path, cls, config, return_numpy=True):
    """
    Load a checkpoint using the loader that matches its layout.

    - ``config.use_ep`` truthy: delegate to ``load_ep_checkpoint`` in key-name
      mode.
    - More than one ``rank*`` sub-directory in ``model_path``: the checkpoint
      is treated as pre-sharded and this rank's shard is read with
      ``get_state_dict``.
    - Otherwise: delegate to ``load_tp_checkpoint``.

    Args:
        model_path (str): Checkpoint directory.
        cls: Forwarded to ``load_tp_checkpoint``.
        config: Object providing ``use_ep`` and, for pre-sharded checkpoints,
            ``tensor_parallel_degree`` (plus whatever the delegate reads).
        return_numpy (bool): Forwarded to ``load_tp_checkpoint`` only.

    Returns:
        dict: The state dict produced by the selected loader.

    Raises:
        ValueError: If the number of ``rank*`` directories is greater than one
            and differs from ``config.tensor_parallel_degree``.
    """
    if config.use_ep:
        return load_ep_checkpoint(
            model_path, config, return_numpy=True, return_key_name=True
        )

    rank_dirs = []
    for entry in os.listdir(model_path):
        if entry.startswith("rank") and os.path.isdir(os.path.join(model_path, entry)):
            rank_dirs.append(entry)

    # Zero or one rank directory: not a pre-sharded layout.
    if len(rank_dirs) <= 1:
        return load_tp_checkpoint(
            model_path, cls, config, return_numpy=return_numpy
        )

    if config.tensor_parallel_degree != len(rank_dirs):
        raise ValueError(
            f"Your model only supports loading with tp{len(rank_dirs)}"
        )
    return get_state_dict(model_path, config)
|
|
|
|
|
|
def parser_quant_type(quant_type):
    """
    Parse a quantization type string into weight, activation and cache dtypes.

    Args:
        quant_type (str): The quantization type string. Supported forms:
            - contains "weight_only_int8" or "wint8": weights are int8.
            - contains "weight_only_int4" or "wint4": weights are int4.
            - a custom string made of ``w``/``a``/``c`` prefixes, each followed
              by a width token, e.g. "w4afp8c8" -> weights "4", activations
              "fp8", cache "8". Recognised width tokens are "4", "8", "16",
              "fp8", "fp16", "bf16" and "fp32". A missing prefix falls back to
              paddle's default dtype.

    Returns:
        tuple: ``(weight_type, activation_type, cache_type)``.
            For the weight-only forms the activation type is the default dtype
            and the cache type is "int8", "fp8" or "int4" when the string
            contains "c8", "cfp8" or "c4" respectively (default dtype otherwise).
            For custom strings each entry is the converted width token.

    Raises:
        AssertionError: If a custom string does not split into at most three
            prefix/value pairs.
        KeyError: If a width token is not one of the recognised tokens.
    """
    default_type = paddle.get_default_dtype()
    conver_dict = {
        "8": "int8",
        "4": "int4",
        # Must be the dtype value itself, not the `get_default_dtype` function.
        "16": default_type,
        "fp8": "float8_e4m3fn",
        "fp16": "float16",
        "bf16": "bfloat16",
        "fp32": "float32",
    }

    # Cache type used by the weight-only forms below.
    # NOTE(review): this path reports fp8 cache as "fp8", while the custom path
    # reports "float8_e4m3fn"; kept as-is because callers may rely on it.
    cache_type = default_type
    if "c8" in quant_type:
        cache_type = "int8"
    elif "cfp8" in quant_type:
        cache_type = "fp8"
    elif "c4" in quant_type:
        cache_type = "int4"

    if "weight_only_int8" in quant_type or "wint8" in quant_type:
        return "int8", default_type, cache_type
    if "weight_only_int4" in quant_type or "wint4" in quant_type:
        return "int4", default_type, cache_type

    # split quant type, eg. w4afp8c8 -> ['w', '4', 'a', 'fp8', 'c', '8']
    splited_type = [token for token in re.split(r"(w|a|c)", quant_type) if token]
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if not (len(splited_type) % 2 == 0 and len(splited_type) <= 6):
        raise AssertionError(f"Quant type[{quant_type}] format error.")

    quant_type_list = []
    for prefix in ("w", "a", "c"):
        if prefix in splited_type:
            # The token right after the prefix is its width.
            width_token = splited_type[splited_type.index(prefix) + 1]
            quant_type_list.append(conver_dict[width_token])
        else:
            quant_type_list.append(default_type)

    return quant_type_list[0], quant_type_list[1], quant_type_list[2]
|