mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Intel HPU] change MoE weights and scales from list to tensor and add q/k rms norm (#5289)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* [Intel HPU] change MoE weights and scales from list to tensor and add q/k rms norm
* update doc
* move HPU_CHUNK_SIZE into envs
This commit is contained in:
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import HPUForwardMeta
|
||||
|
||||
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
|
||||
|
||||
def get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size, query_len):
|
||||
@@ -80,6 +81,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj: RowParallelLinear,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: HPUForwardMeta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""
|
||||
Run a forward.
|
||||
@@ -96,6 +99,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj,
|
||||
layer,
|
||||
forward_meta,
|
||||
q_norm,
|
||||
k_norm,
|
||||
)
|
||||
elif forward_meta.forward_mode.is_decode():
|
||||
return self.forward_decode(
|
||||
@@ -104,6 +109,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj,
|
||||
layer,
|
||||
forward_meta,
|
||||
q_norm,
|
||||
k_norm,
|
||||
)
|
||||
else:
|
||||
return self.forward_extend(
|
||||
@@ -112,6 +119,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj,
|
||||
layer,
|
||||
forward_meta,
|
||||
q_norm,
|
||||
k_norm,
|
||||
)
|
||||
|
||||
def forward_mixed(
|
||||
@@ -121,6 +130,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj: RowParallelLinear,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: HPUForwardMeta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""Run a forward for mix."""
|
||||
raise NotImplementedError()
|
||||
@@ -132,6 +143,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj: RowParallelLinear,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: HPUForwardMeta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""Run a forward for decode."""
|
||||
raise NotImplementedError()
|
||||
@@ -143,6 +156,8 @@ class AttentionBackend_HPU(AttentionBackend):
|
||||
o_proj: RowParallelLinear,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: HPUForwardMeta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""Run a forward for extend."""
|
||||
raise NotImplementedError()
|
||||
@@ -256,7 +271,14 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
return key_cache_shape, value_cache_shape
|
||||
|
||||
def forward_extend(
|
||||
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
|
||||
self,
|
||||
src,
|
||||
qkv_proj: QKVParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
layer: Attention,
|
||||
forward_meta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""
|
||||
forward_extend
|
||||
@@ -274,11 +296,19 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
forward_meta.rotary_embs,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
forward_meta.total_batch,
|
||||
transpose=False,
|
||||
use_neox_style=layer.use_neox_rotary_style,
|
||||
epsilon=1e-6,
|
||||
)
|
||||
|
||||
kv, B, BP_BS, M, H = key_value_states.shape
|
||||
@@ -335,7 +365,14 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
return out_linear_out
|
||||
|
||||
def forward_decode(
|
||||
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
|
||||
self,
|
||||
src,
|
||||
qkv_proj: QKVParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
layer: Attention,
|
||||
forward_meta,
|
||||
q_norm: RMSNorm = None,
|
||||
k_norm: RMSNorm = None,
|
||||
):
|
||||
"""
|
||||
forward_decode
|
||||
@@ -357,8 +394,16 @@ class HPUAttentionBackend(AttentionBackend_HPU):
|
||||
qkv_proj.weight,
|
||||
qkv_proj.bias,
|
||||
o_proj.weight,
|
||||
None, # past_key: not used in decode mode
|
||||
None, # past_value: not used in decode mode
|
||||
q_norm.weight if q_norm is not None else None,
|
||||
k_norm.weight if k_norm is not None else None,
|
||||
getattr(qkv_proj, "act_scale", None),
|
||||
getattr(qkv_proj, "weight_scale", None),
|
||||
getattr(layer, "q_scaling_scale", None),
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "s_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
getattr(o_proj, "act_scale", None),
|
||||
getattr(o_proj, "weight_scale", None),
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
scaling_factor=self.head_dim**-0.5,
|
||||
|
||||
@@ -17,20 +17,20 @@
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
|
||||
|
||||
class HpuMoEMethod(MoEMethodBase):
|
||||
class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
This method is the oldest way to compute MoE in Paddle.
|
||||
Implements Fused Mixture-of-Experts (MoE) computation using HPU-optimized operations.
|
||||
This method leverages the HPU backend's fused_gate_moe function for efficient expert routing and computation.
|
||||
Designed specifically for PaddlePaddle execution on Habana Processing Units (HPU).
|
||||
"""
|
||||
|
||||
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
|
||||
# TODO: split create_parameter from process_loaded_weights
|
||||
return NotImplemented
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle HPU load weight process.
|
||||
@@ -38,19 +38,11 @@ class HpuMoEMethod(MoEMethodBase):
|
||||
# bf16
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
|
||||
weights_list = []
|
||||
for i in range(layer.num_local_experts):
|
||||
weight_tensor = weights_tensor[i]
|
||||
weight = layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
weight.set_value(weight_tensor)
|
||||
weights_list.append(weight)
|
||||
weights_name = self.added_weight_attrs[idx]
|
||||
setattr(layer, weights_name, weights_list)
|
||||
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
|
||||
|
||||
layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
|
||||
layer.down_proj_weight.set_value(stacked_down_proj_weights)
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
@@ -87,51 +79,16 @@ class HpuMoEMethod(MoEMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
# norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
|
||||
"""
|
||||
weights = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
if layer.moe_use_gate_correction_bias:
|
||||
scores = weights + layer.gate_correction_bias
|
||||
_, selected_experts = paddle.topk(scores, layer.top_k, axis=-1)
|
||||
routing_weights = paddle.index_sample(weights, selected_experts)
|
||||
else:
|
||||
routing_weights, selected_experts = paddle.topk(weights, layer.top_k, axis=-1)
|
||||
routing_weights /= paddle.sum(routing_weights, axis=-1, keepdim=True)
|
||||
|
||||
common_inputs = (x, selected_experts, routing_weights.cast("bfloat16"))
|
||||
|
||||
common_params = (
|
||||
False, #permuted_weights
|
||||
"silu", #activation,
|
||||
0,
|
||||
layer.num_experts - 1,
|
||||
)
|
||||
|
||||
weights = (
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
)
|
||||
|
||||
fused_moe_out, _ = mixture_of_experts(
|
||||
*common_inputs, *weights, *common_params, False
|
||||
)
|
||||
|
||||
# if norm_topk_prob:
|
||||
# routing_weights_norm = paddle.sum(routing_weights, axis=-1, keepdim=True).cast("bfloat16")
|
||||
# fused_moe_out = fused_moe_out / routing_weights_norm
|
||||
"""
|
||||
chunk_size = 64
|
||||
chunk_size = envs.FD_HPU_CHUNK_SIZE
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe
|
||||
|
||||
# TODO: fuse matmul to gate_moe
|
||||
gate_out = paddle.matmul(x.cast("float32"), gate.weight)
|
||||
fused_moe_out = fused_gate_moe(
|
||||
x,
|
||||
gate_out,
|
||||
gate.weight,
|
||||
layer.gate_correction_bias,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
layer.top_k,
|
||||
layer.moe_use_gate_correction_bias,
|
||||
norm_topk_prob=True,
|
||||
permuted_weights=False,
|
||||
activation="silu",
|
||||
@@ -219,22 +176,20 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
|
||||
|
||||
# norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
|
||||
|
||||
chunk_size = 64
|
||||
chunk_size = envs.FD_HPU_CHUNK_SIZE
|
||||
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe_fp8
|
||||
|
||||
# TODO: fuse matmul to gate_moe
|
||||
gate_out = paddle.matmul(x.cast("float32"), gate.weight)
|
||||
fused_moe_out = fused_gate_moe_fp8(
|
||||
x,
|
||||
gate_out,
|
||||
gate.weight,
|
||||
layer.gate_correction_bias,
|
||||
layer.up_gate_proj_weight,
|
||||
layer.down_proj_weight,
|
||||
None, # intermediate_hidden_states_scales
|
||||
layer.up_gate_proj_in_scale,
|
||||
layer.down_proj_in_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
layer.top_k,
|
||||
layer.moe_use_gate_correction_bias,
|
||||
norm_topk_prob=True,
|
||||
permuted_weights=False,
|
||||
activation="silu",
|
||||
|
||||
@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||
hidden_size = extra_weight_attrs.pop("hidden_size")
|
||||
moe_intermediate_size = extra_weight_attrs.pop("moe_intermediate_size")
|
||||
self.model_format = extra_weight_attrs.get("model_format")
|
||||
if current_platform.is_cuda() and self.model_format != "torch":
|
||||
if (current_platform.is_cuda() or current_platform.is_intel_hpu()) and self.model_format != "torch":
|
||||
self.up_gate_proj_weight_shape = [num_experts, hidden_size, moe_intermediate_size * 2]
|
||||
self.down_proj_weight_shape = [num_experts, moe_intermediate_size, hidden_size]
|
||||
extra_weight_attrs = {
|
||||
|
||||
Reference in New Issue
Block a user