mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Optimization] Use a separate driver when using Triton with Paddle (#6897)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,9 @@ from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
|
||||
enable_compat_on_triton_kernel,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("worker_process", "worker_process.log")
|
||||
@@ -37,6 +40,7 @@ def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, args: Dict[st
|
||||
return ret
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit
|
||||
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
@@ -47,6 +51,7 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit(launch_metadata=_matmul_launch_metadata)
|
||||
def matmul_kernel_persistent(
|
||||
a_ptr,
|
||||
@@ -226,6 +231,7 @@ def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor |
|
||||
return c
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit
|
||||
def _log_softmax_kernel(
|
||||
input_ptr,
|
||||
@@ -330,6 +336,7 @@ def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
|
||||
return output.reshape(original_shape)
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit
|
||||
def mean_kernel(
|
||||
input_ptr,
|
||||
@@ -475,6 +482,7 @@ def mean_dim(
|
||||
# We thank the SGLang authors and the Thinking Machines Lab for their contributions.
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit # pragma: no cover
|
||||
def bmm_kernel_persistent(
|
||||
a_ptr,
|
||||
@@ -724,6 +732,7 @@ def mean_batch_invariant(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@enable_compat_on_triton_kernel
|
||||
@triton.jit
|
||||
def _rms_norm_kernel( # pragma: no cover
|
||||
input_ptr,
|
||||
|
||||
Reference in New Issue
Block a user