[Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660)

* enable trtllm_all_reduce fusion kernel in glm model

* fix conflict

* format update

* fix a bug

* modify test

* modify test

* support empty tensor and modify test

* fix test_linear config issues

* modify test name

* add edge test case

* modify format

* fix conflict

* modify default max token num in trtllm_allreduce_fusion

* add max token num branch for trtllm_allreduce_fusion

* fix format

* fix rmsnorm config issue

* modify 2025 to 2026

* using compat guard

* Lazily import flashinfer.comm and fix test config issue

* fix test issues

* add flashinfer cache dir clean machine

* fix some issues
This commit is contained in:
Bingoo
2026-04-16 14:10:19 +08:00
committed by GitHub
parent e53f5184ac
commit 6b891da02b
17 changed files with 871 additions and 11 deletions
+1
View File
@@ -671,6 +671,7 @@ class ParallelConfig:
self.pod_ip: str = None self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.disable_custom_all_reduce: bool = False self.disable_custom_all_reduce: bool = False
self.enable_flashinfer_allreduce_fusion: bool = False
for key, value in args.items(): for key, value in args.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
+11
View File
@@ -274,6 +274,11 @@ class EngineArgs:
Flag to disable the custom all-reduce kernel. Flag to disable the custom all-reduce kernel.
""" """
enable_flashinfer_allreduce_fusion: bool = False
"""
Flag to enable all reduce fusion kernel in flashinfer.
"""
use_internode_ll_two_stage: bool = False use_internode_ll_two_stage: bool = False
""" """
Flag to use the internode_ll_two_stage kernel. Flag to use the internode_ll_two_stage kernel.
@@ -995,6 +1000,12 @@ class EngineArgs:
default=EngineArgs.disable_custom_all_reduce, default=EngineArgs.disable_custom_all_reduce,
help="Flag to disable custom all-reduce.", help="Flag to disable custom all-reduce.",
) )
parallel_group.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
default=EngineArgs.enable_flashinfer_allreduce_fusion,
help="Flag to enable all reduce fusion kernel in flashinfer.",
)
parallel_group.add_argument( parallel_group.add_argument(
"--use-internode-ll-two-stage", "--use-internode-ll-two-stage",
action="store_true", action="store_true",
+1
View File
@@ -2503,6 +2503,7 @@ class EngineService:
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy, "enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
} }
for worker_flag, value in worker_store_true_flag.items(): for worker_flag, value in worker_store_true_flag.items():
if value: if value:
+1
View File
@@ -656,6 +656,7 @@ class LLMEngine:
"enable_entropy": self.cfg.model_config.enable_entropy, "enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
} }
for worker_flag, value in worker_store_true_flag.items(): for worker_flag, value in worker_store_true_flag.items():
if value: if value:
@@ -0,0 +1,209 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional, Tuple
import paddle
import paddle.distributed as dist
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import has_flashinfer
from fastdeploy.utils import get_logger
logger = get_logger("flashinfer", "flashinfer.log")
_flashinfer_comm = None
_workspace_manager = None
def _get_flashinfer_comm():
    """Import and cache ``flashinfer.comm`` on first use.

    Deferring the import keeps module load time free of flashinfer side
    effects. Returns the cached module, or ``None`` when flashinfer is
    not installed or the import fails.
    """
    global _flashinfer_comm
    if _flashinfer_comm is None and has_flashinfer():
        try:
            # Import under the compat guard so paddle/flashinfer interop is enabled.
            with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
                import flashinfer.comm as _comm_module
            _flashinfer_comm = _comm_module
        except ImportError:
            logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")
    return _flashinfer_comm
class FlashInferWorkspaceManager:
    """Owns the IPC workspace used by FlashInfer's fused all-reduce kernels.

    The workspace is created lazily by :meth:`initialize` and released by
    :meth:`cleanup`; one module-level instance is shared per process.
    """

    def __init__(self):
        # No workspace exists until initialize() succeeds.
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create (or reuse) the IPC workspace for ``world_size`` ranks."""
        # Reuse the current workspace when the world size is unchanged.
        if self.initialized and self.world_size == world_size:
            return
        comm = _get_flashinfer_comm()
        if comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return
        # Drop any stale workspace before allocating a fresh one.
        self.cleanup()
        self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )
        self.rank = rank
        self.world_size = world_size
        self.initialized = True
        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace (best effort) and reset all state."""
        if self.initialized and self.ipc_handles is not None:
            try:
                comm = _get_flashinfer_comm()
                if comm is not None:
                    comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
            except Exception as e:
                # Best-effort teardown: log and continue so state is always reset.
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False
# Process-wide singleton workspace manager shared by the helpers below.
_workspace_manager = FlashInferWorkspaceManager()
def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Make sure the shared FlashInfer workspace exists for the current TP group.

    Returns True when the workspace is ready, False when flashinfer is
    unavailable or tensor parallelism is not in use.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        return False
    assert fd_config is not None
    tp_size = fd_config.parallel_config.tensor_parallel_size
    if tp_size <= 1:
        # Nothing to all-reduce on a single rank.
        return False
    rank = dist.get_rank()
    # (Re)initialize when no workspace exists or the world size changed.
    needs_init = not _workspace_manager.initialized or _workspace_manager.world_size != tp_size
    if needs_init:
        _workspace_manager.initialize(
            world_size=tp_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )
    return _workspace_manager.initialized
def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Fused all-reduce + residual-add + RMSNorm via FlashInfer's TRT-LLM kernel.

    Args:
        fd_config: FastDeploy config; its tensor_parallel_size decides whether
            the fused path applies.
        input_tensor: [token_num, hidden_dim] partial result to all-reduce.
        residual: residual tensor, same shape as ``input_tensor``.
        weight: RMSNorm gamma of shape [hidden_dim].
        eps: RMSNorm epsilon.
        max_token_num: capacity the IPC workspace is sized for.
        use_oneshot / trigger_completion_at_end / fp32_acc: kernel tuning knobs
            forwarded unchanged to ``trtllm_allreduce_fusion``.

    Returns:
        (norm_out, residual_out) on success, or (None, None) when the fused
        path is unavailable and the caller must fall back to the standard
        all-reduce + rmsnorm implementation.

    Raises:
        ValueError: if ``input_tensor`` holds more tokens than the workspace
            capacity ``max_token_num``.
    """
    comm = _get_flashinfer_comm()
    if not has_flashinfer() or comm is None:
        logger.debug("FlashInfer not available, falling back to standard " "implementation")
        return None, None
    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently overflow the IPC workspace.
    if input_tensor.shape[0] > max_token_num:
        raise ValueError(
            f"input_tensor has {input_tensor.shape[0]} tokens, exceeding workspace capacity {max_token_num}"
        )
    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None
    token_num, hidden_dim = input_tensor.shape
    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # Empty batch: nothing to reduce, return (empty) outputs directly.
    if token_num == 0:
        return norm_out, residual_out
    comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )
    return norm_out, residual_out
def cleanup_flashinfer_workspace():
    """Tear down the process-wide FlashInfer workspace (no-op if none exists)."""
    global _workspace_manager
    manager = _workspace_manager
    if manager is not None:
        manager.cleanup()
+14 -2
View File
@@ -854,6 +854,7 @@ class RowParallelLinear(LinearBase):
skip_quant: bool = False, skip_quant: bool = False,
weight_dtype: str = "", weight_dtype: str = "",
layer_id: int = -1, layer_id: int = -1,
enable_all_reduce_fusion: bool = None,
): ):
""" """
Initialize a linear layer with additional parameters for inference and quantization. Initialize a linear layer with additional parameters for inference and quantization.
@@ -865,9 +866,17 @@ class RowParallelLinear(LinearBase):
input_size (int): Number of input features. Defaults to None. input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None. output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False. with_bias (bool): Whether to include bias or not. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False. skip_quant (bool): Whether to skip quantization or not. Defaults to False.
enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion.
If None, it is determined by the config flag and prefix. Defaults to None.
""" """
self.fd_config = fd_config self.fd_config = fd_config
if enable_all_reduce_fusion is None:
self.enable_all_reduce_fusion = False
else:
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion
)
self.ep_size = fd_config.parallel_config.expert_parallel_size self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group self.tp_group = fd_config.parallel_config.tp_group
@@ -945,7 +954,10 @@ class RowParallelLinear(LinearBase):
out = self.quant_method.apply(self, x) out = self.quant_method.apply(self, x)
if self.reduce_results and self.tp_size > 1: need_tp_all_reduce = (
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
)
if need_tp_all_reduce:
out = tensor_model_parallel_all_reduce(out, self.tp_group) out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out return out
@@ -35,6 +35,7 @@ from .batch_invariant_ops import (
is_batch_invariant_mode_enabled, is_batch_invariant_mode_enabled,
rms_norm_batch_invariant, rms_norm_batch_invariant,
) )
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
from .utils import get_tensor, modules_to_convert from .utils import get_tensor, modules_to_convert
@@ -122,6 +123,10 @@ class RMSNorm(nn.Layer):
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm") is_input_norm = prefix.endswith(".input_layernorm")
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
)
self.is_last_norm = prefix.endswith(".norm") self.is_last_norm = prefix.endswith(".norm")
self.split_x = ( self.split_x = (
self.fd_config.parallel_config.use_sequence_parallel_moe self.fd_config.parallel_config.use_sequence_parallel_moe
@@ -240,6 +245,12 @@ class RMSNorm(nn.Layer):
norm_out = rms_norm(x, self.weight, self.eps) norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps) norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
else: else:
if is_batch_invariant_mode_enabled(): if is_batch_invariant_mode_enabled():
# M-invariant path: per-row Triton kernel, no cross-row reduction # M-invariant path: per-row Triton kernel, no cross-row reduction
@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
""" """
import importlib
import importlib.util
import math import math
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
@@ -25,11 +23,12 @@ from paddle import nn
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.utils import set_weight_attrs from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
if current_platform.is_cuda(): if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
from ..moe import FusedMoE from ..moe import FusedMoE
@@ -59,10 +58,6 @@ def check_device_capability(num):
return False return False
def has_flashinfer():
return importlib.util.find_spec("flashinfer") is not None
def round_up(a, b): def round_up(a, b):
return ((a + b - 1) // b) * b return ((a + b - 1) // b) * b
+1 -1
View File
@@ -130,7 +130,6 @@ class Glm4Moe(nn.Layer):
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1 self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1
@@ -229,6 +228,7 @@ class Glm4MoeAttention(nn.Layer):
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.hidden_size,
layer_id=layer_id, layer_id=layer_id,
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
) )
self.attn = Attention( self.attn = Attention(
+6
View File
@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
""" """
import importlib
import importlib.util
import os import os
import re import re
from collections.abc import Mapping from collections.abc import Mapping
@@ -553,6 +555,10 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
return fn return fn
def has_flashinfer():
return importlib.util.find_spec("flashinfer") is not None
@cache @cache
def get_sm_version(): def get_sm_version():
if paddle.cuda.is_available(): if paddle.cuda.is_available():
+6
View File
@@ -830,6 +830,12 @@ def parse_args():
default=None, default=None,
help="Configuration of SpeculativeConfig.", help="Configuration of SpeculativeConfig.",
) )
parser.add_argument(
"--enable_flashinfer_allreduce_fusion",
action="store_true",
default=False,
help="Flag to enable all reduce fusion kernel in flashinfer.",
)
parser.add_argument( parser.add_argument(
"--max_num_batched_tokens", "--max_num_batched_tokens",
type=int, type=int,
+1 -1
View File
@@ -46,6 +46,6 @@ setproctitle
aistudio_sdk aistudio_sdk
p2pstore p2pstore
py-cpuinfo py-cpuinfo
flashinfer-python-paddle flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
transformers>=4.55.1,<5.0.0 transformers>=4.55.1,<5.0.0
@@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"):
layer.bias = None layer.bias = None
layer.split_x = False layer.split_x = False
layer.allgather_out = False layer.allgather_out = False
layer.enable_all_reduce_fusion = False
return layer return layer
+1
View File
@@ -39,6 +39,7 @@ def _make_cfg(**ov):
pc.use_internode_ll_two_stage = pc.disable_sequence_parallel_moe = False pc.use_internode_ll_two_stage = pc.disable_sequence_parallel_moe = False
pc.shutdown_comm_group_if_worker_idle = False pc.shutdown_comm_group_if_worker_idle = False
pc.ep_prefill_use_worst_num_tokens = False pc.ep_prefill_use_worst_num_tokens = False
pc.enable_flashinfer_allreduce_fusion = False
sc = ns(max_num_seqs=256, max_num_batched_tokens=4096, splitwise_role="mixed", name="local") sc = ns(max_num_seqs=256, max_num_batched_tokens=4096, splitwise_role="mixed", name="local")
sc.enable_overlap_schedule = False sc.enable_overlap_schedule = False
cc = ns(num_gpu_blocks_override=None, gpu_memory_utilization=0.9, block_size=16, enc_dec_block_num=0) cc = ns(num_gpu_blocks_override=None, gpu_memory_utilization=0.9, block_size=16, enc_dec_block_num=0)
@@ -0,0 +1,56 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import subprocess
import sys
def test_run_distributed():
    """Launch the multi-GPU distributed test via ``paddle.distributed.launch``.

    Runs ``trtllm_allreduce_rms_fusion.py`` on GPUs 0,1 as a subprocess,
    prints the child's output, and asserts on its exit code.
    """
    import shutil

    # Clear the flashinfer cache directory so the run starts from a clean state.
    flashinfer_cache_dir = os.path.join(os.sep, "root", ".cache", "flashinfer")
    if os.path.exists(flashinfer_cache_dir):
        print(f"=== Clearing flashinfer cache directory: {flashinfer_cache_dir} ===")
        # Use stdlib shutil.rmtree instead of shelling out to `rm -rf`.
        shutil.rmtree(flashinfer_cache_dir)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    command = [
        sys.executable,
        "-m",
        "paddle.distributed.launch",
        "--gpus",
        "0,1",
        run_script,
    ]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    try:
        stdout, stderr = process.communicate(timeout=400)
        return_code = process.returncode
    except subprocess.TimeoutExpired:
        # Kill the hung launcher and collect whatever output it produced.
        process.kill()
        stdout, stderr = process.communicate()
        return_code = -1
    print(f"=== Distributed test stdout ===\n{stdout}")
    print(f"=== Distributed test stderr ===\n{stderr}")
    # NOTE(review): 250 is accepted alongside 0 — presumably a benign exit
    # code from paddle.distributed.launch; confirm against the launcher docs.
    assert return_code in (0, 250), f"Process exited with code {return_code}"
# Invoked at module level so the script also runs without a pytest runner.
test_run_distributed()
+548
View File
@@ -0,0 +1,548 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import unittest
from unittest.mock import Mock, patch
import numpy as np
import paddle
import paddle.distributed as dist
class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase):
    """Test FlashInfer AllReduce + Residual + RMSNorm fused operator.

    Compares three implementations of allreduce + residual + RMSNorm:
    a manual reference, paddle's fused_rms_norm, and the FlashInfer fused
    kernel. Must be launched under paddle.distributed (multi-rank).
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment: pick the device and join the process group."""
        if paddle.is_compiled_with_cuda():
            paddle.set_device("gpu")
        else:
            paddle.set_device("cpu")
        # One parallel env for the whole class; requires distributed launch.
        dist.init_parallel_env()

    def setUp(self):
        """Initialize each test case."""
        # Fix random seed for reproducibility
        paddle.seed(42)
        np.random.seed(42)
        self.dtype = paddle.float32
        self.token_num = 128
        self.hidden_dim = 768
        self.eps = 1e-6
        self.epsilon = 1e-6
        self.max_token_num = 2048
        # Create mock FDConfig — only tensor_parallel_size is read by the fused op.
        self.fd_config = Mock()
        self.fd_config.parallel_config = Mock()
        self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size()
        self.begin_norm_axis = 1
        # Performance test params - increase iterations for stability
        self.warmup_iterations = 20  # Increase warmup
        self.test_iterations = 200  # Increase test iterations

    def tearDown(self):
        """Clean up GPU resources between tests."""
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()
            paddle.device.cuda.synchronize()

    def create_test_tensors(self):
        """Create random (input, residual, weight) tensors for one test."""
        input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
        residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
        weight = paddle.randn([self.hidden_dim], dtype=self.dtype)
        return input_tensor, residual, weight

    def compute_reference_output(self, input_tensor, residual, weight, eps):
        """Reference implementation: manually compute AllReduce + Residual + RMSNorm.

        NOTE: mutates ``input_tensor`` in place via dist.all_reduce — callers
        pass clones.
        """
        # # Step 1: AllReduce (identity on single device)
        # allreduce_out = input_tensor.clone()
        # Apply all reduce operator
        dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
        # Step 2: Add residual
        residual_out = input_tensor + residual
        # Step 3: RMSNorm
        variance = residual_out.pow(2).mean(axis=-1, keepdim=True)
        norm_out = residual_out * paddle.rsqrt(variance + eps)
        norm_out = norm_out * weight
        # dist.all_reduce(residual_out, op=dist.ReduceOp.SUM)
        return norm_out, residual_out

    def paddle_rms_fuse(self, input_tensor, residual, weight, eps):
        """Baseline: all-reduce then paddle's fused_rms_norm (mutates input in place)."""
        from paddle.incubate.nn.functional import fused_rms_norm

        # Apply all reduce operator
        dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
        out_fused = fused_rms_norm(
            input_tensor,
            norm_weight=weight,
            norm_bias=None,
            epsilon=eps,
            begin_norm_axis=self.begin_norm_axis,
            bias=None,
            residual=residual,
        )
        # fused_rms_norm returns (norm_out, residual_out, ...).
        return out_fused[0], out_fused[1]

    def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps):
        """FlashInfer fused operator under test."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
            fd_config=self.fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=eps,
            max_token_num=self.max_token_num,
            use_oneshot=False,
        )
        return norm_out, residual_out

    def benchmark_function(self, func, *args, name="", **kwargs):
        """
        Improved performance benchmark
        - Wait for GPU frequency stabilization
        - Use median instead of mean (more stable)
        - Filter outliers
        """
        # Force GPU frequency stabilization
        if paddle.is_compiled_with_cuda():
            for _ in range(5):
                paddle.device.cuda.synchronize()
                time.sleep(0.01)
        # Warmup - thorough warm-up
        for _ in range(self.warmup_iterations):
            result = func(*args, **kwargs)
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.synchronize()
        # Extra wait to ensure GPU stability
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.synchronize()
            time.sleep(0.1)
        # Benchmark run — synchronize around each call so wall time covers the kernel.
        times = []
        for i in range(self.test_iterations):
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.synchronize()
            start = time.perf_counter()
            result = func(*args, **kwargs)
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.synchronize()
            end = time.perf_counter()
            elapsed = (end - start) * 1000  # Convert to milliseconds
            times.append(elapsed)
        times = np.array(times)
        # Filter outliers using IQR method
        q1, q3 = np.percentile(times, [25, 75])
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        filtered_times = times[(times >= lower_bound) & (times <= upper_bound)]
        # Fall back to raw data if too many samples filtered out
        if len(filtered_times) < self.test_iterations * 0.5:
            filtered_times = times
        # Statistics
        avg_time = np.mean(filtered_times)
        median_time = np.median(filtered_times)
        std_time = np.std(filtered_times)
        min_time = np.min(filtered_times)
        max_time = np.max(filtered_times)
        cv = (std_time / avg_time) * 100  # Coefficient of variation (%)
        print(f"\n{'='*70}")
        print(f"Performance Benchmark: {name}")
        print(f"{'='*70}")
        print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)")
        print(f"Median: {median_time:.4f} ms (most stable metric)")
        print(f"Average: {avg_time:.4f} ms")
        print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)")
        print(f"Min: {min_time:.4f} ms")
        print(f"Max: {max_time:.4f} ms")
        print(f"{'='*70}\n")
        # Return median (more stable) and result
        return median_time, result

    def test_accuracy_fused_vs_reference(self):
        """Test accuracy of fused operator vs reference implementation."""
        input_tensor, residual, weight = self.create_test_tensors()
        # Each implementation gets clones because they all-reduce in place.
        reference_output, ref_res = self.compute_reference_output(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        fused_output, paddle_res = self.paddle_rms_fuse(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        # Verify results
        np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5)
class TestFlashInferWorkspaceManager(unittest.TestCase):
    """Unit tests for FlashInferWorkspaceManager's default state and cleanup."""

    def setUp(self):
        """Create a fresh manager instance for every test."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        self.manager = FlashInferWorkspaceManager()

    def test_initialization(self):
        """A new manager starts with no workspace and no rank/world metadata."""
        manager = self.manager
        self.assertIsNone(manager.workspace_tensor)
        self.assertIsNone(manager.ipc_handles)
        self.assertIsNone(manager.world_size)
        self.assertIsNone(manager.rank)
        self.assertFalse(manager.initialized)

    def test_cleanup(self):
        """cleanup() on an uninitialized manager is a no-op that keeps state reset."""
        self.manager.cleanup()
        self.assertFalse(self.manager.initialized)
        self.assertIsNone(self.manager.workspace_tensor)
class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase):
    """Test FlashInferWorkspaceManager edge cases and fallback paths.

    Uses mock.patch on module globals so no real flashinfer install or
    distributed environment is required.
    """

    def setUp(self):
        """Initialize test fixtures."""
        # Patch before importing to test fallback paths
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_initialization_early_return_when_already_initialized(self):
        """Early return when already initialized with the same world_size."""
        # Patch _flashinfer_comm to be available
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            # First initialization
            manager.initialized = True
            manager.world_size = 2
            # Mock the comm functions
            mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock()))
            # Second initialization with same world_size - should return early
            # (no exception, no workspace re-creation).
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

    def test_initialization_warning_when_comm_none(self):
        """Warning (not error) when _get_flashinfer_comm returns None."""
        # Patch to ensure _get_flashinfer_comm is None
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            # Should not raise, just log warning and return
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )
            # Verify not initialized
            self.assertFalse(manager.initialized)

    def test_cleanup_with_exception(self):
        """cleanup() swallows destroy-workspace errors and still resets state."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            manager.initialized = True
            manager.ipc_handles = Mock()
            manager.workspace_tensor = Mock()
            # Mock the destroy function to raise exception
            mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error"))
            # Should not raise, just log warning
            manager.cleanup()
            # Verify cleanup happened
            self.assertFalse(manager.initialized)
            self.assertIsNone(manager.workspace_tensor)
            self.assertIsNone(manager.ipc_handles)

    def test_cleanup_without_initialization(self):
        """Test cleanup when not initialized."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        manager = FlashInferWorkspaceManager()
        manager.initialized = False
        # Should not raise
        manager.cleanup()
        # Verify state
        self.assertFalse(manager.initialized)
class TestEnsureWorkspaceInitialized(unittest.TestCase):
    """Test ensure_workspace_initialized fallback paths (all return False)."""

    def setUp(self):
        """Initialize test fixtures."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_ensure_workspace_when_flashinfer_not_available(self):
        """Early return False when flashinfer is not installed."""
        self.mock_has_flashinfer.return_value = False
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            ensure_workspace_initialized,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = 2
        result = ensure_workspace_initialized(fd_config)
        # Should return False (not initialized)
        self.assertFalse(result)

    def test_ensure_workspace_when_comm_none(self):
        """Return False when _get_flashinfer_comm yields None despite has_flashinfer."""
        self.mock_has_flashinfer.return_value = True
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 2
            result = ensure_workspace_initialized(fd_config)
            # Should return False
            self.assertFalse(result)

    def test_ensure_workspace_single_gpu(self):
        """Early return False when world_size <= 1 (nothing to all-reduce)."""
        self.mock_has_flashinfer.return_value = True
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 1
            with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0):
                result = ensure_workspace_initialized(fd_config)
                # Should return False for single GPU
                self.assertFalse(result)
class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase):
    """Exercise the fallback/early-exit paths of flashinfer_allreduce_residual_rmsnorm."""

    def setUp(self):
        """Patch has_flashinfer in the target module for the duration of each test."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Undo the has_flashinfer patch."""
        self.patcher_has_flashinfer.stop()

    @staticmethod
    def _fd_config(tp_size):
        # Build a minimal fd_config stub exposing only tensor_parallel_size.
        cfg = Mock()
        cfg.parallel_config = Mock()
        cfg.parallel_config.tensor_parallel_size = tp_size
        return cfg

    @staticmethod
    def _call(fd_config, input_tensor, residual, weight):
        # Single call site for the function under test; imported lazily so the
        # active patches in each test are in effect when the module is touched.
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        return flashinfer_allreduce_residual_rmsnorm(
            fd_config=fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=1e-6,
            max_token_num=2048,
        )

    def test_flashinfer_not_available_fallback(self):
        """(None, None) is returned when flashinfer is not installed."""
        self.mock_has_flashinfer.return_value = False

        norm_out, residual_out = self._call(
            self._fd_config(2),
            paddle.randn([128, 768]),
            paddle.randn([128, 768]),
            paddle.randn([768]),
        )

        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_single_gpu_fallback(self):
        """(None, None) is returned when tensor_parallel_size == 1."""
        self.mock_has_flashinfer.return_value = True
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            norm_out, residual_out = self._call(
                self._fd_config(1),
                paddle.randn([128, 768]),
                paddle.randn([128, 768]),
                paddle.randn([768]),
            )

        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_empty_tensor_handling(self):
        """Zero-token input short-circuits: empty outputs, fusion kernel never invoked."""
        self.mock_has_flashinfer.return_value = True
        with (
            patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as comm_mock,
            patch(
                "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized",
                return_value=True,
            ),
        ):
            # The fused kernel must not be reached for 0-token inputs.
            comm_mock.trtllm_allreduce_fusion = Mock()

            norm_out, residual_out = self._call(
                self._fd_config(2),
                paddle.zeros([0, 768]),
                paddle.zeros([0, 768]),
                paddle.randn([768]),
            )

            self.assertEqual(norm_out.shape[0], 0)
            self.assertEqual(residual_out.shape[0], 0)
            comm_mock.trtllm_allreduce_fusion.assert_not_called()
class TestCleanupFlashInferWorkspace(unittest.TestCase):
    """Verify the module-level cleanup entry point delegates to the singleton manager."""

    def test_cleanup_workspace_function(self):
        """cleanup_flashinfer_workspace() forwards exactly one call to _workspace_manager.cleanup."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as manager_mock:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                cleanup_flashinfer_workspace,
            )

            manager_mock.cleanup = Mock()

            cleanup_flashinfer_workspace()

            manager_mock.cleanup.assert_called_once()
if __name__ == "__main__":
    # Entry point when the file is executed directly (the distributed test
    # launcher re-runs this module in a subprocess).
    unittest.main(verbosity=2)
+1
View File
@@ -58,6 +58,7 @@ def make_fd_config(
expert_parallel_size=1, expert_parallel_size=1,
tp_group=None, tp_group=None,
use_sequence_parallel_moe=use_sequence_parallel_moe, use_sequence_parallel_moe=use_sequence_parallel_moe,
enable_flashinfer_allreduce_fusion=False,
), ),
scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1), scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1),
load_config=SimpleNamespace( load_config=SimpleNamespace(