remove load_up_proj_weight_first (#6932)

This commit is contained in:
周周周
2026-03-19 17:21:34 +08:00
committed by GitHub
parent 33e01f22a8
commit b1c800b64b
3 changed files with 11 additions and 11 deletions
+3 -4
View File
@@ -322,8 +322,8 @@ class FusedMoE(nn.Layer):
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
)
-def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
-    if self.tp_size > 1 and not is_sharded and not self.fd_config.load_config.is_pre_sharded:
+def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
+    if self.tp_size > 1 and not self.fd_config.load_config.is_pre_sharded:
tp_shard_dim = shard_dim
weight_dim = -1 if tp_shard_dim else 0
size = loaded_weight.shape[weight_dim]
@@ -334,8 +334,7 @@ class FusedMoE(nn.Layer):
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
-switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
-if (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13):
+if shard_id == "gate":
param_shard_offset = 0
else:
param_shard_offset = param_shard_size
@@ -503,14 +503,16 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
@property
def load_up_proj_weight_first(self) -> bool:
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
# 目前默认给True
return True
def process_weights_after_loading(self, layer):
""" """
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
[a, b] = layer.up_gate_proj_weight.split(2, axis=1)
layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1))
[a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1)
layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1))
up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2[:, 0]
free_tensor(layer.up_gate_proj_weight_scale_2)
create_parameter_and_copy(layer, name="up_gate_proj_weight_scale_2", weight=up_gate_proj_weight_scale_2)
@@ -432,7 +432,6 @@ class TestModelOptNvFp4FusedMoE(unittest.TestCase):
scale = paddle.ones([1, 64, 16], dtype=paddle.float16)
swizzled = _process_scale_interleaved(scale)
self.assertEqual(list(swizzled.shape), [1, 128, 16])
-self.assertTrue(method.load_up_proj_weight_first)
def test_process_weights_after_loading(self):
"""Test post-loading weight processing logic for FusedMoE layers."""