mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[V1 Loader] Support loading static C8 scale JSON (#5909)
* v1 loader: support loading static C8 scale JSON * update
This commit is contained in:
@@ -76,7 +76,12 @@ def get_weight_iterator(model_path: str):
|
||||
weights_iterator = safetensors_weights_iterator(files_list)
|
||||
else:
|
||||
weights_iterator = pdparams_weight_iterator(files_list)
|
||||
return weights_iterator
|
||||
|
||||
yield from weights_iterator
|
||||
|
||||
kv_cache_scale_json_path = Path(model_path) / "kv_cache_scale.json"
|
||||
if kv_cache_scale_json_path.exists():
|
||||
yield from kv_cache_scale_iterator(str(kv_cache_scale_json_path))
|
||||
|
||||
|
||||
def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
|
||||
@@ -319,6 +324,17 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
|
||||
return state_dict
|
||||
|
||||
|
||||
def kv_cache_scale_iterator(kv_cache_scale_json_path):
|
||||
"""
|
||||
kv_cache_scale_iterator
|
||||
"""
|
||||
with open(kv_cache_scale_json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for key, value in data.items():
|
||||
scale_tensor = paddle.to_tensor(value, dtype=paddle.get_default_dtype()) * 448.0
|
||||
yield key, scale_tensor
|
||||
|
||||
|
||||
def safetensors_weights_iterator(safe_tensor_list: list[str]):
|
||||
"""
|
||||
safetensors_weights_iterator
|
||||
@@ -399,7 +415,7 @@ def deal_state_dict(state_dict):
|
||||
src_tensor._share_data_with(dst_tensor)
|
||||
|
||||
|
||||
def load_cache_scale(fd_config, state_dict):
|
||||
def load_kv_cache_scale(fd_config, state_dict):
|
||||
file_path = fd_config.model_config.kv_cache_quant_scale_path
|
||||
prefix_layer_name = fd_config.model_config.prefix_layer_name
|
||||
if os.path.exists(file_path):
|
||||
@@ -465,6 +481,6 @@ def load_composite_checkpoint(
|
||||
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
|
||||
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
|
||||
if kv_cache_quant_type == "float8_e4m3fn":
|
||||
load_cache_scale(fd_config, state_dict)
|
||||
load_kv_cache_scale(fd_config, state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
Reference in New Issue
Block a user