[Feature] Support NVFP4 MoE on SM100 (#6003)

* fp4 dense

* [WIP] support nvfp4, dense part

* [wip] developing loading qwen model

* loading

* update

* dense fp4 OK, cudagraph error

* [WIP] moe forward part

* with flashinfer-backend

* qwen3_moe_fp4

* update

* support flashinfer-cutlass moe, qwen3-moe-fp4 OK

* support ernie4.5-fp4

* fix load error

* add some ut

* add docs

* fix CLA, test

* fix the apply() in ModelOptNvFp4FusedMoE

* fix CodeStyle

* del the PADDLE_COMPATIBLE_API

* fix broken url: nvidia_gpu.md

* fix docs

* fix token_ids

* fix CI in Hopper

* move flashinfer imports inside the function

* fix model_runner

Removed the logic for generating random padding IDs.

* Remove skip condition for CUDA version in nvfp4 test

* add test for nvfp4

* fix according to review

* Add Chinese translation link to NVFP4 documentation

* del flashinfer.py

* fix unittest

---------

Co-authored-by: zoooo0820 <zoooo0820@qq.com>
Co-authored-by: bukejiyu <395822456@qq.com>
This commit is contained in:
yuxuan
2026-01-29 14:16:07 +08:00
committed by GitHub
parent eb80724b71
commit 44b52701f6
8 changed files with 1369 additions and 5 deletions
+47 -4
View File
@@ -325,10 +325,10 @@ class FusedMoE(nn.Layer):
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
if shard_id == "gate":
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13):
param_shard_offset = 0
else:
# shard_id == "up":
param_shard_offset = param_shard_size
expert_param = slice_fn(
expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size
@@ -342,8 +342,22 @@ class FusedMoE(nn.Layer):
)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
if len(expert_param.shape) != len(loaded_weight.shape):
logger.warning(
"[MoE] Expert weight rank mismatch detected "
f"(loaded: {loaded_weight.shape}, expected: {expert_param.shape}). "
"Reshaping loaded weight for compatibility."
)
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
logger.warning(
"[MoE] Expert weight layout mismatch detected "
f"(loaded: {loaded_weight.shape}, expected: {expert_param.shape}). "
"Applying transpose to match parameter layout."
)
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
@@ -402,6 +416,32 @@ class FusedMoE(nn.Layer):
for i in range(self.num_local_experts):
param.tensor_track.mark(start=0, batch_id=i)
def _load_per_tensor_weight_scale(
    self,
    param,
    expert_id,
    loaded_weight,
    shard_id,
):
    """Write a per-tensor scale (e.g. ``weight_scale_2`` / ``input_scale``)
    into the slot of the expert selected by ``expert_id``.

    For ``"gate"``/``"up"`` shards the destination is one of two stacked
    scale slots (index 0 for gate, 1 for up); for ``"down"`` the whole
    expert slot is written. A rank mismatch between source and destination
    is resolved by reshape; an equal-rank layout mismatch by a 2-D
    transpose.
    """

    def _fit_and_set(target, value):
        # Adapt ``value`` to ``target``'s shape, then write it in place.
        if target.shape != value.shape:
            if len(target.shape) != len(value.shape):
                # Ranks differ: only the layout metadata needs fixing.
                value = value.reshape(target.shape)
            else:
                # Same rank, axes swapped: transpose to the expected layout.
                value = value.transpose([1, 0])
        target.set_value(value)

    tensor = get_tensor(loaded_weight)
    expert_param = param[expert_id - self.expert_id_offset]
    if shard_id in ["gate", "up"]:
        # Slot 0 holds the gate scale, slot 1 the up-projection scale.
        _fit_and_set(expert_param[0 if shard_id == "gate" else 1], tensor)
    elif shard_id == "down":
        _fit_and_set(expert_param, tensor)
def _load_expert_weight(
self,
param,
@@ -410,7 +450,10 @@ class FusedMoE(nn.Layer):
shard_id,
shard_dim=None,
):
if shard_id == "down":
weight_type = getattr(param, "weight_type", None)
if weight_type in ["weight_scale_2", "input_scale"]:
self._load_per_tensor_weight_scale(param, expert_id, loaded_weight, shard_id)
elif shard_id == "down":
self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
elif shard_id in ["gate", "up"]:
self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim)