[Feature] pd support dy-c8 ipc (#5750)

* pd support dy-c8 ipc

* update code

* support v0

* update code
This commit is contained in:
kevin
2025-12-25 21:22:34 +08:00
committed by GitHub
parent 4fa76296d9
commit 5538dda3c8
4 changed files with 403 additions and 271 deletions
@@ -12,66 +12,74 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <nvml.h>

#include <fstream>
#include <iomanip>
#include <iostream>

#include "helper.h"
// #define PRINT_GPU_MEMORY
// Query and print per-GPU memory usage (total/used/free, MiB) via NVML.
// `callLine` is the caller's __LINE__, prefixed to every message so the
// printout can be correlated with the call site.  Compiled out (returns true
// immediately) unless PRINT_GPU_MEMORY is defined.
// Returns false only when NVML init / device enumeration fails.
bool getNvidiaGPUMemoryUsage(int callLine) {
#ifndef PRINT_GPU_MEMORY
  return true;
#endif
  // Initialize NVML.
  nvmlReturn_t result;
  result = nvmlInit();
  if (NVML_SUCCESS != result) {
    std::cerr << callLine
              << ": Failed to initialize NVML: " << nvmlErrorString(result)
              << std::endl;
    return false;
  }
  // Get the number of GPU devices.
  unsigned int deviceCount;
  result = nvmlDeviceGetCount(&deviceCount);
  if (NVML_SUCCESS != result) {
    std::cerr << callLine
              << ": Failed to get device count: " << nvmlErrorString(result)
              << std::endl;
    nvmlShutdown();
    // NOTE(review): the diff shows both `return true;` and `return false;`
    // here; `false` matches the init-failure branch above — confirm intent.
    return false;
  }
  // Walk every GPU device.
  for (unsigned int i = 0; i < deviceCount; ++i) {
    nvmlDevice_t device;
    result = nvmlDeviceGetHandleByIndex(i, &device);
    if (NVML_SUCCESS != result) {
      std::cerr << callLine << ": Failed to get device handle for device " << i
                << ": " << nvmlErrorString(result) << std::endl;
      continue;
    }
    // Fetch the memory info for this device.
    nvmlMemory_t memory;
    result = nvmlDeviceGetMemoryInfo(device, &memory);
    if (NVML_SUCCESS != result) {
      std::cerr << callLine << ": Failed to get memory info for device " << i
                << ": " << nvmlErrorString(result) << std::endl;
      continue;
    }
    // One line per device, tagged with the caller's line number.
    std::cout << callLine << ": GPU " << i
              << " - Total: " << memory.total / (1024 * 1024)
              << " MiB, Used: " << memory.used / (1024 * 1024)
              << " MiB, Free: " << memory.free / (1024 * 1024) << " MiB"
              << std::endl;
  }
  // Release NVML resources.
  nvmlShutdown();
  return true;
}
// #define DEBUG_IPC_SENT
// #define DEBUG_IPC_SENT_SYNC_AND_PRINT
template<typename T>
template <typename T>
void sent_key_value_by_remote_ptr(
const T* local_key_tensor_base_ptr, // gpu ptr
const T* local_value_tensor_base_ptr, // gpu ptr
const int32_t* local_block_ids_ptr, //cpu ptr,
const T* local_key_tensor_base_ptr, // gpu ptr
const T* local_value_tensor_base_ptr, // gpu ptr
const int32_t* local_block_ids_ptr, // cpu ptr,
const int32_t* remote_block_ids_ptr,
const int32_t block_num,
const int64_t block_idx_stride,
@@ -80,255 +88,275 @@ void sent_key_value_by_remote_ptr(
const int32_t remote_device_id,
T* remote_key_tensor_base_ptr, // gpu ptr
T* remote_value_tensor_base_ptr, // gpu ptr
cudaStream_t stream){
for(int block_idx=0;block_idx < block_num; ++block_idx){
const T* local_key_tensor_sent_ptr = local_key_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_key_tensor_sent_ptr = remote_key_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr
<<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
<<" stream: " << stream
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
<<std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
128 * 1,
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte);
#endif
cudaError_t err = cudaGetLastError();
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
const T* local_value_tensor_sent_ptr = local_value_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_value_tensor_sent_ptr = remote_value_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
cudaStream_t stream) {
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const T* local_key_tensor_sent_ptr =
local_key_tensor_base_ptr +
local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_key_tensor_sent_ptr =
remote_key_tensor_base_ptr +
remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr
<<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
<<" stream: " << stream
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
<<std::endl;
std::cout << "remote_key_tensor_sent_ptr:"
<< (int64_t)remote_key_tensor_sent_ptr
<< " local_key_tensor_sent_ptr:"
<< (int64_t)local_key_tensor_sent_ptr
<< " local_device_id:" << local_device_id
<< " remote_device_id:" << remote_device_id
<< " block_idx_stride:" << block_idx_stride
<< " block_size_byte:" << block_size_byte << " stream: " << stream
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
<< std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
128 * 1,
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
128 * 1,
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte);
cudaDeviceSynchronize();
#endif
err = cudaGetLastError();
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
cudaMemcpyPeer(reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte);
#endif
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
const T* local_value_tensor_sent_ptr =
local_value_tensor_base_ptr +
local_block_ids_ptr[block_idx] * block_idx_stride;
T* remote_value_tensor_sent_ptr =
remote_value_tensor_base_ptr +
remote_block_ids_ptr[block_idx] * block_idx_stride;
#ifdef DEBUG_IPC_SENT
std::cout << "remote_value_tensor_sent_ptr:"
<< (int64_t)remote_value_tensor_sent_ptr
<< " local_value_tensor_sent_ptr:"
<< (int64_t)local_value_tensor_sent_ptr
<< " local_device_id:" << local_device_id
<< " remote_device_id:" << remote_device_id
<< " block_idx_stride:" << block_idx_stride
<< " block_size_byte:" << block_size_byte << " stream: " << stream
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
<< std::endl;
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
PrintMatrix<T>(
reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
128 * 1,
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte);
cudaDeviceSynchronize();
#endif
err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
PrintMatrix<T>(
reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
128 * 1,
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
128 * 1);
cudaDeviceSynchronize();
#endif
}
}
// Paddle custom-op kernel: push selected KV-cache blocks of one layer from
// the local GPU to a remote GPU over CUDA IPC.
//
//   local_key/value_tensor     local caches, layout [blocks, kv_heads,
//                              block_size, hidden] (taken from .shape())
//   local/remote_block_ids     int32 block indices (host tensors)
//   remote_key/value_tensor    int64 scalars holding the remote base
//                              device pointers (opened via IPC)
//   block_num                  number of blocks to transfer
//   local/remote_device_id     CUDA device ordinals
//   cuda_stream_raw            raw cudaStream_t passed as int64
//   is_scale                   true when sending per-block quantization
//                              scales, whose block stride omits hidden_size
void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
                             const paddle::Tensor& local_value_tensor,
                             const paddle::Tensor& local_block_ids,   // cpu
                             const paddle::Tensor& remote_block_ids,  // cpu
                             const paddle::Tensor& remote_key_tensor,
                             const paddle::Tensor& remote_value_tensor,
                             const int& block_num,
                             const int& local_device_id,
                             const int& remote_device_id,
                             const int64_t& cuda_stream_raw,
                             const bool& is_scale) {
  std::vector<int64_t> cache_key_tensor_shape = local_key_tensor.shape();
  getNvidiaGPUMemoryUsage(__LINE__);
  // auto cuda_stream = local_key_tensor.stream();
  cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
  getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
  std::cout << "#### 000" << std::endl;
#endif
  int32_t total_block_num_local = cache_key_tensor_shape[0];
  int32_t kv_num_head_local = cache_key_tensor_shape[1];
  int32_t block_size_local = cache_key_tensor_shape[2];
  int32_t hidden_size_local = cache_key_tensor_shape[3];
  getNvidiaGPUMemoryUsage(__LINE__);

  auto local_block_ids_ptr = local_block_ids.data<int32_t>();    // cpu
  auto remote_block_ids_ptr = remote_block_ids.data<int32_t>();  // cpu
  // Remote base addresses travel as int64 scalar tensors.
  auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
  auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
  getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
  std::cout << "#### 1111"
            << " remote_key_ptr: " << remote_key_ptr
            << " remote_value_ptr: " << remote_value_ptr << std::endl;
#endif
  getNvidiaGPUMemoryUsage(__LINE__);
  // Elements per cache block; scale tensors have no hidden dimension.
  int64_t block_idx_stride =
      kv_num_head_local * block_size_local * hidden_size_local;
  if (is_scale) {
    block_idx_stride = kv_num_head_local * block_size_local;
  }
  auto local_key_tensor_ptr = local_key_tensor.data();
  auto local_value_tensor_ptr = local_value_tensor.data();
  getNvidiaGPUMemoryUsage(__LINE__);
#ifdef DEBUG_IPC_SENT
  std::cout << "#### 2222" << std::endl;
#endif
  // Dispatch on cache dtype; byte size per block = stride * sizeof(element).
  switch (local_key_tensor.type()) {
    case paddle::DataType::BFLOAT16: {
      using dataT = __nv_bfloat16;
      return sent_key_value_by_remote_ptr<dataT>(
          reinterpret_cast<const dataT*>(local_key_tensor_ptr),
          reinterpret_cast<const dataT*>(local_value_tensor_ptr),
          local_block_ids_ptr,
          remote_block_ids_ptr,
          block_num,
          block_idx_stride,
          block_idx_stride * 2,  // 2 bytes per bf16 element
          local_device_id,
          remote_device_id,
          reinterpret_cast<dataT*>((void*)remote_key_ptr),
          reinterpret_cast<dataT*>((void*)remote_value_ptr),
          cuda_stream);
    }
    case paddle::DataType::FLOAT16: {
      using dataT = half;
      return sent_key_value_by_remote_ptr<dataT>(
          reinterpret_cast<const dataT*>(local_key_tensor_ptr),
          reinterpret_cast<const dataT*>(local_value_tensor_ptr),
          local_block_ids_ptr,
          remote_block_ids_ptr,
          block_num,
          block_idx_stride,
          block_idx_stride * 2,  // 2 bytes per fp16 element
          local_device_id,
          remote_device_id,
          reinterpret_cast<dataT*>((void*)remote_key_ptr),
          reinterpret_cast<dataT*>((void*)remote_value_ptr),
          cuda_stream);
    }
    case paddle::DataType::INT8: {
      using dataT = int8_t;
      return sent_key_value_by_remote_ptr<dataT>(
          reinterpret_cast<const dataT*>(local_key_tensor_ptr),
          reinterpret_cast<const dataT*>(local_value_tensor_ptr),
          local_block_ids_ptr,
          remote_block_ids_ptr,
          block_num,
          block_idx_stride,
          block_idx_stride * 1,  // 1 byte per int8 element
          local_device_id,
          remote_device_id,
          reinterpret_cast<dataT*>((void*)remote_key_ptr),
          reinterpret_cast<dataT*>((void*)remote_value_ptr),
          cuda_stream);
    }
    case paddle::DataType::UINT8: {
      using dataT = uint8_t;
      return sent_key_value_by_remote_ptr<dataT>(
          reinterpret_cast<const dataT*>(local_key_tensor_ptr),
          reinterpret_cast<const dataT*>(local_value_tensor_ptr),
          local_block_ids_ptr,
          remote_block_ids_ptr,
          block_num,
          block_idx_stride,
          block_idx_stride * 1,  // 1 byte per uint8 element
          local_device_id,
          remote_device_id,
          reinterpret_cast<dataT*>((void*)remote_key_ptr),
          reinterpret_cast<dataT*>((void*)remote_value_ptr),
          cuda_stream);
    }
    default:
      // Previously an unsupported dtype fell through silently and no data
      // was transferred; surface it instead.
      std::cerr << "SentKeyValueByRemotePtr: unsupported cache dtype, "
                   "no blocks were sent" << std::endl;
      break;
  }
}
// Paddle custom-op kernel: block the host until all peer copies queued on
// `cuda_stream_raw` (a raw cudaStream_t passed as int64) have completed.
// The tensor arguments only exist to express the op's in-place dependency;
// they are not read.
void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
                                      const paddle::Tensor& local_value_tensor,
                                      const int64_t& cuda_stream_raw) {
  cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
  cudaStreamSynchronize(cuda_stream);
}
// Register the IPC send op. Inputs are declared in-place (the local cache
// tensors are not modified but the dependency keeps scheduling correct).
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
    .Inputs({"local_key_tensor",
             "local_value_tensor",
             "local_block_ids",
             "remote_block_ids",
             "remote_key_tensor",
             "remote_value_tensor"})
    .Attrs({"block_num: int",
            "local_device_id: int",
            "remote_device_id: int",
            "cuda_stream_raw: int64_t",
            "is_scale: bool"})
    .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
    .SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},
                    {"local_value_tensor", "local_value_tensor_out"}})
    .SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtr));
// Register the stream-synchronize op paired with the IPC send above.
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
    .Inputs({"local_key_tensor", "local_value_tensor"})
    .Attrs({"cuda_stream_raw: int64_t"})
    .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
    .SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},
                    {"local_value_tensor", "local_value_tensor_out"}})
    .SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));
+72 -3
View File
@@ -81,9 +81,16 @@ def parse_args():
"--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
choices=["uint8", "bfloat16", "block_wise_fp8"],
help="cache dtype",
)
parser.add_argument(
"--default_dtype",
type=str,
default="bfloat16",
choices=["float16", "bfloat16", "uint8", "int8"],
help="paddle default dtype, cache manager only support float16、bfloat16、int8 and uint8 now",
)
parser.add_argument(
"--speculative_config",
type=json.loads,
@@ -127,6 +134,7 @@ class CacheMessager:
num_layers,
gpu_id=0,
rdma_port=None,
cache_dtype="bfloat16",
):
"""
Initialize the CacheMessager object.
@@ -170,6 +178,8 @@ class CacheMessager:
cache_v_ptr_list = []
cache_k = []
cache_v = []
k_scale_ptr_list, v_scale_ptr_list = [], []
k_cache_scale, v_cache_scale = [], []
self.messager = {}
for layer_idx in range(self.num_layers):
# value cache
@@ -181,6 +191,11 @@ class CacheMessager:
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_v_ptr_list.append(val_cache.data_ptr())
if cache_dtype == "block_wise_fp8":
val_scale_key = f"value_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
val_cache_scale = self.gpu_cache_kvs[val_scale_key]
v_cache_scale.append(val_cache_scale)
v_scale_ptr_list.append(val_cache_scale.data_ptr())
# key cache
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
@@ -188,6 +203,11 @@ class CacheMessager:
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
if cache_dtype == "block_wise_fp8":
key_scale_key = f"key_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
key_cache_scale = self.gpu_cache_kvs[key_scale_key]
k_cache_scale.append(key_cache_scale)
k_scale_ptr_list.append(key_cache_scale.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -211,6 +231,9 @@ class CacheMessager:
gpu_id,
cache_k,
cache_v,
k_cache_scale,
v_cache_scale,
cache_dtype,
)
local_device_id = int(str(cache_k[0].place)[-2])
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
@@ -437,6 +460,7 @@ class CacheMessagerV1:
gpu_id=0,
block_size=64,
rdma_port=None,
cache_dtype="bfloat16",
):
"""
Initialize the CacheMessager object.
@@ -473,7 +497,9 @@ class CacheMessagerV1:
self.block_size = block_size
transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
logger.info(
f"splitwise role: {splitwise_role}, transfer_protocol: {transfer_protocol}, rank: {rank}, cache_dtype: {cache_dtype}"
)
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
@@ -481,6 +507,8 @@ class CacheMessagerV1:
cache_v_ptr_list = []
cache_k = []
cache_v = []
k_scale_ptr_list, v_scale_ptr_list = [], []
k_cache_scale, v_cache_scale = [], []
self.messager = {}
for layer_idx in range(self.num_layers):
# value cache
@@ -492,6 +520,11 @@ class CacheMessagerV1:
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_v_ptr_list.append(val_cache.data_ptr())
if cache_dtype == "block_wise_fp8":
val_scale_key = f"value_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
val_cache_scale = self.gpu_cache_kvs[val_scale_key]
v_cache_scale.append(val_cache_scale)
v_scale_ptr_list.append(val_cache_scale.data_ptr())
# key cache
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
@@ -499,6 +532,11 @@ class CacheMessagerV1:
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
if cache_dtype == "block_wise_fp8":
key_scale_key = f"key_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
key_cache_scale = self.gpu_cache_kvs[key_scale_key]
k_cache_scale.append(key_cache_scale)
k_scale_ptr_list.append(key_cache_scale.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -522,6 +560,9 @@ class CacheMessagerV1:
gpu_id,
cache_k,
cache_v,
k_cache_scale,
v_cache_scale,
cache_dtype,
)
local_device_id = int(str(cache_k[0].place)[-2])
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
@@ -850,7 +891,13 @@ def main():
device = args.device_id
rank = args.rank
set_device(device)
cache_type = args.cache_dtype
paddle.set_default_dtype(args.default_dtype)
if args.cache_dtype == "block_wise_fp8":
cache_type = "uint8"
else:
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")]
@@ -894,6 +941,16 @@ def main():
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
if args.cache_dtype == "block_wise_fp8":
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[num_gpu_blocks, key_cache_shape[1], key_cache_shape[2]],
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
f"key_cache_scales_{i}_rank{rank}.device{device}",
)
if value_cache_shape_list:
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=value_cache_shape,
@@ -906,6 +963,16 @@ def main():
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
if args.cache_dtype == "block_wise_fp8":
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[num_gpu_blocks, value_cache_shape[1], value_cache_shape[2]],
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
f"value_cache_scales_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -924,6 +991,7 @@ def main():
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
cache_dtype=args.cache_dtype,
)
else:
cache_messager = CacheMessager(
@@ -938,6 +1006,7 @@ def main():
num_layers=args.num_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
cache_dtype=args.cache_dtype,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -380,6 +380,7 @@ class PrefixCacheManager:
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --pod_ip {pod_ip}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --cache_queue_port {cache_config.local_cache_queue_port}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
@@ -31,7 +31,7 @@ class IPCConnector:
IPC communication class.
"""
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_, cache_dtype):
"""
Args:
rank_id_: rank id
@@ -40,16 +40,26 @@ class IPCConnector:
"""
self.remote_key_tensor_ptr_list = []
self.remote_value_tensor_ptr_list = []
self.remote_key_scale_tensor_ptr_list = []
self.remote_value_scale_tensor_ptr_list = []
self.remote_gpu_id = int(remote_gpu_id_)
self.rank_id = rank_id_
self.local_gpu_id = int(local_gpu_id_)
self.cache_dtype = cache_dtype
tmp = paddle.ones([1, 1])
logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
logger.info(
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}, cache dtype {self.cache_dtype}"
)
for layer_id in range(layer_num):
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
self.remote_key_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_unique_name))
self.remote_value_tensor_ptr_list.append(get_data_ptr_ipc(tmp, value_unique_name))
if self.cache_dtype == "block_wise_fp8":
key_scale_name = f"key_cache_scales_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
val_scale_name = f"value_cache_scales_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
self.remote_key_scale_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_scale_name))
self.remote_value_scale_tensor_ptr_list.append(get_data_ptr_ipc(tmp, val_scale_name))
self.write_stream = paddle.device.Stream(f"gpu:{self.local_gpu_id}")
@@ -64,13 +74,19 @@ class IPCCommManager:
gpu_idx_,
local_key_cache_tensor_list, # tensor list
local_value_cache_tensor_list, # tensor
local_key_cache_scale_list,
local_value_cache_scale_list,
cache_dtype,
):
self.rank_id = rank_id_
self.gpu_idx = gpu_idx_
self.cache_dtype = cache_dtype
# local cache to tensor
self.local_key_cache_tensor_list = local_key_cache_tensor_list
self.local_value_cache_tensor_list = local_value_cache_tensor_list
self.layer_num = len(self.local_key_cache_tensor_list)
self.local_key_cache_scale_list = local_key_cache_scale_list
self.local_value_cache_scale_list = local_value_cache_scale_list
# record connected ipc info
self.comm_map = {}
@@ -82,7 +98,9 @@ class IPCCommManager:
if self.is_connected(remote_gpu_id_):
return True
else:
self.comm_map[remote_gpu_id_] = IPCConnector(self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
self.comm_map[remote_gpu_id_] = IPCConnector(
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx, self.cache_dtype
)
return True
def is_connected(self, remote_gpu_id_=0):
@@ -114,7 +132,23 @@ class IPCCommManager:
self.gpu_idx,
comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream,
False,
)
if self.cache_dtype == "block_wise_fp8":
logger.info(f"IPC write cache scales for layer: {layer_idx}")
ipc_sent_key_value_cache_by_remote_ptr(
self.local_key_cache_scale_list[layer_idx],
self.local_value_cache_scale_list[layer_idx],
local_block_ids,
remote_block_ids,
comm.remote_key_scale_tensor_ptr_list[layer_idx],
comm.remote_value_scale_tensor_ptr_list[layer_idx],
block_num,
self.gpu_idx,
comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream,
True,
)
return 0
def write_block_by_sync(self, remote_gpu_id):