[Feature] support nvfp4 tbo (#7259)

This commit is contained in:
lizexu123
2026-04-09 17:29:39 +08:00
committed by GitHub
parent fcaf614133
commit 613f92ee8f
@@ -15,6 +15,7 @@
"""
import os
import threading
from typing import Callable, Optional
import paddle
@@ -31,6 +32,7 @@ from fastdeploy.model_executor.utils import (
get_sm_version,
set_weight_attrs,
)
from fastdeploy.worker.tbo import let_another_thread_run
from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
@@ -54,6 +56,7 @@ if is_nvfp4_supported():
"compiles GEMM kernels during first load. This may take several minutes. "
"The wait is expected and only happens once per process."
)
from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
flashinfer_cutedsl_moe_masked,
)
@@ -445,6 +448,9 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
return out
# Module-level scratch storage keyed by thread name
# (threading.current_thread().name). The EP prefill code in this file creates
# one inner dict per thread and stores the dispatch inputs/outputs
# ("x", "topk_idx", "topk_weights", "recv_x_value", "handle", ...) and the
# combine input/output ("combine_in", "combine_out") in it.
# NOTE(review): presumably this keeps those tensors referenced while the other
# thread runs after let_another_thread_run() — TODO confirm. No removal of
# entries is visible in this view, so the latest values stored for each thread
# appear to stay referenced until they are overwritten.
global_values = {}
class ModelOptNvFp4FusedMoE(MoEMethodBase):
"""Fused MoE method for Model Optimizer NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
@@ -672,6 +678,9 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
event = deep_ep.Buffer.capture()
if self.ep_prefill_runner.num_worst_tokens <= 0:
let_another_thread_run()
# 2. ep dispatch
(
recv_x,
@@ -688,26 +697,51 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
previous_event=event,
)
if self.ep_prefill_runner.num_worst_tokens > 0:
let_another_thread_run()
thread_name = threading.current_thread().name
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
global global_values
if thread_name not in global_values:
global_values[thread_name] = {}
# nvfp4 dispatch returns a plain BF16 tensor (no fp8 scale), unlike deepgemm which returns (value, scale) tuple
recv_x_value = recv_x
recv_x_scale = None
if isinstance(recv_x, tuple):
(recv_x_value, recv_x_scale) = recv_x
else:
recv_x_value = recv_x
recv_x_scale = None
global_values[thread_name]["x"] = x
global_values[thread_name]["topk_idx"] = topk_idx
global_values[thread_name]["topk_weights"] = topk_weights
global_values[thread_name]["x_scale_tensor"] = None
global_values[thread_name]["recv_x_value"] = recv_x_value
global_values[thread_name]["recv_x_scale"] = recv_x_scale
global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
global_values[thread_name]["handle"] = handle
global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
# 3. compute ffn
token_all_num = sum(recv_num_tokens_per_expert_list)
if self.ep_prefill_runner.num_worst_tokens > 0:
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
use_tbo = os.getenv("USE_TBO", "0")
token_split_factor = 2 if int(use_tbo) == 1 else 1
max_tokens_per_rank = (
layer.fd_config.scheduler_config.max_num_batched_tokens
// layer.fd_config.parallel_config.tensor_parallel_size
// token_split_factor
)
# logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}")
permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = (
call_prefill_permute_to_masked_gemm(
x=recv_x_value,
@@ -717,6 +751,7 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
max_token_num=layer.ep_size * max_tokens_per_rank,
)
)
max_token_num = layer.ep_size * max_tokens_per_rank
permute_input = permute_input.reshape([layer.num_local_experts, max_token_num, recv_x_value.shape[-1]])
@@ -734,7 +769,7 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
w2_blockscale=layer.down_proj_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=token_nums_per_expert.squeeze(-1).cast(paddle.int32),
masked_m=token_nums_per_expert.squeeze(-1),
)
tmp_ffn_out = call_depermute_prefill_combine(
@@ -751,14 +786,28 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
else:
tmp_ffn_out = paddle.empty([0, hidden_size], dtype=paddle.bfloat16)
if shared_experts is not None:
s_x = shared_experts(x)
# 4. EP combine
event = deep_ep.Buffer.capture()
if self.ep_prefill_runner.num_worst_tokens <= 0:
let_another_thread_run()
global_values[thread_name]["combine_in"] = tmp_ffn_out
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
if self.ep_prefill_runner.num_worst_tokens > 0:
let_another_thread_run()
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
global_values[thread_name]["combine_out"] = tmp_ffn_out
if shared_experts is not None:
tmp_ffn_out += s_x
return tmp_ffn_out
def apply_ep_decode(
@@ -798,8 +847,14 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
masked_m=token_nums_per_expert,
)
if shared_experts is not None:
s_x = shared_experts(x)
out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
if shared_experts is not None:
out += s_x
return out
def apply_tp(