[Iluvatar] refactor attn and moe code (#6887)

This commit is contained in:
yzwu
2026-03-18 10:31:00 +08:00
committed by GitHub
parent 0359794e08
commit 8b890c0d72
16 changed files with 877 additions and 140 deletions
@@ -20,7 +20,6 @@ from .block_multihead_attn_backend import BlockAttentionBackend
from .dsa_attention_backend import DSAAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .flash_mask_attn_backend import FlashMaskAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
@@ -33,7 +32,6 @@ __all__ = [
"MLAAttentionBackend",
"DSAAttentionBackend",
"FlashAttentionBackend",
"IluvatarAttnBackend",
"BlockAttentionBackend",
"Attention",
"PlasAttentionBackend",
@@ -62,3 +62,10 @@ if current_platform.is_intel_hpu():
if hasattr(intel_hpu, "__all__"):
globals().update({name: getattr(intel_hpu, name) for name in intel_hpu.__all__})
__all__.extend(intel_hpu.__all__)
if current_platform.is_iluvatar():
from . import iluvatar
if hasattr(iluvatar, "__all__"):
globals().update({name: getattr(iluvatar, name) for name in iluvatar.__all__})
__all__.extend(iluvatar.__all__)
@@ -0,0 +1,30 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
iluvatar gpu backend methods
"""
from .attention.mha_attn_backend import MhaAttnBackend
from .moe.fuse_moe_cutlass_iluvatar_backend import (
IluvatarCutlassMoEMethod,
IluvatarCutlassWeightOnlyMoEMethod,
)
from .quantization.weight_only import IluvatarWeightOnlyLinearMethod
__all__ = [
"MhaAttnBackend",
"IluvatarCutlassMoEMethod",
"IluvatarCutlassWeightOnlyMoEMethod",
"IluvatarWeightOnlyLinearMethod",
]
@@ -0,0 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
iluvatar gpu backend attention methods
"""
@@ -39,9 +39,9 @@ if TYPE_CHECKING:
@dataclass
class IluvatarAttentionMetadata(AttentionMetadata):
class MhaAttentionMetadata(AttentionMetadata):
"""
IluvatarAttentionMetadata
MhaAttentionMetadata
"""
alibi_slopes: Optional[paddle.Tensor] = None
@@ -60,7 +60,7 @@ class IluvatarAttentionMetadata(AttentionMetadata):
decode_block_tables: paddle.Tensor = None
class IluvatarAttnBackend(AttentionBackend):
class MhaAttnBackend(AttentionBackend):
"""
The backend class that uses paddle native attention implementation.
Which is used only for testing purpose.
@@ -76,7 +76,7 @@ class IluvatarAttnBackend(AttentionBackend):
decoder_block_shape_q: int = -1,
):
super().__init__()
self.attention_metadata = IluvatarAttentionMetadata()
self.attention_metadata = MhaAttentionMetadata()
self.block_size = fd_config.cache_config.block_size
assert self.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
self.max_context_len = fd_config.model_config.max_model_len
@@ -0,0 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
iluvatar gpu backend moe methods
"""
@@ -0,0 +1,510 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
UnquantizedFusedMoEMethod,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.iluvatar import (
moe_expert_dispatch,
moe_expert_ffn,
moe_expert_reduce,
)
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
)
class IluvatarCutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.
This method is the oldest way to compute MoE in Paddle.
"""
def process_loaded_weights(self, layer: nn.Layer, state_dict):
    """Copy the per-expert FFN weights (and biases, if present) into the layer.

    The tensors extracted from ``state_dict`` are stacked along a new leading
    expert axis before being written into the pre-created parameters.
    """
    extracted = layer.extract_moe_ffn_weights(state_dict)
    up_gate_weights, down_weights, _logical_expert_ids, _ep_rank_to_expert_id = extracted

    layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_weights, axis=0))
    layer.down_proj_weight.set_value(paddle.stack(down_weights, axis=0))

    if layer.with_bias:
        up_gate_bias, down_bias = layer.extract_moe_ffn_bias(state_dict)
        layer.up_gate_proj_bias.set_value(paddle.stack(up_gate_bias, axis=0))
        layer.down_proj_bias.set_value(paddle.stack(down_bias, axis=0))
def compute_ffn(
    self,
    layer: nn.Layer,
    permute_input: paddle.Tensor,
    token_nums_per_expert: paddle.Tensor,
    expert_idx_per_token: paddle.Tensor,
):
    """
    Paddle Cutlass compute Fused MoE.

    Runs the grouped expert FFN (up_gate_proj -> activation -> down_proj) via the
    Iluvatar ``moe_expert_ffn`` custom op on tokens already permuted by expert.

    Args:
        layer: The FusedMoE layer owning the expert weights/scales/biases.
        permute_input: Tokens grouped contiguously per expert by the dispatch step.
        token_nums_per_expert: Per-expert token counts for the grouped GEMM.
        expert_idx_per_token: Expert id of each permuted token; used here only to
            gather the matching ``down_proj_bias`` row when ``layer.with_bias``.

    Returns:
        The expert FFN output, with ``down_proj_bias`` added when the layer has bias.
    """
    # NOTE: arguments are positional and must match the custom op's signature;
    # scales/biases are passed only when the layer defines them.
    ffn_out_without_down_proj_bias = moe_expert_ffn(
        permute_input,
        token_nums_per_expert,
        getattr(layer, self.added_weight_attrs[0]),
        getattr(layer, self.added_weight_attrs[1]),
        (layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
        (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
        (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
        self.moe_quant_type,
        layer.fd_config.model_config.moe_phase.phase,
    )
    if layer.with_bias:
        # The op does not add down_proj_bias itself: gather each token's expert
        # bias row and add it explicitly.
        down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
        ffn_out_without_down_proj_bias = paddle.add(ffn_out_without_down_proj_bias, down_proj_bias_expand)
    return ffn_out_without_down_proj_bias
def apply_ep_prefill(
    self,
    layer: nn.Layer,
    x: paddle.Tensor,
    gate: nn.Layer,
    topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
    """Expert-parallel prefill is not supported by the Iluvatar cutlass backend."""
    raise NotImplementedError
def apply_ep_decode(
    self,
    layer: nn.Layer,
    x: paddle.Tensor,
    gate: nn.Layer,
    topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
    """Expert-parallel decode is not supported by the Iluvatar cutlass backend."""
    raise NotImplementedError
def apply_tp(
    self,
    layer: nn.Layer,
    x: paddle.Tensor,
    gate: nn.Layer,
    topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
    """
    Paddle Cutlass compute Fused MoE.

    Tensor-parallel MoE forward: route tokens with the gate, dispatch them to
    experts, run the grouped expert FFN, then reduce the top-k expert outputs
    back into token order.

    Args:
        layer: The FusedMoE layer (experts, routing config, weights).
        x: Input hidden states.
        gate: Router producing per-expert logits for each token.
        topk_ids_hookfunc: Optional callback invoked with the selected
            ``topk_ids`` (e.g. for logging/statistics).

    Returns:
        The fused MoE output in the original token order.
    """
    gate_out = gate(x)
    # Routing is computed in float32 regardless of the model dtype.
    gate_out = gate_out.cast("float32")
    if layer.topk_method == "noaux_tc":
        # Grouped top-k scoring ("noaux_tc"): scores, weights and ids come from
        # get_moe_scores, so dispatch only permutes (topk_only_mode=True) and
        # the correction bias is NOT passed again.
        gate_out, topk_weights, topk_idx = get_moe_scores(
            gate_out,
            layer.n_group,
            layer.topk_group,
            layer.top_k,
            layer.routed_scaling_factor,
            layer.gate_correction_bias,
            getattr(layer, "renormalize", True),
        )
        (
            permute_input,
            token_nums_per_expert,
            permute_indices_per_token,
            topk_weights,
            topk_idx,
            expert_idx_per_token,
        ) = moe_expert_dispatch(
            x,
            gate_out,
            None,  # Use layer.gate_correction_bias in get_moe_scores.
            (
                layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
            ),  # if set, permute_input will be int8_t
            layer.top_k,
            False,
            self.moe_quant_type,
            topk_only_mode=True,
        )
    else:
        # Default routing: dispatch performs top-k selection itself, applying
        # the gate correction bias if the layer defines one.
        (
            permute_input,
            token_nums_per_expert,
            permute_indices_per_token,
            topk_weights,
            topk_idx,
            expert_idx_per_token,
        ) = moe_expert_dispatch(
            x,
            gate_out,
            layer.gate_correction_bias,
            (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
            layer.top_k,
            False,
            self.moe_quant_type,
            topk_only_mode=False,
        )
    if topk_ids_hookfunc is not None:
        topk_ids_hookfunc(topk_ids=topk_idx)
    if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
        # only w4a8 need expert_idx_per_token
        # Other need not this tensor, so we make it None.
        expert_idx_per_token = None
    else:
        expert_idx_per_token = expert_idx_per_token.cast("int64")
    ffn_out = self.compute_ffn(
        layer,
        permute_input,
        token_nums_per_expert,
        expert_idx_per_token,
    )
    # The reduce op normalizes the top-k weights and applies routed_scaling_factor;
    # for "noaux_tc" the weights were already normalized in get_moe_scores, so
    # norm_topk_prob is disabled in that case.
    fused_moe_out = moe_expert_reduce(
        ffn_out,
        topk_weights,
        permute_indices_per_token,
        topk_idx,
        None,
        norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
        routed_scaling_factor=1.0,
    )
    return fused_moe_out
class IluvatarCutlassWeightOnlyMoEMethod(IluvatarCutlassMoEMethod):
"""
weight only for moe
"""
def __init__(self, quant_config):
    """Initialize the weight-only MoE method from a quantization config."""
    super().__init__(quant_config)
    self.quant_config = quant_config
    # Algorithm name (e.g. "weight_only_int8" / "weight_only_int4") drives
    # shape selection and the quantize calls below.
    self.moe_quant_type = quant_config.algo
    self.pack_num = 1
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
    """
    Paddle cutlass process prequanted weights.

    Loads already-quantized per-expert weights plus their scales from
    ``state_dict`` and copies them into the layer's parameters.

    Args:
        layer: The FusedMoE layer with pre-created quantized parameters.
        state_dict: Checkpoint tensors (dict, or list of key/value pairs).
        is_rearrange: Forwarded to ``layer.load_experts_weight``.
    """
    # Key templates are formatted with each logical expert id below.
    up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
    down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
    up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
    down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
    up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
        state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
    )
    # self.check(layer, up_gate_proj_weights, down_proj_weights)
    up_gate_proj_weight_scale = []
    down_proj_weight_scale = []
    # state_dict may arrive as a list of (key, tensor) pairs; normalize so the
    # scale lookups below can pop by key.
    if isinstance(state_dict, list):
        state_dict = dict(state_dict)
    for expert_idx in logical_expert_ids:
        up_gate_proj_weight_scale.append(
            get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
        )
        down_proj_weight_scale.append(
            get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
        )
    # Stack per-expert tensors along a new leading expert axis.
    up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
    down_proj_weight = paddle.stack(down_proj_weights, axis=0)
    up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
    down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
    name_tensor_map = {
        "up_gate_proj_weight": up_gate_proj_weight,
        "down_proj_weight": down_proj_weight,
        "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
        "down_proj_weight_scale": down_proj_weight_scale,
    }
    for name, tensor in name_tensor_map.items():
        getattr(layer, name).set_value(tensor)
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
    """
    Paddle cutlass create weight process.

    Creates the MoE parameters. Two regimes:
    * bf16 checkpoint + v1 loader: create unquantized bf16 weights (to be
      quantized later in ``process_weights_after_loading``); shapes depend on
      the checkpoint layout (``model_format == "torch"`` means transposed).
    * otherwise: create int8 quantized weights plus per-channel scales directly.

    Args:
        layer: The FusedMoE layer the parameters are attached to.
        **extra_weight_attrs: Loader metadata (e.g. ``model_format``,
            ``output_dim``) propagated onto the created parameters.
    """
    self.default_dtype = layer._helper.get_default_dtype()
    # Quantized weight shapes: int4 packs two values per int8 element, which
    # halves one dimension relative to int8.
    if self.moe_quant_type == "weight_only_int4":
        self.up_gate_proj_weight_shape = [
            layer.num_local_experts,
            layer.moe_intermediate_size,
            layer.hidden_size,
        ]
    else:
        self.up_gate_proj_weight_shape = [
            layer.num_local_experts,
            layer.moe_intermediate_size * 2,
            layer.hidden_size,
        ]
    if self.moe_quant_type == "weight_only_int4":
        self.down_proj_weight_shape = [
            layer.num_local_experts,
            layer.hidden_size // 2,
            layer.moe_intermediate_size,
        ]
    else:
        self.down_proj_weight_shape = [
            layer.num_local_experts,
            layer.hidden_size,
            layer.moe_intermediate_size,
        ]
    # Per-channel scales: one scale per output channel per expert.
    self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
    self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
    self.model_format = extra_weight_attrs.get("model_format")
    # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
    if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
        if self.model_format != "torch":
            # Paddle layout: [experts, in, out] for up_gate, [experts, out, in] for down.
            up_gate_proj_weight_shape = [
                layer.num_local_experts,
                layer.hidden_size,
                layer.moe_intermediate_size * 2,
            ]
            down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
            up_gate_proj_attrs = {
                **extra_weight_attrs,
                "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
            }
            down_proj_attrs = {
                **extra_weight_attrs,
                "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
            }
        else:
            # Torch layout is transposed relative to paddle; output_dim flips
            # and shard dims are remapped accordingly.
            up_gate_proj_weight_shape = [
                layer.num_local_experts,
                layer.moe_intermediate_size * 2,
                layer.hidden_size,
            ]
            down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
            up_gate_proj_attrs = {
                **extra_weight_attrs,
                "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
                "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
            }
            down_proj_attrs = {
                **extra_weight_attrs,
                "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
                "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
            }
        layer.up_gate_proj_weight = layer.create_parameter(
            shape=up_gate_proj_weight_shape,
            dtype=layer.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.down_proj_weight = layer.create_parameter(
            shape=down_proj_weight_shape,
            dtype=layer.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        set_weight_attrs(
            layer.up_gate_proj_weight,
            up_gate_proj_attrs,
        )
        set_weight_attrs(
            layer.down_proj_weight,
            down_proj_attrs,
        )
    else:
        # Directly create int8 quantized weights + scales (v0 loader, or an
        # already-quantized checkpoint).
        self.weight_dtype = "int8"
        up_gate_proj_weight_name = self.added_weight_attrs[0]
        down_proj_weight_name = self.added_weight_attrs[1]
        up_gate_proj_scale_name = self.added_scale_attrs[0]
        down_proj_scale_name = self.added_scale_attrs[1]
        setattr(
            layer,
            up_gate_proj_weight_name,
            layer.create_parameter(
                shape=self.up_gate_proj_weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        setattr(
            layer,
            down_proj_weight_name,
            layer.create_parameter(
                shape=self.down_proj_weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # weight_scale
        setattr(
            layer,
            up_gate_proj_scale_name,
            layer.create_parameter(
                shape=self.up_gate_proj_scale_shape,
                dtype=self.default_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        setattr(
            layer,
            down_proj_scale_name,
            layer.create_parameter(
                shape=self.down_proj_scale_shape,
                dtype=self.default_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # The v1 loader currently does not support loading offline quantized weight-only weights.
        moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
        set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
        set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
        scale_extra_weight_attrs = {
            **extra_weight_attrs,
            "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None},
        }
        set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs)
        set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs)
    if layer.with_bias:
        # Biases are kept in the weight dtype and stacked per expert.
        layer.up_gate_proj_bias = layer.create_parameter(
            shape=[layer.num_experts, layer.moe_intermediate_size * 2],
            dtype=layer.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.down_proj_bias = layer.create_parameter(
            shape=[layer.num_experts, layer.hidden_size],
            dtype=layer.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        set_weight_attrs(
            layer.up_gate_proj_bias,
            extra_weight_attrs,
        )
        set_weight_attrs(
            layer.down_proj_bias,
            extra_weight_attrs,
        )
def process_weights_after_loading(self, layer):
    """Quantize freshly loaded bf16 expert weights into int8 + scales.

    Called by the v1 loader once a weight group ("gate_up" or "down") has been
    fully copied. For pre-quantized checkpoints this is a no-op.

    Args:
        layer: The FusedMoE layer whose bf16 weights should be quantized.
    """

    def _process_quantize(weight_idx, weight_type):
        # `weight_type` is passed explicitly instead of being captured from the
        # enclosing scope, so the helper has no hidden ordering dependency on
        # when the caller assigns it.
        # 1. resolve names/shapes/dtypes for this weight group
        weight_name = self.added_weight_attrs[weight_idx]
        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
        weight_dtype = "int8"
        scale_name = self.added_scale_attrs[weight_idx]
        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
        scale_dtype = self.default_dtype
        # 2. temporary buffers for the quantized weight and its scales
        weight = paddle.empty(weight_shape, dtype=weight_dtype)
        scale = paddle.empty(scale_shape, dtype=scale_dtype)
        # 3. quantize expert by expert, then drop the bf16 source tensor
        for expert_id in range(layer.num_local_experts):
            weight[expert_id], scale[expert_id] = weight_quantize(
                getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type
            )
        free_tensor(getattr(layer, unquantized_weight_name))
        # 4. create the final parameters and copy the quantized data in
        setattr(
            layer,
            weight_name,
            layer.create_parameter(
                shape=weight_shape,
                dtype=weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        setattr(
            layer,
            scale_name,
            layer.create_parameter(
                shape=scale_shape,
                dtype=scale_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        getattr(layer, weight_name).copy_(weight, False)
        getattr(layer, scale_name).copy_(scale, False)

    if not self.quant_config.is_checkpoint_bf16:
        # Offline-quantized checkpoint: weights were loaded quantized already.
        return
    weight_id_map = {"gate_up": 0, "down": 1}
    # Whichever group finished copying is the one to quantize now.
    weight_type = "gate_up" if weight_fully_copied(layer.up_gate_proj_weight) else "down"
    if self.model_format == "torch":
        # Torch checkpoints store weights transposed relative to paddle layout;
        # restore the expected layout before quantizing.
        unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
            "quant_weight", "weight"
        )
        process_weight_transpose(layer, unquantized_weight_name)
    _process_quantize(weight_id_map[weight_type], weight_type)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
    """
    Paddle cutlass load weight process.

    Quantizes each expert's bf16 FFN weights and writes the stacked quantized
    weights and scales into the layer's parameters.
    """
    up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
    self.check(layer, up_gate_proj_weights, down_proj_weights)
    for idx, expert_weights in enumerate((up_gate_proj_weights, down_proj_weights)):
        quantized_per_expert = []
        scales_per_expert = []
        # Quantize one expert at a time with the configured algorithm.
        for i in range(layer.num_local_experts):
            quant_weight, scale = weight_quantize(expert_weights[i], algo=self.moe_quant_type)
            quantized_per_expert.append(quant_weight)
            scales_per_expert.append(scale)
        getattr(layer, self.added_weight_attrs[idx]).set_value(paddle.stack(quantized_per_expert, axis=0))
        getattr(layer, self.added_scale_attrs[idx]).set_value(paddle.stack(scales_per_expert, axis=0))
@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
iluvatar quantization methods
"""
@@ -0,0 +1,183 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVGateParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig,
WeightOnlyLinearMethod,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
)
class IluvatarWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Weight only quantization method for linear layer
"""
def __init__(
    self,
    quant_config: WeightOnlyConfig,
) -> None:
    """Initialize the Iluvatar weight-only linear method.

    Args:
        quant_config: Weight-only quantization configuration; its
            ``weight_only_linear_arch`` is forced to -1 for this backend.
    """
    super().__init__(quant_config)
    self.group_size = -1
    # -1 lets the Iluvatar kernel ignore CUDA-arch-specific dispatch.
    self.quant_config.weight_only_linear_arch = -1
def create_weights(self, layer, **extra_weight_attrs):
    """Create the linear layer's weight (and scale) parameters.

    Two regimes:
    * bf16 checkpoint + v1 loader: create the unquantized weight now; it is
      quantized later in ``process_weights_after_loading``.
    * otherwise: create the int8 quantized weight and per-channel scale directly.

    Args:
        layer: The linear layer the parameters are attached to.
        **extra_weight_attrs: Loader metadata (``model_format``, ``output_dim``,
            ...) propagated onto the created parameters.
    """
    # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
    self.model_format = extra_weight_attrs.get("model_format")
    if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
        # Torch checkpoints store the weight transposed relative to paddle.
        weight_shape = layer.weight_shape[::-1] if self.model_format == "torch" else layer.weight_shape
        layer.weight = layer.create_parameter(
            shape=weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        quant_attrs = extra_weight_attrs
        if (
            isinstance(layer, MergedColumnParallelLinear)
            or isinstance(layer, QKVParallelLinear)
            or isinstance(layer, MergedReplicatedLinear)
            or isinstance(layer, QKVGateParallelLinear)
        ):
            # Only MergedReplicatedLinear uses the default outdim.
            # XOR flips the tracked output dim for torch-format checkpoints.
            tensor_output_dim = (self.model_format == "torch") ^ quant_attrs.get("output_dim", True)
            quant_attrs = {
                **quant_attrs,
                "tensor_track": TensorTracker(shape=weight_shape, output_dim=tensor_output_dim),
            }
        if self.model_format == "torch" and "output_dim" in quant_attrs:
            quant_attrs["output_dim"] = not quant_attrs["output_dim"]
        set_weight_attrs(
            layer.weight,
            quant_attrs,
        )
    else:
        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
        weight_scale_shape = [layer.weight_shape[1]]
        # Quantized weights are stored transposed; int4 packs two values per byte.
        layer.weight_shape.reverse()
        if self.quant_config.name() == "wint4":
            layer.weight_shape[0] //= 2
        layer.weight_dtype = "int8"
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        if "output_dim" in extra_weight_attrs:
            # The transpose above flips which axis is the output dimension.
            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
        set_weight_attrs(
            layer.weight,
            extra_weight_attrs,
        )
        layer.weight_scale = layer.create_parameter(
            shape=weight_scale_shape,
            dtype=layer._dtype,
            is_bias=False,
        )
        set_weight_attrs(
            layer.weight_scale,
            extra_weight_attrs,
        )
def process_weights_after_loading(self, layer) -> None:
    """Quantize the loaded bf16 weight into int8 weight + per-channel scale.

    No-op for checkpoints that are already quantized offline.
    """
    if not self.quant_config.is_checkpoint_bf16:
        # Nothing to do: the checkpoint was loaded pre-quantized.
        return
    if self.model_format == "torch":
        # Torch checkpoints store the weight transposed; fix the layout first.
        process_weight_transpose(layer, "weight")
    quanted_weight_tensor, weight_scale_tensor = weight_quantize(
        layer.weight,
        algo=self.quant_config.algo,
        arch=self.quant_config.weight_only_linear_arch,
    )
    # Release the bf16 weight and recreate the parameter with the quantized
    # shape/dtype before copying the quantized data in.
    free_tensor(layer.weight)
    layer.weight = layer.create_parameter(
        shape=quanted_weight_tensor.shape,
        dtype="int8",
        is_bias=False,
        default_initializer=paddle.nn.initializer.Constant(0),
    )
    layer.weight_scale = layer.create_parameter(
        shape=weight_scale_tensor.shape,
        dtype=layer._dtype,
        is_bias=False,
        default_initializer=paddle.nn.initializer.Constant(0),
    )
    layer.weight.copy_(quanted_weight_tensor, False)
    layer.weight_scale.copy_(weight_scale_tensor, False)
def process_loaded_weights(self, layer, weight) -> None:
    """Quantize a loaded bf16 weight and store the result plus its scale."""
    quanted, scale = weight_quantize(
        weight,
        algo=self.quant_config.algo,
        arch=self.quant_config.weight_only_linear_arch,
    )
    layer.weight.set_value(quanted)
    # Scales are kept in the model's default dtype.
    layer.weight_scale.set_value(scale.astype(paddle.get_default_dtype()))
def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
    """Copy pre-quantized weight and scale tensors from the checkpoint.

    Args:
        layer: The layer that owns the weights.
        state_dict: Checkpoint mapping holding the quantized weight and its scale
            under ``layer.weight_key`` / ``layer.weight_scale_key``.
        is_rearrange: Accepted for interface compatibility; not used here.
    """
    quant_weight = get_tensor(state_dict.pop(layer.weight_key))
    weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
    layer.weight.set_value(quant_weight)
    layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype()))
def apply(self, layer, x):
    """Run the weight-only quantized linear transform on ``x``."""
    # wint8 stores int8 values directly; anything else here is int4-packed.
    stored_dtype = "int8" if self.quant_config.name() == "wint8" else "int4"
    return weight_only_linear(
        x,
        weight=layer.weight,
        bias=layer.bias if layer.with_bias else None,
        weight_scale=layer.weight_scale,
        weight_dtype=stored_dtype,
        arch=self.quant_config.weight_only_linear_arch,
    )
@@ -37,11 +37,6 @@ if current_platform.is_cuda():
)
except:
logger.warning("import w4afp8_gemm_scale_permute Failed!")
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
moe_expert_dispatch,
moe_expert_reduce,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.utils import (
@@ -91,40 +86,24 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Paddle Cutlass compute Fused MoE.
"""
if current_platform.is_iluvatar():
ffn_out_without_down_proj_bias = fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
layer.fd_config.model_config.moe_phase.phase,
)
else:
ffn_out_without_down_proj_bias = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
dequant_scale,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
max_tokens_per_expert,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation,
)
ffn_out_without_down_proj_bias = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
dequant_scale,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
max_tokens_per_expert,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation,
)
if layer.with_bias:
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
@@ -307,91 +286,47 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
if current_platform.is_iluvatar():
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=True,
)
dequant_scale = None
max_tokens_per_expert = None
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=True,
)
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=True,
)
else:
if current_platform.is_iluvatar():
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)
dequant_scale = None
max_tokens_per_expert = None
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
+5 -1
View File
@@ -47,10 +47,14 @@ def get_moe_method(layer=None):
return moe method based on device platform
"""
if current_platform.is_cuda() or current_platform.is_iluvatar():
if current_platform.is_cuda():
from .fused_moe_cutlass_backend import CutlassMoEMethod
return CutlassMoEMethod(None)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.layers.backends import IluvatarCutlassMoEMethod
return IluvatarCutlassMoEMethod(None)
elif current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends import XPUMoEMethod
@@ -149,6 +149,22 @@ class WeightOnlyConfig(QuantConfigBase):
else:
return GPUWeightOnlyLinearMethod(self)
elif current_platform.is_iluvatar():
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":
from fastdeploy.model_executor.layers.backends import (
IluvatarCutlassWeightOnlyMoEMethod,
)
return IluvatarCutlassWeightOnlyMoEMethod(self)
else:
raise ValueError(f"Unsupported MOE backend {layer.use_method}")
else:
from fastdeploy.model_executor.layers.backends import (
IluvatarWeightOnlyLinearMethod,
)
return IluvatarWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":
@@ -100,19 +100,13 @@ def iluvatar_moe_expert_ffn(
up_gate_proj_bias: Optional[paddle.Tensor],
up_gate_proj_scale: Optional[paddle.Tensor],
down_proj_scale: Optional[paddle.Tensor],
down_proj_in_scale: Optional[paddle.Tensor],
expert_idx_per_token: Optional[paddle.Tensor],
quant_method: str,
used_in_ep_low_latency: bool,
moe_phase: str,
):
assert up_gate_proj_bias is None
assert up_gate_proj_scale is not None
assert down_proj_scale is not None
assert down_proj_in_scale is None
assert expert_idx_per_token is None
assert quant_method in ("weight_only_int8")
assert not used_in_ep_low_latency
group_gemm_func, tokens_per_expert = _pre_process_expert_ffn(moe_phase, tokens_expert_prefix_sum)
ffn1_output = group_gemm_func(permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_per_expert, -1)
act_out = swiglu(ffn1_output)
+12 -2
View File
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import Platform
from fastdeploy.utils import console_logger as logger
from .base import Platform, _Backend
class IluvatarPlatform(Platform):
@@ -22,4 +25,11 @@ class IluvatarPlatform(Platform):
"""
get_attention_backend_cls
"""
return "fastdeploy.model_executor.layers.attention.IluvatarAttnBackend"
if selected_backend == _Backend.APPEND_ATTN:
logger.info("Using ixinfer MHA backend instead of append attention")
return "fastdeploy.model_executor.layers.backends.iluvatar.MhaAttnBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
)
+3 -2
View File
@@ -20,7 +20,7 @@ import paddle
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention import IluvatarAttnBackend
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
@@ -90,7 +90,8 @@ class IluvatarModelRunner(GPUModelRunner):
1,
int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
)
attn_backend = IluvatarAttnBackend(
attn_cls = get_attention_backend()
attn_backend = attn_cls(
self.fd_config,
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
-1
View File
@@ -33,7 +33,6 @@ omit =
*/fastdeploy/model_executor/ops/gpu/fastdeploy_ops.py
*/fastdeploy/model_executor/ops/gpu/fastdeploy_ops/__init__.py
*/fastdeploy/model_executor/ops/gpu/deep_gemm/utils.py
*/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
*/fastdeploy/model_executor/xpu_pre_and_post_process.py
*/fastdeploy/**/dcu/*
*/fastdeploy/worker/dcu*.py