Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)
* Support static CacheKV C8 quantization in MTP mode
* Optimize memory allocation
@@ -475,15 +475,16 @@ def deal_state_dict(state_dict):
         src_tensor._share_data_with(dst_tensor)
 
 
-def load_cache_scale(model_path, fd_config, state_dict):
-    file_path = os.path.join(model_path, "kv_cache_scale.json")
+def load_cache_scale(fd_config, state_dict):
+    file_path = fd_config.model_config.kv_cache_quant_scale_path
+    prefix_layer_name = fd_config.model_config.prefix_layer_name
     if os.path.exists(file_path):
         with open(file_path, "r") as f:
             data = json.load(f)
         for i in range(fd_config.model_config.num_hidden_layers):
 
-            k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
-            v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
+            k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
+            v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
 
             k_scale = data[k_scale_name]
             k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
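For readability, here is the new loading path from this hunk collected into one self-contained sketch. It assumes an `fd_config` exposing the attributes referenced above; the hunk cuts off after the `k_scale_tensor` line, so the `v_scale` conversion and the writes into `state_dict` are hypothetical illustrations, not the actual FastDeploy code.

```python
# Minimal sketch of the updated scale-loading flow, based only on the lines
# visible in this hunk. How the tensors are handed back to the loader
# (the state_dict writes below) is a guess, not the FastDeploy implementation.
import json
import os

import paddle


def load_cache_scale_sketch(fd_config, state_dict):
    file_path = fd_config.model_config.kv_cache_quant_scale_path
    prefix_layer_name = fd_config.model_config.prefix_layer_name  # e.g. "layers" (assumed)
    if not os.path.exists(file_path):
        return
    with open(file_path, "r") as f:
        data = json.load(f)
    for i in range(fd_config.model_config.num_hidden_layers):
        k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
        v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"

        # Convert the stored scales into tensors in the default dtype.
        k_scale_tensor = paddle.to_tensor(data[k_scale_name], dtype=paddle.get_default_dtype())
        v_scale_tensor = paddle.to_tensor(data[v_scale_name], dtype=paddle.get_default_dtype())

        # Hypothetical: expose the scales to the rest of the loader via state_dict.
        state_dict[k_scale_name] = k_scale_tensor
        state_dict[v_scale_name] = v_scale_tensor
```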
@@ -547,6 +548,6 @@ def load_composite_checkpoint(
     if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
         kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
         if kv_cache_quant_type == "float8_e4m3fn":
-            load_cache_scale(model_path, fd_config, state_dict)
+            load_cache_scale(fd_config, state_dict)
 
     return state_dict
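The call site now passes only `fd_config` and `state_dict`, with the scale file located via `fd_config.model_config.kv_cache_quant_scale_path` instead of a `kv_cache_scale.json` next to the model. As a hedged illustration of what a compatible scale file might contain, the snippet below writes a dummy file using the key naming from the first hunk; the value layout (a single float per scale) is an assumption inferred from `data[k_scale_name]` being passed straight to `paddle.to_tensor`, not a documented FastDeploy format.

```python
# Hypothetical helper that writes a scale file whose keys match the
# prefix-aware naming scheme introduced above. Both the helper and the
# per-layer scalar values are illustrative assumptions.
import json


def dump_dummy_cache_scales(path, num_hidden_layers, prefix_layer_name="layers"):
    scales = {}
    for i in range(num_hidden_layers):
        base = f"ernie.{prefix_layer_name}.{i}.self_attn"
        scales[f"{base}.cachek_matmul.activation_scale"] = 1.0  # placeholder K scale
        scales[f"{base}.cachev_matmul.activation_scale"] = 1.0  # placeholder V scale
    with open(path, "w") as f:
        json.dump(scales, f, indent=2)


# Example: point fd_config.model_config.kv_cache_quant_scale_path at this file.
dump_dummy_cache_scales("kv_cache_scale.json", num_hidden_layers=4)
```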