mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660)
* enable trtllm_all_reduce fusion kernel in glm model * fix conflict * format update * fix a bug * modify test * modify test * support empty tensor and modify test * fix test_linear config issues * modify test name * add edge test case * modify format * fix conflict * modify default max token num in trtllm_allreduce_fusion * add max token num branch for trtllm_allreduce_fusion * fix format * fix rmsnorm config issue * modify 2025 to 2026 * using compat grard * Lazily import flashinfer.comm and fix test config issue * fix test issues * add flashinfer cache dir clean machine * fix some issues
This commit is contained in:
@@ -671,6 +671,7 @@ class ParallelConfig:
|
|||||||
self.pod_ip: str = None
|
self.pod_ip: str = None
|
||||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||||
self.disable_custom_all_reduce: bool = False
|
self.disable_custom_all_reduce: bool = False
|
||||||
|
self.enable_flashinfer_allreduce_fusion: bool = False
|
||||||
for key, value in args.items():
|
for key, value in args.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|||||||
@@ -274,6 +274,11 @@ class EngineArgs:
|
|||||||
Flag to disable the custom all-reduce kernel.
|
Flag to disable the custom all-reduce kernel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
enable_flashinfer_allreduce_fusion: bool = False
|
||||||
|
"""
|
||||||
|
Flag to enable all reduce fusion kernel in flashinfer.
|
||||||
|
"""
|
||||||
|
|
||||||
use_internode_ll_two_stage: bool = False
|
use_internode_ll_two_stage: bool = False
|
||||||
"""
|
"""
|
||||||
Flag to use the internode_ll_two_stage kernel.
|
Flag to use the internode_ll_two_stage kernel.
|
||||||
@@ -995,6 +1000,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.disable_custom_all_reduce,
|
default=EngineArgs.disable_custom_all_reduce,
|
||||||
help="Flag to disable custom all-reduce.",
|
help="Flag to disable custom all-reduce.",
|
||||||
)
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--enable-flashinfer-allreduce-fusion",
|
||||||
|
action="store_true",
|
||||||
|
default=EngineArgs.enable_flashinfer_allreduce_fusion,
|
||||||
|
help="Flag to enable all reduce fusion kernel in flashinfer.",
|
||||||
|
)
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--use-internode-ll-two-stage",
|
"--use-internode-ll-two-stage",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -2503,6 +2503,7 @@ class EngineService:
|
|||||||
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
|
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
|
||||||
"enable_entropy": self.cfg.model_config.enable_entropy,
|
"enable_entropy": self.cfg.model_config.enable_entropy,
|
||||||
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
|
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
|
||||||
|
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
|
||||||
}
|
}
|
||||||
for worker_flag, value in worker_store_true_flag.items():
|
for worker_flag, value in worker_store_true_flag.items():
|
||||||
if value:
|
if value:
|
||||||
|
|||||||
@@ -656,6 +656,7 @@ class LLMEngine:
|
|||||||
"enable_entropy": self.cfg.model_config.enable_entropy,
|
"enable_entropy": self.cfg.model_config.enable_entropy,
|
||||||
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
|
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
|
||||||
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
|
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
|
||||||
|
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
|
||||||
}
|
}
|
||||||
for worker_flag, value in worker_store_true_flag.items():
|
for worker_flag, value in worker_store_true_flag.items():
|
||||||
if value:
|
if value:
|
||||||
|
|||||||
@@ -0,0 +1,209 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.distributed as dist
|
||||||
|
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
|
from fastdeploy.model_executor.utils import has_flashinfer
|
||||||
|
from fastdeploy.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("flashinfer", "flashinfer.log")
|
||||||
|
|
||||||
|
_flashinfer_comm = None
|
||||||
|
_workspace_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_flashinfer_comm():
|
||||||
|
"""Lazily import flashinfer.comm to avoid side effects at module load time."""
|
||||||
|
global _flashinfer_comm
|
||||||
|
if _flashinfer_comm is not None:
|
||||||
|
return _flashinfer_comm
|
||||||
|
if has_flashinfer():
|
||||||
|
try:
|
||||||
|
with paddle.use_compat_guard(enable=True, scope={"flashinfer"}):
|
||||||
|
import flashinfer.comm as comm
|
||||||
|
|
||||||
|
_flashinfer_comm = comm
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")
|
||||||
|
return _flashinfer_comm
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferWorkspaceManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.workspace_tensor = None
|
||||||
|
self.ipc_handles = None
|
||||||
|
self.world_size = None
|
||||||
|
self.rank = None
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
world_size: int,
|
||||||
|
rank: int,
|
||||||
|
max_token_num: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
group=None,
|
||||||
|
use_fp32_lamport: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize workspace"""
|
||||||
|
if self.initialized and self.world_size == world_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
comm = _get_flashinfer_comm()
|
||||||
|
if comm is None:
|
||||||
|
logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cleanup()
|
||||||
|
|
||||||
|
self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
max_token_num,
|
||||||
|
hidden_dim,
|
||||||
|
group=group,
|
||||||
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.world_size = world_size
|
||||||
|
self.rank = rank
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up workspace"""
|
||||||
|
if self.initialized and self.ipc_handles is not None:
|
||||||
|
try:
|
||||||
|
comm = _get_flashinfer_comm()
|
||||||
|
if comm is not None:
|
||||||
|
comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
|
||||||
|
finally:
|
||||||
|
self.workspace_tensor = None
|
||||||
|
self.ipc_handles = None
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
_workspace_manager = FlashInferWorkspaceManager()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_workspace_initialized(
|
||||||
|
fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
|
||||||
|
):
|
||||||
|
"""Ensure workspace is initialized"""
|
||||||
|
comm = _get_flashinfer_comm()
|
||||||
|
if not has_flashinfer() or comm is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert fd_config is not None
|
||||||
|
world_size = fd_config.parallel_config.tensor_parallel_size
|
||||||
|
if world_size <= 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
|
||||||
|
if not _workspace_manager.initialized or _workspace_manager.world_size != world_size:
|
||||||
|
_workspace_manager.initialize(
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
max_token_num=max_token_num,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
use_fp32_lamport=use_fp32_lamport,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _workspace_manager.initialized
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_allreduce_residual_rmsnorm(
|
||||||
|
fd_config: FDConfig,
|
||||||
|
input_tensor: paddle.Tensor,
|
||||||
|
residual: paddle.Tensor,
|
||||||
|
weight: paddle.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
max_token_num: int = 2048,
|
||||||
|
use_oneshot: Optional[bool] = None,
|
||||||
|
trigger_completion_at_end: bool = False,
|
||||||
|
fp32_acc: bool = False,
|
||||||
|
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""
|
||||||
|
Use FlashInfer's fused allreduce + residual + RMS norm operation
|
||||||
|
"""
|
||||||
|
comm = _get_flashinfer_comm()
|
||||||
|
if not has_flashinfer() or comm is None:
|
||||||
|
logger.debug("FlashInfer not available, falling back to standard " "implementation")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
assert fd_config is not None
|
||||||
|
world_size = fd_config.parallel_config.tensor_parallel_size
|
||||||
|
if world_size <= 1:
|
||||||
|
logger.debug("Single GPU, no need for allreduce fusion")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
assert input_tensor.shape[0] <= max_token_num
|
||||||
|
|
||||||
|
if not ensure_workspace_initialized(
|
||||||
|
fd_config=fd_config,
|
||||||
|
max_token_num=max_token_num,
|
||||||
|
hidden_dim=input_tensor.shape[-1],
|
||||||
|
use_fp32_lamport=(input_tensor.dtype == paddle.float32),
|
||||||
|
):
|
||||||
|
logger.debug("FlashInfer workspace not available")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
token_num, hidden_dim = input_tensor.shape
|
||||||
|
|
||||||
|
residual_out = paddle.empty_like(residual)
|
||||||
|
norm_out = paddle.empty_like(input_tensor)
|
||||||
|
# support empty tensor
|
||||||
|
if input_tensor.shape[0] == 0:
|
||||||
|
return norm_out, residual_out
|
||||||
|
comm.trtllm_allreduce_fusion(
|
||||||
|
allreduce_in=input_tensor,
|
||||||
|
world_size=world_size,
|
||||||
|
world_rank=dist.get_rank(),
|
||||||
|
token_num=token_num,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
workspace_ptrs=_workspace_manager.workspace_tensor,
|
||||||
|
launch_with_pdl=True,
|
||||||
|
use_oneshot=use_oneshot,
|
||||||
|
trigger_completion_at_end=trigger_completion_at_end,
|
||||||
|
fp32_acc=fp32_acc,
|
||||||
|
pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm),
|
||||||
|
allreduce_out=None,
|
||||||
|
residual_in=residual,
|
||||||
|
residual_out=residual_out,
|
||||||
|
norm_out=norm_out,
|
||||||
|
quant_out=None,
|
||||||
|
scale_out=None,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=eps,
|
||||||
|
scale_factor=None,
|
||||||
|
layout_code=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return norm_out, residual_out
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_flashinfer_workspace():
|
||||||
|
global _workspace_manager
|
||||||
|
if _workspace_manager is not None:
|
||||||
|
_workspace_manager.cleanup()
|
||||||
@@ -854,6 +854,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
skip_quant: bool = False,
|
skip_quant: bool = False,
|
||||||
weight_dtype: str = "",
|
weight_dtype: str = "",
|
||||||
layer_id: int = -1,
|
layer_id: int = -1,
|
||||||
|
enable_all_reduce_fusion: bool = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a linear layer with additional parameters for inference and quantization.
|
Initialize a linear layer with additional parameters for inference and quantization.
|
||||||
@@ -865,9 +866,17 @@ class RowParallelLinear(LinearBase):
|
|||||||
input_size (int): Number of input features. Defaults to None.
|
input_size (int): Number of input features. Defaults to None.
|
||||||
output_size (int): Number of output features. Defaults to None.
|
output_size (int): Number of output features. Defaults to None.
|
||||||
with_bias (bool): Whether to include bias or not. Defaults to False.
|
with_bias (bool): Whether to include bias or not. Defaults to False.
|
||||||
skip_quant (bool): Whether to skip quantization. Defaults to False.
|
skip_quant (bool): Whether to skip quantization or not. Defaults to False.
|
||||||
|
enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion.
|
||||||
|
If None, it is determined by the config flag and prefix. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.fd_config = fd_config
|
self.fd_config = fd_config
|
||||||
|
if enable_all_reduce_fusion is None:
|
||||||
|
self.enable_all_reduce_fusion = False
|
||||||
|
else:
|
||||||
|
self.enable_all_reduce_fusion = (
|
||||||
|
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion
|
||||||
|
)
|
||||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||||
self.tp_group = fd_config.parallel_config.tp_group
|
self.tp_group = fd_config.parallel_config.tp_group
|
||||||
@@ -945,7 +954,10 @@ class RowParallelLinear(LinearBase):
|
|||||||
|
|
||||||
out = self.quant_method.apply(self, x)
|
out = self.quant_method.apply(self, x)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
need_tp_all_reduce = (
|
||||||
|
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
|
||||||
|
)
|
||||||
|
if need_tp_all_reduce:
|
||||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .batch_invariant_ops import (
|
|||||||
is_batch_invariant_mode_enabled,
|
is_batch_invariant_mode_enabled,
|
||||||
rms_norm_batch_invariant,
|
rms_norm_batch_invariant,
|
||||||
)
|
)
|
||||||
|
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
|
||||||
from .utils import get_tensor, modules_to_convert
|
from .utils import get_tensor, modules_to_convert
|
||||||
|
|
||||||
|
|
||||||
@@ -122,6 +123,10 @@ class RMSNorm(nn.Layer):
|
|||||||
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
|
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
|
||||||
self.tp_group = self.fd_config.parallel_config.tp_group
|
self.tp_group = self.fd_config.parallel_config.tp_group
|
||||||
is_input_norm = prefix.endswith(".input_layernorm")
|
is_input_norm = prefix.endswith(".input_layernorm")
|
||||||
|
self.enable_all_reduce_fusion = (
|
||||||
|
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
|
||||||
|
)
|
||||||
|
|
||||||
self.is_last_norm = prefix.endswith(".norm")
|
self.is_last_norm = prefix.endswith(".norm")
|
||||||
self.split_x = (
|
self.split_x = (
|
||||||
self.fd_config.parallel_config.use_sequence_parallel_moe
|
self.fd_config.parallel_config.use_sequence_parallel_moe
|
||||||
@@ -240,6 +245,12 @@ class RMSNorm(nn.Layer):
|
|||||||
norm_out = rms_norm(x, self.weight, self.eps)
|
norm_out = rms_norm(x, self.weight, self.eps)
|
||||||
return norm_out.astype(x_dtype), residual_out
|
return norm_out.astype(x_dtype), residual_out
|
||||||
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
|
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
|
||||||
|
# enable trtllm all reduce fusion
|
||||||
|
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
|
||||||
|
norm_out = flashinfer_allreduce_residual_rmsnorm(
|
||||||
|
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
|
||||||
|
)
|
||||||
|
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
|
||||||
else:
|
else:
|
||||||
if is_batch_invariant_mode_enabled():
|
if is_batch_invariant_mode_enabled():
|
||||||
# M-invariant path: per-row Triton kernel, no cross-row reduction
|
# M-invariant path: per-row Triton kernel, no cross-row reduction
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
import importlib.util
|
|
||||||
import math
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
@@ -25,11 +23,12 @@ from paddle import nn
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
|
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
|
||||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||||
|
|
||||||
from fastdeploy.utils import get_logger
|
from fastdeploy.utils import get_logger
|
||||||
|
|
||||||
from ..moe import FusedMoE
|
from ..moe import FusedMoE
|
||||||
@@ -59,10 +58,6 @@ def check_device_capability(num):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def has_flashinfer():
|
|
||||||
return importlib.util.find_spec("flashinfer") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def round_up(a, b):
|
def round_up(a, b):
|
||||||
return ((a + b - 1) // b) * b
|
return ((a + b - 1) // b) * b
|
||||||
|
|
||||||
|
|||||||
@@ -130,7 +130,6 @@ class Glm4Moe(nn.Layer):
|
|||||||
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
|
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
|
||||||
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||||
self.tp_group = fd_config.parallel_config.tp_group
|
self.tp_group = fd_config.parallel_config.tp_group
|
||||||
|
|
||||||
self.use_ep = self.expert_parallel_size > 1
|
self.use_ep = self.expert_parallel_size > 1
|
||||||
self.use_tp = self.tensor_parallel_size > 1
|
self.use_tp = self.tensor_parallel_size > 1
|
||||||
|
|
||||||
@@ -229,6 +228,7 @@ class Glm4MoeAttention(nn.Layer):
|
|||||||
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
|
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
|
||||||
output_size=fd_config.model_config.hidden_size,
|
output_size=fd_config.model_config.hidden_size,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@@ -553,6 +555,10 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def has_flashinfer():
|
||||||
|
return importlib.util.find_spec("flashinfer") is not None
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_sm_version():
|
def get_sm_version():
|
||||||
if paddle.cuda.is_available():
|
if paddle.cuda.is_available():
|
||||||
|
|||||||
@@ -830,6 +830,12 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Configuration of SpeculativeConfig.",
|
help="Configuration of SpeculativeConfig.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_flashinfer_allreduce_fusion",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Flag to enable all reduce fusion kernel in flashinfer.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_num_batched_tokens",
|
"--max_num_batched_tokens",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
+1
-1
@@ -46,6 +46,6 @@ setproctitle
|
|||||||
aistudio_sdk
|
aistudio_sdk
|
||||||
p2pstore
|
p2pstore
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
flashinfer-python-paddle
|
flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl
|
||||||
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
|
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
|
||||||
transformers>=4.55.1,<5.0.0
|
transformers>=4.55.1,<5.0.0
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"):
|
|||||||
layer.bias = None
|
layer.bias = None
|
||||||
layer.split_x = False
|
layer.split_x = False
|
||||||
layer.allgather_out = False
|
layer.allgather_out = False
|
||||||
|
layer.enable_all_reduce_fusion = False
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ def _make_cfg(**ov):
|
|||||||
pc.use_internode_ll_two_stage = pc.disable_sequence_parallel_moe = False
|
pc.use_internode_ll_two_stage = pc.disable_sequence_parallel_moe = False
|
||||||
pc.shutdown_comm_group_if_worker_idle = False
|
pc.shutdown_comm_group_if_worker_idle = False
|
||||||
pc.ep_prefill_use_worst_num_tokens = False
|
pc.ep_prefill_use_worst_num_tokens = False
|
||||||
|
pc.enable_flashinfer_allreduce_fusion = False
|
||||||
sc = ns(max_num_seqs=256, max_num_batched_tokens=4096, splitwise_role="mixed", name="local")
|
sc = ns(max_num_seqs=256, max_num_batched_tokens=4096, splitwise_role="mixed", name="local")
|
||||||
sc.enable_overlap_schedule = False
|
sc.enable_overlap_schedule = False
|
||||||
cc = ns(num_gpu_blocks_override=None, gpu_memory_utilization=0.9, block_size=16, enc_dec_block_num=0)
|
cc = ns(num_gpu_blocks_override=None, gpu_memory_utilization=0.9, block_size=16, enc_dec_block_num=0)
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_distributed():
|
||||||
|
"""Launch multi-GPU distributed test via paddle.distributed.launch as subprocess"""
|
||||||
|
# clearn flashinfer cache directory
|
||||||
|
flashinfer_cache_dir = os.path.join(os.sep, "root", ".cache", "flashinfer")
|
||||||
|
if os.path.exists(flashinfer_cache_dir):
|
||||||
|
print(f"=== Clearing flashinfer cache directory: {flashinfer_cache_dir} ===")
|
||||||
|
subprocess.run(["rm", "-rf", flashinfer_cache_dir], check=True)
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py")
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"paddle.distributed.launch",
|
||||||
|
"--gpus",
|
||||||
|
"0,1",
|
||||||
|
run_script,
|
||||||
|
]
|
||||||
|
|
||||||
|
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = process.communicate(timeout=400)
|
||||||
|
return_code = process.returncode
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
process.kill()
|
||||||
|
stdout, stderr = process.communicate()
|
||||||
|
return_code = -1
|
||||||
|
print(f"=== Distributed test stdout ===\n{stdout}")
|
||||||
|
print(f"=== Distributed test stderr ===\n{stderr}")
|
||||||
|
assert return_code in (0, 250), f"Process exited with code {return_code}"
|
||||||
|
|
||||||
|
|
||||||
|
test_run_distributed()
|
||||||
@@ -0,0 +1,548 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase):
|
||||||
|
"""Test FlashInfer AllReduce + Residual + RMSNorm fused operator"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"""Set up test environment"""
|
||||||
|
if paddle.is_compiled_with_cuda():
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
else:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
dist.init_parallel_env()
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Initialize each test case"""
|
||||||
|
# Fix random seed for reproducibility
|
||||||
|
paddle.seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
self.dtype = paddle.float32
|
||||||
|
self.token_num = 128
|
||||||
|
self.hidden_dim = 768
|
||||||
|
self.eps = 1e-6
|
||||||
|
self.epsilon = 1e-6
|
||||||
|
self.max_token_num = 2048
|
||||||
|
|
||||||
|
# Create mock FDConfig
|
||||||
|
self.fd_config = Mock()
|
||||||
|
self.fd_config.parallel_config = Mock()
|
||||||
|
self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size()
|
||||||
|
self.begin_norm_axis = 1
|
||||||
|
|
||||||
|
# Performance test params - increase iterations for stability
|
||||||
|
self.warmup_iterations = 20 # Increase warmup
|
||||||
|
self.test_iterations = 200 # Increase test iterations
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up resources"""
|
||||||
|
if paddle.is_compiled_with_cuda():
|
||||||
|
paddle.device.cuda.empty_cache()
|
||||||
|
paddle.device.cuda.synchronize()
|
||||||
|
|
||||||
|
def create_test_tensors(self):
|
||||||
|
"""Create test tensors"""
|
||||||
|
input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
|
||||||
|
residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
|
||||||
|
weight = paddle.randn([self.hidden_dim], dtype=self.dtype)
|
||||||
|
return input_tensor, residual, weight
|
||||||
|
|
||||||
|
def compute_reference_output(self, input_tensor, residual, weight, eps):
|
||||||
|
"""Reference implementation: manually compute AllReduce + Residual + RMSNorm"""
|
||||||
|
# # Step 1: AllReduce (identity on single device)
|
||||||
|
# allreduce_out = input_tensor.clone()
|
||||||
|
# Apply all reduce operator
|
||||||
|
dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
|
||||||
|
# Step 2: Add residual
|
||||||
|
residual_out = input_tensor + residual
|
||||||
|
|
||||||
|
# Step 3: RMSNorm
|
||||||
|
variance = residual_out.pow(2).mean(axis=-1, keepdim=True)
|
||||||
|
norm_out = residual_out * paddle.rsqrt(variance + eps)
|
||||||
|
norm_out = norm_out * weight
|
||||||
|
|
||||||
|
# dist.all_reduce(residual_out, op=dist.ReduceOp.SUM)
|
||||||
|
return norm_out, residual_out
|
||||||
|
|
||||||
|
def paddle_rms_fuse(self, input_tensor, residual, weight, eps):
|
||||||
|
from paddle.incubate.nn.functional import fused_rms_norm
|
||||||
|
|
||||||
|
# Apply all reduce operator
|
||||||
|
dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
|
||||||
|
out_fused = fused_rms_norm(
|
||||||
|
input_tensor,
|
||||||
|
norm_weight=weight,
|
||||||
|
norm_bias=None,
|
||||||
|
epsilon=eps,
|
||||||
|
begin_norm_axis=self.begin_norm_axis,
|
||||||
|
bias=None,
|
||||||
|
residual=residual,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out_fused[0], out_fused[1]
|
||||||
|
|
||||||
|
def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps):
|
||||||
|
"""FlashInfer fused operator"""
|
||||||
|
from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
|
||||||
|
flashinfer_allreduce_residual_rmsnorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
|
||||||
|
fd_config=self.fd_config,
|
||||||
|
input_tensor=input_tensor,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
eps=eps,
|
||||||
|
max_token_num=self.max_token_num,
|
||||||
|
use_oneshot=False,
|
||||||
|
)
|
||||||
|
return norm_out, residual_out
|
||||||
|
|
||||||
|
def benchmark_function(self, func, *args, name="", **kwargs):
    """
    Benchmark ``func`` and print timing statistics.

    Robustness measures:
    - waits for the GPU clock to settle before timing,
    - reports the median (more stable than the mean),
    - drops outliers with the IQR rule.

    Returns ``(median_ms, last_result)``.
    """
    # Compile-time property; hoisted since it cannot change mid-run.
    cuda = paddle.is_compiled_with_cuda()

    # Force GPU frequency stabilization.
    if cuda:
        for _ in range(5):
            paddle.device.cuda.synchronize()
            time.sleep(0.01)

    # Thorough warm-up runs.
    for _ in range(self.warmup_iterations):
        result = func(*args, **kwargs)
        if cuda:
            paddle.device.cuda.synchronize()

    # Extra settle time before the timed section.
    if cuda:
        paddle.device.cuda.synchronize()
        time.sleep(0.1)

    # Timed runs, collected in milliseconds.
    samples = []
    for _ in range(self.test_iterations):
        if cuda:
            paddle.device.cuda.synchronize()

        t0 = time.perf_counter()
        result = func(*args, **kwargs)

        if cuda:
            paddle.device.cuda.synchronize()

        samples.append((time.perf_counter() - t0) * 1000)

    samples = np.array(samples)

    # Drop outliers outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
    q1, q3 = np.percentile(samples, [25, 75])
    margin = 1.5 * (q3 - q1)
    kept = samples[(samples >= q1 - margin) & (samples <= q3 + margin)]

    # If filtering removed more than half the samples, trust the raw data.
    if len(kept) < self.test_iterations * 0.5:
        kept = samples

    # Summary statistics.
    avg_time = np.mean(kept)
    median_time = np.median(kept)
    std_time = np.std(kept)
    cv = (std_time / avg_time) * 100  # coefficient of variation (%)

    print(f"\n{'='*70}")
    print(f"Performance Benchmark: {name}")
    print(f"{'='*70}")
    print(f"Iterations: {len(kept)}/{self.test_iterations} (after {self.warmup_iterations} warmup)")
    print(f"Median: {median_time:.4f} ms (most stable metric)")
    print(f"Average: {avg_time:.4f} ms")
    print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)")
    print(f"Min: {np.min(kept):.4f} ms")
    print(f"Max: {np.max(kept):.4f} ms")
    print(f"{'='*70}\n")

    # Median is the stable headline metric; result is the last call's output.
    return median_time, result
def test_accuracy_fused_vs_reference(self):
    """Both fused operators must match the step-by-step reference implementation."""
    input_tensor, residual, weight = self.create_test_tensors()

    # Each path gets its own clones because the operators mutate their inputs.
    ref_norm, ref_res = self.compute_reference_output(
        input_tensor.clone(), residual.clone(), weight.clone(), self.eps
    )
    paddle_norm, paddle_res = self.paddle_rms_fuse(
        input_tensor.clone(), residual.clone(), weight.clone(), self.eps
    )
    fi_norm, fi_res = self.flashinfer_rms_fuse(
        input_tensor.clone(), residual.clone(), weight.clone(), self.eps
    )

    # Paddle fused path vs reference.
    np.testing.assert_allclose(paddle_norm.numpy(), ref_norm.numpy(), rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5)
    # FlashInfer fused path vs reference.
    np.testing.assert_allclose(fi_norm.numpy(), ref_norm.numpy(), rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(ref_res.numpy(), fi_res.numpy(), rtol=1e-5, atol=1e-5)
class TestFlashInferWorkspaceManager(unittest.TestCase):
    """Unit tests for FlashInferWorkspaceManager lifecycle state."""

    def setUp(self):
        """Create a fresh, uninitialized manager for every test."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        self.manager = FlashInferWorkspaceManager()

    def test_initialization(self):
        """A new manager carries no workspace, handles, topology, or init flag."""
        for attr in ("workspace_tensor", "ipc_handles", "world_size", "rank"):
            self.assertIsNone(getattr(self.manager, attr))
        self.assertFalse(self.manager.initialized)

    def test_cleanup(self):
        """cleanup() on a fresh manager leaves it uninitialized and workspace-free."""
        self.manager.cleanup()
        self.assertFalse(self.manager.initialized)
        self.assertIsNone(self.manager.workspace_tensor)
class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase):
    """Test FlashInferWorkspaceManager edge cases and fallback paths."""

    def setUp(self):
        """Patch has_flashinfer so availability can be controlled per test."""
        # Patch before importing to test fallback paths.
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches"""
        self.patcher_has_flashinfer.stop()

    def test_initialization_early_return_when_already_initialized(self):
        """initialize() is a no-op when already initialized with the same world_size."""
        # Patch _flashinfer_comm to be available.
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()

            # Simulate a previous successful initialization.
            manager.initialized = True
            manager.world_size = 2

            # Mock the workspace-creation function so a (wrong) call is observable.
            mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock()))

            # Second initialization with same world_size - should return early.
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

            # Fix: the original test made no assertions. Verify the early return:
            # no new workspace was created and the manager's state is untouched.
            mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion.assert_not_called()
            self.assertTrue(manager.initialized)
            self.assertEqual(manager.world_size, 2)

    def test_initialization_warning_when_comm_none(self):
        """initialize() only warns and returns when _get_flashinfer_comm is None."""
        # Patch to ensure _get_flashinfer_comm is None.
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()

            # Should not raise, just log warning and return.
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

            # Verify not initialized.
            self.assertFalse(manager.initialized)

    def test_cleanup_with_exception(self):
        """cleanup() swallows destroy-workspace errors and still resets state."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            manager.initialized = True
            manager.ipc_handles = Mock()
            manager.workspace_tensor = Mock()

            # Mock the destroy function to raise exception.
            mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error"))

            # Should not raise, just log warning.
            manager.cleanup()

            # Verify cleanup happened despite the error.
            self.assertFalse(manager.initialized)
            self.assertIsNone(manager.workspace_tensor)
            self.assertIsNone(manager.ipc_handles)

    def test_cleanup_without_initialization(self):
        """cleanup() on an uninitialized manager is a safe no-op."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        manager = FlashInferWorkspaceManager()
        manager.initialized = False

        # Should not raise.
        manager.cleanup()

        # Verify state.
        self.assertFalse(manager.initialized)
class TestEnsureWorkspaceInitialized(unittest.TestCase):
    """Fallback paths of ensure_workspace_initialized."""

    def setUp(self):
        """Patch has_flashinfer so availability can be toggled per test."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Stop the availability patch."""
        self.patcher_has_flashinfer.stop()

    def _fd_config(self, tp_size):
        """Build a minimal mocked fd_config with the given tensor-parallel size."""
        cfg = Mock()
        cfg.parallel_config = Mock()
        cfg.parallel_config.tensor_parallel_size = tp_size
        return cfg

    def test_ensure_workspace_when_flashinfer_not_available(self):
        """Returns False immediately when flashinfer is not available."""
        self.mock_has_flashinfer.return_value = False

        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            ensure_workspace_initialized,
        )

        # Should report "not initialized".
        self.assertFalse(ensure_workspace_initialized(self._fd_config(2)))

    def test_ensure_workspace_when_comm_none(self):
        """Returns False when _get_flashinfer_comm yields None."""
        self.mock_has_flashinfer.return_value = True

        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            self.assertFalse(ensure_workspace_initialized(self._fd_config(2)))

    def test_ensure_workspace_single_gpu(self):
        """Returns False when world_size <= 1 (nothing to all-reduce)."""
        self.mock_has_flashinfer.return_value = True

        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0):
                self.assertFalse(ensure_workspace_initialized(self._fd_config(1)))
class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase):
    """Fallback paths of flashinfer_allreduce_residual_rmsnorm."""

    def setUp(self):
        """Patch has_flashinfer so availability can be toggled per test."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Stop the availability patch."""
        self.patcher_has_flashinfer.stop()

    def _run_fusion(self, tp_size, input_tensor, residual, weight):
        """Invoke the fused op with a mocked fd_config of the given TP size."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = tp_size

        return flashinfer_allreduce_residual_rmsnorm(
            fd_config=fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=1e-6,
            max_token_num=2048,
        )

    def test_flashinfer_not_available_fallback(self):
        """Returns (None, None) when flashinfer is unavailable."""
        self.mock_has_flashinfer.return_value = False

        norm_out, residual_out = self._run_fusion(
            2, paddle.randn([128, 768]), paddle.randn([128, 768]), paddle.randn([768])
        )

        # The caller is expected to fall back to the non-fused path.
        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_single_gpu_fallback(self):
        """Returns (None, None) for a single GPU (no all-reduce needed)."""
        self.mock_has_flashinfer.return_value = True

        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            norm_out, residual_out = self._run_fusion(
                1, paddle.randn([128, 768]), paddle.randn([128, 768]), paddle.randn([768])
            )

        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_empty_tensor_handling(self):
        """Zero-token inputs return empty tensors without touching the kernel."""
        self.mock_has_flashinfer.return_value = True

        with (
            patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm,
            patch(
                "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized",
                return_value=True,
            ),
        ):
            # The fusion kernel must not be invoked for empty input.
            mock_comm.trtllm_allreduce_fusion = Mock()

            # Empty tensor (0 tokens).
            norm_out, residual_out = self._run_fusion(
                2, paddle.zeros([0, 768]), paddle.zeros([0, 768]), paddle.randn([768])
            )

            self.assertEqual(norm_out.shape[0], 0)
            self.assertEqual(residual_out.shape[0], 0)
            mock_comm.trtllm_allreduce_fusion.assert_not_called()
class TestCleanupFlashInferWorkspace(unittest.TestCase):
    """Tests for the module-level cleanup_flashinfer_workspace helper."""

    def test_cleanup_workspace_function(self):
        """The helper delegates to the global workspace manager's cleanup()."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as mock_manager:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                cleanup_flashinfer_workspace,
            )

            mock_manager.cleanup = Mock()
            cleanup_flashinfer_workspace()
            mock_manager.cleanup.assert_called_once()
if __name__ == "__main__":
    # Run the suite directly (invoked by a subprocess after distributed launch).
    unittest.main(verbosity=2)
@@ -58,6 +58,7 @@ def make_fd_config(
|
|||||||
expert_parallel_size=1,
|
expert_parallel_size=1,
|
||||||
tp_group=None,
|
tp_group=None,
|
||||||
use_sequence_parallel_moe=use_sequence_parallel_moe,
|
use_sequence_parallel_moe=use_sequence_parallel_moe,
|
||||||
|
enable_flashinfer_allreduce_fusion=False,
|
||||||
),
|
),
|
||||||
scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1),
|
scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1),
|
||||||
load_config=SimpleNamespace(
|
load_config=SimpleNamespace(
|
||||||
|
|||||||
Reference in New Issue
Block a user