"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import inspect
import os
import re
import sys
from importlib.metadata import PackageNotFoundError, distribution

import paddle
import triton
from paddle.base.framework import OpProtoHolder

from fastdeploy import envs

compile_file = triton.__path__[0] + "/tools/compile.py"
link_file = triton.__path__[0] + "/tools/link.py"
python_path = sys.executable


def _is_package_installed(dist_name: str) -> bool:
    try:
        distribution(dist_name)
        return True
    except PackageNotFoundError:
        return False


def enable_compat_on_triton_kernel(triton_kernel):
    # When torch is not installed, this wrapper does nothing and simply
    # returns the original triton kernel.
    if not _is_package_installed("torch"):
        return triton_kernel

    class WrappedTritonKernel:
        def __init__(self, kernel):
            self.kernel = kernel

        def __getitem__(self, index):
            return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])

    return WrappedTritonKernel(triton_kernel)
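
# Illustrative usage (a sketch; `my_triton_kernel` is a hypothetical
# @triton.jit function). The wrapper keeps triton's kernel[grid](...) launch
# syntax while running the launch under paddle's torch-compat guard:
#
#     my_kernel = enable_compat_on_triton_kernel(my_triton_kernel)
#     my_kernel[(grid_x,)](x_ptr, y_ptr, n)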


def SubstituteTemplate(template, values):
    """
    Substitute all variables in the given template string using the provided values dictionary.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
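
# Example (illustrative): placeholders use the ${name} form and substitution
# repeats until the text stops changing, so a value may itself contain
# another placeholder.
#
#     SubstituteTemplate("hello ${who}", {"who": "world"})   # -> "hello world"
#     SubstituteTemplate("${a}", {"a": "${b}!", "b": "hi"})  # -> "hi!"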


def find_so_path(generated_dir, python_package_name):
    """
    Find the .so built for python_package_name under generated_dir;
    return None if it is not found.
    """
    so_path = []
    for root, dirs, files in os.walk(generated_dir):
        for file in files:
            if file.endswith(python_package_name + ".so"):
                so_path.append(os.path.join(root, file))
    if len(so_path) == 0:
        return None
    else:
        assert len(so_path) == 1
        return so_path[0]


def multi_process_do(commands):
    """
    Execute the given shell commands in parallel, using multiple processes.
    """
    THREADS = 40
    import multiprocessing

    process = []

    def one_process_work(commands, thread_id):
        # each worker executes every THREADS-th command, starting at its id.
        i = thread_id
        while i < len(commands):
            ret = os.system(commands[i])
            assert ret == 0
            i += THREADS

    for i in range(THREADS):
        p = multiprocessing.Process(target=one_process_work, args=(commands, i))
        process.append(p)
    for p in process:
        p.start()
    for p in process:
        p.join()
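
# Illustrative usage (a sketch; the echo commands are placeholders). Each
# command is an independent shell invocation, so this is only safe for
# commands that do not depend on each other's outputs:
#
#     multi_process_do([f"echo compiling variant {i}" for i in range(8)])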


def extract_triton_kernel(kernel, file_name):
    """
    Extract the source of the triton kernel and write it to file_name.

    Args:
        kernel: the triton kernel.
        file_name: the file to write the extracted source to.
    """
    import textwrap

    fn = kernel
    if isinstance(kernel, triton.runtime.jit.JITFunction):
        fn = kernel.fn
    elif isinstance(kernel, triton.runtime.autotuner.Autotuner):
        fn = kernel.fn.fn
    else:
        # a plain (not yet jit-compiled) python function is used as-is;
        # @triton.jit is prepended to the extracted source below.
        pass
    py_script = textwrap.dedent(inspect.getsource(fn))
    # the extracted source must contain exactly one function definition.
    assert len(re.findall("def ", py_script)) == 1
    py_script = py_script[py_script.find("def ") :]
    py_script = "import triton\nimport triton.language as tl\n\n\n@triton.jit\n" + py_script
    py_script = py_script.replace("if bias_ptr is not None", "if bias_ptr")
    with open(file_name, "w") as f:
        f.write(py_script)
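
# Illustrative flow (a sketch; `my_kernel` is hypothetical): this writes a
# standalone, AOT-compilable copy of the kernel source, with the
# import/@triton.jit preamble added back in.
#
#     extract_triton_kernel(my_kernel, "/tmp/triton_kernels.py")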


template_install = """
import os

generated_cu = []
for root, dirs, files in os.walk("./"):
    for file in files:
        if file.endswith(".c") or file.endswith(".cu"):
            generated_cu.append(os.path.join(root, file))

import paddle
from paddle.utils.cpp_extension import CUDAExtension, setup


def get_gencode_flags():
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return ["-gencode", "arch=compute_{{0}},code=sm_{{0}}".format(cc)]


gencode_flags = get_gencode_flags()

setup(
    name="{python_package_name}",
    ext_modules=CUDAExtension(
        sources=generated_cu,
        extra_compile_args={{
            "cc": ["-lcuda"],
            "nvcc": [
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            ]
            + gencode_flags,
        }},
        extra_link_args=["-lcuda"],
    ),
)
"""


def get_op_name_with_suffix(op_name, x_list):
    """
    Get the operator name with a suffix derived from the integer inputs:
    16 for multiples of 16, 1 for exactly 1, and 0 otherwise.
    """
    suffix = []
    for x in x_list:
        if x % 16 == 0:
            suffix.append(16)
        elif x == 1:
            suffix.append(1)
        else:
            suffix.append(0)
    return op_name + "_".join([str(i) for i in suffix])
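
# Example (illustrative): the suffix encodes triton's alignment/specialization
# buckets, so the same op compiled for differently shaped inputs registers
# under a distinct name.
#
#     get_op_name_with_suffix("matmul", [32, 1, 7])  # -> "matmul16_1_0"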


def get_value_hint(x):
    """
    Get the triton AOT signature hint for a list of scalar values.
    """
    hint = ""
    for ele in x:
        if isinstance(ele, int):
            if ele % 16 == 0 and ele > 0:
                hint += "i64:16,"
            elif ele == 1:
                hint += "i64:1,"
            else:
                hint += "i64,"
        elif isinstance(ele, float):
            hint += "fp32,"
    return hint
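
# Example (illustrative): positive multiples of 16 carry an alignment hint,
# 1 carries a specialization hint, and floats are passed as fp32.
#
#     get_value_hint([64, 1, 5, 0.5])  # -> "i64:16,i64:1,i64,fp32,"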


def get_dtype_str(dtype):
    """
    Get the dtype suffix string for the given paddle dtype.
    """
    if dtype == paddle.float16:
        return "_fp16"
    elif dtype == paddle.float8_e4m3fn:
        return "_float8_e4m3fn"
    elif dtype == paddle.uint8:
        return "_u8"
    elif dtype == paddle.int8:
        return "_i8"
    elif dtype == paddle.int16:
        return "_i16"
    elif dtype == paddle.int32:
        return "_i32"
    elif dtype == paddle.int64:
        return "_i64"
    elif dtype == paddle.float32:
        return "_fp32"
    elif dtype == paddle.bfloat16:
        return "_bf16"
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
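
# Example (illustrative): the suffix is typically appended to an op name so
# per-dtype variants register under distinct names.
#
#     get_dtype_str(paddle.bfloat16)  # -> "_bf16"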


def build_package(generated_dir, python_package_name):
    """
    Build the extension package, but do not install it.

    Args:
        generated_dir: the directory containing the generated .cu sources.
        python_package_name: the python package name.
    """
    setup_file_path = generated_dir + "/setup_cuda.py"
    python_path = sys.executable
    with open(setup_file_path, "w") as f:
        f.write(template_install.format(python_package_name=python_package_name))
    install_command = f"cd {generated_dir} && {python_path} setup_cuda.py build"
    ret = os.system(install_command)
    assert ret == 0


def rename_c_to_cu(generated_dir):
    """
    Rename the .c files in generated_dir to .cu files, because the triton
    AOT tool generates .c files.
    """
    for filename in os.listdir(generated_dir):
        if filename.endswith(".c"):
            old_path = os.path.join(generated_dir, filename)
            new_path = os.path.join(generated_dir, filename + "u")
            os.rename(old_path, new_path)


def get_pointer_hint(dtypes):
    """
    Get the triton AOT signature hint for a list of tensor dtypes (pointers).
    """
    hint = ""
    for ele in dtypes:
        if ele == paddle.float16:
            hint += "*fp16:16,"
        elif ele == paddle.uint8:
            hint += "*u8:16,"
        elif ele == paddle.int8:
            hint += "*i8:16,"
        elif ele == paddle.int16:
            hint += "*i16:16,"
        elif ele == paddle.float32:
            hint += "*fp32:16,"
        elif ele == paddle.bfloat16:
            hint += "*bf16:16,"
        elif ele == paddle.int32:
            hint += "*i32:16,"
        elif ele == paddle.int64:
            hint += "*i64,"
        elif ele == paddle.float8_e4m3fn:
            hint += "*fp8e4nv:16,"
    return hint
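
# Example (illustrative): pointer hints mirror the value hints above, with a
# leading "*" and, for most dtypes, a 16-byte alignment annotation.
#
#     get_pointer_hint([paddle.float16, paddle.int64])  # -> "*fp16:16,*i64,"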


paddle_custom_op_head_part = """#include <vector>
#include <map>
#include "${op_name}_kernel.h"
#include "paddle/extension.h"

std::map<std::vector<int>, int> map_problem_${op_name};

CUdeviceptr get_tensor_ptr(const paddle::Tensor& input) {
  if (input.type() == paddle::DataType::FLOAT16) {
    return (CUdeviceptr)(input.data<phi::dtype::float16>());
  } else if (input.type() == paddle::DataType::BFLOAT16) {
    return (CUdeviceptr)(input.data<phi::dtype::bfloat16>());
  } else if (input.type() == paddle::DataType::INT32) {
    return (CUdeviceptr)(input.data<int32_t>());
  } else if (input.type() == paddle::DataType::FLOAT32) {
    return (CUdeviceptr)(input.data<float>());
  } else if (input.type() == paddle::DataType::UINT8) {
    return (CUdeviceptr)(input.data<uint8_t>());
  } else if (input.type() == paddle::DataType::INT8) {
    return (CUdeviceptr)(input.data<int8_t>());
  } else if (input.type() == paddle::DataType::INT64) {
    return (CUdeviceptr)(input.data<int64_t>());
  } else if (input.type() == paddle::DataType::INT16) {
    return (CUdeviceptr)(input.data<int16_t>());
  } else if (input.type() == paddle::DataType::FLOAT8_E4M3FN) {
    return (CUdeviceptr)(input.data<phi::dtype::float8_e4m3fn>());
  } else {
    assert(false);
    return (CUdeviceptr)(nullptr);
  }
}

int triton_cdiv(int x, int y) {
  return (x + y - 1) / y;
}
"""


tune_and_invoke_part = """
  std::vector<int> problem_size = {${key}};

  auto run_triton_kernel = [&](int algo_id) -> CUresult {
    return ${op_name}_kernel(run_stream,
                             ${triton_kernel_args},
                             algo_id);
  };

  // pre-selecting algo 0 here means the tuning branch below is skipped.
  map_problem_${op_name}[problem_size] = 0;

  if (!map_problem_${op_name}.count(problem_size)) {
    std::cout << "we are tuning for ${op_name} which key is: {";
    for (int i = 0; i < problem_size.size(); i++) {
      std::cout << problem_size[i] << ", ";
    }
    std::cout << "}" << std::endl;

    float min_time = 10000.f;
    int select_id = -1;
    constexpr int WARMUP = 5;
    constexpr int REPEAT = 10;

    for (int algo_id = 0; algo_id < ${op_name}_kernel_get_num_algos(); ++algo_id) {
      cudaEvent_t beg[REPEAT];
      cudaEvent_t end[REPEAT];
      float elapsed_times[REPEAT];
      auto status = CUDA_SUCCESS;
      for (int ii = 0; ii < WARMUP + REPEAT; ii++) {
        int repeat_id = ii - WARMUP;
        if (repeat_id >= 0) {
          (cudaEventCreate(beg + repeat_id));
          (cudaEventCreate(end + repeat_id));
          (cudaEventRecord(beg[repeat_id]));
        }
        auto flush_l2_cache = paddle::full(
            {10 * 1024 * 1024}, 0, paddle::DataType::INT32, ${arbitrary_output_name}.place());
        // this is used when the output needs to be reset to zero, such as split-k gemm.
        ${reset_zero_when_tune};
        status = run_triton_kernel(algo_id);
        if (repeat_id >= 0) {
          (cudaEventRecord(end[repeat_id]));
          (cudaEventSynchronize(end[repeat_id]));
          (cudaEventElapsedTime(
              elapsed_times + repeat_id, beg[repeat_id], end[repeat_id]));
        }
      }
      float avg_elapsed_time = 0.f;
      for (int ii = 0; ii < REPEAT; ++ii) {
        avg_elapsed_time += elapsed_times[ii];
      }
      std::cout << "algo id " << algo_id << " costs " << avg_elapsed_time << " ms" << std::endl;
      if (avg_elapsed_time < min_time && status == CUDA_SUCCESS) {
        min_time = avg_elapsed_time;
        select_id = algo_id;
      }
    }

    map_problem_${op_name}[problem_size] = select_id;
    std::cout << "select algo id: " << select_id << std::endl;
    ${reset_zero_when_tune};
  }

  if (map_problem_${op_name}.count(problem_size)) {
    int algo_id = map_problem_${op_name}[problem_size];
    auto status = run_triton_kernel(algo_id);
    assert(status == CUDA_SUCCESS);
  }
"""


common_template = (
    """
std::vector<paddle::Tensor> ${op_name}_func(${input_and_attr}) {
  ${prepare_attr_for_triton_kernel}
  ${prepare_ptr_for_triton_kernel}
  auto run_stream = ${arbitrary_output_name}.stream();
"""
    + tune_and_invoke_part
    + """
  return {${return_tensor_names}};
}

${d2s_infer_code}

PD_BUILD_OP(${op_name})
    .Inputs({${paddle_input_sig}})
    .Outputs({${paddle_output_sig}})
    .Attrs({${paddle_attr_sig}})
    .SetKernelFn(PD_KERNEL(${op_name}_func))
    .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
"""
)


def rendering_common_template(
    func,
    prepare_attr_for_triton_kernel,
    prepare_ptr_for_triton_kernel,
    return_tensor_names=None,
    d2s_infer_code="",
):
    """
    Render the common custom-op template for the given function.

    Args:
        func: The function to render.
        prepare_attr_for_triton_kernel: The code snippet that prepares attributes for the Triton kernel.
        prepare_ptr_for_triton_kernel: The code snippet that prepares pointers for the Triton kernel.
        return_tensor_names: The names of the returned tensors. Default is None.
        d2s_infer_code: Custom InferShape/InferDtype code. Default is "".
    """
    signature = inspect.signature(func)
    arg_names = [v.name for v in signature.parameters.values()]
    arg_defaults = [v.default for v in signature.parameters.values()]
    input_and_attr = ""
    paddle_input_sig = ""
    paddle_attr_sig = ""

    if return_tensor_names is None:
        # if the op has no real output, synthesize a dummy tensor so the
        # generated code still has an output to work with.
        return_tensor_names = "useless"
        prepare_ptr_for_triton_kernel += (
            "auto useless = paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());"
        )

    for i in range(len(arg_names)):
        if arg_defaults[i] is None:
            input_and_attr += f"paddle::optional<paddle::Tensor> & {arg_names[i]},"
            paddle_input_sig += f"""paddle::Optional("{arg_names[i]}"),"""
        elif isinstance(arg_defaults[i], float):
            input_and_attr += f"float {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: float","""
        elif isinstance(arg_defaults[i], bool):
            input_and_attr += f"bool {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: bool","""
        elif isinstance(arg_defaults[i], int):
            input_and_attr += f"int64_t {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: int64_t","""
        elif isinstance(arg_defaults[i], str):
            input_and_attr += f"std::string {arg_names[i]},"
            paddle_attr_sig += f""""{arg_names[i]}: std::string","""
        elif arg_names[i] == "config":
            continue
        else:
            input_and_attr += f"const paddle::Tensor & {arg_names[i]},"
            paddle_input_sig += f""""{arg_names[i]}","""
    # drop the trailing commas.
    input_and_attr = input_and_attr[:-1]
    paddle_input_sig = paddle_input_sig[:-1]
    if len(paddle_attr_sig) > 1:
        paddle_attr_sig = paddle_attr_sig[:-1]

    paddle_output_sig = ""
    arbitrary_output_name = ""
    for name in return_tensor_names.split(","):
        name = name.strip()
        arbitrary_output_name = name
        paddle_output_sig += f""""{name}","""
    paddle_output_sig = paddle_output_sig[:-1]

    # if the caller does not provide InferShape/InferDtype, default to
    # propagating the first input's shape/dtype to every output.
    if "${op_name}_InferShape" not in d2s_infer_code:
        d2s_infer_shape_part = (
            "std::vector<std::vector<int64_t>> ${op_name}_InferShape("
            "const std::vector<int64_t>& A_shape) {"
            "return {${tmp}};"
            "}\n "
        )
        tmp = ",".join(["A_shape"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_shape_part = SubstituteTemplate(d2s_infer_shape_part, tmp_dict)
        d2s_infer_code += d2s_infer_shape_part

    if "${op_name}_InferDtype" not in d2s_infer_code:
        d2s_infer_dtype_part = (
            "std::vector<paddle::DataType> ${op_name}_InferDtype("
            "const paddle::DataType& A_dtype) {"
            "return {${tmp}};"
            "}\n "
        )
        tmp = ",".join(["A_dtype"] * len(return_tensor_names.split(",")))
        tmp_dict = {"tmp": tmp}
        d2s_infer_dtype_part = SubstituteTemplate(d2s_infer_dtype_part, tmp_dict)
        d2s_infer_code += d2s_infer_dtype_part

    result_str = SubstituteTemplate(
        common_template,
        {
            "input_and_attr": input_and_attr,
            "prepare_attr_for_triton_kernel": prepare_attr_for_triton_kernel,
            "prepare_ptr_for_triton_kernel": prepare_ptr_for_triton_kernel,
            "return_tensor_names": return_tensor_names,
            "arbitrary_output_name": arbitrary_output_name,
            "d2s_infer_code": d2s_infer_code,
            "paddle_input_sig": paddle_input_sig,
            "paddle_output_sig": paddle_output_sig,
            "paddle_attr_sig": paddle_attr_sig,
        },
    )
    return paddle_custom_op_head_part + result_str
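
# Illustrative usage (a sketch; my_func and the snippet strings are
# hypothetical). The C++ fragments prepare the attributes and raw pointers
# that the generated op passes to the AOT-compiled kernel, and
# return_tensor_names lists the op's outputs:
#
#     prepare_attr = "int M = x.shape()[0];"
#     prepare_ptr = "CUdeviceptr input_ptrs[2] = {get_tensor_ptr(x), get_tensor_ptr(y)};"
#     op_template = rendering_common_template(my_func, prepare_attr, prepare_ptr, "y")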


class KernelInterface:
    """
    triton kernel interface.
    """

    def __init__(
        self,
        func,
        other_config,
        key_args=["1"],
    ):
        """
        triton kernel interface.
        """
        self.func = func
        self.key_args = key_args
        signature = inspect.signature(func)
        self.arg_names = [v.name for v in signature.parameters.values()]
        # argument names must be unique.
        for ele in self.arg_names:
            assert self.arg_names.count(ele) == 1
        self.annotations = dict(func.__annotations__)
        self.constexprs = [
            self.arg_names.index(name)
            for name in self.arg_names
            if self.annotations.get(name) == triton.language.core.constexpr
        ]
        self.arg_exclude_constexpr = [
            self.arg_names[i] for i in range(len(self.arg_names)) if i not in self.constexprs
        ]
        import textwrap

        py_script = textwrap.dedent(inspect.getsource(func))

        pat = r"def\s" + func.__name__
        func_begin = re.findall(pat, py_script)
        assert len(func_begin) == 1
        func_begin = func_begin[0]
        py_script = py_script[py_script.find(func_begin) :]

        def decorator(*args, **kwargs):
            """
            Triggered through KernelInterface[...]: collects the call
            arguments, AOT-compiles the triton kernel if its .so is not
            cached yet, and registers it as a paddle custom op.

            Args:
                *args: positional arguments
                **kwargs: keyword arguments
            """
            all_input = []
            for i in range(len(args)):
                all_input.append(args[i])
            position_arguments_num = len(all_input)
            for i in range(position_arguments_num, len(self.arg_names)):
                if self.arg_names[i] in kwargs.keys():
                    all_input.append(kwargs[self.arg_names[i]])
                else:
                    # this input is not specified, so it must be a tl.constexpr.
                    assert i in self.constexprs
                    all_input.append(None)

            dtypes = []
            x_list = []
            const_args = [self.arg_names[i] for i in self.constexprs]
            # no name in const_args may be a substring of another, since the
            # names are substituted into the AOT signature as plain text.
            for i in const_args:
                for j in const_args:
                    if i != j and i.find(j) != -1:
                        raise ValueError(
                            f"We find {i}, {j} in tl.constexpr args, and {j} is a substring of {i}, "
                            "please modify your triton kernel arguments names to avoid this."
                        )

            # copy, so repeated calls do not see each other's edits.
            modified_arg_exclude_constexpr = list(self.arg_exclude_constexpr)
            const_hint_dict = {}
            for i in range(len(all_input)):
                ele = all_input[i]
                if isinstance(
                    ele,
                    (
                        paddle.Tensor,
                        paddle.base.framework.EagerParamBase,
                        paddle.base.framework.Parameter,
                        paddle.base.framework.Variable,
                        paddle.base.libpaddle.pir.Value,
                    ),
                ):
                    dtypes.append(ele.dtype)
                    modified_arg_exclude_constexpr[i] = f"input_ptrs[{i}]"
                elif i in self.constexprs:
                    const_hint_dict[self.arg_names[i]] = ele
                else:
                    x_list.append(ele)

            op_name = self.op_name
            python_package_name = f"{op_name}_package"
            tp_rank = paddle.distributed.get_rank()
            generated_dir = envs.FD_TRITON_KERNEL_CACHE_DIR
            if generated_dir is None:
                generated_dir = f"/tmp/triton_cache/rank{tp_rank}"
            print("the kernel cache dir is:", generated_dir)
            generated_dir = f"{generated_dir}/{op_name}"
            os.makedirs(generated_dir, exist_ok=True)

            py_script_file = f"{generated_dir}/triton_kernels.py"
            extract_triton_kernel(func, py_script_file)

            address_hint = get_pointer_hint(dtypes)
            value_hint = get_value_hint(x_list)
            # constexpr args become format placeholders such as {BLOCK}.
            const_args = [f"{{{ele}}}" for ele in const_args]
            const_args = ",".join(const_args)

            launch_grid = list(self.grid)
            for i in range(len(launch_grid)):
                ele = launch_grid[i]
                if isinstance(ele, str):
                    # turn constexpr names inside a grid expression into
                    # format placeholders, e.g. "M*N/BLOCK" -> "M*N/{BLOCK}".
                    for key in const_hint_dict.keys():
                        if key in ele:
                            ele = ele.replace(key, f"{{{key}}}")
                else:
                    ele = str(ele)
                launch_grid[i] = ele
            if len(launch_grid) < 3:
                launch_grid += ["1"] * (3 - len(launch_grid))
            launch_grid = ",".join(launch_grid)

            op_dict = {"op_name": op_name, "reset_zero_when_tune": ""}
            op_dict["triton_kernel_args"] = ",".join(modified_arg_exclude_constexpr)
            op_dict["key"] = ",".join(self.key_args)
            # when tuning, the output may need to be reset to zero.
            if "reset_zero_when_tune" in other_config.keys():
                op_dict["reset_zero_when_tune"] = other_config["reset_zero_when_tune"]

            paddle_custom_op_file_path = f"{generated_dir}/{op_name}.cu"
            so_path = find_so_path(generated_dir, python_package_name)

            if so_path is None:
                print("== we do not find so_path, we need to compile it")
                with open(paddle_custom_op_file_path, "w") as f:
                    f.write(
                        SubstituteTemplate(
                            self.custom_op_template,
                            op_dict,
                        )
                    )

                # ahead-of-time compile command.
                aot_template = (
                    f"""{python_path} {compile_file} {py_script_file} """
                    + f""" -n {func.__name__} -o {generated_dir}/{op_name}_kernel """
                    + f"""--out-name {op_name}_kernel """
                    + """ -w {num_warps} -ns {num_stages} """
                    + f""" -s"{address_hint} {value_hint} {const_args}" """
                    + f""" -g "{launch_grid}" """
                )
                all_tune_config = [{key: value} for key, value in self.tune_config.items()]
                if len(all_tune_config) == 0:
                    # when the user does not specify a config, use const_hint_dict as the config.
                    all_tune_config = [const_hint_dict]
                    # reset const_hint_dict to empty.
                    const_hint_dict = {}
                codegen_commands = []
                for config in all_tune_config:
                    for key in const_hint_dict.keys():
                        if const_hint_dict[key] is not None:
                            if key not in config.keys():
                                config[key] = const_hint_dict[key]
                            elif config[key] != const_hint_dict[key]:
                                raise ValueError(
                                    f"you specify {key} both in arguments and config, "
                                    "and they are not the same, this is wrong."
                                )
                        else:
                            assert key in config.keys(), f"you must specify {key} in your config."
                    if "num_warps" not in config.keys():
                        config["num_warps"] = 4
                    if "num_stages" not in config.keys():
                        config["num_stages"] = 4
                    for key in config:
                        assert config[key] is not None, f"{key} must be specified."
                    codegen_command = aot_template.format(
                        **config,
                    )
                    print(codegen_command)
                    codegen_commands.append(codegen_command)
                multi_process_do(codegen_commands)

                link_command = f"{python_path} {link_file} {generated_dir}/*.h -o {generated_dir}/{op_name}_kernel"
                ret = os.system(link_command)
                assert ret == 0

                # rename the .c files to .cu files.
                rename_c_to_cu(generated_dir)
                # build the package into a .so, but do not install it.
                build_package(generated_dir, python_package_name)

            if op_name not in OpProtoHolder.instance().op_proto_map.keys():
                so_path = find_so_path(generated_dir, python_package_name)
                print("== we find so_path: ", so_path)
                assert so_path is not None
                paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)

        self.decorator = decorator

    def __getitem__(self, op_name_and_grid):
        """
        Override operator [], which stores the launch configuration and
        returns the decorator function.

        Args:
            op_name_and_grid: (op_name, custom_op_template, grid), plus an
                optional fourth element holding the tune config.
        Returns:
            the decorator function.
        """
        assert len(op_name_and_grid) >= 3, "len(op_name_and_grid) must be >= 3."
        self.op_name = op_name_and_grid[0]
        self.custom_op_template = op_name_and_grid[1]
        self.grid = op_name_and_grid[2]
        if len(op_name_and_grid) == 3:
            self.tune_config = {}
        else:
            self.tune_config = op_name_and_grid[3]
        return self.decorator


def paddle_use_triton(other_config={}, key=[]):
    """
    Decorator factory that wraps a triton kernel function in a KernelInterface.

    Args:
        other_config: extra codegen options, e.g. "reset_zero_when_tune".
        key: the argument names that form the tuning cache key.
    Returns:
        the decorator.
    """

    def decorator(func):
        """
        The decorator that wraps the original function.

        Args:
            func: the original function.
        Returns:
            the wrapping KernelInterface.
        """
        return KernelInterface(func, other_config, key)

    return decorator
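
# End-to-end usage sketch (illustrative; my_add_kernel, op_template, and the
# grid expression are hypothetical). Decorate a triton kernel, then index the
# wrapper with (op_name, custom_op_template, grid[, tune_config]) to trigger
# AOT compilation and registration as a paddle custom op on first call:
#
#     @paddle_use_triton(key=["M"])
#     def my_add_kernel(x_ptr, y_ptr, out_ptr, M, BLOCK: tl.constexpr):
#         pid = tl.program_id(0)
#         offs = pid * BLOCK + tl.arange(0, BLOCK)
#         tl.store(out_ptr + offs, tl.load(x_ptr + offs) + tl.load(y_ptr + offs))
#
#     grid = ("(M+BLOCK-1)/BLOCK",)
#     my_add_kernel[("my_add", op_template, grid, {"BLOCK": 256})](x, y, out, M)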