mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support nvfp4 tbo (#7259)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
|
||||
import paddle
|
||||
@@ -31,6 +32,7 @@ from fastdeploy.model_executor.utils import (
|
||||
get_sm_version,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from fastdeploy.worker.tbo import let_another_thread_run
|
||||
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
|
||||
|
||||
@@ -54,6 +56,7 @@ if is_nvfp4_supported():
|
||||
"compiles GEMM kernels during first load. This may take several minutes. "
|
||||
"The wait is expected and only happens once per process."
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import (
|
||||
flashinfer_cutedsl_moe_masked,
|
||||
)
|
||||
@@ -445,6 +448,9 @@ class ModelOptNvFp4LinearMethod(QuantMethodBase):
|
||||
return out
|
||||
|
||||
|
||||
global_values = {}
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
"""Fused MoE method for Model Optimizer NVFP4.
|
||||
Supports loading NVFP4 checkpoints with the following structure:
|
||||
@@ -672,6 +678,9 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
|
||||
event = deep_ep.Buffer.capture()
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens <= 0:
|
||||
let_another_thread_run()
|
||||
|
||||
# 2. ep dispatch
|
||||
(
|
||||
recv_x,
|
||||
@@ -688,26 +697,51 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
previous_event=event,
|
||||
)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
let_another_thread_run()
|
||||
|
||||
thread_name = threading.current_thread().name
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
global global_values
|
||||
|
||||
if thread_name not in global_values:
|
||||
global_values[thread_name] = {}
|
||||
|
||||
# nvfp4 dispatch returns a plain BF16 tensor (no fp8 scale), unlike deepgemm which returns (value, scale) tuple
|
||||
recv_x_value = recv_x
|
||||
recv_x_scale = None
|
||||
if isinstance(recv_x, tuple):
|
||||
(recv_x_value, recv_x_scale) = recv_x
|
||||
else:
|
||||
recv_x_value = recv_x
|
||||
recv_x_scale = None
|
||||
|
||||
global_values[thread_name]["x"] = x
|
||||
global_values[thread_name]["topk_idx"] = topk_idx
|
||||
global_values[thread_name]["topk_weights"] = topk_weights
|
||||
|
||||
global_values[thread_name]["x_scale_tensor"] = None
|
||||
|
||||
global_values[thread_name]["recv_x_value"] = recv_x_value
|
||||
global_values[thread_name]["recv_x_scale"] = recv_x_scale
|
||||
global_values[thread_name]["recv_topk_idx"] = recv_topk_idx
|
||||
global_values[thread_name]["recv_topk_weights"] = recv_topk_weights
|
||||
global_values[thread_name]["handle"] = handle
|
||||
global_values[thread_name]["recv_num_tokens_per_expert_list"] = recv_num_tokens_per_expert_list
|
||||
|
||||
# 3. compute ffn
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
token_split_factor = 2 if int(os.getenv("USE_TBO", "0")) == 1 else 1
|
||||
use_tbo = os.getenv("USE_TBO", "0")
|
||||
token_split_factor = 2 if int(use_tbo) == 1 else 1
|
||||
max_tokens_per_rank = (
|
||||
layer.fd_config.scheduler_config.max_num_batched_tokens
|
||||
// layer.fd_config.parallel_config.tensor_parallel_size
|
||||
// token_split_factor
|
||||
)
|
||||
|
||||
# logger.debug(f"max_tokens_per_rank {max_tokens_per_rank}")
|
||||
|
||||
permute_input, permute_scale, permuted_indice_map, token_nums_per_expert = (
|
||||
call_prefill_permute_to_masked_gemm(
|
||||
x=recv_x_value,
|
||||
@@ -717,6 +751,7 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
max_token_num=layer.ep_size * max_tokens_per_rank,
|
||||
)
|
||||
)
|
||||
|
||||
max_token_num = layer.ep_size * max_tokens_per_rank
|
||||
permute_input = permute_input.reshape([layer.num_local_experts, max_token_num, recv_x_value.shape[-1]])
|
||||
|
||||
@@ -734,7 +769,7 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
|
||||
w2_blockscale=layer.down_proj_blockscale_swizzled,
|
||||
w2_alpha=layer.g2_alphas,
|
||||
masked_m=token_nums_per_expert.squeeze(-1).cast(paddle.int32),
|
||||
masked_m=token_nums_per_expert.squeeze(-1),
|
||||
)
|
||||
|
||||
tmp_ffn_out = call_depermute_prefill_combine(
|
||||
@@ -751,14 +786,28 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
else:
|
||||
tmp_ffn_out = paddle.empty([0, hidden_size], dtype=paddle.bfloat16)
|
||||
|
||||
if shared_experts is not None:
|
||||
s_x = shared_experts(x)
|
||||
|
||||
# 4. EP combine
|
||||
event = deep_ep.Buffer.capture()
|
||||
if self.ep_prefill_runner.num_worst_tokens <= 0:
|
||||
let_another_thread_run()
|
||||
|
||||
global_values[thread_name]["combine_in"] = tmp_ffn_out
|
||||
|
||||
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights, event)
|
||||
|
||||
if self.ep_prefill_runner.num_worst_tokens > 0:
|
||||
let_another_thread_run()
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
global_values[thread_name]["combine_out"] = tmp_ffn_out
|
||||
if shared_experts is not None:
|
||||
tmp_ffn_out += s_x
|
||||
|
||||
return tmp_ffn_out
|
||||
|
||||
def apply_ep_decode(
|
||||
@@ -798,8 +847,14 @@ class ModelOptNvFp4FusedMoE(MoEMethodBase):
|
||||
masked_m=token_nums_per_expert,
|
||||
)
|
||||
|
||||
if shared_experts is not None:
|
||||
s_x = shared_experts(x)
|
||||
|
||||
out = self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
|
||||
|
||||
if shared_experts is not None:
|
||||
out += s_x
|
||||
|
||||
return out
|
||||
|
||||
def apply_tp(
|
||||
|
||||
Reference in New Issue
Block a user