mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[XPU] Support W4A8C8-TP4-300B Model (#4068)
* support w4a8 * delete ep block attn * delete moe_topk_select * update note * update * delte useless info * update * add some note * fix some format * update scale info * add ans baseline --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
@@ -94,7 +95,14 @@ class KvCacheQuantConfig(QuantConfigBase):
|
||||
"""
|
||||
get_quant_method
|
||||
"""
|
||||
return KVCacheMethodBase(self)
|
||||
if current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.layers.backends.xpu.quantization.kv_cache import (
|
||||
XPUKVCacheMethodBase,
|
||||
)
|
||||
|
||||
return XPUKVCacheMethodBase(self)
|
||||
else:
|
||||
return KVCacheMethodBase(self)
|
||||
|
||||
|
||||
class KVCacheMethodBase(QuantMethodBase):
|
||||
@@ -118,6 +126,7 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
"""
|
||||
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
|
||||
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
|
||||
|
||||
layer.cache_k_zp.set_value(cache_k_zeropoint)
|
||||
layer.cache_v_zp.set_value(cache_v_zeropoint)
|
||||
|
||||
@@ -125,7 +134,6 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
"""
|
||||
load_scale
|
||||
"""
|
||||
|
||||
cache_k_scale_tensor = (
|
||||
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
|
||||
)
|
||||
@@ -186,6 +194,7 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
dtype=paddle.get_default_dtype(),
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
set_weight_attrs(
|
||||
layer.cache_k_scale,
|
||||
{
|
||||
@@ -198,6 +207,7 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
**extra_weight_attrs,
|
||||
},
|
||||
)
|
||||
|
||||
layer.cache_k_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
|
||||
Reference in New Issue
Block a user