Support MXFP4 for GPT-OSS (#5435)

* support mxfp4 in gpt-oss

* add scope for flashinfer

* remove torch code

* update envs.FD_MXFP4_BACKEND

* update process_weights_after_loading

* update env name

* support tp in gpt-oss, add e2e test

* add flashinfer-python-paddle in requirements

* fix import error

* add test
Haonan Luo
2026-01-22 14:21:01 +08:00
committed by GitHub
parent 309c7d9764
commit 82057cb71f
13 changed files with 670 additions and 25 deletions
+9 -2
@@ -254,7 +254,12 @@ class FusedMoE(nn.Layer):
         )

     def weight_loader(
-        self, param, loaded_weight, expert_id, shard_id: Optional[str] = None, source: Optional[str] = None
+        self,
+        param,
+        loaded_weight,
+        expert_id,
+        shard_id: Optional[str] = None,
+        source: Optional[str] = None,
     ):
         """
         source: Avoid redundant transpose of fused weights when weight_loader is called iteratively
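
For context on the `source` argument documented above: it lets a caller that invokes `weight_loader` once per expert tag fused weights so they are not transposed repeatedly. A hedged caller-side sketch; the checkpoint loop, the `parse_expert_id` helper, and the `"fused"` sentinel are all illustrative assumptions, not code from this PR:

for name, tensor in checkpoint_tensors.items():  # assumed iterator over a loaded checkpoint
    moe_layer.weight_loader(
        param,                            # target fused parameter
        tensor,
        expert_id=parse_expert_id(name),  # hypothetical helper
        shard_id=None,
        source="fused",                   # assumed sentinel; the diff only documents the intent
    )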
@@ -376,7 +381,7 @@
                 h2d_copy(dst=expert_param, src=loaded_weight)

     def _load_fused_experts_weight(self, param, loaded_weight):
-        if self.tp_size > 1:
+        if self.tp_size > 1 and self.moe_quant_type != "mxfp4":
             dim = -1
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
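
The new condition means mxfp4 fused expert weights bypass per-rank slicing and are loaded whole even when tp_size > 1. A plausible reason, inferred rather than stated in the PR: MXFP4 packs two FP4 values per byte and shares one scale per 32-element block, so a naive slice along the last dimension could cut through packed bytes and scale blocks. A minimal NumPy sketch of the sharding arithmetic that the guard now skips for mxfp4 (shard_offset/shard_size mirror the diff; slice_fn itself is not reproduced here):

import numpy as np

def shard_last_dim(weight: np.ndarray, tp_rank: int, tp_size: int) -> np.ndarray:
    # Each rank takes one contiguous block of the last dimension.
    size = weight.shape[-1]
    assert size % tp_size == 0, "last dim must split evenly across ranks"
    block_size = size // tp_size
    shard_offset = tp_rank * block_size        # inclusive start
    shard_size = (tp_rank + 1) * block_size    # exclusive end, as in the diff
    return weight[..., shard_offset:shard_size]

# Example: rank 1 of 4 takes columns 16..31 of an (8, 64) weight.
w = np.arange(8 * 64, dtype=np.float32).reshape(8, 64)
assert shard_last_dim(w, tp_rank=1, tp_size=4).shape == (8, 16)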
@@ -386,9 +391,11 @@
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
             loaded_weight = slice_fn(loaded_weight, dim, shard_offset, shard_size)
+        assert param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
         h2d_copy(dst=param, src=loaded_weight)
         if hasattr(param, "tensor_track"):
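
The added assertion turns a mis-sized load into an immediate, descriptive failure instead of an opaque copy error. A self-contained sketch of the same guard, using paddle.assign as a stand-in for the h2d_copy helper, whose implementation is not shown in this diff:

import paddle

def checked_h2d_copy(dst: paddle.Tensor, src) -> None:
    src = paddle.to_tensor(src)  # accept numpy arrays or tensors
    # Fail fast with both shapes in the message, as the diff now does.
    assert dst.shape == src.shape, (
        f"Attempted to load weight ({src.shape}) into parameter ({dst.shape})"
    )
    paddle.assign(src, dst)  # stand-in; the real h2d_copy may differ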