[BugFix] fix flashinfer-cutedsl moe nvfp4 (#7120)

* fix nvfp4

* fix

* add document

* fix nvfp4

* support eb5

* support bka

* support eb5

* support xpu

* fix

* fix

* add import cutedsl

* fix

* fix

* fix test

* fix H-card (NVIDIA Hopper GPU) support

* update document

* fix

* update document

* update document

* fix
This commit is contained in:
lizexu123
2026-04-03 15:43:19 +08:00
committed by GitHub
parent 095a11d932
commit 5f612a348d
8 changed files with 317 additions and 90 deletions
+59 -9
View File
@@ -124,17 +124,52 @@ class TestModelOptNvFp4ModuleInit(unittest.TestCase):
"""Unit tests for nvfp4 module initialization under different environments."""
def test_module_import_without_flashinfer(self):
"""Test module reloading when flashinfer is not available."""
with mock.patch.dict(sys.modules, {"flashinfer": None}):
"""Test module reloading when flashinfer is not available (non-Blackwell GPU)."""
# Mock is_nvfp4_supported at the source (quant_base) to return False
# This simulates H-card or non-CUDA platform
with mock.patch(
"fastdeploy.model_executor.layers.quantization.quant_base.is_nvfp4_supported",
return_value=False,
):
with mock.patch("paddleformers.utils.log.logger.warning"):
# Clear the module's flashinfer-related attributes before reload
# to simulate a fresh import on non-supported GPU
if hasattr(nvfp4_module, "fp4_quantize"):
delattr(nvfp4_module, "fp4_quantize")
if hasattr(nvfp4_module, "mm_fp4"):
delattr(nvfp4_module, "mm_fp4")
if hasattr(nvfp4_module, "flashinfer_cutlass_fused_moe"):
delattr(nvfp4_module, "flashinfer_cutlass_fused_moe")
importlib.reload(nvfp4_module)
# Verify that flashinfer imports were skipped
self.assertIsNone(nvfp4_module.fp4_quantize)
self.assertIsNone(nvfp4_module.mm_fp4)
def test_module_import_with_flashinfer(self):
"""Test module reloading when flashinfer is available."""
"""Test module reloading when flashinfer is available (Blackwell GPU)."""
# Create mock flashinfer module with required functions
mock_flashinfer = types.ModuleType("flashinfer")
with mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer}):
with mock.patch("paddle.compat.enable_torch_proxy"):
importlib.reload(nvfp4_module)
mock_flashinfer.fp4_quantize = mock.Mock()
mock_flashinfer.mm_fp4 = mock.Mock()
mock_fused_moe = types.ModuleType("flashinfer.fused_moe")
mock_fused_moe.cutlass_fused_moe = mock.Mock()
mock_flashinfer.fused_moe = mock_fused_moe
# Mock is_nvfp4_supported at the source (quant_base) to return True (simulating B-card)
with (
mock.patch(
"fastdeploy.model_executor.layers.quantization.quant_base.is_nvfp4_supported",
return_value=True,
),
mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer, "flashinfer.fused_moe": mock_fused_moe}),
mock.patch("paddle.compat.enable_torch_proxy"),
):
importlib.reload(nvfp4_module)
# Verify that flashinfer imports succeeded
self.assertIsNotNone(nvfp4_module.fp4_quantize)
self.assertIsNotNone(nvfp4_module.mm_fp4)
class TestModelOptNvFp4ConfigValidation(unittest.TestCase):
@@ -328,11 +363,15 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
"""Test the apply() method with flashinfer-cutlass backend for Linear layers."""
def fake_fp4_quantize(x, input_scale_inv):
# NVFP4 packs two 4-bit values into one uint8, so shape stays the same
# but the actual packed dimension is K//2 in terms of elements
x_fp4 = paddle.zeros(x.shape, dtype=paddle.uint8)
x_scale_interleaved = paddle.zeros(x.shape, dtype=paddle.uint8)
# Scale shape should match the packed K dimension
x_scale_interleaved = paddle.zeros([x.shape[0], x.shape[1]], dtype=paddle.uint8)
return x_fp4, x_scale_interleaved
def fake_fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, alpha, output_dtype, backend=None):
# Simply return zeros with correct output shape
return paddle.zeros([x_fp4.shape[0], w.shape[1]], dtype=output_dtype)
prev_flashinfer, prev_fused = _install_fake_flashinfer(fp4_quantize=fake_fp4_quantize, mm_fp4=fake_fp4_gemm)
@@ -341,6 +380,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.uint8),
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
# Patch the module-level imports to use our fake functions
mock.patch.object(nvfp4_module, "fp4_quantize", fake_fp4_quantize),
mock.patch.object(nvfp4_module, "mm_fp4", fake_fp4_gemm),
):
method = ModelOptNvFp4LinearMethod(
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
@@ -352,7 +394,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
layer.weight_scale_2.set_value(paddle.ones([1], dtype=paddle.float32))
layer.weight_scale.set_value(paddle.ones(layer.weight_scale.shape, dtype=paddle.uint8))
method.process_weights_after_loading(layer)
x = paddle.ones([2, layer.weight.shape[1]], dtype=paddle.float16)
# Input dimension should be K (original, not packed)
# layer.weight_shape[0] = K = 32
x = paddle.ones([2, layer.weight_shape[0]], dtype=paddle.float16)
out = method.apply(layer, x)
self.assertEqual(list(out.shape), [2, layer.weight.shape[0]])
finally:
@@ -380,6 +424,8 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
# Patch the module-level fp4_quantize for H-card (SM 90) where it's None
mock.patch.object(nvfp4_module, "fp4_quantize", fake_fp4_quantize),
):
method = ModelOptNvFp4LinearMethod(
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
@@ -392,7 +438,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
method.process_weights_after_loading(layer)
method.backend = "unsupported"
with self.assertRaises(ValueError):
method.apply(layer, paddle.ones([2, layer.weight.shape[1]], dtype=paddle.float16))
# Input dimension should be K (original, not packed)
x = paddle.ones([2, layer.weight_shape[0]], dtype=paddle.float16)
method.apply(layer, x)
finally:
# Restore original modules to avoid affecting other tests
if prev_flashinfer is None:
@@ -479,6 +527,8 @@ class TestModelOptNvFp4FusedMoE(unittest.TestCase):
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
# Patch the module-level import to use our fake function
mock.patch.object(nvfp4_module, "flashinfer_cutlass_fused_moe", fake_cutlass_fused_moe),
):
method = ModelOptNvFp4FusedMoE(
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)