[V1 Loader] Support loading static C8 scale JSON (#5909)

* v1 loader: support loading static C8 scale JSON

* update
This commit is contained in:
sunxin
2026-01-07 11:49:30 +08:00
committed by GitHub
parent 7ad5737560
commit 6ee8241521
+19 -3
View File
@@ -76,7 +76,12 @@ def get_weight_iterator(model_path: str):
weights_iterator = safetensors_weights_iterator(files_list)
else:
weights_iterator = pdparams_weight_iterator(files_list)
return weights_iterator
yield from weights_iterator
kv_cache_scale_json_path = Path(model_path) / "kv_cache_scale.json"
if kv_cache_scale_json_path.exists():
yield from kv_cache_scale_iterator(str(kv_cache_scale_json_path))
def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
@@ -319,6 +324,17 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
return state_dict
def kv_cache_scale_iterator(kv_cache_scale_json_path):
"""
kv_cache_scale_iterator
"""
with open(kv_cache_scale_json_path, "r") as f:
data = json.load(f)
for key, value in data.items():
scale_tensor = paddle.to_tensor(value, dtype=paddle.get_default_dtype()) * 448.0
yield key, scale_tensor
def safetensors_weights_iterator(safe_tensor_list: list[str]):
"""
safetensors_weights_iterator
@@ -399,7 +415,7 @@ def deal_state_dict(state_dict):
src_tensor._share_data_with(dst_tensor)
def load_cache_scale(fd_config, state_dict):
def load_kv_cache_scale(fd_config, state_dict):
file_path = fd_config.model_config.kv_cache_quant_scale_path
prefix_layer_name = fd_config.model_config.prefix_layer_name
if os.path.exists(file_path):
@@ -465,6 +481,6 @@ def load_composite_checkpoint(
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
if kv_cache_quant_type == "float8_e4m3fn":
load_cache_scale(fd_config, state_dict)
load_kv_cache_scale(fd_config, state_dict)
return state_dict