mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Intel HPU] enable tensor_wise_fp8 (#5324)
* [Intel HPU] enable tensor_wise_fp8 * update code based on comments * fix code style issue * fix bug about RP 5138 * mv kv_cache modifications to HPU backend * fix FP8 Precision Issues * fix FP8 Precision Issues * Add quantization UT --------- Co-authored-by: yanfeich <yanfei.cheng@intel.com> Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -18,9 +18,11 @@ intel_hpu backend methods
|
||||
|
||||
from .attention.hpu_attn_backend import HPUAttentionBackend
|
||||
from .moe.fused_moe_hpu_backend import HpuMoEMethod, HpuTensorWiseFP8MoEMethod
|
||||
from .quantization.tensor_wise_fp8 import HpuTensorWiseFP8LinearMethod
|
||||
|
||||
__all__ = [
|
||||
"HPUAttentionBackend",
|
||||
"HpuMoEMethod",
|
||||
"HpuTensorWiseFP8MoEMethod",
|
||||
"HpuTensorWiseFP8LinearMethod",
|
||||
]
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .hpu_attn_backend import HPUAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"HPUAttentionBackend",
|
||||
]
|
||||
"""
|
||||
intel_hpu attention
|
||||
"""
|
||||
|
||||
+145
-65
@@ -236,6 +236,8 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index = llm_config.model_config.start_layer_index
|
||||
if llm_config.quant_config:
|
||||
self.quant_method = llm_config.quant_config.get_quant_method(self)
|
||||
|
||||
def init_attention_metadata(self, forward_meta):
|
||||
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
||||
@@ -287,29 +289,47 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
|
||||
from fastdeploy.model_executor.ops.intel_hpu import (
|
||||
fused_qkv_rope,
|
||||
fused_sdpa_proj_t,
|
||||
fused_qkv_rope_ref,
|
||||
fused_sdpa_proj,
|
||||
fused_sdpa_proj_ref,
|
||||
index_copy_,
|
||||
)
|
||||
|
||||
query_states, key_value_states = fused_qkv_rope(
|
||||
src,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
forward_meta.rotary_embs,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
forward_meta.total_batch,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
epsilon=1e-6,
|
||||
)
|
||||
if forward_meta.measurement_mode:
|
||||
qkv_proj_act_scale_key = qkv_proj.weight_key.replace("weight", "activation_scale")
|
||||
query_states, key_value_states = fused_qkv_rope_ref(
|
||||
src,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
forward_meta.rotary_embs,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
forward_meta.total_batch,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
measurement_mode=True,
|
||||
qkv_act_scale_key=qkv_proj_act_scale_key,
|
||||
)
|
||||
else:
|
||||
query_states, key_value_states = fused_qkv_rope(
|
||||
src,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
forward_meta.rotary_embs,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
forward_meta.total_batch,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
epsilon=1e-6,
|
||||
)
|
||||
|
||||
kv, B, BP_BS, M, H = key_value_states.shape
|
||||
key_value_states_reshape = key_value_states.reshape([kv, -1, forward_meta.block_size, M, H])
|
||||
@@ -321,16 +341,38 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
index_copy_(v_cache, forward_meta.block_indices, value_states, 0)
|
||||
|
||||
if forward_meta.block_list.shape == forward_meta.block_indices.shape:
|
||||
out_linear_out = fused_sdpa_proj_t(
|
||||
query_states,
|
||||
key_value_states,
|
||||
forward_meta.attn_mask,
|
||||
None,
|
||||
o_proj.weight,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
softmax_mode=0,
|
||||
)
|
||||
if forward_meta.measurement_mode:
|
||||
o_proj_act_scale_key = o_proj.weight_key.replace("weight", "activation_scale")
|
||||
out_linear_out = fused_sdpa_proj_ref(
|
||||
query_states,
|
||||
key_value_states,
|
||||
forward_meta.attn_mask,
|
||||
o_proj.weight,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
softmax_mode=0,
|
||||
measurement_mode=True,
|
||||
o_act_scale_key=o_proj_act_scale_key,
|
||||
)
|
||||
else:
|
||||
out_linear_out = fused_sdpa_proj(
|
||||
query_states,
|
||||
key_value_states,
|
||||
forward_meta.attn_mask,
|
||||
None,
|
||||
o_proj.weight,
|
||||
getattr(layer, "q_out_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "s_scale", None),
|
||||
getattr(o_proj, "act_scale", None),
|
||||
getattr(layer, "s_out_scale", None),
|
||||
getattr(o_proj, "act_scale_inv", None),
|
||||
getattr(o_proj, "weight_scale", None),
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
causal=True,
|
||||
softmax_mode=0,
|
||||
)
|
||||
else:
|
||||
key_states_with_context = k_cache.index_select(forward_meta.block_list)
|
||||
val_states_with_context = v_cache.index_select(forward_meta.block_list)
|
||||
@@ -344,12 +386,20 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
query_states.shape[0],
|
||||
query_states.shape[1],
|
||||
)
|
||||
out_linear_out = fused_sdpa_proj_t(
|
||||
out_linear_out = fused_sdpa_proj(
|
||||
query_states,
|
||||
key_value_states_with_context,
|
||||
forward_meta.attn_mask,
|
||||
None,
|
||||
o_proj.weight,
|
||||
getattr(layer, "q_out_scale", None),
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "s_scale", None),
|
||||
getattr(o_proj, "act_scale", None),
|
||||
getattr(layer, "s_out_scale", None),
|
||||
getattr(o_proj, "act_scale_inv", None),
|
||||
getattr(o_proj, "weight_scale", None),
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
causal=False,
|
||||
softmax_mode=0,
|
||||
@@ -378,45 +428,75 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
forward_decode
|
||||
"""
|
||||
# metadata = self.attention_metadata
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_block_attention
|
||||
|
||||
res = fused_block_attention(
|
||||
src,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
forward_meta.block_groups,
|
||||
forward_meta.block_list,
|
||||
forward_meta.block_mapping,
|
||||
forward_meta.attention_mask,
|
||||
forward_meta.block_indices,
|
||||
forward_meta.block_offsets,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
o_proj.weight,
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scaling_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "s_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(o_proj, "act_scale", None),
|
||||
getattr(o_proj, "weight_scale", None),
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
epsilon=1e-6,
|
||||
from fastdeploy.model_executor.ops.intel_hpu import (
|
||||
fused_block_attention,
|
||||
fused_block_attention_ref,
|
||||
)
|
||||
|
||||
if forward_meta.measurement_mode:
|
||||
qkv_proj_act_scale_key = qkv_proj.weight_key.replace("weight", "activation_scale")
|
||||
o_proj_act_scale_key = o_proj.weight_key.replace("weight", "activation_scale")
|
||||
out_linear_out = fused_block_attention_ref(
|
||||
src,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
forward_meta.block_groups,
|
||||
forward_meta.block_list,
|
||||
forward_meta.block_mapping,
|
||||
forward_meta.attention_mask,
|
||||
forward_meta.block_indices,
|
||||
forward_meta.block_offsets,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
o_proj.weight,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
measurement_mode=True,
|
||||
qkv_act_scale_key=qkv_proj_act_scale_key,
|
||||
o_act_scale_key=o_proj_act_scale_key,
|
||||
)
|
||||
else:
|
||||
out_linear_out = fused_block_attention(
|
||||
src,
|
||||
forward_meta.rotary_embs,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
forward_meta.block_groups,
|
||||
forward_meta.block_list,
|
||||
forward_meta.block_mapping,
|
||||
forward_meta.attention_mask,
|
||||
forward_meta.block_indices,
|
||||
forward_meta.block_offsets,
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
o_proj.weight,
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scaling_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "s_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(o_proj, "act_scale", None),
|
||||
getattr(o_proj, "weight_scale", None),
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
epsilon=1e-6,
|
||||
)
|
||||
|
||||
# all_reduce
|
||||
if self.tp_size > 1:
|
||||
from fastdeploy.distributed.communication import (
|
||||
tensor_model_parallel_all_reduce_custom,
|
||||
)
|
||||
|
||||
tensor_model_parallel_all_reduce_custom(res)
|
||||
return res
|
||||
tensor_model_parallel_all_reduce_custom(out_linear_out)
|
||||
return out_linear_out
|
||||
|
||||
@@ -11,6 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" "
|
||||
"""
|
||||
intel_hpu moe
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
|
||||
class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
@@ -41,10 +42,17 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||
|
||||
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
|
||||
layer.down_proj_weight.set_value(stacked_down_proj_weights)
|
||||
|
||||
# for measurement mode
|
||||
up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
|
||||
down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
|
||||
self.up_gate_proj_act_scale_key = up_gate_proj_expert_weight_key.replace("{}.", "").replace(
|
||||
"weight", "activation_scale"
|
||||
)
|
||||
self.down_proj_expert_act_scale_key = down_proj_expert_weight_key.replace("weight", "activation_scale")
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
@@ -84,22 +92,44 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
# norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
|
||||
chunk_size = envs.FD_HPU_CHUNK_SIZE
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe
|
||||
measurement_mode = getattr(layer, "measurement_mode", False)
|
||||
if measurement_mode:
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe_ref
|
||||
|
||||
fused_moe_out = fused_gate_moe(
|
||||
x,
|
||||
gate.weight,
|
||||
layer.gate_correction_bias,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
layer.top_k,
|
||||
norm_topk_prob=True,
|
||||
permuted_weights=False,
|
||||
activation="silu",
|
||||
experts_min=layer.expert_id_offset,
|
||||
experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
fused_moe_out = fused_gate_moe_ref(
|
||||
x,
|
||||
gate.weight,
|
||||
layer.gate_correction_bias,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
layer.top_k,
|
||||
norm_topk_prob=True,
|
||||
permuted_weights=False,
|
||||
activation="silu",
|
||||
experts_min=layer.expert_id_offset,
|
||||
experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
|
||||
chunk_size=chunk_size,
|
||||
measurement_mode=True,
|
||||
up_gate_act_scale_key=self.up_gate_proj_act_scale_key,
|
||||
down_act_scale_key=self.down_proj_expert_act_scale_key,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe
|
||||
|
||||
fused_moe_out = fused_gate_moe(
|
||||
x,
|
||||
gate.weight,
|
||||
layer.gate_correction_bias,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
layer.top_k,
|
||||
norm_topk_prob=True,
|
||||
permuted_weights=False,
|
||||
activation="silu",
|
||||
experts_min=layer.expert_id_offset,
|
||||
experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
@@ -110,9 +140,133 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
|
||||
This method is the oldest way to compute MoE in Paddle.
|
||||
"""
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
||||
"""
|
||||
Paddle HPU process prequanted weights.
|
||||
"""
|
||||
|
||||
def _extract_scale_tensor(key_template, logical_expert_ids):
|
||||
result = []
|
||||
for i in logical_expert_ids:
|
||||
result.append(get_tensor(state_dict.pop(key_template.format(i))))
|
||||
return result # bf16 tensor list
|
||||
|
||||
def _extract_descale_tensor(key_template, logical_expert_ids):
|
||||
if key_template.format(0) in state_dict:
|
||||
# Extract scale tensors for all logical_expert_ids
|
||||
scale_tensors = []
|
||||
for i in logical_expert_ids:
|
||||
scale_tensor = get_tensor(state_dict.pop(key_template.format(i)))
|
||||
scale_tensors.append(scale_tensor)
|
||||
# Stack all scale tensors into one tensor
|
||||
stacked = paddle.stack(scale_tensors)
|
||||
reciprocal = 1.0 / stacked
|
||||
# Take min over all logical_expert_ids (axis=0)
|
||||
min_tensor = paddle.min(reciprocal, axis=0)
|
||||
return min_tensor.cast(paddle.get_default_dtype())
|
||||
else:
|
||||
key = key_template.replace("{}.", "")
|
||||
scale_tensor = get_tensor(state_dict.pop(key))
|
||||
reciprocal = 1.0 / scale_tensor
|
||||
return reciprocal.cast(paddle.get_default_dtype())
|
||||
|
||||
up_gate_proj_weight, down_proj_weight, logical_expert_ids, _ = layer.extract_moe_ffn_weights(state_dict)
|
||||
up_gate_proj_weights = [t.view(paddle.float8_e4m3fn) for t in up_gate_proj_weight]
|
||||
down_proj_weights = [t.view(paddle.float8_e4m3fn) for t in down_proj_weight]
|
||||
|
||||
weight_key_map = layer.weight_key_map
|
||||
|
||||
up_gate_proj_expert_weight_scale_key = weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||
up_gate_proj_expert_in_scale_key = weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
|
||||
down_proj_expert_in_scale_key = weight_key_map.get("down_proj_expert_in_scale_key", None)
|
||||
|
||||
up_gate_proj_weight_scale = _extract_scale_tensor(up_gate_proj_expert_weight_scale_key, logical_expert_ids)
|
||||
down_proj_weight_scale = _extract_scale_tensor(down_proj_expert_weight_scale_key, logical_expert_ids)
|
||||
up_gate_proj_in_scale = _extract_descale_tensor(up_gate_proj_expert_in_scale_key, logical_expert_ids)
|
||||
down_proj_in_scale = _extract_scale_tensor(down_proj_expert_in_scale_key, logical_expert_ids)
|
||||
|
||||
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
|
||||
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
|
||||
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
|
||||
|
||||
name_tensor_map = {
|
||||
"up_gate_proj_weight": up_gate_proj_weight,
|
||||
"down_proj_weight": down_proj_weight,
|
||||
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
|
||||
"down_proj_weight_scale": down_proj_weight_scale,
|
||||
"up_gate_proj_in_scale": up_gate_proj_in_scale,
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
getattr(layer, name).set_value(tensor)
|
||||
setattr(layer, "down_proj_in_scale", down_proj_in_scale)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
# TODO: split create_parameter from process_loaded_weights
|
||||
return NotImplemented
|
||||
"""
|
||||
Paddle HPU create weight process.
|
||||
"""
|
||||
self.weight_dtype = "float8_e4m3fn"
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size * 2,
|
||||
]
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size,
|
||||
layer.hidden_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
self.added_weight_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_weight_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
self.default_dtype = layer._helper.get_default_dtype()
|
||||
# in_scales
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_in_scale",
|
||||
layer.create_parameter(
|
||||
shape=[1],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
# weight_scales
|
||||
setattr(
|
||||
layer,
|
||||
"up_gate_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
"down_proj_weight_scale",
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
hpu quantization methods
|
||||
"""
|
||||
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import (
|
||||
QuantConfigBase,
|
||||
QuantMethodBase,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class KvCacheQuantzationTypes(str, Enum):
|
||||
"""
|
||||
KvCacheQuantzationTypes
|
||||
"""
|
||||
|
||||
FP8 = "float8_e4m3fn"
|
||||
FP8_E4M3 = "float8_e4m3"
|
||||
|
||||
|
||||
class HPUKvCacheQuantConfig(QuantConfigBase):
|
||||
"""
|
||||
quantization config for weight fp8
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
|
||||
"""
|
||||
__init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.kv_cache_quant_type = kv_cache_quant_type
|
||||
|
||||
try:
|
||||
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
||||
|
||||
if self.quant_type == KvCacheQuantzationTypes.FP8_E4M3:
|
||||
self.max_bound = 240.0
|
||||
elif self.quant_type == KvCacheQuantzationTypes.FP8:
|
||||
self.max_bound = 448.0
|
||||
else:
|
||||
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
get_name
|
||||
"""
|
||||
return "kvcache"
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
|
||||
) -> "HPUKvCacheQuantConfig":
|
||||
"""
|
||||
from_config
|
||||
"""
|
||||
return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
get_quant_method
|
||||
"""
|
||||
return HPUKVCacheMethodBase(self)
|
||||
|
||||
|
||||
class HPUKVCacheMethodBase(QuantMethodBase):
|
||||
"""
|
||||
HPUKVCacheMethodBase: HPU need scale in fp32 format but GPU define all scale in bf16 format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: HPUKvCacheQuantConfig,
|
||||
) -> None:
|
||||
"""
|
||||
HPUKVCacheMethodBase __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.cache_quant_config = quant_config
|
||||
|
||||
def load_scale(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
load_scale
|
||||
"""
|
||||
|
||||
cache_k_scale_tensor = get_tensor(state_dict.pop(self.cache_k_scale_name)).cast("float32").reshape_([-1])
|
||||
cache_v_scale_tensor = get_tensor(state_dict.pop(self.cache_v_scale_name)).cast("float32").reshape_([-1])
|
||||
q_scale_tensor = get_tensor(state_dict.pop(self.q_scale_name)).cast("float32").reshape_([-1])
|
||||
s_scale_tensor = get_tensor(state_dict.pop(self.s_scale_name)).cast("float32").reshape_([-1])
|
||||
|
||||
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
|
||||
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
|
||||
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
|
||||
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
|
||||
q_scale = self.cache_quant_config.max_bound / q_scale_tensor
|
||||
q_out_scale = q_scale_tensor / self.cache_quant_config.max_bound
|
||||
s_scale = self.cache_quant_config.max_bound / s_scale_tensor
|
||||
s_out_scale = s_scale_tensor / self.cache_quant_config.max_bound
|
||||
scaling_factor = layer.head_dim**-0.5
|
||||
q_scaling_scale = self.cache_quant_config.max_bound / (q_scale_tensor * scaling_factor)
|
||||
q_scaling_out_scale = (q_scale_tensor * scaling_factor) / self.cache_quant_config.max_bound
|
||||
|
||||
layer.cache_k_scale.set_value(cache_k_scale)
|
||||
layer.cache_v_scale.set_value(cache_v_scale)
|
||||
layer.cache_k_out_scale.set_value(cache_k_out_scale)
|
||||
layer.cache_v_out_scale.set_value(cache_v_out_scale)
|
||||
layer.q_scale.set_value(q_scale)
|
||||
layer.q_out_scale.set_value(q_out_scale)
|
||||
layer.q_scaling_scale.set_value(q_scaling_scale)
|
||||
layer.q_scaling_out_scale.set_value(q_scaling_out_scale)
|
||||
layer.s_scale.set_value(s_scale)
|
||||
layer.s_out_scale.set_value(s_out_scale)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
"""
|
||||
create_weights
|
||||
"""
|
||||
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.FP8_E4M3:
|
||||
layer.cache_quant_type_str = "cache_fp8_sdpa_fp8"
|
||||
layer.quant_max_bound = 240.0
|
||||
layer.quant_min_bound = -240.0
|
||||
else:
|
||||
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
|
||||
|
||||
scale_shape = [1]
|
||||
|
||||
layer.cache_k_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.cache_v_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.cache_k_scale,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.cache_v_scale,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
},
|
||||
)
|
||||
layer.cache_k_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.cache_v_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.q_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.q_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.q_scaling_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.q_scaling_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.s_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.s_out_scale = layer.create_parameter(
|
||||
shape=scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
use for loader v0
|
||||
"""
|
||||
self.prefix = layer.prefix
|
||||
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
|
||||
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
|
||||
self.q_scale_name = layer.prefix + ".q_matmul.activation_scale"
|
||||
self.s_scale_name = layer.prefix + ".s_matmul.activation_scale"
|
||||
|
||||
self.load_scale(layer, state_dict)
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Layer):
|
||||
"""
|
||||
use for loader v1
|
||||
"""
|
||||
# cache_k_out_scale is the reciprocal of cache_k_scale
|
||||
if layer.cache_k_scale._is_initialized():
|
||||
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
|
||||
if layer.cache_v_scale._is_initialized():
|
||||
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
|
||||
|
||||
def apply(self, layer):
|
||||
"""
|
||||
apply
|
||||
"""
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import (
|
||||
TensorWiseFP8Config,
|
||||
TensorWiseFP8LinearMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_quant
|
||||
|
||||
|
||||
class HpuTensorWiseFP8LinearMethod(TensorWiseFP8LinearMethod):
|
||||
"""
|
||||
Tensor wise fp8 quantization method for linear layer on HPU
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: TensorWiseFP8Config,
|
||||
) -> None:
|
||||
super().__init__(quant_config)
|
||||
self.max_bound = 240.0
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
|
||||
"""
|
||||
Process pre-quantized weights before applying them to the model
|
||||
Args:
|
||||
layer: The layer that owns the weights
|
||||
quant_weight: The quantized weights
|
||||
weight_scale: The scale of the quantized weights
|
||||
"""
|
||||
|
||||
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
|
||||
|
||||
# these activation_scale will fall in, but only quant for self_attn
|
||||
# mlp.shared_experts.up_gate_proj / down_proj
|
||||
# self_attn.qkv_proj / o_proj
|
||||
if "self_attn" in layer.act_scale_key:
|
||||
act_scale_inv = act_scale / self.max_bound
|
||||
act_scale = self.max_bound / act_scale
|
||||
else:
|
||||
act_scale_inv = act_scale
|
||||
act_scale = 1.0 / act_scale
|
||||
|
||||
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype()))
|
||||
layer.act_scale.set_value(act_scale.astype(paddle.get_default_dtype()))
|
||||
layer.act_scale_inv.set_value(act_scale_inv.astype(paddle.get_default_dtype()))
|
||||
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
|
||||
"""
|
||||
Create weights for linear layer on HPU
|
||||
"""
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
shape=[1],
|
||||
dtype="bfloat16",
|
||||
is_bias=False,
|
||||
)
|
||||
layer.act_scale = layer.create_parameter(
|
||||
shape=[1],
|
||||
dtype="bfloat16",
|
||||
is_bias=False,
|
||||
)
|
||||
layer.act_scale_inv = layer.create_parameter(
|
||||
shape=[1],
|
||||
dtype="bfloat16",
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
|
||||
"""
|
||||
loaded_weights using HPU specific quantization
|
||||
"""
|
||||
quanted_weight_tensor, weight_scale_tensor = fused_quant(weight)
|
||||
layer.weight.set_value(quanted_weight_tensor)
|
||||
layer.weight_scale.set_value(weight_scale_tensor)
|
||||
@@ -64,7 +64,6 @@ def get_moe_method():
|
||||
from fastdeploy.model_executor.layers.backends import HpuMoEMethod
|
||||
|
||||
return HpuMoEMethod(None)
|
||||
# return HpuTensorWiseFP8MoEMethod(None)
|
||||
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
|
||||
@@ -34,6 +34,7 @@ class KvCacheQuantzationTypes(str, Enum):
|
||||
|
||||
INT8 = "int8"
|
||||
FP8 = "float8_e4m3fn"
|
||||
FP8_E4M3 = "float8_e4m3"
|
||||
BLOCK_WISE_FP8 = "block_wise_fp8"
|
||||
INT8_ZP = "int8_zp"
|
||||
INT4_ZP = "int4_zp"
|
||||
@@ -65,6 +66,8 @@ class KvCacheQuantConfig(QuantConfigBase):
|
||||
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
|
||||
self.max_bound = 127.0
|
||||
self.is_channel_wise = True
|
||||
elif self.quant_type == KvCacheQuantzationTypes.FP8_E4M3:
|
||||
self.max_bound = 240.0
|
||||
elif (
|
||||
self.quant_type == KvCacheQuantzationTypes.FP8
|
||||
or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
|
||||
@@ -101,6 +104,12 @@ class KvCacheQuantConfig(QuantConfigBase):
|
||||
)
|
||||
|
||||
return XPUKVCacheMethodBase(self)
|
||||
elif current_platform.is_intel_hpu():
|
||||
from fastdeploy.model_executor.layers.backends.intel_hpu.quantization.kv_cache import (
|
||||
HPUKVCacheMethodBase,
|
||||
)
|
||||
|
||||
return HPUKVCacheMethodBase(self)
|
||||
else:
|
||||
return KVCacheMethodBase(self)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Optional
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..utils import get_tensor
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
@@ -52,14 +53,28 @@ class TensorWiseFP8Config(QuantConfigBase):
|
||||
"""
|
||||
return method according to this config!
|
||||
"""
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
|
||||
TensorWiseFP8MoEMethod,
|
||||
)
|
||||
if current_platform.is_intel_hpu():
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
HpuTensorWiseFP8MoEMethod,
|
||||
)
|
||||
|
||||
return TensorWiseFP8MoEMethod(self)
|
||||
return HpuTensorWiseFP8MoEMethod(self)
|
||||
else:
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
HpuTensorWiseFP8LinearMethod,
|
||||
)
|
||||
|
||||
return HpuTensorWiseFP8LinearMethod(self)
|
||||
else:
|
||||
return TensorWiseFP8LinearMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
|
||||
TensorWiseFP8MoEMethod,
|
||||
)
|
||||
|
||||
return TensorWiseFP8MoEMethod(self)
|
||||
else:
|
||||
return TensorWiseFP8LinearMethod(self)
|
||||
|
||||
|
||||
class TensorWiseFP8LinearMethod(QuantMethodBase):
|
||||
|
||||
Reference in New Issue
Block a user