[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)

* Support static CacheKV C8 quantization in MTP mode

* Optimize memory allocation
This commit is contained in:
freeliuzc
2025-11-21 15:10:13 +08:00
committed by GitHub
parent 3c36283d7d
commit 2d1dade5e2
6 changed files with 350 additions and 295 deletions
@@ -475,15 +475,16 @@ def deal_state_dict(state_dict):
src_tensor._share_data_with(dst_tensor)
def load_cache_scale(model_path, fd_config, state_dict):
file_path = os.path.join(model_path, "kv_cache_scale.json")
def load_cache_scale(fd_config, state_dict):
file_path = fd_config.model_config.kv_cache_quant_scale_path
prefix_layer_name = fd_config.model_config.prefix_layer_name
if os.path.exists(file_path):
with open(file_path, "r") as f:
data = json.load(f)
for i in range(fd_config.model_config.num_hidden_layers):
k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
k_scale = data[k_scale_name]
k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
@@ -547,6 +548,6 @@ def load_composite_checkpoint(
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
if kv_cache_quant_type == "float8_e4m3fn":
load_cache_scale(model_path, fd_config, state_dict)
load_cache_scale(fd_config, state_dict)
return state_dict