# fastdeploy/model_executor/layers/quantization/nvfp4.py
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import threading
from typing import Callable, Optional
import paddle
from paddle import nn
from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.utils import (
create_parameter_and_copy,
free_tensor,
get_sm_version,
set_weight_attrs,
)
from fastdeploy.worker.tbo import let_another_thread_run
from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
# Only import flashinfer on supported GPUs (Blackwell-class, SM >= 100).
if is_nvfp4_supported():
    # Enable torch-compatible tensor APIs, scoped to flashinfer interop only.
    paddle.enable_compat(scope={"flashinfer"})
    from flashinfer import fp4_quantize, mm_fp4
    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe

    from fastdeploy.model_executor.layers.moe.ep import deep_ep
    from fastdeploy.model_executor.ops.gpu import (
        depermute_prefill_combine,
        prefill_permute_to_masked_gemm,
    )

    if envs.FD_MOE_BACKEND == "flashinfer-cutedsl":
        logger.info(
            "FlashInfer cutedsl is slow to import because it triggers JIT compilation of "
            "CUDA kernels via TVM/CODEGEN, and cuBLASLt initializes lookup tables and "
            "compiles GEMM kernels during first load. This may take several minutes. "
            "The wait is expected and only happens once per process."
        )
        from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
            flashinfer_cutedsl_moe_masked,
        )
    # NOTE(review): when is_nvfp4_supported() is True but FD_MOE_BACKEND is not
    # "flashinfer-cutedsl", `flashinfer_cutedsl_moe_masked` is never bound, so the
    # EP paths below would raise NameError if reached — confirm this is intended.
else:
    # Not a Blackwell-class GPU: stub out flashinfer-dependent symbols with None
    # so this module can still be imported (callers must not reach these paths).
    deep_ep = None
    depermute_prefill_combine = None
    prefill_permute_to_masked_gemm = None
    flashinfer_cutedsl_moe_masked = None
    fp4_quantize = None
    mm_fp4 = None
    flashinfer_cutlass_fused_moe = None
    logger.warning(
        f"NVFP4 requires Blackwell GPU (SM >= 100), "
        f"current GPU has SM {get_sm_version()}. Skipping flashinfer imports."
    )
def call_prefill_permute_to_masked_gemm(
    x: paddle.Tensor,
    scale: paddle.Tensor,
    topk_ids: paddle.Tensor,
    num_local_experts: int,
    max_token_num: int,
):
    """
    Permute input tokens and scales from token-major to expert-major layout
    for MoE masked GEMM operations.

    Args:
        x: Input hidden states [num_tokens, hidden].
        scale: Input scales [num_tokens, hidden_scale], or None when dispatch
            produced plain BF16 activations without scales.
        topk_ids: Expert routing indices [num_tokens, topk] (int64 or int32).
        num_local_experts: Number of local experts on this device.
        max_token_num: Maximum tokens per expert buffer.

    Returns:
        tuple: (permute_x, permute_scale, permuted_indice_map, token_nums_per_expert)
    """
    # The underlying op expects int64 routing ids.
    ids = topk_ids if topk_ids.dtype == paddle.int64 else topk_ids.cast(paddle.int64)
    # NVFP4 dispatch returns plain BF16 (no fp8 scale); pass an empty tensor so
    # the C++ op can detect the no-scale path via tensor.numel() == 0.
    in_scale = paddle.empty([0], dtype=paddle.float32) if scale is None else scale
    permute_x, permute_scale, indice_map, tokens_per_expert = prefill_permute_to_masked_gemm(
        x, in_scale, ids, num_local_experts, max_token_num
    )[:4]
    return permute_x, permute_scale, indice_map, tokens_per_expert
def call_depermute_prefill_combine(
    x: paddle.Tensor,
    indice_map: paddle.Tensor,
    topk_weights: paddle.Tensor,
    num_worst_tokens: int,
):
    """
    Depermute and combine expert outputs back to token-major layout.

    Args:
        x: Expert outputs [num_local_experts, max_tokens_per_expert, hidden].
        indice_map: Flat index tensor [num_worst_tokens, topk] (int32).
        topk_weights: Combination weights [num_worst_tokens, topk] (float32).
        num_worst_tokens: Number of output tokens to produce.

    Returns:
        depermuted_x: Combined output [num_worst_tokens, hidden].
    """
    # Thin wrapper: the custom op directly produces the combined tensor.
    return depermute_prefill_combine(x, indice_map, topk_weights, num_worst_tokens)
def next_power_of_2(n: int):
    """Return the smallest power of two >= n; returns 1 for n <= 0."""
    if n <= 0:
        return 1
    # bit_length of (n - 1) gives the exponent of the next power of two.
    return 1 << (n - 1).bit_length()
def _process_scale_interleaved(scales):
    """Swizzle FP8 block scales into the 128x4 interleaved layout used by the
    FlashInfer NVFP4 GEMM kernels.

    Accepts a 2-D [M, K] or 3-D [B, M, K] scale tensor. M is padded up to a
    multiple of 128 and K to a multiple of 4, then the tensor is permuted into
    (32, 4, 4) tile order and flattened back to the padded 2-D/3-D shape.
    """
    scale_dim = len(scales.shape)
    # Normalize to 3-D so the same tiling code serves both linear (2-D) and
    # per-expert MoE (3-D) scales.
    if scale_dim == 2:
        scales = scales.unsqueeze(0)
    assert len(scales.shape) == 3
    B, M, K = scales.shape
    round_up_multiple = lambda x, m: (x + m - 1) // m * m
    M_padded = round_up_multiple(M, 128)
    K_padded = round_up_multiple(K, 4)
    # NOTE(review): the padding region of `padded_scales` stays uninitialized
    # (paddle.empty + partial copy_); presumably the kernels never read beyond
    # the original [M, K] extent — confirm.
    padded_scales = paddle.empty([B, M_padded, K_padded], dtype=scales.dtype)
    padded_scales[:B, :M, :K].copy_(scales)
    batches, rows, cols = padded_scales.shape
    assert rows % 128 == 0
    assert cols % 4 == 0
    # Split rows as 128 = 4 x 32 and cols into groups of 4, then interleave so
    # each tile becomes contiguous (positional reshape/transpose rely on the
    # torch-compat mode enabled at import time).
    padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
    padded_scales = padded_scales.transpose([0, 1, 4, 3, 2, 5])
    # [batches, rows // 128, cols // 4, 32, 4, 4]
    padded_scales = padded_scales.contiguous().to(paddle.device.get_device())
    # Flatten back, dropping the batch dim again for originally-2-D input.
    padded_scales = (
        padded_scales.reshape(M_padded, K_padded) if scale_dim == 2 else padded_scales.reshape(B, M_padded, K_padded)
    )
    return padded_scales
class ModelOptNvFp4Config(QuantConfigBase):
    """
    quantization config for ModelOpt Nvfp4 datatype
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
        group_size: int = 16,
        is_checkpoint_bf16: bool = False,
    ) -> None:
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules
        self.is_checkpoint_bf16 = is_checkpoint_bf16
        # FP4 (E2M1) representable magnitude bound and rounding mode id.
        self.quant_max_bound = 6
        self.quant_min_bound = -6
        self.quant_round_type = 1

    def name(self) -> str:
        """Identifier under which this quantization method is registered."""
        return "modelopt_fp4"

    @classmethod
    def from_config(cls, config: dict) -> "ModelOptNvFp4Config":
        """Build a config object from a parsed quantization-config dict.

        Raises:
            ValueError: on a missing 'quant_algo', a non-string
                kv_cache_quant_algo, a non-integer group_size, a non-list
                exclude_modules, or missing required NVFP4 fields.
        """
        algo = config.get("quant_algo", "")
        if not algo:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # kv_cache_quant_algo: optional; must be a string when present.
        kv_algo = config.get("kv_cache_quant_algo")
        if kv_algo is not None and not isinstance(kv_algo, str):
            raise ValueError(f"kv_cache_quant_algo must be a string, got " f"{type(kv_algo)}")

        # group_size: optional; coerce to int when possible, default 16.
        raw_group_size = config.get("group_size")
        if raw_group_size is None:
            group_size = 16
        elif isinstance(raw_group_size, int):
            group_size = raw_group_size
        else:
            try:
                group_size = int(raw_group_size)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got {type(raw_group_size)}") from None

        # "exclude_modules" is the key in the legacy hf_quant_config.json.
        excluded = config.get("exclude_modules", [])
        if not isinstance(excluded, list):
            raise ValueError(f"exclude_modules must be a list, got {type(excluded)}")

        serialized_nvfp4 = "NVFP4" in algo
        # For FP4 checkpoints with a nested "quantization" section, validate
        # that the required fields are present there.
        if serialized_nvfp4 and "quantization" in config:
            nested = config["quantization"]
            missing_fields = [f for f in ("group_size", "kv_cache_quant_algo", "exclude_modules") if f not in nested]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}"
                )
        return cls(
            is_checkpoint_nvfp4_serialized=serialized_nvfp4,
            kv_cache_quant_algo=kv_algo,
            exclude_modules=excluded,
            group_size=group_size,
        )

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """
        Get quantization method.
        """
        if isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoE(self)
        return ModelOptNvFp4LinearMethod(self)
class ModelOptNvFp4LinearMethod(QuantMethodBase):
    """Linear method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: paddle.float32, scalar,
    weight: NVFP4 (represented as uint8 bytes, two FP4 values packed per byte),
        shape [N, K/2],
    weight_scale: FP8-E4M3, shape [N, K/group_size], aka per-block scale,
    weight_scale_2: paddle.float32, scalar (global scale).

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config
        # Backend resolution: default to flashinfer-cutlass, or honor any
        # explicit "flashinfer-*" choice from the environment.
        self.backend = "none"
        if envs.FD_NVFP4_GEMM_BACKEND is None:
            self.backend = "flashinfer-cutlass"
        elif envs.FD_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.FD_NVFP4_GEMM_BACKEND
        if self.backend == "none":
            # NOTE(review): "installtion" typo in this runtime message.
            raise ValueError(
                "No valid NVFP4 GEMM backend found. Please check your platform capability and installtion of Flashinfer."
            )
        logger.info(f"Using {self.backend} for NVFP4 GEMM")

    def create_weights(
        self,
        layer,
        **extra_weight_attrs,
    ):
        """Allocate the packed weight, input scale, and block-scale parameters
        for an NVFP4 linear layer."""
        # Model storage is column-major, so we need to invert the output_dim flag
        extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
        K = layer.weight_shape[0]
        N = layer.weight_shape[1]
        # Model stored weights are in [N, K//2] format
        # (two FP4 values packed per uint8 byte along K).
        # Create weight shape to match model storage format
        weight_shape = [N, K // 2]
        layer.weight_dtype = "uint8"
        input_scale_shape = [1]
        # One FP8 block scale per `group_size` elements along K.
        weight_scale_shape = [N, K // self.quant_config.group_size]
        weight_scale_2_shape = [1]
        self._create_main_weight(layer, weight_shape, extra_weight_attrs)
        self._create_input_scale(layer, input_scale_shape)
        self._create_weight_scales(layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs)

    def _create_main_weight(self, layer, weight_shape, extra_weight_attrs):
        """Create main weight parameter (packed FP4 as uint8).

        Args:
            layer: Current layer object
            weight_shape: Weight shape
            extra_weight_attrs: Extra weight attributes
        """
        layer.weight = layer.create_parameter(
            shape=weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        set_weight_attrs(
            layer.weight,
            extra_weight_attrs,
        )

    def _create_input_scale(self, layer, input_scale_shape):
        """Create input scale parameter (scalar float32).

        Args:
            layer: Current layer object
            input_scale_shape: Input scale shape
        """
        layer.input_scale = layer.create_parameter(
            shape=input_scale_shape,
            dtype=paddle.float32,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def _create_weight_scales(self, layer, weight_scale_shape, weight_scale_2_shape, extra_weight_attrs):
        """Create weight scale parameters (FP8 per-block + float32 global).

        Args:
            layer: Current layer object
            weight_scale_shape: Weight scale shape
            weight_scale_2_shape: Secondary weight scale shape
            extra_weight_attrs: Extra weight attributes
        """
        layer.weight_scale = layer.create_parameter(
            shape=weight_scale_shape,
            dtype=paddle.float8_e4m3fn,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        set_weight_attrs(
            layer.weight_scale,
            extra_weight_attrs,
        )
        layer.weight_scale_2 = layer.create_parameter(
            shape=weight_scale_2_shape,
            dtype=paddle.float32,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def process_weights_after_loading(self, layer) -> None:
        """Fold the loaded scales into kernel-ready form.

        Derives:
          * alpha = input_scale_2 * weight_scale_2 (GEMM epilogue dequant factor),
          * input_scale_inv = 1 / input_scale_2 (activation quantization scale),
          * the 128x4-interleaved copy of the FP8 block scales for FlashInfer.
        Then recreates the parameters so they hold the processed values.
        """
        input_scale_2 = layer.input_scale.max().to(paddle.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(paddle.float32)
        alpha = input_scale_2 * weight_scale_2
        input_scale_inv = (1 / input_scale_2).to(paddle.float32)
        weight_scale_interleaved = _process_scale_interleaved(layer.weight_scale)
        # Release the raw loaded parameters before re-creating them.
        # NOTE(review): layer.weight_scale (the non-interleaved copy) is never
        # freed here even though apply() only reads weight_scale_interleaved —
        # confirm whether keeping it is intentional.
        free_tensor(layer.input_scale)
        free_tensor(layer.weight_scale_2)
        layer.weight_scale_2 = layer.create_parameter(
            shape=weight_scale_2.shape,
            dtype=weight_scale_2.dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.input_scale = layer.create_parameter(
            shape=input_scale_2.shape,
            dtype=input_scale_2.dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.alpha = layer.create_parameter(
            shape=alpha.shape,
            dtype=alpha.dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.input_scale_inv = layer.create_parameter(
            shape=input_scale_inv.shape,
            dtype=input_scale_inv.dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.weight_scale_interleaved = layer.create_parameter(
            shape=weight_scale_interleaved.shape,
            dtype=weight_scale_interleaved.dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.weight_scale_2.copy_(weight_scale_2, False)
        layer.input_scale.copy_(input_scale_2, False)
        layer.alpha.copy_(alpha, False)
        layer.input_scale_inv.copy_(input_scale_inv, False)
        layer.weight_scale_interleaved.copy_(weight_scale_interleaved, False)

    def apply(
        self,
        layer,
        x,
    ):
        """Compute y = x @ W^T in NVFP4 via FlashInfer, plus optional bias.

        `x` is quantized to FP4 on the fly using the precomputed 1/input_scale;
        the GEMM consumes the interleaved block scales and the alpha factor.
        """
        x_m, _ = x.shape
        w_n, _ = layer.weight.shape
        output_shape = [x_m, w_n]
        output_dtype = x.dtype
        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
        assert x_fp4.dtype == paddle.uint8
        assert layer.weight.dtype == paddle.uint8
        assert layer.weight_scale_interleaved.dtype == paddle.float8_e4m3fn
        assert layer.alpha.dtype == paddle.float32
        if self.backend.startswith("flashinfer-"):
            # e.g. "flashinfer-cutlass" -> "cutlass" (name expected by mm_fp4).
            backend = self.backend[len("flashinfer-") :]
        else:
            raise ValueError(f"Unsupported backend: {self.backend}.")
        w = layer.weight.T
        w_scale_interleaved = layer.weight_scale_interleaved.T
        if backend == "cutlass":
            # The cutlass path takes the FP8 scale tensors as raw uint8 views.
            x_scale_interleaved = x_scale_interleaved.view(paddle.uint8)
            w_scale_interleaved = w_scale_interleaved.view(paddle.uint8)
        out = mm_fp4(x_fp4, w, x_scale_interleaved, w_scale_interleaved, layer.alpha, output_dtype, backend=backend)
        if layer.with_bias:
            out = paddle.add(out, layer.bias)
        assert out.shape == output_shape
        return out
# Per-thread scratch dict used by ModelOptNvFp4FusedMoE.apply_ep_prefill to
# stash dispatch/combine intermediates keyed by thread name (presumably for
# TBO / cross-step inspection; nothing in this file reads it back — verify).
global_values = {}
class ModelOptNvFp4FusedMoE(MoEMethodBase):
    """Fused MoE method for Model Optimizer NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: paddle.float32, scalar,
    weight: NVFP4 (represented as byte) Shape: [1, X, y/2],
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: paddle.float32, scalar.

    Args:
        quant_config: The ModelOpt quantization config.
        moe_config: The MoE configuration.
        layer: The linear layer.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config):
        self.quant_config = quant_config
        # Names of the layer attributes that hold packed FP4 weights and FP8
        # per-block scales for the two expert projections.
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        self.added_scale_attrs = [
            "up_gate_proj_weight_scale",
            "down_proj_weight_scale",
        ]
        # Backend resolution: default to flashinfer-cutlass, or honor any
        # explicit "flashinfer-*" selection from the environment.
        self.backend = "none"
        if envs.FD_MOE_BACKEND is None:
            # currently support flashinfer-cutlass, flashinfer-trtllm will support in the future
            self.backend = "flashinfer-cutlass"
        elif envs.FD_MOE_BACKEND.startswith("flashinfer-"):
            self.backend = envs.FD_MOE_BACKEND
        if self.backend == "none":
            # NOTE(review): "installtion" typo in this runtime message.
            raise ValueError(
                "No valid NVFP4 flashinfer MoE backend found. Please check your platform capability and installtion of FlashInfer."
            )
        logger.info(f"Using {self.backend} for NVFP4 FusedMoE")

    def create_weights(self, layer, **extra_weight_attrs):
        """
        NVFP4 MoE create weight.

        Per local expert, allocates packed FP4 weights (uint8, two values per
        byte — hence the halved last dim), FP8-E4M3 per-block scales (one per
        `group_size` elements), plus float32 weight_scale_2 / input_scale
        tensors ([E, 2] for the fused up_gate projection, [E] for down).
        """
        self.up_gate_proj_weight_shape = [
            layer.num_local_experts,
            layer.moe_intermediate_size * 2,
            layer.hidden_size // 2,
        ]
        self.down_proj_weight_shape = [
            layer.num_local_experts,
            layer.hidden_size,
            layer.moe_intermediate_size // 2,
        ]
        self.up_gate_proj_scale_shape = self.up_gate_proj_weight_shape[0:2] + [
            layer.hidden_size // self.quant_config.group_size
        ]
        self.down_proj_scale_shape = self.down_proj_weight_shape[0:2] + [
            layer.moe_intermediate_size // self.quant_config.group_size
        ]
        self.weight_scale_dtype = paddle.float8_e4m3fn
        self.weight_dtype = paddle.uint8
        up_gate_proj_weight_name = self.added_weight_attrs[0]
        down_proj_weight_name = self.added_weight_attrs[1]
        up_gate_proj_scale_name = self.added_scale_attrs[0]
        down_proj_scale_name = self.added_scale_attrs[1]
        setattr(
            layer,
            up_gate_proj_weight_name,
            layer.create_parameter(
                shape=self.up_gate_proj_weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        setattr(
            layer,
            down_proj_weight_name,
            layer.create_parameter(
                shape=self.down_proj_weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # weight_scale
        setattr(
            layer,
            up_gate_proj_scale_name,
            layer.create_parameter(
                shape=self.up_gate_proj_scale_shape,
                dtype=self.weight_scale_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        setattr(
            layer,
            down_proj_scale_name,
            layer.create_parameter(
                shape=self.down_proj_scale_shape,
                dtype=self.weight_scale_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # weight_scale_2
        layer.up_gate_proj_weight_scale_2 = layer.create_parameter(
            shape=[layer.num_local_experts, 2],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.down_proj_weight_scale_2 = layer.create_parameter(
            shape=[layer.num_local_experts],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        # input_scale
        layer.up_gate_proj_input_scale = layer.create_parameter(
            shape=[layer.num_local_experts, 2],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        layer.down_proj_input_scale = layer.create_parameter(
            shape=[layer.num_local_experts],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        # Attach sharding metadata used by the weight loader.
        # NOTE(review): down_proj uses the same SHARD_ID_TO_SHARDED_DIM map as
        # up_gate_proj — confirm that is intended for the "down" shard.
        set_weight_attrs(
            getattr(layer, up_gate_proj_weight_name),
            {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}},
        )
        set_weight_attrs(
            getattr(layer, up_gate_proj_scale_name),
            {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}},
        )
        set_weight_attrs(
            getattr(layer, down_proj_weight_name),
            {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}},
        )
        set_weight_attrs(
            getattr(layer, down_proj_scale_name),
            {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}},
        )
        set_weight_attrs(layer.up_gate_proj_weight_scale_2, {**extra_weight_attrs, "weight_type": "weight_scale_2"})
        set_weight_attrs(layer.down_proj_weight_scale_2, {**extra_weight_attrs, "weight_type": "weight_scale_2"})
        set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})
        set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"})

    @property
    def load_up_proj_weight_first(self) -> bool:
        """Whether the loader must place Up before Gate in the fused W13 weight."""
        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
        if self.backend == "flashinfer-cutlass":
            return True
        # NOTE(review): implicitly returns None (falsy) for other backends,
        # despite the -> bool annotation; an explicit `return False` would be
        # clearer — confirm no caller distinguishes None from False.

    def process_weights_after_loading(self, layer):
        """Post-process loaded MoE weights into kernel-ready form.

        Reorders the fused W13 halves for the cutlass layout, reduces the
        per-expert input scales to global scalars, precomputes the alpha
        dequant factors and 1/input_scale quantization scales, and swizzles
        the FP8 block scales into the interleaved layout (unless the
        checkpoint was stored pre-swizzled, per FD_NVFP4_LOAD_BLOCKSCALE_LEAVE).
        """
        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
        if self.backend == "flashinfer-cutlass":
            # Swap the [Gate, Up] halves (and their block scales) to [Up, Gate].
            [a, b] = layer.up_gate_proj_weight.split(2, axis=1)
            layer.up_gate_proj_weight.set_value(paddle.concat([b, a], axis=1))
            [a, b] = layer.up_gate_proj_weight_scale.split(2, axis=1)
            layer.up_gate_proj_weight_scale.set_value(paddle.concat([b, a], axis=1))
        # Collapse the [E, 2] up_gate weight_scale_2 to [E] by taking column 0
        # (presumably gate and up share the same global scale — confirm).
        up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2[:, 0]
        free_tensor(layer.up_gate_proj_weight_scale_2)
        create_parameter_and_copy(layer, name="up_gate_proj_weight_scale_2", weight=up_gate_proj_weight_scale_2)
        # Reduce per-expert input scales to a single global scalar each.
        up_gate_proj_input_scale = paddle.max(layer.up_gate_proj_input_scale).cast("float32")
        down_proj_input_scale = paddle.max(layer.down_proj_input_scale).cast("float32")
        # Create shared parameters
        create_parameter_and_copy(
            layer, "g1_alphas", (up_gate_proj_input_scale * up_gate_proj_weight_scale_2).cast("float32")
        )
        create_parameter_and_copy(
            layer, "g2_alphas", (down_proj_input_scale * layer.down_proj_weight_scale_2).cast("float32")
        )
        create_parameter_and_copy(
            layer, "up_gate_proj_input_scale_quant", (1 / up_gate_proj_input_scale).cast("float32")
        )
        create_parameter_and_copy(layer, "down_proj_input_scale_quant", (1 / down_proj_input_scale).cast("float32"))
        # Sanity-check the block scales before swizzling.
        for name, weight_scale in [
            ("up_gate", layer.up_gate_proj_weight_scale),
            ("down", layer.down_proj_weight_scale),
        ]:
            assert weight_scale.shape[2] % 16 == 0, f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
            assert (
                weight_scale.dtype == paddle.float8_e4m3fn
            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
        # Either keep the checkpoint's (already swizzled) scales as-is, or
        # interleave them into the 128x4 layout here.
        if envs.FD_NVFP4_LOAD_BLOCKSCALE_LEAVE:
            up_gate_proj_blockscale_swizzled = layer.up_gate_proj_weight_scale
        else:
            up_gate_proj_blockscale_swizzled = _process_scale_interleaved(layer.up_gate_proj_weight_scale)
        create_parameter_and_copy(
            layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled
        )
        free_tensor(layer.up_gate_proj_weight_scale)
        layer.up_gate_proj_weight_scale = None
        if envs.FD_NVFP4_LOAD_BLOCKSCALE_LEAVE:
            down_proj_blockscale_swizzled = layer.down_proj_weight_scale
        else:
            down_proj_blockscale_swizzled = _process_scale_interleaved(layer.down_proj_weight_scale)
        create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled)
        free_tensor(layer.down_proj_weight_scale)
        layer.down_proj_weight_scale = None

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        shared_experts: nn.Layer = None,
    ) -> paddle.Tensor:
        """Expert-parallel prefill path: route, dispatch via DeepEP, run the
        masked-GEMM FFN, then combine expert outputs back to tokens.

        The `let_another_thread_run()` calls interleave with a sibling thread
        when two-batch-overlap (TBO) is active; intermediates are stashed into
        the module-level `global_values` dict keyed by thread name.
        """
        # 1. top experts and weights
        gate_out = gate(x.cast("float32"))
        topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
        hidden_size = x.shape[1]
        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)
        event = deep_ep.Buffer.capture()
        if self.ep_prefill_runner.num_worst_tokens <= 0:
            let_another_thread_run()
        # 2. ep dispatch
        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            recv_num_tokens_per_expert_list,
            handle,
            event,
        ) = self.ep_prefill_runner.dispatch(
            x,
            topk_idx,
            topk_weights,
            expert_alignment=128,
            previous_event=event,
        )
        if self.ep_prefill_runner.num_worst_tokens > 0:
            let_another_thread_run()
        thread_name = threading.current_thread().name
        if self.ep_prefill_runner.ep_engine.async_finish:
            event.current_stream_wait()
        global global_values
        if thread_name not in global_values:
            global_values[thread_name] = {}
        # nvfp4 dispatch returns a plain BF16 tensor (no fp8 scale), unlike deepgemm which returns (value, scale) tuple
        if isinstance(recv_x, tuple):
            (recv_x_value, recv_x_scale) = recv_x
        else:
            recv_x_value = recv_x
            recv_x_scale = None
        global_values[thread_name]["x"] = x
        global_values[thread_name]["topk_idx"] = topk_idx
        global_values[thread_name]["topk_weights"] = topk_weights
        global_values[thread_name]["x_scale_tensor"] = None
        global_values[thread_name]["recv_x_value"] = recv_x_value
        global_values[thread_name]["recv_x_scale"] = recv_x_scale
        global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
        global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
        global_values[thread_name]["handle"] = handle
        global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
        # 3. compute ffn
        token_all_num = sum(recv_num_tokens_per_expert_list)
        if self.ep_prefill_runner.num_worst_tokens > 0:
            # TBO halves the per-rank token budget because two micro-batches
            # are in flight at once.
            use_tbo = os.getenv("USE_TBO", "0")
            token_split_factor = 2 if int(use_tbo) == 1 else 1
            max_tokens_per_rank = (
                layer.fd_config.scheduler_config.max_num_batched_tokens
                // layer.fd_config.parallel_config.tensor_parallel_size
                // token_split_factor
            )
            permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = (
                call_prefill_permute_to_masked_gemm(
                    x=recv_x_value,
                    scale=recv_x_scale,
                    topk_ids=recv_topk_idx,
                    num_local_experts=layer.num_local_experts,
                    max_token_num=layer.ep_size * max_tokens_per_rank,
                )
            )
            max_token_num = layer.ep_size * max_tokens_per_rank
            permute_input = permute_input.reshape([layer.num_local_experts, max_token_num, recv_x_value.shape[-1]])
            # ffn_out: [num_local_experts, m, hidden_size]
            # NVFP4 dispatch returns BF16 (no pre-quantized scale), so permute_scale is empty.
            # Use per-expert 1/input_scale (up_gate_proj_input_scale_quant) as input_global_scale,
            # consistent with apply_ep_decode which also uses this value directly.
            ffn_out = flashinfer_cutedsl_moe_masked(
                hidden_states=(permute_input, None),
                input_global_scale=layer.up_gate_proj_input_scale_quant.expand([layer.num_local_experts]),
                w1=layer.up_gate_proj_weight,
                w1_blockscale=layer.up_gate_proj_blockscale_swizzled,
                w1_alpha=layer.g1_alphas,
                w2=layer.down_proj_weight,
                a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
                w2_blockscale=layer.down_proj_blockscale_swizzled,
                w2_alpha=layer.g2_alphas,
                masked_m=token_nums_per_expert.squeeze(-1),
            )
            tmp_ffn_out = call_depermute_prefill_combine(
                x=ffn_out,
                indice_map=permuted_indice_map,
                topk_weights=recv_topk_weights,
                num_worst_tokens=recv_x_value.shape[0],
            )
        elif token_all_num > 0:
            raise NotImplementedError(
                "NVFP4 EP prefill contiguous path (num_worst_tokens <= 0, token_all_num > 0) is not yet implemented."
            )
        else:
            # No tokens routed to this rank: emit an empty combine input.
            tmp_ffn_out = paddle.empty([0, hidden_size], dtype=paddle.bfloat16)
        if shared_experts is not None:
            s_x = shared_experts(x)
        # 4. EP combine
        event = deep_ep.Buffer.capture()
        if self.ep_prefill_runner.num_worst_tokens <= 0:
            let_another_thread_run()
        global_values[thread_name]["combine_in"] = tmp_ffn_out
        tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
        if self.ep_prefill_runner.num_worst_tokens > 0:
            let_another_thread_run()
        if self.ep_prefill_runner.ep_engine.async_finish:
            event.current_stream_wait()
        global_values[thread_name]["combine_out"] = tmp_ffn_out
        if shared_experts is not None:
            tmp_ffn_out += s_x
        return tmp_ffn_out

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        shared_experts: nn.Layer = None,
    ) -> paddle.Tensor:
        """Expert-parallel decode path: low-latency dispatch (BF16, use_fp8=False),
        CuteDSL masked grouped GEMM, then combine; optional shared experts added."""
        gate_out = gate(x.cast("float32"))
        topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)
        recv_x, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
            x,
            topk_idx,
            topk_weights,
            use_fp8=False,
        )
        # Compute FFN via CuteDSL masked grouped GEMM
        num_experts = layer.num_local_experts
        ffn_out = flashinfer_cutedsl_moe_masked(
            hidden_states=(recv_x, None),
            input_global_scale=layer.up_gate_proj_input_scale_quant.expand([num_experts]),
            w1=layer.up_gate_proj_weight,
            w1_blockscale=layer.up_gate_proj_blockscale_swizzled,
            w1_alpha=layer.g1_alphas,
            w2=layer.down_proj_weight,
            a2_global_scale=layer.down_proj_input_scale_quant.expand([num_experts]),
            w2_blockscale=layer.down_proj_blockscale_swizzled,
            w2_alpha=layer.g2_alphas,
            masked_m=token_nums_per_expert,
        )
        if shared_experts is not None:
            s_x = shared_experts(x)
        out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
        if shared_experts is not None:
            out += s_x
        return out

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        shared_experts: nn.Layer = None,
    ) -> paddle.Tensor:
        """Tensor-parallel path: top-k routing then the FlashInfer CUTLASS
        fused-MoE kernel with the precomputed NVFP4 quant scales."""
        if self.backend == "flashinfer-cutlass":
            gate_out = gate(x.cast("float32"))
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                layer.gate_correction_bias,
                layer.top_k,
                True,  # apply_norm_weight,
                False,
            )
            if topk_ids_hookfunc is not None:
                # NOTE(review): called positionally here but with the
                # `topk_ids=` keyword in the EP paths — confirm hook signature.
                topk_ids_hookfunc(topk_ids)
            output_dtype = x.dtype
            x_sf = None
            output = paddle.empty_like(x)
            # flashinfer cutlass; packed uint8 weights are reinterpreted as
            # int64 ("long") views for the kernel API.
            _ = flashinfer_cutlass_fused_moe(
                input=x,
                token_selected_experts=topk_ids.to(paddle.int),
                token_final_scales=topk_weights,
                fc1_expert_weights=getattr(layer, self.added_weight_attrs[0]).view(paddle.long),
                fc2_expert_weights=getattr(layer, self.added_weight_attrs[1]).view(paddle.long),
                output_dtype=output_dtype,
                input_sf=x_sf,
                quant_scales=[
                    layer.up_gate_proj_input_scale_quant,
                    layer.up_gate_proj_blockscale_swizzled.view(paddle.int32),
                    layer.g1_alphas,
                    layer.down_proj_input_scale_quant,
                    layer.down_proj_blockscale_swizzled.view(paddle.int32),
                    layer.g2_alphas,
                ],
                ep_size=layer.ep_size,
                ep_rank=layer.ep_rank,
                tp_size=layer.tp_size,
                tp_rank=layer.tp_rank,
                tune_max_num_tokens=next_power_of_2(x.shape[0]),
                output=output,
            )
            return output
        # NOTE(review): any other backend silently returns an UNINITIALIZED
        # tensor (paddle.empty_like); consider raising instead — confirm no
        # caller relies on this fallthrough.
        return paddle.empty_like(x)