[Metax] adapt prefix caching & cpu swap (#5844)

Co-authored-by: root <root@lt-wks-10-0-180-15.pub.metax-tech.com>
This commit is contained in:
MingkunZhang
2025-12-31 17:02:48 +08:00
committed by GitHub
parent 193886e745
commit f732d7d2ad
4 changed files with 94 additions and 1 deletions
+52
View File
@@ -0,0 +1,52 @@
#include "paddle/extension.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
// 自定义异常类,用于处理CUDA错误
class CudaError : public std::exception {
public:
explicit CudaError(cudaError_t error) : error_(error) {}
const char* what() const noexcept override {
return cudaGetErrorString(error_);
}
private:
cudaError_t error_;
};
// Inspect a CUDA runtime status and raise CudaError for anything other
// than cudaSuccess; a successful status is a no-op.
void check_cuda_error(cudaError_t error) {
  if (error == cudaSuccess) {
    return;
  }
  throw CudaError(error);
}
// Python-facing wrapper around cudaHostAlloc: allocates `size` bytes of
// page-locked (pinned) host memory and returns the address as an integer
// so Python can hold it opaquely. `flags` is forwarded to cudaHostAlloc
// unchanged. Throws CudaError on allocation failure.
uintptr_t cuda_host_alloc(size_t size,
                          unsigned int flags = cudaHostAllocDefault) {
  void* buffer = nullptr;
  check_cuda_error(cudaHostAlloc(&buffer, size, flags));
  return reinterpret_cast<uintptr_t>(buffer);
}
// Python-facing wrapper around cudaFreeHost: releases pinned host memory
// previously handed out by cuda_host_alloc. Throws CudaError on failure.
void cuda_host_free(uintptr_t ptr) {
  void* host_ptr = reinterpret_cast<void*>(ptr);
  check_cuda_error(cudaFreeHost(host_ptr));
}
// Module bindings (alloc_cache_pinned.cc): expose the pinned-memory
// helpers cuda_host_alloc / cuda_host_free and the CudaError exception.
PYBIND11_MODULE(fastdeploy_ops, m) {
  py::register_exception<CudaError>(m, "CudaError");

  m.def("cuda_host_alloc",
        &cuda_host_alloc,
        "Allocate pinned memory",
        py::arg("size"),
        py::arg("flags") = cudaHostAllocDefault);

  m.def("cuda_host_free",
        &cuda_host_free,
        "Free pinned memory",
        py::arg("ptr"));
}
+5
View File
@@ -637,12 +637,17 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_data_ptr_ipc.cu",
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
"gpu_ops/unset_data_ipc.cu",
"gpu_ops/swap_cache_batch.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
"metax_ops/fused_moe.cu",
"metax_ops/apply_rope_qkv.cu",
"metax_ops/cache_kv_with_rope.cu",
"metax_ops/cpp_extensions.cc",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
+31
View File
@@ -14,6 +14,8 @@
# limitations under the License.
"""
import os
import paddle
from fastdeploy.platforms import current_platform
@@ -39,6 +41,29 @@ try:
def get_peer_mem_addr(*args, **kwargs):
    # Stub for the CUDA branch: this helper is not needed on CUDA, so any
    # call is a programming error and fails loudly.
    raise RuntimeError("CUDA no need of get_peer_mem_addr!")
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import ( # get_output_kv_signal,; ipc_sent_key_value_cache_by_remote_ptr_block_sync,
cuda_host_alloc,
cuda_host_free,
get_data_ptr_ipc,
ipc_sent_key_value_cache_by_remote_ptr,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
unset_data_ipc,
)
memory_allocated = paddle.device.memory_allocated
def get_peer_mem_addr(*args, **kwargs):
    # Stub for the Metax (MACA) branch: this helper is not needed here, so
    # any call is a programming error and fails loudly.
    # Fix: the original message said "CUDA" — a copy-paste from the CUDA
    # branch — even though this stub lives under current_platform.is_maca().
    raise RuntimeError("Metax no need of get_peer_mem_addr!")
def get_output_kv_signal(*args, **kwargs):
    # Stub: get_output_kv_signal is not implemented on the Metax platform;
    # calling it is an error. Fix: corrected the misspelled "UNIMPLENENTED"
    # in the error message.
    raise RuntimeError("Metax get_output_kv_signal UNIMPLEMENTED!")
def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
    # Stub: the blocking-sync variant of the IPC KV-cache transfer is not
    # implemented on the Metax platform; calling it is an error.
    # Fix: corrected the misspelled "UNIMPLENENTED" in the error message.
    raise RuntimeError("Metax ipc_sent_key_value_cache_by_remote_ptr_block_sync UNIMPLEMENTED!")
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
@@ -69,6 +94,8 @@ try:
def set_device(device):
if current_platform.is_cuda():
paddle.set_device(f"gpu:{device}")
elif current_platform.is_maca():
paddle.set_device(f"metax_gpu:{device}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{device}")
else:
@@ -77,6 +104,8 @@ try:
def share_external_data_(cache, cache_name, cache_shape, use_ipc):
if current_platform.is_cuda():
cache = share_external_data(cache, cache_name, cache_shape)
elif current_platform.is_maca():
cache = share_external_data(cache, cache_name, cache_shape)
elif current_platform.is_xpu():
cache = share_external_data(cache, cache_name, cache_shape, use_ipc)
else:
@@ -86,6 +115,8 @@ try:
def get_all_visible_devices():
    # Return the platform-appropriate device-visibility environment-variable
    # string ("NAME=ids"). XPU and MACA get their own variable names; every
    # other platform falls back to CUDA_VISIBLE_DEVICES with all 8 devices.
    if current_platform.is_xpu():
        return "XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
    if current_platform.is_maca():
        # Preserve an explicitly configured MACA_VISIBLE_DEVICES; otherwise
        # expose all 8 devices by default.
        maca_devices = os.environ.get("MACA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")
        return "MACA_VISIBLE_DEVICES=" + maca_devices
    return "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+6 -1
View File
@@ -531,7 +531,12 @@ class EngineArgs:
self.tokenizer = self.model
if self.splitwise_role == "decode":
self.enable_prefix_caching = False
if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
if (
not current_platform.is_cuda()
and not current_platform.is_xpu()
and not current_platform.is_intel_hpu()
and not current_platform.is_maca()
):
self.enable_prefix_caching = False
if self.enable_logprob:
if not current_platform.is_cuda() and not current_platform.is_xpu():