[XPU] Support W4A8C8-TP4-300B Model (#4068)

* support w4a8

* delete ep block attn

* delete moe_topk_select

* update note

* update

* delte useless info

* update

* add some note

* fix some format

* update scale info

* add ans baseline

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
yinwei
2025-10-10 15:41:32 +08:00
committed by GitHub
parent c46d5e48f8
commit 20c7b741f4
21 changed files with 2029 additions and 714 deletions
@@ -22,6 +22,7 @@ from paddle import nn
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.platforms import current_platform
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -94,7 +95,14 @@ class KvCacheQuantConfig(QuantConfigBase):
"""
get_quant_method
"""
return KVCacheMethodBase(self)
if current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends.xpu.quantization.kv_cache import (
XPUKVCacheMethodBase,
)
return XPUKVCacheMethodBase(self)
else:
return KVCacheMethodBase(self)
class KVCacheMethodBase(QuantMethodBase):
@@ -118,6 +126,7 @@ class KVCacheMethodBase(QuantMethodBase):
"""
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
layer.cache_k_zp.set_value(cache_k_zeropoint)
layer.cache_v_zp.set_value(cache_v_zeropoint)
@@ -125,7 +134,6 @@ class KVCacheMethodBase(QuantMethodBase):
"""
load_scale
"""
cache_k_scale_tensor = (
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
@@ -186,6 +194,7 @@ class KVCacheMethodBase(QuantMethodBase):
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_scale,
{
@@ -198,6 +207,7 @@ class KVCacheMethodBase(QuantMethodBase):
**extra_weight_attrs,
},
)
layer.cache_k_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),