refactor pt loading (#4532)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
bukejiyu
2025-11-11 21:30:39 +08:00
committed by GitHub
parent 4c911ecb74
commit b09ebb2813
35 changed files with 1094 additions and 797 deletions
@@ -19,7 +19,13 @@ from abc import abstractmethod
import paddle
from paddle import nn
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
default_weight_loader,
free_tensor,
set_weight_attrs,
weight_fully_copied,
)
from fastdeploy.platforms import current_platform
from ..quantization.quant_base import QuantMethodBase
@@ -215,14 +221,21 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
num_experts = extra_weight_attrs.pop("num_experts")
hidden_size = extra_weight_attrs.pop("hidden_size")
moe_intermediate_size = extra_weight_attrs.pop("moe_intermediate_size")
if current_platform.is_cuda():
self.model_format = extra_weight_attrs.get("model_format")
if current_platform.is_cuda() and self.model_format != "torch":
self.up_gate_proj_weight_shape = [num_experts, hidden_size, moe_intermediate_size * 2]
self.down_proj_weight_shape = [num_experts, moe_intermediate_size, hidden_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}}
extra_weight_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1},
}
else:
self.up_gate_proj_weight_shape = [num_experts, moe_intermediate_size * 2, hidden_size]
self.down_proj_weight_shape = [num_experts, hidden_size, moe_intermediate_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
extra_weight_attrs = {
**(extra_weight_attrs or {}),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
@@ -235,31 +248,46 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_loader"] = extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
)
if self.model_format != "torch":
up_gate_proj_attrs = extra_weight_attrs
down_proj_attrs = extra_weight_attrs
else:
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.up_gate_proj_weight.shape,
output_dim=extra_weight_attrs["SHARD_ID_TO_SHARDED_DIM"]["gate"],
),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.down_proj_weight.shape,
output_dim=extra_weight_attrs["SHARD_ID_TO_SHARDED_DIM"]["down"],
),
}
set_weight_attrs(
layer.up_gate_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
},
down_proj_attrs,
)
if layer.with_bias:
# only pt model now
layer.up_gate_proj_bias = layer.create_parameter(
shape=[layer.num_experts, layer.moe_intermediate_size * 2],
shape=[num_experts, moe_intermediate_size * 2],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_bias = layer.create_parameter(
shape=[layer.num_experts, layer.hidden_size],
shape=[num_experts, hidden_size],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -267,13 +295,37 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
layer.up_gate_proj_bias,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
},
)
set_weight_attrs(
layer.down_proj_bias,
{
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
"model_format": extra_weight_attrs.get("model_format", ""),
},
)
def process_weights_after_loading(self, layer):
if self.model_format != "torch":
return
if not weight_fully_copied(layer.up_gate_proj_weight) or not weight_fully_copied(layer.down_proj_weight):
return
up_gate_proj_weight_transpose = layer.up_gate_proj_weight.transpose([0, 2, 1])
down_proj_weight_transpose = layer.down_proj_weight.transpose([0, 2, 1])
up_gate_proj = layer.create_parameter(
shape=up_gate_proj_weight_transpose.shape,
dtype=up_gate_proj_weight_transpose.dtype,
default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.02),
is_bias=False,
)
up_gate_proj.copy_(up_gate_proj_weight_transpose, False)
free_tensor(layer.up_gate_proj_weight)
layer.up_gate_proj_weight = up_gate_proj
down_proj = layer.create_parameter(
shape=down_proj_weight_transpose.shape,
dtype=down_proj_weight_transpose.dtype,
default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.02),
is_bias=False,
)
down_proj.copy_(down_proj_weight_transpose, False)
free_tensor(layer.down_proj_weight)
layer.down_proj_weight = down_proj
@@ -40,7 +40,13 @@ elif current_platform.is_iluvatar():
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
)
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
@@ -1084,33 +1090,60 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
]
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
self.model_format = extra_weight_attrs.get("model_format")
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
self.weight_dtype = "int8"
@@ -1157,7 +1190,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
# The v1 loader currently does not support loading offline quantized weight-only weights.
moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
@@ -1191,66 +1224,70 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
else:
weight_type = "down"
def _process_quantize(weight_idx):
# 1.init shape and type
# quantized_weight_name
weight_name = self.added_weight_attrs[weight_idx]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_idx]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
# 1.init shape and type
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = "int8"
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = self.default_dtype
# 2.crate tmp tensor
# 2.crate tmp tensor
weight = paddle.empty(weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
weight = paddle.empty(weight_shape, dtype=weight_dtype)
scale = paddle.empty(scale_shape, dtype=scale_dtype)
# 3.quantize weight
# 3.quantize weight
for expert_id in range(layer.num_local_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type
)
for expert_id in range(layer.num_local_experts):
weight[expert_id], scale[expert_id] = weight_quantize(
getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type
free_tensor(getattr(layer, unquantized_weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
free_tensor(getattr(layer, unquantized_weight_name))
if self.quant_config.is_checkpoint_bf16:
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
if self.model_format == "torch":
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
else:
return
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -22,10 +22,9 @@ import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div
from .fused_moe_backend_base import MoEMethodBase
from .fused_moe_triton_backend import BlockWiseFP8MoEMethod
class DeepGemmFusedMoeMethod(MoEMethodBase):
@@ -37,184 +36,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
deepgemm create weight process.
"""
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
self.up_gate_proj_scale_shape = [
layer.num_local_experts,
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
]
self.down_proj_scale_shape = [
layer.num_local_experts,
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
self.weight_dtype = paddle.float8_e4m3fn
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
up_gate_proj_scale_name,
layer.create_parameter(
shape=self.up_gate_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_scale_name,
layer.create_parameter(
shape=self.down_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
)
BlockWiseFP8MoEMethod.create_weights(self, layer, **extra_weight_attrs)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# 1.init shape and type
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
# 2.crate tmp tensor
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
for expert_id in range(layer.num_local_experts):
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
)
weight[expert_id].copy_(weight_quant, False)
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight.shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale.shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
BlockWiseFP8MoEMethod.process_weights_after_loading(self, layer)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -20,7 +20,13 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
)
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
@@ -59,10 +65,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
self.weight_dtype = "int8"
self.default_dtype = layer._helper.get_default_dtype()
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
@@ -73,36 +76,69 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
self.model_format = extra_weight_attrs.get("model_format")
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
self.weight_dtype = "int8"
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
setattr(
layer,
up_gate_proj_weight_name,
@@ -124,7 +160,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
up_gate_proj_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
@@ -133,7 +169,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
)
setattr(
layer,
self.added_scale_attrs[1],
down_proj_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
@@ -185,59 +221,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
def _process_quantize(weight_idx):
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
free_tensor(getattr(layer, weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
if self.quant_config.is_checkpoint_bf16:
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"
if self.model_format == "torch":
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
return
def apply(
self,
@@ -443,34 +482,59 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
layer.hidden_size,
1,
]
self.model_format = extra_weight_attrs.get("model_format")
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
self.weight_dtype = paddle.float8_e4m3fn
@@ -518,66 +582,70 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
def _process_quantize(weight_idx):
# weight
weight_name = self.added_weight_attrs[weight_idx]
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_idx]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
# 2.crate tmp tensor
# 2.crate tmp tensor
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale = paddle.empty(shape=scale_shape, dtype=scale_dtype)
weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
scale = paddle.empty(shape=scale_shape, dtype=scale_dtype)
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
for expert_id in range(layer.num_experts):
weight_quant, scale[expert_id] = per_token_cast_to_fp8(
getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(),
for expert_id in range(layer.num_experts):
weight_quant, scale[expert_id] = per_token_cast_to_fp8(
getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(),
)
weight[expert_id].copy_(weight_quant, False)
free_tensor(getattr(layer, weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
weight[expert_id].copy_(weight_quant, False)
getattr(layer, weight_name).value().get_tensor()._clear()
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale_shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight, False)
getattr(layer, scale_name).copy_(scale, False)
if self.quant_config.is_checkpoint_bf16:
# dynamic quantize
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"
if self.model_format == "torch":
# pt model
process_weight_transpose(layer, self.added_weight_attrs[weight_id_map[weight_type]])
_process_quantize(weight_id_map[weight_type])
else:
return
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
@@ -1107,45 +1175,123 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
]
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
self.model_format = extra_weight_attrs.get("model_format")
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
}
else:
up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
# 1.init shape
extra_weight_attrs = {**extra_weight_attrs}
if layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
# transpose [0,2,1]
up_gate_proj_weight_shape = (
self.up_gate_proj_weight_shape[:1] + self.up_gate_proj_weight_shape[1:][::-1]
)
up_gate_proj_scale_shape = (
self.up_gate_proj_scale_shape[:1] + self.up_gate_proj_scale_shape[1:][::-1]
)
down_proj_weight_shape = self.down_proj_weight_shape[:1] + self.down_proj_weight_shape[1:][::-1]
down_proj_scale_shape = self.down_proj_scale_shape[:1] + self.down_proj_scale_shape[1:][::-1]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=up_gate_proj_weight_shape,
output_dim=False,
),
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=down_proj_weight_shape,
output_dim=False,
),
}
else:
up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
up_gate_proj_scale_shape = self.up_gate_proj_scale_shape
down_proj_weight_shape = self.down_proj_weight_shape
down_proj_scale_shape = self.down_proj_scale_shape
up_gate_proj_attrs = {
**extra_weight_attrs,
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
else:
# v0 loader
up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
up_gate_proj_scale_shape = self.up_gate_proj_scale_shape
down_proj_weight_shape = self.down_proj_weight_shape
down_proj_scale_shape = self.down_proj_scale_shape
up_gate_proj_attrs = {}
down_proj_attrs = {}
self.weight_dtype = paddle.float8_e4m3fn
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
shape=up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
@@ -1154,7 +1300,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
shape=down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
@@ -1164,7 +1310,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer,
up_gate_proj_scale_name,
layer.create_parameter(
shape=self.up_gate_proj_scale_shape,
shape=up_gate_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
@@ -1173,97 +1319,116 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer,
down_proj_scale_name,
layer.create_parameter(
shape=self.down_proj_scale_shape,
shape=down_proj_scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
),
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
set_weight_attrs(
getattr(layer, up_gate_proj_weight_name),
extra_weight_attrs,
up_gate_proj_attrs,
)
set_weight_attrs(
getattr(layer, up_gate_proj_scale_name),
extra_weight_attrs,
up_gate_proj_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_weight_name),
extra_weight_attrs,
down_proj_attrs,
)
set_weight_attrs(
getattr(layer, down_proj_scale_name),
extra_weight_attrs,
down_proj_attrs,
)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# 1.init shape and type
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
def _process_quantize(weight_idx):
# 1.init shape and type
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
# weight
weight_name = self.added_weight_attrs[weight_idx]
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
weight_dtype = paddle.float8_e4m3fn
# scale
scale_name = self.added_scale_attrs[weight_idx]
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
scale_dtype = "float32"
# 2.crate tmp tensor
# 2.crate tmp tensor
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
# 3.quantize weight
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
for expert_id in range(layer.num_local_experts):
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
for expert_id in range(layer.num_local_experts):
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
)
weight[expert_id].copy_(weight_quant, False)
free_tensor(getattr(layer, unquantized_weight_name))
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight.shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
weight[expert_id].copy_(weight_quant, False)
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale.shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight.shape,
dtype=weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=scale.shape,
dtype=scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
if self.quant_config.is_checkpoint_bf16:
# dynamic quantize
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"
if self.model_format == "torch":
# pt model
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
else:
if self.model_format != "torch":
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
if (
not weight_fully_copied(getattr(layer, up_gate_proj_weight_name))
or not weight_fully_copied(getattr(layer, down_proj_weight_name))
or not weight_fully_copied(getattr(layer, up_gate_proj_scale_name))
or not weight_fully_copied(getattr(layer, down_proj_scale_name))
):
return
process_weight_transpose(layer, up_gate_proj_weight_name)
process_weight_transpose(layer, down_proj_weight_name)
process_weight_transpose(layer, up_gate_proj_scale_name)
process_weight_transpose(layer, down_proj_scale_name)
else:
return
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
+56 -65
View File
@@ -16,14 +16,13 @@
from typing import Optional
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
@@ -31,6 +30,7 @@ try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
import numpy as np
def get_moe_method():
@@ -118,6 +118,7 @@ class FusedMoE(nn.Layer):
weight_key_map: dict = {},
with_bias: bool = False,
activation="swiglu",
model_format: Optional[str] = None,
):
"""
Initialize the Moe layer with given parameters.
@@ -201,7 +202,7 @@ class FusedMoE(nn.Layer):
self.quant_method.create_weights(
self,
weight_loader=self.weight_loader,
model_format=fd_config.model_config.model_format,
model_format=fd_config.model_config.model_format if model_format is None else model_format,
num_experts=self.num_local_experts if self.ep_size > 1 else self.num_experts,
hidden_size=self.hidden_size,
moe_intermediate_size=self.moe_intermediate_size,
@@ -214,72 +215,68 @@ class FusedMoE(nn.Layer):
tp_size={self.tp_size}."
)
def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
def weight_loader(
self, param, loaded_weight, expert_id, shard_id: Optional[str] = None, source: Optional[str] = None
):
"""
source:Avoid redundant transpose of fused weights when weight_loader is called iteratively
"""
if expert_id is None and shard_id is None:
# MoE experts has been fused in disk
self._load_fused_experts_weight(param, loaded_weight)
return
if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
elif current_platform.is_cuda() or current_platform.is_iluvatar():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
if expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts:
if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
elif current_platform.is_cuda() or current_platform.is_iluvatar():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
if not param._is_initialized():
param.initialize()
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
return
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if shard_id is None:
# 1.gate up fused in disk
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size // 2 * self.tp_size),
("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
]
if not param._is_initialized():
param.initialize()
if shard_id is None:
# 1.gate up fused in disk
weight_need_transpose = getattr(param, "weight_need_transpose", False)
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
per_rank = output_size // 2
start = self.tp_rank * per_rank
loaded_weight_shard_gate = slice_fn(
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
)
self._load_gate_up_weight(
param,
expert_id,
loaded_weight_shard_gate,
"gate",
SHARD_ID_TO_SHARDED_DIM["gate"],
is_sharded=True,
)
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
loaded_weight_shard_up = slice_fn(
loaded_weight, weight_need_transpose ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
)
else:
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
self._load_expert_weight(
param=param,
expert_id=expert_id,
loaded_weight=loaded_weight,
shard_id=shard_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id, "fused")
else:
if weight_need_transpose and source != "fused":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
self._load_expert_weight(
param=param,
expert_id=expert_id,
loaded_weight=loaded_weight,
shard_id=shard_id,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
)
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and not is_sharded:
tp_shard_dim = weight_need_transpose ^ shard_dim
tp_shard_dim = shard_dim
weight_dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[weight_dim]
else:
size = loaded_weight.get_shape()[weight_dim]
size = loaded_weight.shape[weight_dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
@@ -310,22 +307,17 @@ class FusedMoE(nn.Layer):
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
h2d_copy(dst=expert_param, src=loaded_weight)
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if self.tp_size > 1 and shard_dim is not None:
tp_shard_dim = weight_need_transpose ^ shard_dim
tp_shard_dim = shard_dim
dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
size = loaded_weight.shape[dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
if hasattr(param, "tensor_track"):
# for dyn quant
@@ -341,7 +333,7 @@ class FusedMoE(nn.Layer):
loaded_weight = loaded_weight.view(expert_param.dtype)
else:
loaded_weight = loaded_weight.cast(expert_param.dtype)
expert_param.copy_(loaded_weight, False)
h2d_copy(dst=expert_param, src=loaded_weight)
def _load_fused_experts_weight(self, param, loaded_weight):
if self.tp_size > 1:
@@ -357,8 +349,7 @@ class FusedMoE(nn.Layer):
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
loaded_weight = get_tensor(loaded_weight)
param.copy_(loaded_weight, False)
h2d_copy(dst=param, src=loaded_weight)
if hasattr(param, "tensor_track"):
for i in range(self.num_local_experts):