[Optimization] Use a separate driver when using Triton with Paddle (#6897)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Authored by Nyakku Shigure on 2026-03-24 10:56:00 +08:00; committed by GitHub.
parent e87ce4b8cd
commit 8b6bbb3504
13 changed files with 97 additions and 19 deletions
@@ -18,29 +18,42 @@ import inspect
import os
import re
import sys
from importlib.metadata import PackageNotFoundError, distribution
import paddle
import triton
from paddle.base.framework import OpProtoHolder
from fastdeploy import envs
from fastdeploy.utils import _is_package_installed
# Absolute paths to Triton's ahead-of-time compile/link helper scripts and the
# running Python interpreter — presumably invoked as subprocesses by AOT
# helpers elsewhere in this module (not visible here; confirm against callers).
compile_file = triton.__path__[0] + "/tools/compile.py"
link_file = triton.__path__[0] + "/tools/link.py"
python_path = sys.executable
# When torch is installed alongside paddle, import Triton's driver factory
# under paddle's compat guard so torch-oriented Triton internals load without
# conflicting with paddle.
# NOTE(review): assumes this paddle build provides `use_compat_guard` — confirm.
if _is_package_installed("torch"):
    with paddle.use_compat_guard(enable=True, silent=True):
        from triton.runtime.driver import _create_driver
def _is_package_installed(dist_name: str) -> bool:
try:
distribution(dist_name)
return True
except PackageNotFoundError:
return False
# Triton driver instance backed by Paddle, created once at import time —
# presumably via the `_create_driver` imported under the compat guard above.
paddle_driver = _create_driver()
def swap_driver_guard(fn):
    """Return a wrapper that runs *fn* with the Paddle Triton driver active.

    The wrapper installs the module-level ``paddle_driver`` as Triton's active
    driver before each call and always restores the previous active driver
    afterwards — even when ``fn`` raises — so unrelated Triton usage is not
    affected by the swap.

    Args:
        fn: The callable (typically an indexed Triton kernel) to wrap.

    Returns:
        A lightweight wrapper with the same call signature as ``fn``.
    """
    from functools import wraps

    from triton.runtime.driver import driver

    # wraps() preserves fn's metadata (__name__, __doc__, ...) so wrapped
    # kernels remain identifiable in tracebacks and introspection; it tolerates
    # callables that lack some of these attributes.
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        driver.set_active(paddle_driver)
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore whichever driver was active before this call.
            driver.reset_active()

    return wrapped_fn
def enable_compat_on_triton_kernel(triton_kernel):
    """Enable paddle/torch compatibility handling on *triton_kernel*.

    When torch is not installed there is nothing to bridge, so the original
    kernel is returned unchanged; `paddle.enable_compat(scope={"triton"})` is
    already enabled in `__init__.py`, so this path adds zero runtime overhead.
    """
    if not _is_package_installed("torch"):
        return triton_kernel
@@ -49,7 +62,7 @@ def enable_compat_on_triton_kernel(triton_kernel):
self.kernel = kernel
def __getitem__(self, index):
return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])
return swap_driver_guard(self.kernel[index])
return WrappedTritonKernel(triton_kernel)