#include <cassert>
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>

#include "paddle/extension.h"

// PDTraits maps a paddle::DataType onto its device-side (DataType_) and
// host-side (data_t) C++ types; it is defined in this repo's shared
// custom-op helper header.
template <paddle::DataType D>
void SwapCacheImpl(const paddle::Tensor& cache_gpu,   // gpu
                   const int64_t& cache_cpu_pointer,  // cpu
                   const int64_t& max_block_num_cpu,
                   const std::vector<int64_t>& swap_block_ids_gpu,
                   const std::vector<int64_t>& swap_block_ids_cpu,
                   int mode) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
  auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
  // Cache layout: [max_block_num, num_heads, block_size, head_dim].
  auto cache_shape = cache_gpu.shape();
  const int64_t max_block_num_gpu = cache_shape[0];
  const int num_heads = cache_shape[1];
  const int block_size = cache_shape[2];
  const int head_dim = cache_shape[3];
  // Elements per cache block; GPU and CPU pools use the same block shape,
  // so one stride serves both sides.
  const int64_t cache_stride = num_heads * block_size * head_dim;
  auto stream = cache_gpu.stream();
  for (size_t i = 0; i < swap_block_ids_gpu.size(); ++i) {
    int64_t gpu_block_id = swap_block_ids_gpu[i];
    int64_t cpu_block_id = swap_block_ids_cpu[i];
    assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
    assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
    auto* cache_gpu_ptr_now = cache_gpu_ptr + gpu_block_id * cache_stride;
    auto* cache_cpu_ptr_now = cache_cpu_ptr + cpu_block_id * cache_stride;
    if (mode == 0) {  // copy from device to host
      cudaMemcpyAsync(cache_cpu_ptr_now,
                      cache_gpu_ptr_now,
                      cache_stride * sizeof(DataType_),
                      cudaMemcpyDeviceToHost,
                      stream);
    } else {  // copy from host to device
      cudaMemcpyAsync(cache_gpu_ptr_now,
                      cache_cpu_ptr_now,
                      cache_stride * sizeof(DataType_),
                      cudaMemcpyHostToDevice,
                      stream);
    }
  }
  // Synchronize once after all per-block copies have been queued.
  cudaStreamSynchronize(stream);
}
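Because `cache_cpu_pointer` arrives as a raw `int64_t`, the host-side pool must be allocated and kept alive by the caller. Below is a minimal sketch of that setup, assuming pinned (page-locked) memory; with pageable memory, `cudaMemcpyAsync` silently degrades to a synchronous copy. `AllocCpuCachePool` and all shape parameters are illustrative names, not part of this op.

#include <cstdint>
#include <cuda_runtime.h>

// Sketch only: allocate a pinned host pool holding `max_block_num_cpu` blocks
// and hand its address to the op as an int64_t attribute.
int64_t AllocCpuCachePool(int64_t max_block_num_cpu,
                          int64_t num_heads,
                          int64_t block_size,
                          int64_t head_dim,
                          size_t elem_bytes) {
  void* host_ptr = nullptr;
  size_t bytes = static_cast<size_t>(max_block_num_cpu) * num_heads *
                 block_size * head_dim * elem_bytes;
  // cudaMallocHost returns page-locked memory, which lets the async copies
  // above overlap with other work on the stream.
  cudaMallocHost(&host_ptr, bytes);
  return reinterpret_cast<int64_t>(host_ptr);
}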

void SwapCache(const paddle::Tensor& cache_gpu,  // gpu
               int64_t cache_cpu_ptr,            // cpu memory pointer
               int64_t max_block_num_cpu,        // cpu max block num
               const std::vector<int64_t>& swap_block_ids_gpu,
               const std::vector<int64_t>& swap_block_ids_cpu,
               int rank,
               int mode) {
  cudaSetDevice(rank);  // used for distributed launch
  switch (cache_gpu.dtype()) {
    case paddle::DataType::BFLOAT16:
      return SwapCacheImpl<paddle::DataType::BFLOAT16>(cache_gpu,
                                                       cache_cpu_ptr,
                                                       max_block_num_cpu,
                                                       swap_block_ids_gpu,
                                                       swap_block_ids_cpu,
                                                       mode);
    case paddle::DataType::FLOAT16:
      return SwapCacheImpl<paddle::DataType::FLOAT16>(cache_gpu,
                                                      cache_cpu_ptr,
                                                      max_block_num_cpu,
                                                      swap_block_ids_gpu,
                                                      swap_block_ids_cpu,
                                                      mode);
    case paddle::DataType::UINT8:
      return SwapCacheImpl<paddle::DataType::UINT8>(cache_gpu,
                                                    cache_cpu_ptr,
                                                    max_block_num_cpu,
                                                    swap_block_ids_gpu,
                                                    swap_block_ids_cpu,
                                                    mode);
    default:
      PD_THROW("Unsupported data type.");
  }
}
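The runtime calls above discard their `cudaError_t` results, so a failed copy (for example, a stale host pointer) would surface only much later. A hedged sketch of a checking helper one could wrap them with follows; `CheckCuda` is a hypothetical name, not something this repo provides.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical helper: fail fast with a readable message when a CUDA
// runtime call returns an error.
static inline void CheckCuda(cudaError_t err, const char* what) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(err));
    std::abort();
  }
}

// Usage inside the copy loop, for example:
//   CheckCuda(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToHost, stream),
//             "cudaMemcpyAsync");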

PD_BUILD_STATIC_OP(swap_cache)
    .Inputs({
        "cache_gpu",
    })
    .Attrs({
        "cache_cpu_ptr: int64_t",
        "max_block_num_cpu: int64_t",
        "swap_block_ids_gpu: std::vector<int64_t>",
        "swap_block_ids_cpu: std::vector<int64_t>",
        "rank: int",
        "mode: int",
    })
    .Outputs({"cache_dst_out"})
    .SetInplaceMap({{"cache_gpu", "cache_dst_out"}})
    .SetKernelFn(PD_KERNEL(SwapCache));
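For reference: Paddle passes a custom op's inputs first and then its attributes in declared order, so the registration lines up one-to-one with `SwapCache`'s parameter list, and `SetInplaceMap` marks `cache_gpu` as mutated in place rather than copied into a fresh output. The correspondence, written out as a comment sketch:

// Registration slot                 -> SwapCache parameter
// Inputs[0]  "cache_gpu"            -> const paddle::Tensor& cache_gpu
// Attrs[0]   "cache_cpu_ptr"        -> int64_t cache_cpu_ptr
// Attrs[1]   "max_block_num_cpu"    -> int64_t max_block_num_cpu
// Attrs[2]   "swap_block_ids_gpu"   -> const std::vector<int64_t>& swap_block_ids_gpu
// Attrs[3]   "swap_block_ids_cpu"   -> const std::vector<int64_t>& swap_block_ids_cpu
// Attrs[4]   "rank"                 -> int rank
// Attrs[5]   "mode"                 -> int mode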