mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Optimization] Use a separate driver when using Triton with Paddle (#6897)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -18,29 +18,42 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from importlib.metadata import PackageNotFoundError, distribution
|
||||
|
||||
import paddle
|
||||
import triton
|
||||
from paddle.base.framework import OpProtoHolder
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import _is_package_installed
|
||||
|
||||
compile_file = triton.__path__[0] + "/tools/compile.py"
|
||||
link_file = triton.__path__[0] + "/tools/link.py"
|
||||
python_path = sys.executable
|
||||
|
||||
if _is_package_installed("torch"):
|
||||
with paddle.use_compat_guard(enable=True, silent=True):
|
||||
from triton.runtime.driver import _create_driver
|
||||
|
||||
def _is_package_installed(dist_name: str) -> bool:
|
||||
try:
|
||||
distribution(dist_name)
|
||||
return True
|
||||
except PackageNotFoundError:
|
||||
return False
|
||||
paddle_driver = _create_driver()
|
||||
|
||||
|
||||
def swap_driver_guard(fn):
|
||||
from triton.runtime.driver import driver
|
||||
|
||||
# A lightweight wrapper to enable compatibility for triton kernel
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
driver.set_active(paddle_driver)
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
driver.reset_active()
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
def enable_compat_on_triton_kernel(triton_kernel):
|
||||
# When torch is not installed, this decorator does not do anything, just return the original triton kernel.
|
||||
# Because the `paddle.enable_compat(scope={"triton"})` already enabled in `__init__.py`, it will take zero runtime overhead.
|
||||
if not _is_package_installed("torch"):
|
||||
return triton_kernel
|
||||
|
||||
@@ -49,7 +62,7 @@ def enable_compat_on_triton_kernel(triton_kernel):
|
||||
self.kernel = kernel
|
||||
|
||||
def __getitem__(self, index):
|
||||
return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])
|
||||
return swap_driver_guard(self.kernel[index])
|
||||
|
||||
return WrappedTritonKernel(triton_kernel)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user