mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix flashinfer-cutedsl moe nvfp4 (#7120)
* fix nvfp4 * fix * add document * fix nvfp4 * support eb5 * support bka * support eb5 * support xpu * fix * fix * add import cutedsl * fix * fix * fix test * fix H卡 * update document * fix * update document * update document * fix
This commit is contained in:
@@ -18,7 +18,38 @@ Based on [FlashInfer](https://github.com/flashinfer-ai/flashinfer), Fastdeploy s
|
|||||||
Please ensure that FastDeploy is installed with NVIDIA GPU support.
|
Please ensure that FastDeploy is installed with NVIDIA GPU support.
|
||||||
Follow the official guide to set up the base environment: [Fastdeploy NVIDIA GPU Environment Installation Guide](https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/).
|
Follow the official guide to set up the base environment: [Fastdeploy NVIDIA GPU Environment Installation Guide](https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/).
|
||||||
|
|
||||||
### Running Inference Service
|
### FlashInfer-cutedsl backend
|
||||||
|
|
||||||
|
#### PaddlePaddle Compatibility Patches for FlashInfer
|
||||||
|
|
||||||
|
Due to compatibility issues between FlashInfer and PaddlePaddle, you need to apply the following patches in `miniconda/envs/<your_env>/lib/python3.10/site-packages/`:
|
||||||
|
|
||||||
|
1. **nvidia_cutlass_dsl/python_packages/cutlass/torch.py**
|
||||||
|
|
||||||
|
Replace `torch.device` with `"torch.device"` (as a string to avoid conflicts).
|
||||||
|
|
||||||
|
2. **flashinfer/utils.py**
|
||||||
|
|
||||||
|
Modify the `get_compute_capability` function:
|
||||||
|
```bash
|
||||||
|
@functools.cache
|
||||||
|
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
|
||||||
|
return torch.cuda.get_device_capability(device)
|
||||||
|
if device.type != "cuda":
|
||||||
|
raise ValueError("device must be a cuda device")
|
||||||
|
return torch.cuda.get_device_capability(device.index)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **flashinfer/cute_dsl/blockscaled_gemm.py**
|
||||||
|
|
||||||
|
Replace `cutlass_torch.current_stream()` with:
|
||||||
|
```bash
|
||||||
|
cuda.CUstream(torch.cuda.current_stream().stream_base.raw_stream)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Running Inference Service
|
||||||
|
|
||||||
|
flashinfer-cutlass backend:
|
||||||
```bash
|
```bash
|
||||||
python -m fastdeploy.entrypoints.openai.api_server \
|
python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
--model nv-community/Qwen3-30B-A3B-FP4 \
|
--model nv-community/Qwen3-30B-A3B-FP4 \
|
||||||
@@ -31,6 +62,26 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
|||||||
--max-num-seqs 128
|
--max-num-seqs 128
|
||||||
```
|
```
|
||||||
|
|
||||||
|
flashinfer-cutedsl backend:
|
||||||
|
```bash
|
||||||
|
python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||||
|
--ports "9811,9812,9813,9814" \
|
||||||
|
--num-servers 4 \
|
||||||
|
--model ERNIE-4.5-21B-A3B-FP4 \
|
||||||
|
--disable-custom-all-reduce \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--data-parallel-size 4 \
|
||||||
|
--no-enable-prefix-caching \
|
||||||
|
--max-model-len 65536 \
|
||||||
|
--enable-expert-parallel \
|
||||||
|
--num-gpu-blocks-override 8192 \
|
||||||
|
--max-num-seqs 4 \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
|
--max-num-batched-tokens 512 \
|
||||||
|
--ep-prefill-use-worst-num-tokens \
|
||||||
|
--graph-optimization-config '{"use_cudagraph":false}'
|
||||||
|
```
|
||||||
|
|
||||||
### API Access
|
### API Access
|
||||||
Make service requests using the following command
|
Make service requests using the following command
|
||||||
|
|
||||||
@@ -43,6 +94,15 @@ curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
|||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
```shell
|
||||||
|
curl -X POST "http://0.0.0.0:9811/v1/chat/completions" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "把李白的静夜思改写为现代诗"}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
FastDeploy service interface is compatible with OpenAI protocol. You can make service requests using the following Python code.
|
FastDeploy service interface is compatible with OpenAI protocol. You can make service requests using the following Python code.
|
||||||
|
|
||||||
@@ -64,4 +124,4 @@ for chunk in response:
|
|||||||
if chunk.choices[0].delta:
|
if chunk.choices[0].delta:
|
||||||
print(chunk.choices[0].delta.content, end='')
|
print(chunk.choices[0].delta.content, end='')
|
||||||
print('\n')
|
print('\n')
|
||||||
```.
|
```
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ NVFP4 是 NVIDIA 引入的创新 4 位浮点格式,详细介绍请参考[Intro
|
|||||||
FastDeploy 需以 NVIDIA GPU 模式安装,具体安装方式请参考官方文档:[Fastdeploy NVIDIA GPU 环境安装指南](https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/)。
|
FastDeploy 需以 NVIDIA GPU 模式安装,具体安装方式请参考官方文档:[Fastdeploy NVIDIA GPU 环境安装指南](https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/)。
|
||||||
|
|
||||||
### 运行推理服务
|
### 运行推理服务
|
||||||
|
|
||||||
|
flashinfer-cutlass后端:
|
||||||
```bash
|
```bash
|
||||||
python -m fastdeploy.entrypoints.openai.api_server \
|
python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
--model nv-community/Qwen3-30B-A3B-FP4 \
|
--model nv-community/Qwen3-30B-A3B-FP4 \
|
||||||
@@ -30,6 +32,62 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
|||||||
--max-num-seqs 128
|
--max-num-seqs 128
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### flashinfer-cutedsl后端:
|
||||||
|
|
||||||
|
#### PaddlePaddle 兼容性补丁
|
||||||
|
|
||||||
|
由于 FlashInfer 与 PaddlePaddle 之间存在兼容性问题,需要在 `miniconda/envs/<your_env>/lib/python3.10/site-packages/` 中应用以下补丁:
|
||||||
|
|
||||||
|
1. **nvidia_cutlass_dsl/python_packages/cutlass/torch.py**
|
||||||
|
|
||||||
|
将 `torch.device` 替换为 `"torch.device"`(作为字符串以避免冲突)。
|
||||||
|
|
||||||
|
2. **flashinfer/utils.py**
|
||||||
|
|
||||||
|
修改 `get_compute_capability` 函数:
|
||||||
|
```bash
|
||||||
|
@functools.cache
|
||||||
|
def get_compute_capability(device: torch.device) -> Tuple[int, int]:
|
||||||
|
return torch.cuda.get_device_capability(device)
|
||||||
|
if device.type != "cuda":
|
||||||
|
raise ValueError("device must be a cuda device")
|
||||||
|
return torch.cuda.get_device_capability(device.index)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **flashinfer/cute_dsl/blockscaled_gemm.py**
|
||||||
|
|
||||||
|
将 `cutlass_torch.current_stream()` 替换为:
|
||||||
|
```bash
|
||||||
|
cuda.CUstream(torch.cuda.current_stream().stream_base.raw_stream)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 运行推理服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export FD_MOE_BACKEND="flashinfer-cutedsl"
|
||||||
|
export FD_USE_PFCC_DEEP_EP=1
|
||||||
|
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||||
|
--ports "9811,9812,9813,9814" \
|
||||||
|
--num-servers 4 \
|
||||||
|
--model ERNIE-4.5-21B-A3B-FP4 \
|
||||||
|
--disable-custom-all-reduce \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--data-parallel-size 4 \
|
||||||
|
--no-enable-prefix-caching \
|
||||||
|
--max-model-len 65536 \
|
||||||
|
--enable-expert-parallel \
|
||||||
|
--num-gpu-blocks-override 8192 \
|
||||||
|
--max-num-seqs 4 \
|
||||||
|
--gpu-memory-utilization 0.9 \
|
||||||
|
--max-num-batched-tokens 512 \
|
||||||
|
--ep-prefill-use-worst-num-tokens \
|
||||||
|
--graph-optimization-config '{"use_cudagraph":false}'
|
||||||
|
```
|
||||||
|
|
||||||
### 接口访问
|
### 接口访问
|
||||||
通过如下命令发起服务请求
|
通过如下命令发起服务请求
|
||||||
|
|
||||||
@@ -42,6 +100,15 @@ curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
|||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
```shell
|
||||||
|
curl -X POST "http://0.0.0.0:9811/v1/chat/completions" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "把李白的静夜思改写为现代诗"}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
FastDeploy服务接口兼容OpenAI协议,可以通过如下Python代码发起服务请求。
|
FastDeploy服务接口兼容OpenAI协议,可以通过如下Python代码发起服务请求。
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -75,7 +75,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently.
|
# Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently.
|
||||||
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
|
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
|
||||||
# Set nvfp4 load interleaved weight scale.
|
# Set nvfp4 load interleaved weight scale.
|
||||||
"FD_NVFP4_LOAD_BLOCKSCALE_LEAVE": lambda: os.getenv("FD_NVFP4_LOAD_BLOCKSCALE_LEAVE", "0"),
|
"FD_NVFP4_LOAD_BLOCKSCALE_LEAVE": lambda: bool(int(os.getenv("FD_NVFP4_LOAD_BLOCKSCALE_LEAVE", "0"))),
|
||||||
# Set mxfp4 backend."flashinfer" can be set currently.
|
# Set mxfp4 backend."flashinfer" can be set currently.
|
||||||
"FD_MOE_MXFP4_BACKEND": lambda: os.getenv("FD_MOE_MXFP4_BACKEND", "flashinfer"),
|
"FD_MOE_MXFP4_BACKEND": lambda: os.getenv("FD_MOE_MXFP4_BACKEND", "flashinfer"),
|
||||||
# Whether to use Machete for wint4 dense gemm.
|
# Whether to use Machete for wint4 dense gemm.
|
||||||
|
|||||||
@@ -18,7 +18,20 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
|
from fastdeploy.model_executor.layers.quantization.quant_base import is_nvfp4_supported
|
||||||
|
|
||||||
|
# Only import flashinfer on supported GPUs (B卡)
|
||||||
|
if is_nvfp4_supported():
|
||||||
|
from flashinfer import (
|
||||||
|
scaled_fp4_grouped_quantize,
|
||||||
|
silu_and_mul_scaled_nvfp4_experts_quantize,
|
||||||
|
)
|
||||||
|
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
||||||
|
else:
|
||||||
|
# Not B卡, skip flashinfer imports
|
||||||
|
scaled_fp4_grouped_quantize = None
|
||||||
|
silu_and_mul_scaled_nvfp4_experts_quantize = None
|
||||||
|
grouped_gemm_nt_masked = None
|
||||||
|
|
||||||
|
|
||||||
def _dtype_str(dtype) -> str:
|
def _dtype_str(dtype) -> str:
|
||||||
@@ -87,11 +100,6 @@ def flashinfer_cutedsl_moe_masked(
|
|||||||
Returns:
|
Returns:
|
||||||
paddle.Tensor: [num_experts, m, k] bf16
|
paddle.Tensor: [num_experts, m, k] bf16
|
||||||
"""
|
"""
|
||||||
from flashinfer import (
|
|
||||||
scaled_fp4_grouped_quantize,
|
|
||||||
silu_and_mul_scaled_nvfp4_experts_quantize,
|
|
||||||
)
|
|
||||||
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
|
||||||
|
|
||||||
# === Dtype assertions ===
|
# === Dtype assertions ===
|
||||||
# Use string-based dtype check to be compatible with both paddle and torch proxy tensors
|
# Use string-based dtype check to be compatible with both paddle and torch proxy tensors
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
|
|||||||
quant_config_name = _get_offline_quant_config_name(
|
quant_config_name = _get_offline_quant_config_name(
|
||||||
quantization_config, model_config.model_format == "torch", is_v1_loader
|
quantization_config, model_config.model_format == "torch", is_v1_loader
|
||||||
)
|
)
|
||||||
|
|
||||||
elif args.quantization is not None:
|
elif args.quantization is not None:
|
||||||
quantization_config = {}
|
quantization_config = {}
|
||||||
try:
|
try:
|
||||||
@@ -161,7 +162,10 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
|
|||||||
from .block_wise_fp8 import BlockWiseFP8Config
|
from .block_wise_fp8 import BlockWiseFP8Config
|
||||||
from .kv_cache import KvCacheQuantConfig
|
from .kv_cache import KvCacheQuantConfig
|
||||||
from .mix_quant import MixQuantConfig
|
from .mix_quant import MixQuantConfig
|
||||||
from .nvfp4 import ModelOptNvFp4Config
|
|
||||||
|
if quantization == "modelopt_fp4":
|
||||||
|
from .nvfp4 import ModelOptNvFp4Config
|
||||||
|
|
||||||
from .tensor_wise_fp8 import TensorWiseFP8Config
|
from .tensor_wise_fp8 import TensorWiseFP8Config
|
||||||
from .w4a8 import W4A8Config
|
from .w4a8 import W4A8Config
|
||||||
from .w4afp8 import W4AFP8Config
|
from .w4afp8 import W4AFP8Config
|
||||||
@@ -186,9 +190,10 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
|
|||||||
"tensor_wise_fp8": TensorWiseFP8Config,
|
"tensor_wise_fp8": TensorWiseFP8Config,
|
||||||
"kvcache": KvCacheQuantConfig,
|
"kvcache": KvCacheQuantConfig,
|
||||||
"mix_quant": MixQuantConfig,
|
"mix_quant": MixQuantConfig,
|
||||||
"modelopt_fp4": ModelOptNvFp4Config,
|
|
||||||
}
|
}
|
||||||
if envs.FD_MOE_MXFP4_BACKEND is not None:
|
if envs.FD_MOE_MXFP4_BACKEND is not None:
|
||||||
method_to_config["mxfp4"] = MXFP4Config
|
method_to_config["mxfp4"] = MXFP4Config
|
||||||
|
if quantization == "modelopt_fp4":
|
||||||
|
method_to_config["modelopt_fp4"] = ModelOptNvFp4Config
|
||||||
|
|
||||||
return method_to_config[quantization]
|
return method_to_config[quantization]
|
||||||
|
|||||||
@@ -28,75 +28,105 @@ from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMetho
|
|||||||
from fastdeploy.model_executor.utils import (
|
from fastdeploy.model_executor.utils import (
|
||||||
create_parameter_and_copy,
|
create_parameter_and_copy,
|
||||||
free_tensor,
|
free_tensor,
|
||||||
|
get_sm_version,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
|
||||||
|
|
||||||
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
|
# Only import flashinfer on supported GPUs (B卡)
|
||||||
|
if is_nvfp4_supported():
|
||||||
|
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
|
||||||
|
|
||||||
from fastdeploy.platforms import current_platform
|
from flashinfer import fp4_quantize, mm_fp4
|
||||||
|
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
from fastdeploy.model_executor.layers.moe.ep import deep_ep
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
depermute_prefill_combine,
|
depermute_prefill_combine,
|
||||||
prefill_permute_to_masked_gemm,
|
prefill_permute_to_masked_gemm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_prefill_permute_to_masked_gemm(
|
if envs.FD_MOE_BACKEND == "flashinfer-cutedsl":
|
||||||
x: paddle.Tensor,
|
logger.info(
|
||||||
scale: paddle.Tensor,
|
"FlashInfer cutedsl is slow to import because it triggers JIT compilation of "
|
||||||
topk_ids: paddle.Tensor,
|
"CUDA kernels via TVM/CODEGEN, and cuBLASLt initializes lookup tables and "
|
||||||
num_local_experts: int,
|
"compiles GEMM kernels during first load. This may take several minutes. "
|
||||||
max_token_num: int,
|
"The wait is expected and only happens once per process."
|
||||||
):
|
)
|
||||||
"""
|
from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
|
||||||
Permute input tokens and scales from token-major to expert-major layout
|
flashinfer_cutedsl_moe_masked,
|
||||||
for MoE masked GEMM operations.
|
)
|
||||||
|
else:
|
||||||
|
# Not B卡, skip flashinfer imports
|
||||||
|
deep_ep = None
|
||||||
|
depermute_prefill_combine = None
|
||||||
|
prefill_permute_to_masked_gemm = None
|
||||||
|
flashinfer_cutedsl_moe_masked = None
|
||||||
|
fp4_quantize = None
|
||||||
|
mm_fp4 = None
|
||||||
|
flashinfer_cutlass_fused_moe = None
|
||||||
|
logger.warning(
|
||||||
|
f"NVFP4 requires Blackwell GPU (SM >= 100), "
|
||||||
|
f"current GPU has SM {get_sm_version()}. Skipping flashinfer imports."
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
|
||||||
x: Input hidden states [num_tokens, hidden].
|
|
||||||
scale: Input scales [num_tokens, hidden_scale].
|
|
||||||
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
|
|
||||||
num_local_experts: Number of local experts on this device.
|
|
||||||
max_token_num: Maximum tokens per expert buffer.
|
|
||||||
|
|
||||||
Returns:
|
def call_prefill_permute_to_masked_gemm(
|
||||||
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
|
x: paddle.Tensor,
|
||||||
"""
|
scale: paddle.Tensor,
|
||||||
if topk_ids.dtype != paddle.int64:
|
topk_ids: paddle.Tensor,
|
||||||
topk_ids = topk_ids.cast(paddle.int64)
|
num_local_experts: int,
|
||||||
|
max_token_num: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Permute input tokens and scales from token-major to expert-major layout
|
||||||
|
for MoE masked GEMM operations.
|
||||||
|
|
||||||
# NVFP4 dispatch returns plain BF16 (no fp8 scale); pass empty tensor so the
|
Args:
|
||||||
# C++ op can detect the no-scale path via tensor.numel() == 0.
|
x: Input hidden states [num_tokens, hidden].
|
||||||
if scale is None:
|
scale: Input scales [num_tokens, hidden_scale].
|
||||||
scale = paddle.empty([0], dtype=paddle.float32)
|
topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
|
||||||
|
num_local_experts: Number of local experts on this device.
|
||||||
|
max_token_num: Maximum tokens per expert buffer.
|
||||||
|
|
||||||
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
|
Returns:
|
||||||
|
tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
|
||||||
|
"""
|
||||||
|
if topk_ids.dtype != paddle.int64:
|
||||||
|
topk_ids = topk_ids.cast(paddle.int64)
|
||||||
|
|
||||||
return results[0], results[1], results[2], results[3]
|
# NVFP4 dispatch returns plain BF16 (no fp8 scale); pass empty tensor so the
|
||||||
|
# C++ op can detect the no-scale path via tensor.numel() == 0.
|
||||||
|
if scale is None:
|
||||||
|
scale = paddle.empty([0], dtype=paddle.float32)
|
||||||
|
|
||||||
def call_depermute_prefill_combine(
|
results = prefill_permute_to_masked_gemm(x, scale, topk_ids, num_local_experts, max_token_num)
|
||||||
x: paddle.Tensor,
|
|
||||||
indice_map: paddle.Tensor,
|
|
||||||
topk_weights: paddle.Tensor,
|
|
||||||
num_worst_tokens: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Depermute and combine expert outputs back to token-major layout.
|
|
||||||
|
|
||||||
Args:
|
return results[0], results[1], results[2], results[3]
|
||||||
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
|
|
||||||
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
|
|
||||||
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
|
|
||||||
num_worst_tokens: Number of output tokens to produce.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
depermuted_x: Combined output [num_worst_tokens, hidden].
|
|
||||||
"""
|
|
||||||
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
|
|
||||||
|
|
||||||
return results
|
def call_depermute_prefill_combine(
|
||||||
|
x: paddle.Tensor,
|
||||||
|
indice_map: paddle.Tensor,
|
||||||
|
topk_weights: paddle.Tensor,
|
||||||
|
num_worst_tokens: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Depermute and combine expert outputs back to token-major layout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
|
||||||
|
indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
|
||||||
|
topk_weights: Combination weights [num_worst_tokens, topk] (float32).
|
||||||
|
num_worst_tokens: Number of output tokens to produce.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
depermuted_x: Combined output [num_worst_tokens, hidden].
|
||||||
|
"""
|
||||||
|
results = depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def next_power_of_2(n: int):
|
def next_power_of_2(n: int):
|
||||||
@@ -389,8 +419,6 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
|
|||||||
output_dtype = x.dtype
|
output_dtype = x.dtype
|
||||||
|
|
||||||
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||||
from flashinfer import fp4_quantize
|
|
||||||
|
|
||||||
x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
|
x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
|
||||||
|
|
||||||
assert x_fp4.dtype == paddle.uint8
|
assert x_fp4.dtype == paddle.uint8
|
||||||
@@ -409,9 +437,8 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
|
|||||||
if backend == "cutlass":
|
if backend == "cutlass":
|
||||||
x_scale_interleaved = x_scale_interleaved.view(paddle.uint8)
|
x_scale_interleaved = x_scale_interleaved.view(paddle.uint8)
|
||||||
w_scale_interleaved = w_scale_interleaved.view(paddle.uint8)
|
w_scale_interleaved = w_scale_interleaved.view(paddle.uint8)
|
||||||
from flashinfer import mm_fp4 as fp4_gemm
|
|
||||||
|
|
||||||
out = fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, layer.alpha, output_dtype, backend=backend)
|
out = mm_fp4(x_fp4, w, x_scale_interleaved, w_scale_interleaved, layer.alpha, output_dtype, backend=backend)
|
||||||
if layer.with_bias:
|
if layer.with_bias:
|
||||||
out = paddle.add(out, layer.bias)
|
out = paddle.add(out, layer.bias)
|
||||||
assert out.shape == output_shape
|
assert out.shape == output_shape
|
||||||
@@ -564,9 +591,14 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
|||||||
set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
|
set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
|
||||||
set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
|
set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def load_up_proj_weight_first(self) -> bool:
|
||||||
|
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||||
|
if self.backend == "flashinfer-cutlass":
|
||||||
|
return True
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
|
||||||
|
|
||||||
if self.backend == "flashinfer-cutlass":
|
if self.backend == "flashinfer-cutlass":
|
||||||
@@ -606,18 +638,20 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
|||||||
up_gate_proj_blockscale_swizzled = layer.up_gate_proj_weight_scale
|
up_gate_proj_blockscale_swizzled = layer.up_gate_proj_weight_scale
|
||||||
else:
|
else:
|
||||||
up_gate_proj_blockscale_swizzled = _process_scale_interleaved(layer.up_gate_proj_weight_scale)
|
up_gate_proj_blockscale_swizzled = _process_scale_interleaved(layer.up_gate_proj_weight_scale)
|
||||||
free_tensor(layer.up_gate_proj_weight_scale)
|
|
||||||
layer.up_gate_proj_weight_scale = None
|
|
||||||
create_parameter_and_copy(
|
create_parameter_and_copy(
|
||||||
layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled
|
layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled
|
||||||
)
|
)
|
||||||
|
|
||||||
|
free_tensor(layer.up_gate_proj_weight_scale)
|
||||||
|
layer.up_gate_proj_weight_scale = None
|
||||||
|
|
||||||
if envs.FD_NVFP4_LOAD_BLOCKSCALE_LEAVE:
|
if envs.FD_NVFP4_LOAD_BLOCKSCALE_LEAVE:
|
||||||
down_proj_blockscale_swizzled = layer.down_proj_weight_scale
|
down_proj_blockscale_swizzled = layer.down_proj_weight_scale
|
||||||
else:
|
else:
|
||||||
down_proj_blockscale_swizzled = _process_scale_interleaved(layer.down_proj_weight_scale)
|
down_proj_blockscale_swizzled = _process_scale_interleaved(layer.down_proj_weight_scale)
|
||||||
|
create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled)
|
||||||
free_tensor(layer.down_proj_weight_scale)
|
free_tensor(layer.down_proj_weight_scale)
|
||||||
layer.down_proj_weight_scale = None
|
layer.down_proj_weight_scale = None
|
||||||
create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled)
|
|
||||||
|
|
||||||
def apply_ep_prefill(
|
def apply_ep_prefill(
|
||||||
self,
|
self,
|
||||||
@@ -628,11 +662,6 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
|||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
|
|
||||||
from fastdeploy.model_executor.layers.moe.ep import deep_ep
|
|
||||||
from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
|
|
||||||
flashinfer_cutedsl_moe_masked,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. top experts and weights
|
# 1. top experts and weights
|
||||||
gate_out = gate(x.cast("float32"))
|
gate_out = gate(x.cast("float32"))
|
||||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
||||||
@@ -741,10 +770,6 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
|||||||
shared_experts: nn.Layer = None,
|
shared_experts: nn.Layer = None,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
|
|
||||||
from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
|
|
||||||
flashinfer_cutedsl_moe_masked,
|
|
||||||
)
|
|
||||||
|
|
||||||
gate_out = gate(x.cast("float32"))
|
gate_out = gate(x.cast("float32"))
|
||||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||||
|
|
||||||
@@ -803,10 +828,6 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
|||||||
output = paddle.empty_like(x)
|
output = paddle.empty_like(x)
|
||||||
|
|
||||||
# flashinfer cutlass
|
# flashinfer cutlass
|
||||||
from flashinfer.fused_moe import (
|
|
||||||
cutlass_fused_moe as flashinfer_cutlass_fused_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
_ = flashinfer_cutlass_fused_moe(
|
_ = flashinfer_cutlass_fused_moe(
|
||||||
input=x,
|
input=x,
|
||||||
token_selected_experts=topk_ids.to(paddle.int),
|
token_selected_experts=topk_ids.to(paddle.int),
|
||||||
|
|||||||
@@ -17,6 +17,22 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# NVFP4 requires SM >= 100 (Blackwell architecture)
|
||||||
|
NVFP4_MIN_SM_VERSION = 100
|
||||||
|
|
||||||
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def is_nvfp4_supported() -> bool:
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
"""Check if current GPU supports NVFP4 (requires SM >= 100, Blackwell)."""
|
||||||
|
from fastdeploy.model_executor.utils import get_sm_version
|
||||||
|
|
||||||
|
sm_version = get_sm_version()
|
||||||
|
return sm_version >= NVFP4_MIN_SM_VERSION
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class QuantMethodBase(ABC):
|
class QuantMethodBase(ABC):
|
||||||
"""Base class for different quantized methods."""
|
"""Base class for different quantized methods."""
|
||||||
|
|||||||
@@ -124,17 +124,52 @@ class TestModelOptNvFp4ModuleInit(unittest.TestCase):
|
|||||||
"""Unit tests for nvfp4 module initialization under different environments."""
|
"""Unit tests for nvfp4 module initialization under different environments."""
|
||||||
|
|
||||||
def test_module_import_without_flashinfer(self):
|
def test_module_import_without_flashinfer(self):
|
||||||
"""Test module reloading when flashinfer is not available."""
|
"""Test module reloading when flashinfer is not available (non-Blackwell GPU)."""
|
||||||
with mock.patch.dict(sys.modules, {"flashinfer": None}):
|
# Mock is_nvfp4_supported at the source (quant_base) to return False
|
||||||
|
# This simulates H-card or non-CUDA platform
|
||||||
|
with mock.patch(
|
||||||
|
"fastdeploy.model_executor.layers.quantization.quant_base.is_nvfp4_supported",
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
with mock.patch("paddleformers.utils.log.logger.warning"):
|
with mock.patch("paddleformers.utils.log.logger.warning"):
|
||||||
|
# Clear the module's flashinfer-related attributes before reload
|
||||||
|
# to simulate a fresh import on non-supported GPU
|
||||||
|
if hasattr(nvfp4_module, "fp4_quantize"):
|
||||||
|
delattr(nvfp4_module, "fp4_quantize")
|
||||||
|
if hasattr(nvfp4_module, "mm_fp4"):
|
||||||
|
delattr(nvfp4_module, "mm_fp4")
|
||||||
|
if hasattr(nvfp4_module, "flashinfer_cutlass_fused_moe"):
|
||||||
|
delattr(nvfp4_module, "flashinfer_cutlass_fused_moe")
|
||||||
importlib.reload(nvfp4_module)
|
importlib.reload(nvfp4_module)
|
||||||
|
# Verify that flashinfer imports were skipped
|
||||||
|
self.assertIsNone(nvfp4_module.fp4_quantize)
|
||||||
|
self.assertIsNone(nvfp4_module.mm_fp4)
|
||||||
|
|
||||||
def test_module_import_with_flashinfer(self):
|
def test_module_import_with_flashinfer(self):
|
||||||
"""Test module reloading when flashinfer is available."""
|
"""Test module reloading when flashinfer is available (Blackwell GPU)."""
|
||||||
|
# Create mock flashinfer module with required functions
|
||||||
mock_flashinfer = types.ModuleType("flashinfer")
|
mock_flashinfer = types.ModuleType("flashinfer")
|
||||||
with mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer}):
|
mock_flashinfer.fp4_quantize = mock.Mock()
|
||||||
with mock.patch("paddle.compat.enable_torch_proxy"):
|
mock_flashinfer.mm_fp4 = mock.Mock()
|
||||||
importlib.reload(nvfp4_module)
|
|
||||||
|
mock_fused_moe = types.ModuleType("flashinfer.fused_moe")
|
||||||
|
mock_fused_moe.cutlass_fused_moe = mock.Mock()
|
||||||
|
mock_flashinfer.fused_moe = mock_fused_moe
|
||||||
|
|
||||||
|
# Mock is_nvfp4_supported at the source (quant_base) to return True (simulating B-card)
|
||||||
|
with (
|
||||||
|
mock.patch(
|
||||||
|
"fastdeploy.model_executor.layers.quantization.quant_base.is_nvfp4_supported",
|
||||||
|
return_value=True,
|
||||||
|
),
|
||||||
|
mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer, "flashinfer.fused_moe": mock_fused_moe}),
|
||||||
|
mock.patch("paddle.compat.enable_torch_proxy"),
|
||||||
|
):
|
||||||
|
importlib.reload(nvfp4_module)
|
||||||
|
|
||||||
|
# Verify that flashinfer imports succeeded
|
||||||
|
self.assertIsNotNone(nvfp4_module.fp4_quantize)
|
||||||
|
self.assertIsNotNone(nvfp4_module.mm_fp4)
|
||||||
|
|
||||||
|
|
||||||
class TestModelOptNvFp4ConfigValidation(unittest.TestCase):
|
class TestModelOptNvFp4ConfigValidation(unittest.TestCase):
|
||||||
@@ -328,11 +363,15 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
|
|||||||
"""Test the apply() method with flashinfer-cutlass backend for Linear layers."""
|
"""Test the apply() method with flashinfer-cutlass backend for Linear layers."""
|
||||||
|
|
||||||
def fake_fp4_quantize(x, input_scale_inv):
|
def fake_fp4_quantize(x, input_scale_inv):
|
||||||
|
# NVFP4 packs two 4-bit values into one uint8, so shape stays the same
|
||||||
|
# but the actual packed dimension is K//2 in terms of elements
|
||||||
x_fp4 = paddle.zeros(x.shape, dtype=paddle.uint8)
|
x_fp4 = paddle.zeros(x.shape, dtype=paddle.uint8)
|
||||||
x_scale_interleaved = paddle.zeros(x.shape, dtype=paddle.uint8)
|
# Scale shape should match the packed K dimension
|
||||||
|
x_scale_interleaved = paddle.zeros([x.shape[0], x.shape[1]], dtype=paddle.uint8)
|
||||||
return x_fp4, x_scale_interleaved
|
return x_fp4, x_scale_interleaved
|
||||||
|
|
||||||
def fake_fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, alpha, output_dtype, backend=None):
|
def fake_fp4_gemm(x_fp4, w, x_scale_interleaved, w_scale_interleaved, alpha, output_dtype, backend=None):
|
||||||
|
# Simply return zeros with correct output shape
|
||||||
return paddle.zeros([x_fp4.shape[0], w.shape[1]], dtype=output_dtype)
|
return paddle.zeros([x_fp4.shape[0], w.shape[1]], dtype=output_dtype)
|
||||||
|
|
||||||
prev_flashinfer, prev_fused = _install_fake_flashinfer(fp4_quantize=fake_fp4_quantize, mm_fp4=fake_fp4_gemm)
|
prev_flashinfer, prev_fused = _install_fake_flashinfer(fp4_quantize=fake_fp4_quantize, mm_fp4=fake_fp4_gemm)
|
||||||
@@ -341,6 +380,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
|
|||||||
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
||||||
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.uint8),
|
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.uint8),
|
||||||
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
||||||
|
# Patch the module-level imports to use our fake functions
|
||||||
|
mock.patch.object(nvfp4_module, "fp4_quantize", fake_fp4_quantize),
|
||||||
|
mock.patch.object(nvfp4_module, "mm_fp4", fake_fp4_gemm),
|
||||||
):
|
):
|
||||||
method = ModelOptNvFp4LinearMethod(
|
method = ModelOptNvFp4LinearMethod(
|
||||||
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
||||||
@@ -352,7 +394,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
|
|||||||
layer.weight_scale_2.set_value(paddle.ones([1], dtype=paddle.float32))
|
layer.weight_scale_2.set_value(paddle.ones([1], dtype=paddle.float32))
|
||||||
layer.weight_scale.set_value(paddle.ones(layer.weight_scale.shape, dtype=paddle.uint8))
|
layer.weight_scale.set_value(paddle.ones(layer.weight_scale.shape, dtype=paddle.uint8))
|
||||||
method.process_weights_after_loading(layer)
|
method.process_weights_after_loading(layer)
|
||||||
x = paddle.ones([2, layer.weight.shape[1]], dtype=paddle.float16)
|
# Input dimension should be K (original, not packed)
|
||||||
|
# layer.weight_shape[0] = K = 32
|
||||||
|
x = paddle.ones([2, layer.weight_shape[0]], dtype=paddle.float16)
|
||||||
out = method.apply(layer, x)
|
out = method.apply(layer, x)
|
||||||
self.assertEqual(list(out.shape), [2, layer.weight.shape[0]])
|
self.assertEqual(list(out.shape), [2, layer.weight.shape[0]])
|
||||||
finally:
|
finally:
|
||||||
@@ -380,6 +424,8 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
|
|||||||
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
||||||
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
|
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
|
||||||
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
||||||
|
# Patch the module-level fp4_quantize for H-card (SM 90) where it's None
|
||||||
|
mock.patch.object(nvfp4_module, "fp4_quantize", fake_fp4_quantize),
|
||||||
):
|
):
|
||||||
method = ModelOptNvFp4LinearMethod(
|
method = ModelOptNvFp4LinearMethod(
|
||||||
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
||||||
@@ -392,7 +438,9 @@ class TestModelOptNvFp4LinearMethod(unittest.TestCase):
|
|||||||
method.process_weights_after_loading(layer)
|
method.process_weights_after_loading(layer)
|
||||||
method.backend = "unsupported"
|
method.backend = "unsupported"
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
method.apply(layer, paddle.ones([2, layer.weight.shape[1]], dtype=paddle.float16))
|
# Input dimension should be K (original, not packed)
|
||||||
|
x = paddle.ones([2, layer.weight_shape[0]], dtype=paddle.float16)
|
||||||
|
method.apply(layer, x)
|
||||||
finally:
|
finally:
|
||||||
# Restore original modules to avoid affecting other tests
|
# Restore original modules to avoid affecting other tests
|
||||||
if prev_flashinfer is None:
|
if prev_flashinfer is None:
|
||||||
@@ -479,6 +527,8 @@ class TestModelOptNvFp4FusedMoE(unittest.TestCase):
|
|||||||
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
mock.patch.dict(os.environ, {"FD_MOE_BACKEND": "flashinfer-cutlass"}),
|
||||||
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
|
mock.patch.object(nvfp4_module.paddle, "float8_e4m3fn", paddle.float16),
|
||||||
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
mock.patch.object(nvfp4_module, "free_tensor", side_effect=lambda _: None),
|
||||||
|
# Patch the module-level import to use our fake function
|
||||||
|
mock.patch.object(nvfp4_module, "flashinfer_cutlass_fused_moe", fake_cutlass_fused_moe),
|
||||||
):
|
):
|
||||||
method = ModelOptNvFp4FusedMoE(
|
method = ModelOptNvFp4FusedMoE(
|
||||||
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
ModelOptNvFp4Config(True, kv_cache_quant_algo=None, exclude_modules=[], group_size=16)
|
||||||
|
|||||||
Reference in New Issue
Block a user