[Optimization] enable trtllm_all_reduce fusion kernel in glm model (#6660)

* enable trtllm_all_reduce fusion kernel in glm model

* fix conflict

* format update

* fix a bug

* modify test

* modify test

* support empty tensor and modify test

* fix test_linear config issues

* modify test name

* add edge test case

* modify format

* fix conflict

* modify default max token num in trtllm_allreduce_fusion

* add max token num branch for trtllm_allreduce_fusion

* fix format

* fix rmsnorm config issue

* modify 2025 to 2026

* using compat guard

* Lazily import flashinfer.comm and fix test config issue

* fix test issues

* add flashinfer cache dir clean machine

* fix some issues
This commit is contained in:
Bingoo
2026-04-16 14:10:19 +08:00
committed by GitHub
parent e53f5184ac
commit 6b891da02b
17 changed files with 871 additions and 11 deletions
@@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"):
layer.bias = None
layer.split_x = False
layer.allgather_out = False
layer.enable_all_reduce_fusion = False
return layer
+1
View File
@@ -39,6 +39,7 @@ def _make_cfg(**ov):
pc.use_internode_ll_two_stage = pc.disable_sequence_parallel_moe = False
pc.shutdown_comm_group_if_worker_idle = False
pc.ep_prefill_use_worst_num_tokens = False
pc.enable_flashinfer_allreduce_fusion = False
sc = ns(max_num_seqs=256, max_num_batched_tokens=4096, splitwise_role="mixed", name="local")
sc.enable_overlap_schedule = False
cc = ns(num_gpu_blocks_override=None, gpu_memory_utilization=0.9, block_size=16, enc_dec_block_num=0)
@@ -0,0 +1,56 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import subprocess
import sys
def test_run_distributed():
    """Launch the multi-GPU distributed test via ``paddle.distributed.launch``.

    Runs ``trtllm_allreduce_rms_fusion.py`` (located next to this file) as a
    subprocess on GPUs 0 and 1, captures its output, and asserts on the exit
    code. A hung launcher is killed after the 400 s timeout.
    """
    import shutil

    # Clean the flashinfer cache directory so stale JIT-compiled kernels
    # cannot mask build problems in the fused all-reduce kernel under test.
    flashinfer_cache_dir = os.path.join(os.sep, "root", ".cache", "flashinfer")
    if os.path.exists(flashinfer_cache_dir):
        print(f"=== Clearing flashinfer cache directory: {flashinfer_cache_dir} ===")
        # shutil.rmtree replaces the previous `rm -rf` subprocess: no shell
        # dependency, and failures raise a clear OSError just like check=True did.
        shutil.rmtree(flashinfer_cache_dir)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    command = [
        sys.executable,
        "-m",
        "paddle.distributed.launch",
        "--gpus",
        "0,1",
        run_script,
    ]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    try:
        stdout, stderr = process.communicate(timeout=400)
        return_code = process.returncode
    except subprocess.TimeoutExpired:
        # Kill the hung launcher and still collect whatever it printed.
        process.kill()
        stdout, stderr = process.communicate()
        return_code = -1
    print(f"=== Distributed test stdout ===\n{stdout}")
    print(f"=== Distributed test stderr ===\n{stderr}")
    # NOTE(review): 250 appears to be an exit code the paddle launcher can
    # return on normal shutdown — confirm against paddle.distributed.launch.
    assert return_code in (0, 250), f"Process exited with code {return_code}"


test_run_distributed()
+548
View File
@@ -0,0 +1,548 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import unittest
from unittest.mock import Mock, patch
import numpy as np
import paddle
import paddle.distributed as dist
class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase):
    """Test FlashInfer AllReduce + Residual + RMSNorm fused operator.

    Intended to run under ``paddle.distributed.launch`` with >= 2 ranks:
    the reference and baseline paths both perform a real ``dist.all_reduce``.
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment: select device and join the parallel group."""
        if paddle.is_compiled_with_cuda():
            paddle.set_device("gpu")
        else:
            paddle.set_device("cpu")
        dist.init_parallel_env()

    def setUp(self):
        """Initialize each test case."""
        # Fix random seed for reproducibility
        paddle.seed(42)
        np.random.seed(42)
        self.dtype = paddle.float32
        self.token_num = 128
        self.hidden_dim = 768
        self.eps = 1e-6
        self.epsilon = 1e-6
        self.max_token_num = 2048
        # Create mock FDConfig; only tensor_parallel_size is read by the op.
        self.fd_config = Mock()
        self.fd_config.parallel_config = Mock()
        self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size()
        self.begin_norm_axis = 1
        # Performance test params - increase iterations for stability
        self.warmup_iterations = 20  # Increase warmup
        self.test_iterations = 200  # Increase test iterations

    def tearDown(self):
        """Clean up resources (free cached GPU memory between cases)."""
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()
            paddle.device.cuda.synchronize()

    def create_test_tensors(self):
        """Create random (token_num, hidden_dim) input/residual and (hidden_dim,) weight."""
        input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
        residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype)
        weight = paddle.randn([self.hidden_dim], dtype=self.dtype)
        return input_tensor, residual, weight

    def compute_reference_output(self, input_tensor, residual, weight, eps):
        """Reference implementation: manually compute AllReduce + Residual + RMSNorm.

        Returns (norm_out, residual_out). Note: ``input_tensor`` is reduced
        in place by ``dist.all_reduce`` — callers pass clones.
        """
        # Step 1: AllReduce across the parallel group (in place)
        dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
        # Step 2: Add residual
        residual_out = input_tensor + residual
        # Step 3: RMSNorm
        variance = residual_out.pow(2).mean(axis=-1, keepdim=True)
        norm_out = residual_out * paddle.rsqrt(variance + eps)
        norm_out = norm_out * weight
        return norm_out, residual_out

    def paddle_rms_fuse(self, input_tensor, residual, weight, eps):
        """Baseline: eager all_reduce followed by Paddle's fused_rms_norm."""
        from paddle.incubate.nn.functional import fused_rms_norm

        # Apply all reduce operator (in place)
        dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM)
        out_fused = fused_rms_norm(
            input_tensor,
            norm_weight=weight,
            norm_bias=None,
            epsilon=eps,
            begin_norm_axis=self.begin_norm_axis,
            bias=None,
            residual=residual,
        )
        # fused_rms_norm returns a tuple; [0] = normalized, [1] = residual sum
        return out_fused[0], out_fused[1]

    def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps):
        """FlashInfer fused operator (the implementation under test)."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
            fd_config=self.fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=eps,
            max_token_num=self.max_token_num,
            use_oneshot=False,
        )
        return norm_out, residual_out

    def benchmark_function(self, func, *args, name="", **kwargs):
        """
        Improved performance benchmark:
        - Wait for GPU frequency stabilization
        - Use median instead of mean (more stable)
        - Filter outliers via the IQR rule

        Returns (median_time_ms, last_result).
        """
        # Force GPU frequency stabilization
        if paddle.is_compiled_with_cuda():
            for _ in range(5):
                paddle.device.cuda.synchronize()
                time.sleep(0.01)
        # Warmup - thorough warm-up
        for _ in range(self.warmup_iterations):
            result = func(*args, **kwargs)
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.synchronize()
        # Extra wait to ensure GPU stability
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.synchronize()
            time.sleep(0.1)
        # Benchmark run: synchronize before and after each call so the
        # wall-clock window covers only this kernel's work.
        times = []
        for i in range(self.test_iterations):
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.synchronize()
            start = time.perf_counter()
            result = func(*args, **kwargs)
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.synchronize()
            end = time.perf_counter()
            elapsed = (end - start) * 1000  # Convert to milliseconds
            times.append(elapsed)
        times = np.array(times)
        # Filter outliers using IQR method
        q1, q3 = np.percentile(times, [25, 75])
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        filtered_times = times[(times >= lower_bound) & (times <= upper_bound)]
        # Fall back to raw data if too many samples filtered out
        if len(filtered_times) < self.test_iterations * 0.5:
            filtered_times = times
        # Statistics
        avg_time = np.mean(filtered_times)
        median_time = np.median(filtered_times)
        std_time = np.std(filtered_times)
        min_time = np.min(filtered_times)
        max_time = np.max(filtered_times)
        cv = (std_time / avg_time) * 100  # Coefficient of variation (%)
        print(f"\n{'='*70}")
        print(f"Performance Benchmark: {name}")
        print(f"{'='*70}")
        print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)")
        print(f"Median: {median_time:.4f} ms (most stable metric)")
        print(f"Average: {avg_time:.4f} ms")
        print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)")
        print(f"Min: {min_time:.4f} ms")
        print(f"Max: {max_time:.4f} ms")
        print(f"{'='*70}\n")
        # Return median (more stable) and result
        return median_time, result

    def test_accuracy_fused_vs_reference(self):
        """Compare fused operators against the manual reference implementation.

        Each path gets its own clones because the all_reduce mutates its input.
        """
        input_tensor, residual, weight = self.create_test_tensors()
        reference_output, ref_res = self.compute_reference_output(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        fused_output, paddle_res = self.paddle_rms_fuse(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse(
            input_tensor.clone(), residual.clone(), weight.clone(), self.eps
        )
        # Verify both norm outputs and both residual outputs against reference
        np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5)
class TestFlashInferWorkspaceManager(unittest.TestCase):
    """Lifecycle tests for FlashInferWorkspaceManager.

    A freshly constructed manager must be fully unconfigured, and cleanup()
    must be callable and leave the manager de-initialized.
    """

    def setUp(self):
        """Construct a brand-new manager for every test case."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        self.manager = FlashInferWorkspaceManager()

    def test_initialization(self):
        """A new manager has no workspace, no topology, and is not initialized."""
        for attr in ("workspace_tensor", "ipc_handles", "world_size", "rank"):
            self.assertIsNone(getattr(self.manager, attr))
        self.assertFalse(self.manager.initialized)

    def test_cleanup(self):
        """cleanup() leaves the manager de-initialized with no workspace tensor."""
        self.manager.cleanup()
        self.assertFalse(self.manager.initialized)
        self.assertIsNone(self.manager.workspace_tensor)
class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase):
    """Test FlashInferWorkspaceManager edge cases and fallback paths.

    ``has_flashinfer`` is patched in setUp so each test can simulate the
    library being present or absent without importing real flashinfer.
    """

    def setUp(self):
        """Initialize test fixtures."""
        # Patch before importing to test fallback paths
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_initialization_early_return_when_already_initialized(self):
        """Early return when already initialized with the same world_size."""
        # Patch _flashinfer_comm to be available
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            # Pretend a first initialization already happened
            manager.initialized = True
            manager.world_size = 2
            # Mock the comm functions
            mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock()))
            # Second initialization with same world_size - should return early
            # (no exception; workspace creation is not repeated)
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )

    def test_initialization_warning_when_comm_none(self):
        """initialize() only warns and returns when _get_flashinfer_comm is None."""
        # Patch to ensure _get_flashinfer_comm is None
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            # Should not raise, just log warning and return
            manager.initialize(
                world_size=2,
                rank=0,
                max_token_num=2048,
                hidden_dim=4096,
            )
            # Verify not initialized
            self.assertFalse(manager.initialized)

    def test_cleanup_with_exception(self):
        """cleanup() swallows errors from the IPC destroy call and still resets state."""
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                FlashInferWorkspaceManager,
            )

            manager = FlashInferWorkspaceManager()
            manager.initialized = True
            manager.ipc_handles = Mock()
            manager.workspace_tensor = Mock()
            # Mock the destroy function to raise an exception
            mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error"))
            # Should not raise, just log warning
            manager.cleanup()
            # Verify cleanup happened despite the error
            self.assertFalse(manager.initialized)
            self.assertIsNone(manager.workspace_tensor)
            self.assertIsNone(manager.ipc_handles)

    def test_cleanup_without_initialization(self):
        """cleanup() is a no-op (and must not raise) when never initialized."""
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            FlashInferWorkspaceManager,
        )

        manager = FlashInferWorkspaceManager()
        manager.initialized = False
        # Should not raise
        manager.cleanup()
        # Verify state
        self.assertFalse(manager.initialized)
class TestEnsureWorkspaceInitialized(unittest.TestCase):
    """Test ensure_workspace_initialized fallback paths.

    Each fallback must return False instead of raising when flashinfer is
    unavailable, the comm module is missing, or only one GPU is configured.
    """

    def setUp(self):
        """Initialize test fixtures."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_ensure_workspace_when_flashinfer_not_available(self):
        """Early return (False) when flashinfer is not available."""
        self.mock_has_flashinfer.return_value = False
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            ensure_workspace_initialized,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = 2
        result = ensure_workspace_initialized(fd_config)
        # Should return False (not initialized)
        self.assertFalse(result)

    def test_ensure_workspace_when_comm_none(self):
        """Returns False when _get_flashinfer_comm yields None."""
        self.mock_has_flashinfer.return_value = True
        with patch(
            "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm",
            return_value=None,
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 2
            result = ensure_workspace_initialized(fd_config)
            # Should return False
            self.assertFalse(result)

    def test_ensure_workspace_single_gpu(self):
        """Early return (False) when world_size <= 1 — no workspace needed."""
        self.mock_has_flashinfer.return_value = True
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                ensure_workspace_initialized,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 1
            with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0):
                result = ensure_workspace_initialized(fd_config)
                # Should return False for single GPU
                self.assertFalse(result)
class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase):
    """Test flashinfer_allreduce_residual_rmsnorm fallback paths.

    The op signals "fall back to the unfused path" by returning (None, None);
    empty inputs short-circuit without invoking the fused kernel.
    """

    def setUp(self):
        """Initialize test fixtures."""
        self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer")
        self.mock_has_flashinfer = self.patcher_has_flashinfer.start()

    def tearDown(self):
        """Clean up patches."""
        self.patcher_has_flashinfer.stop()

    def test_flashinfer_not_available_fallback(self):
        """Returns (None, None) when flashinfer is not available."""
        self.mock_has_flashinfer.return_value = False
        from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
            flashinfer_allreduce_residual_rmsnorm,
        )

        fd_config = Mock()
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.tensor_parallel_size = 2
        input_tensor = paddle.randn([128, 768])
        residual = paddle.randn([128, 768])
        weight = paddle.randn([768])
        norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
            fd_config=fd_config,
            input_tensor=input_tensor,
            residual=residual,
            weight=weight,
            eps=1e-6,
            max_token_num=2048,
        )
        # Should return None, None when flashinfer not available
        self.assertIsNone(norm_out)
        self.assertIsNone(residual_out)

    def test_single_gpu_fallback(self):
        """Returns (None, None) for single GPU — nothing to all-reduce."""
        self.mock_has_flashinfer.return_value = True
        with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_residual_rmsnorm,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 1
            input_tensor = paddle.randn([128, 768])
            residual = paddle.randn([128, 768])
            weight = paddle.randn([768])
            norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
                fd_config=fd_config,
                input_tensor=input_tensor,
                residual=residual,
                weight=weight,
                eps=1e-6,
                max_token_num=2048,
            )
            # Should return None, None for single GPU
            self.assertIsNone(norm_out)
            self.assertIsNone(residual_out)

    def test_empty_tensor_handling(self):
        """Empty (0-token) input returns empty outputs without calling the kernel."""
        self.mock_has_flashinfer.return_value = True
        with (
            patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm,
            patch(
                "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized",
                return_value=True,
            ),
        ):
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_residual_rmsnorm,
            )

            fd_config = Mock()
            fd_config.parallel_config = Mock()
            fd_config.parallel_config.tensor_parallel_size = 2
            # Empty tensor (0 tokens)
            input_tensor = paddle.zeros([0, 768])
            residual = paddle.zeros([0, 768])
            weight = paddle.randn([768])
            # Mock the fused kernel so we can assert it is never invoked
            mock_comm.trtllm_allreduce_fusion = Mock()
            norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm(
                fd_config=fd_config,
                input_tensor=input_tensor,
                residual=residual,
                weight=weight,
                eps=1e-6,
                max_token_num=2048,
            )
            # Should return empty tensors, not call flashinfer
            self.assertEqual(norm_out.shape[0], 0)
            self.assertEqual(residual_out.shape[0], 0)
            mock_comm.trtllm_allreduce_fusion.assert_not_called()
class TestCleanupFlashInferWorkspace(unittest.TestCase):
    """Verify cleanup_flashinfer_workspace delegates to the module-level manager."""

    def test_cleanup_workspace_function(self):
        """cleanup_flashinfer_workspace() calls _workspace_manager.cleanup() once."""
        target = "fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager"
        with patch(target) as fake_manager:
            from fastdeploy.model_executor.layers.flashinfer_comm_fusion import (
                cleanup_flashinfer_workspace,
            )

            fake_manager.cleanup = Mock()
            cleanup_flashinfer_workspace()
            fake_manager.cleanup.assert_called_once()
if __name__ == "__main__":
    """Run tests directly (called by subprocess after distributed launch)"""
    # verbosity=2 prints each test name, which is useful in the launcher logs.
    unittest.main(verbosity=2)
+1
View File
@@ -58,6 +58,7 @@ def make_fd_config(
expert_parallel_size=1,
tp_group=None,
use_sequence_parallel_moe=use_sequence_parallel_moe,
enable_flashinfer_allreduce_fusion=False,
),
scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1),
load_config=SimpleNamespace(