[Optimization] Use a separate driver when using Triton with Paddle (#6897)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nyakku Shigure
2026-03-24 10:56:00 +08:00
committed by GitHub
parent e87ce4b8cd
commit 8b6bbb3504
13 changed files with 97 additions and 19 deletions
@@ -6,6 +6,9 @@ from collections import namedtuple
from collections.abc import Callable
from typing import Any, Dict
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
enable_compat_on_triton_kernel,
)
from fastdeploy.utils import get_logger
logger = get_logger("worker_process", "worker_process.log")
@@ -37,6 +40,7 @@ def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, args: Dict[st
return ret
@enable_compat_on_triton_kernel
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
group_id = tile_id // num_pid_in_group
@@ -47,6 +51,7 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
return pid_m, pid_n
@enable_compat_on_triton_kernel
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
a_ptr,
@@ -226,6 +231,7 @@ def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor |
return c
@enable_compat_on_triton_kernel
@triton.jit
def _log_softmax_kernel(
input_ptr,
@@ -330,6 +336,7 @@ def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
return output.reshape(original_shape)
@enable_compat_on_triton_kernel
@triton.jit
def mean_kernel(
input_ptr,
@@ -475,6 +482,7 @@ def mean_dim(
# We thank the SGLang authors and the Thinking Machines Lab for their contributions.
@enable_compat_on_triton_kernel
@triton.jit # pragma: no cover
def bmm_kernel_persistent(
a_ptr,
@@ -724,6 +732,7 @@ def mean_batch_invariant(
# ---------------------------------------------------------------------------
@enable_compat_on_triton_kernel
@triton.jit
def _rms_norm_kernel( # pragma: no cover
input_ptr,