polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions
+63 -53
View File
@@ -27,16 +27,21 @@ def get_moe_method():
return moe method based on device platform
"""
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from .fused_moe_cutlass_backend import CutlassMoEMethod
return CutlassMoEMethod(None)
elif current_platform.is_xpu():
from .fused_moe_xpu_backend import XPUMoEMethod
return XPUMoEMethod(None)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
return GCUFusedMoeMethod(None)
raise NotImplementedError()
raise NotImplementedError
class FusedMoE(nn.Layer):
"""
@@ -76,9 +81,9 @@ class FusedMoE(nn.Layer):
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
assert (self.tp_size >= 1 and self.ep_size == 1) or \
(self.tp_size == 1 and self.ep_size > 1), \
'MoE only support parallelism on TP or EP dimension.'
assert (self.tp_size >= 1 and self.ep_size == 1) or (
self.tp_size == 1 and self.ep_size > 1
), "MoE only support parallelism on TP or EP dimension."
self.hidden_size = fd_config.model_config.hidden_size
self.num_experts = num_experts
@@ -123,7 +128,8 @@ class FusedMoE(nn.Layer):
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
, ep_size={self.ep_size}, \
tp_size={self.tp_size}.")
tp_size={self.tp_size}."
)
def init_moe_weights(self):
"""
@@ -147,15 +153,31 @@ class FusedMoE(nn.Layer):
)
up_gate_proj_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["fp8", "wint8"]:
up_gate_proj_weight_shape = [self.num_local_experts, up_gate_proj_output_dim, self.hidden_size]
down_proj_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
up_gate_proj_weight_shape = [
self.num_local_experts,
up_gate_proj_output_dim,
self.hidden_size,
]
down_proj_weight_shape = [
self.num_local_experts,
self.hidden_size,
self.moe_intermediate_size,
]
else:
up_gate_proj_weight_shape = [self.num_local_experts, self.hidden_size, up_gate_proj_output_dim]
down_proj_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
up_gate_proj_weight_shape = [
self.num_local_experts,
self.hidden_size,
up_gate_proj_output_dim,
]
down_proj_weight_shape = [
self.num_local_experts,
self.moe_intermediate_size,
self.hidden_size,
]
# Create parameters
if self.moe_quant_type == "fp8":
#(TODO:gaoziyuan)
# (TODO:gaoziyuan)
pass
elif self.moe_quant_type == "wint8":
self.weight_dtype = "int8"
@@ -187,9 +209,12 @@ class FusedMoE(nn.Layer):
dtype=self._dtype,
)
def load_experts_weight(self, state_dict: dict,
up_gate_proj_expert_weight_key: str,
down_proj_expert_weight_key: str):
def load_experts_weight(
self,
state_dict: dict,
up_gate_proj_expert_weight_key: str,
down_proj_expert_weight_key: str,
):
"""
Load experts weight from state_dict.
Args:
@@ -199,35 +224,23 @@ class FusedMoE(nn.Layer):
"""
up_gate_proj_weights = []
down_proj_weights = []
is_ffn_merged = up_gate_proj_expert_weight_key.format(
self.expert_id_offset) in state_dict
is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
up_gate_proj_weights.append(
get_tensor(
state_dict.pop(
up_gate_proj_expert_weight_key.format(expert_idx))))
down_proj_weights.append(
get_tensor(
state_dict.pop(
down_proj_expert_weight_key.format(expert_idx))))
get_tensor(state_dict.pop(up_gate_proj_expert_weight_key.format(expert_idx)))
)
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
else:
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "gate_proj")
up_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "up_proj")
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
gate = get_tensor(
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(
state_dict.pop(up_expert_weight_key.format(expert_idx)))
gate = get_tensor(state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(state_dict.pop(up_expert_weight_key.format(expert_idx)))
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
down_proj_weights.append(
get_tensor(
state_dict.pop(
down_proj_expert_weight_key.format(expert_idx))))
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
return up_gate_proj_weights, down_proj_weights
def extract_moe_ffn_weights(self, state_dict: dict):
@@ -246,46 +259,43 @@ class FusedMoE(nn.Layer):
AssertionError: If required weight keys are missing or number of weights
doesn't match number of local experts.
"""
up_gate_proj_expert_weight_key = self.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = self.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_key = self.weight_key_map.get("up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = self.weight_key_map.get("down_proj_expert_weight_key", None)
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
assert len(
up_gate_proj_weights
) == self.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
assert len(
down_proj_weights
) == self.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
)
assert (
len(up_gate_proj_weights) == self.num_local_experts
), "up_gate_proj_weights length should be equal to num_local_experts."
assert (
len(down_proj_weights) == self.num_local_experts
), "down_proj_weights length should be equal to num_local_experts."
return up_gate_proj_weights, down_proj_weights
def extract_gate_correction_bias(self, gate_correction_bias_key,
state_dict):
def extract_gate_correction_bias(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(
state_dict.pop(gate_correction_bias_key)).astype("float32")
gate_correction_bias_tensor = get_tensor(state_dict.pop(gate_correction_bias_key)).astype("float32")
return gate_correction_bias_tensor
def load_state_dict(self, state_dict):
"""
load_state_dict function.
"""
self.gate_correction_bias_key = self.weight_key_map.get(
"gate_correction_bias_key", None)
self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict)
gate_correction_bias_tensor = self.extract_gate_correction_bias(self.gate_correction_bias_key, state_dict)
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_tensor.shape,
dtype="float32",