[Intel HPU] enable MoE EP for hpu (#5855)

* enable HPU MoE EP

* MoE intermediate_scale stack

* enable loader_v1 esp for tensor_wise_fp8 TP or EP

* modify activation_scale name
This commit is contained in:
Cheng Yanfei
2026-01-15 13:08:00 +08:00
committed by GitHub
parent 7c56041272
commit fbcccaa750
9 changed files with 177 additions and 11 deletions
@@ -24,6 +24,7 @@ from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
UnquantizedFusedMoEMethod,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
class HpuMoEMethod(UnquantizedFusedMoEMethod):
@@ -53,6 +54,24 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
)
self.down_proj_expert_act_scale_key = down_proj_expert_weight_key.replace("weight", "activation_scale")
def init_ep(self, layer: nn.Layer) -> None:
    """Expert-Parallel setup hook; nothing to initialize for this method, so it is a no-op."""
    return None
def apply_tp(self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None) -> paddle.Tensor:
    """Tensor-parallel prefill entry point.

    The base class provides no implementation; concrete backends must
    override this.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def apply_ep_prefill(
self,
layer: nn.Layer,
@@ -77,7 +96,7 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
"""
raise NotImplementedError
def apply_tp(
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
@@ -190,6 +209,7 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
@@ -197,10 +217,10 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_in_scale": up_gate_proj_in_scale,
"down_proj_in_scale": down_proj_in_scale,
}
for name, tensor in name_tensor_map.items():
getattr(layer, name).set_value(tensor)
setattr(layer, "down_proj_in_scale", down_proj_in_scale)
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
@@ -247,6 +267,15 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
"down_proj_in_scale",
layer.create_parameter(
shape=[layer.num_local_experts, 1],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scales
setattr(
@@ -267,6 +296,19 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1},
}
set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
extra_scale_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None},
}
set_weight_attrs(layer.down_proj_in_scale, extra_scale_attrs)
set_weight_attrs(layer.up_gate_proj_weight_scale, extra_scale_attrs)
set_weight_attrs(layer.down_proj_weight_scale, extra_scale_attrs)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -296,6 +338,27 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
setattr(layer, weights_name, weights_list)
setattr(layer, scales_name, scales_list)
def process_weights_after_loading(self, layer):
    """No post-load processing is needed for this quant method; intentional no-op."""
    return None
def init_ep(self, layer: nn.Layer) -> None:
    """Hook for Expert-Parallel initialization.

    This method requires no extra EP modules, so there is nothing to do.
    """
    return None
def apply_tp(self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None) -> paddle.Tensor:
    """Tensor-parallel decode entry point; unimplemented here.

    Raises:
        NotImplementedError: always — subclasses supply the real path.
    """
    raise NotImplementedError
def apply_ep_prefill(
self,
layer: nn.Layer,
@@ -320,7 +383,7 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
"""
raise NotImplementedError
def apply_tp(
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
@@ -131,6 +131,11 @@ class HPUKVCacheMethodBase(QuantMethodBase):
layer.s_scale.set_value(s_scale)
layer.s_out_scale.set_value(s_out_scale)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
    """Load a KV-cache scale into ``param`` as ``max_bound / weight``.

    The checkpoint value is cast to float32 and inverted against the quant
    config's ``max_bound`` before being copied in place.
    # NOTE(review): assumes the loaded scale tensor is nonzero — confirm upstream.
    """
    scale = self.cache_quant_config.max_bound / get_tensor(loaded_weight).cast("float32")
    param.copy_(scale, False)
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create_weights
@@ -158,12 +163,14 @@ class HPUKVCacheMethodBase(QuantMethodBase):
layer.cache_k_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
layer.cache_v_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.cache_k_out_scale = layer.create_parameter(
@@ -182,6 +189,13 @@ class HPUKVCacheMethodBase(QuantMethodBase):
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.q_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.q_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
@@ -202,6 +216,13 @@ class HPUKVCacheMethodBase(QuantMethodBase):
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.s_scale,
{
**extra_weight_attrs,
"weight_loader": self.weight_loader,
},
)
layer.s_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
@@ -226,9 +247,16 @@ class HPUKVCacheMethodBase(QuantMethodBase):
"""
# cache_k_out_scale is the reciprocal of cache_k_scale
if layer.cache_k_scale._is_initialized():
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
layer.cache_k_out_scale.set_value(1.0 / layer.cache_k_scale)
if layer.cache_v_scale._is_initialized():
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
layer.cache_v_out_scale.set_value(1.0 / layer.cache_v_scale)
if layer.q_scale._is_initialized():
scaling_factor = layer.head_dim**-0.5
layer.q_scaling_scale.set_value(layer.q_scale / scaling_factor)
layer.q_scaling_out_scale.set_value(scaling_factor / layer.q_scale)
layer.q_out_scale.set_value(1.0 / layer.q_scale)
if layer.s_scale._is_initialized():
layer.s_out_scale.set_value(1.0 / layer.s_scale)
def apply(self, layer):
"""
@@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.quantization.tensor_wise_fp8 import (
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.intel_hpu import fused_quant
from fastdeploy.model_executor.utils import set_weight_attrs
class HpuTensorWiseFP8LinearMethod(TensorWiseFP8LinearMethod):
@@ -92,6 +93,14 @@ class HpuTensorWiseFP8LinearMethod(TensorWiseFP8LinearMethod):
is_bias=False,
)
self.model_format = extra_weight_attrs.get("model_format")
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
set_weight_attrs(
layer.weight,
extra_weight_attrs,
)
def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
"""
loaded_weights using HPU specific quantization
@@ -99,3 +108,20 @@ class HpuTensorWiseFP8LinearMethod(TensorWiseFP8LinearMethod):
quanted_weight_tensor, weight_scale_tensor = fused_quant(weight)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor)
def process_weights_after_loading(self, layer: nn.Layer):
    """
    Loader-v1 hook: rewrite ``layer.act_scale`` / ``layer.act_scale_inv`` in place.

    Activation scales fall in for several projections, but only the
    self-attention ones (qkv_proj / o_proj) are actually quantized, so their
    scales are rescaled by ``self.max_bound``; the shared-expert MLP
    projections (up_gate_proj / down_proj) just get the reciprocal pair.
    """
    if layer.act_scale._is_initialized():
        if "self_attn" in layer.act_scale_key:
            act_scale_inv = layer.act_scale / self.max_bound
            act_scale = self.max_bound / layer.act_scale
        else:
            # act_scale_inv aliases the live parameter here (no arithmetic
            # creates a copy); see the write-ordering note below.
            act_scale_inv = layer.act_scale
            act_scale = 1.0 / layer.act_scale
        dtype = paddle.get_default_dtype()
        # BUGFIX: write act_scale_inv FIRST. In the non-attention branch
        # act_scale_inv is the same tensor object as layer.act_scale, so
        # calling layer.act_scale.set_value(...) first mutated it in place
        # and act_scale_inv ended up holding 1/scale instead of the raw
        # scale. Writing the inverse first reads it before any mutation.
        layer.act_scale_inv.set_value(act_scale_inv.astype(dtype))
        layer.act_scale.set_value(act_scale.astype(dtype))
+7 -4
View File
@@ -659,6 +659,12 @@ class FusedMoE(nn.Layer):
tp_group=self.fd_config.parallel_config.tp_group,
)
if current_platform.is_intel_hpu():
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
if self.reduce_results and (self.ep_size > 1 or self.tp_size > 1):
tensor_model_parallel_all_reduce_custom(out)
return out
token_num = x.shape[0]
if (
self.ep_size > 1
@@ -678,10 +684,7 @@ class FusedMoE(nn.Layer):
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
if self.reduce_results and self.tp_size > 1:
if current_platform.is_intel_hpu():
tensor_model_parallel_all_reduce_custom(out)
else:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(