mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] pd support dy-c8 ipc (#5750)
* pd support dy-c8 ipc * update code * support v0 * update code
This commit is contained in:
@@ -12,66 +12,74 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "helper.h"
|
||||
#include "fstream"
|
||||
#include "iostream"
|
||||
#include "iomanip"
|
||||
#include <nvml.h>
|
||||
#include <iostream>
|
||||
#include "fstream"
|
||||
#include "helper.h"
|
||||
#include "iomanip"
|
||||
#include "iostream"
|
||||
// #define PRINT_GPU_MEMORY
|
||||
// 函数用于获取 NVIDIA GPU 显存信息
|
||||
bool getNvidiaGPUMemoryUsage(int callLine) {
|
||||
#ifndef PRINT_GPU_MEMORY
|
||||
return true;
|
||||
#endif
|
||||
// 初始化 NVML
|
||||
nvmlReturn_t result;
|
||||
result = nvmlInit();
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to initialize NVML: " << nvmlErrorString(result) << std::endl;
|
||||
return false;
|
||||
}
|
||||
// 获取 GPU 设备数量
|
||||
unsigned int deviceCount;
|
||||
result = nvmlDeviceGetCount(&deviceCount);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get device count: " << nvmlErrorString(result) << std::endl;
|
||||
nvmlShutdown();
|
||||
return false;
|
||||
}
|
||||
// 遍历每个 GPU 设备
|
||||
for (unsigned int i = 0; i < deviceCount; ++i) {
|
||||
nvmlDevice_t device;
|
||||
result = nvmlDeviceGetHandleByIndex(i, &device);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get device handle for device " << i << ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 获取显存信息
|
||||
nvmlMemory_t memory;
|
||||
result = nvmlDeviceGetMemoryInfo(device, &memory);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get memory info for device " << i << ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 只打印一行信息并显示调用函数时的行号
|
||||
std::cout << callLine << ": GPU " << i << " - Total: " << memory.total / (1024 * 1024)
|
||||
<< " MiB, Used: " << memory.used / (1024 * 1024)
|
||||
<< " MiB, Free: " << memory.free / (1024 * 1024) << " MiB" << std::endl;
|
||||
}
|
||||
// 清理 NVML 资源
|
||||
#ifndef PRINT_GPU_MEMORY
|
||||
return true;
|
||||
#endif
|
||||
// 初始化 NVML
|
||||
nvmlReturn_t result;
|
||||
result = nvmlInit();
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine
|
||||
<< ": Failed to initialize NVML: " << nvmlErrorString(result)
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
// 获取 GPU 设备数量
|
||||
unsigned int deviceCount;
|
||||
result = nvmlDeviceGetCount(&deviceCount);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine
|
||||
<< ": Failed to get device count: " << nvmlErrorString(result)
|
||||
<< std::endl;
|
||||
nvmlShutdown();
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
// 遍历每个 GPU 设备
|
||||
for (unsigned int i = 0; i < deviceCount; ++i) {
|
||||
nvmlDevice_t device;
|
||||
result = nvmlDeviceGetHandleByIndex(i, &device);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get device handle for device " << i
|
||||
<< ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 获取显存信息
|
||||
nvmlMemory_t memory;
|
||||
result = nvmlDeviceGetMemoryInfo(device, &memory);
|
||||
if (NVML_SUCCESS != result) {
|
||||
std::cerr << callLine << ": Failed to get memory info for device " << i
|
||||
<< ": " << nvmlErrorString(result) << std::endl;
|
||||
continue;
|
||||
}
|
||||
// 只打印一行信息并显示调用函数时的行号
|
||||
std::cout << callLine << ": GPU " << i
|
||||
<< " - Total: " << memory.total / (1024 * 1024)
|
||||
<< " MiB, Used: " << memory.used / (1024 * 1024)
|
||||
<< " MiB, Free: " << memory.free / (1024 * 1024) << " MiB"
|
||||
<< std::endl;
|
||||
}
|
||||
// 清理 NVML 资源
|
||||
nvmlShutdown();
|
||||
return true;
|
||||
}
|
||||
|
||||
// #define DEBUG_IPC_SENT
|
||||
// #define DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
void sent_key_value_by_remote_ptr(
|
||||
const T* local_key_tensor_base_ptr, // gpu ptr
|
||||
const T* local_value_tensor_base_ptr, // gpu ptr
|
||||
const int32_t* local_block_ids_ptr, //cpu ptr,
|
||||
const T* local_key_tensor_base_ptr, // gpu ptr
|
||||
const T* local_value_tensor_base_ptr, // gpu ptr
|
||||
const int32_t* local_block_ids_ptr, // cpu ptr,
|
||||
const int32_t* remote_block_ids_ptr,
|
||||
const int32_t block_num,
|
||||
const int64_t block_idx_stride,
|
||||
@@ -80,255 +88,275 @@ void sent_key_value_by_remote_ptr(
|
||||
const int32_t remote_device_id,
|
||||
T* remote_key_tensor_base_ptr, // gpu ptr
|
||||
T* remote_value_tensor_base_ptr, // gpu ptr
|
||||
cudaStream_t stream){
|
||||
for(int block_idx=0;block_idx < block_num; ++block_idx){
|
||||
const T* local_key_tensor_sent_ptr = local_key_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
T* remote_key_tensor_sent_ptr = remote_key_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr
|
||||
<<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr
|
||||
<<" local_device_id:" << local_device_id
|
||||
<<" remote_device_id:" << remote_device_id
|
||||
<<" block_idx_stride:" << block_idx_stride
|
||||
<<" block_size_byte:" << block_size_byte
|
||||
<<" stream: " << stream
|
||||
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
|
||||
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
|
||||
<<std::endl;
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeerAsync(
|
||||
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte,
|
||||
stream);
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeer(
|
||||
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte);
|
||||
#endif
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if ( err != cudaSuccess )
|
||||
{
|
||||
printf("CUDA Error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
const T* local_value_tensor_sent_ptr = local_value_tensor_base_ptr + local_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
T* remote_value_tensor_sent_ptr = remote_value_tensor_base_ptr + remote_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
cudaStream_t stream) {
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const T* local_key_tensor_sent_ptr =
|
||||
local_key_tensor_base_ptr +
|
||||
local_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
T* remote_key_tensor_sent_ptr =
|
||||
remote_key_tensor_base_ptr +
|
||||
remote_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr
|
||||
<<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr
|
||||
<<" local_device_id:" << local_device_id
|
||||
<<" remote_device_id:" << remote_device_id
|
||||
<<" block_idx_stride:" << block_idx_stride
|
||||
<<" block_size_byte:" << block_size_byte
|
||||
<<" stream: " << stream
|
||||
<<" local_block_ids: " << local_block_ids_ptr[block_idx]
|
||||
<<" remote_block_ids: " << remote_block_ids_ptr[block_idx]
|
||||
<<std::endl;
|
||||
std::cout << "remote_key_tensor_sent_ptr:"
|
||||
<< (int64_t)remote_key_tensor_sent_ptr
|
||||
<< " local_key_tensor_sent_ptr:"
|
||||
<< (int64_t)local_key_tensor_sent_ptr
|
||||
<< " local_device_id:" << local_device_id
|
||||
<< " remote_device_id:" << remote_device_id
|
||||
<< " block_idx_stride:" << block_idx_stride
|
||||
<< " block_size_byte:" << block_size_byte << " stream: " << stream
|
||||
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
|
||||
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
|
||||
<< std::endl;
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(
|
||||
reinterpret_cast<const T*>(local_key_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_src_key.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeerAsync(
|
||||
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte,
|
||||
stream);
|
||||
cudaMemcpyPeerAsync(
|
||||
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte,
|
||||
stream);
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeer(
|
||||
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
err = cudaGetLastError();
|
||||
if ( err != cudaSuccess )
|
||||
{
|
||||
printf("CUDA Error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpyPeer(reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte);
|
||||
#endif
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA Error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(
|
||||
reinterpret_cast<T*>(remote_key_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_tgt_key.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
const T* local_value_tensor_sent_ptr =
|
||||
local_value_tensor_base_ptr +
|
||||
local_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
T* remote_value_tensor_sent_ptr =
|
||||
remote_value_tensor_base_ptr +
|
||||
remote_block_ids_ptr[block_idx] * block_idx_stride;
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout << "remote_value_tensor_sent_ptr:"
|
||||
<< (int64_t)remote_value_tensor_sent_ptr
|
||||
<< " local_value_tensor_sent_ptr:"
|
||||
<< (int64_t)local_value_tensor_sent_ptr
|
||||
<< " local_device_id:" << local_device_id
|
||||
<< " remote_device_id:" << remote_device_id
|
||||
<< " block_idx_stride:" << block_idx_stride
|
||||
<< " block_size_byte:" << block_size_byte << " stream: " << stream
|
||||
<< " local_block_ids: " << local_block_ids_ptr[block_idx]
|
||||
<< " remote_block_ids: " << remote_block_ids_ptr[block_idx]
|
||||
<< std::endl;
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaDeviceSynchronize();
|
||||
PrintMatrix<T>(
|
||||
reinterpret_cast<const T*>(local_value_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_src_value.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeerAsync(
|
||||
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte,
|
||||
stream);
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
cudaMemcpyPeer(reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
|
||||
remote_device_id,
|
||||
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
|
||||
local_device_id,
|
||||
block_size_byte);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA Error: %s\n", cudaGetErrorString(err));
|
||||
}
|
||||
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
|
||||
PrintMatrix<T>(
|
||||
reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
|
||||
128 * 1,
|
||||
"ipc_send_tgt_value.datatxt." + std::to_string(local_device_id),
|
||||
128 * 1);
|
||||
cudaDeviceSynchronize();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
void SentKeyValueByRemotePtr(const paddle::Tensor& local_key_tensor,
|
||||
const paddle::Tensor& local_value_tensor,
|
||||
const paddle::Tensor& local_block_ids, // cpu
|
||||
const paddle::Tensor& remote_block_ids, // cpu
|
||||
const paddle::Tensor& local_block_ids, // cpu
|
||||
const paddle::Tensor& remote_block_ids, // cpu
|
||||
const paddle::Tensor& remote_key_tensor,
|
||||
const paddle::Tensor& remote_value_tensor,
|
||||
const int& block_num,
|
||||
const int& local_device_id,
|
||||
const int& remote_device_id,
|
||||
const int64_t& cuda_stream_raw) {
|
||||
std::vector<int64_t> cache_key_tensor_shape = local_key_tensor.shape();
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
// auto cuda_stream = local_key_tensor.stream();
|
||||
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
// const cudaStream_t cuda_stream = *(reinterpret_cast<const cudaStream_t*>(&stream));
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout<<"#### 000"<<std::endl;
|
||||
#endif
|
||||
const int64_t& cuda_stream_raw,
|
||||
const bool& is_scale) {
|
||||
std::vector<int64_t> cache_key_tensor_shape = local_key_tensor.shape();
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
// auto cuda_stream = local_key_tensor.stream();
|
||||
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
// const cudaStream_t cuda_stream = *(reinterpret_cast<const
|
||||
// cudaStream_t*>(&stream));
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout << "#### 000" << std::endl;
|
||||
#endif
|
||||
|
||||
int32_t total_block_num_local = cache_key_tensor_shape[0];
|
||||
int32_t kv_num_head_local = cache_key_tensor_shape[1];
|
||||
int32_t block_size_local = cache_key_tensor_shape[2];
|
||||
int32_t hidden_size_local = cache_key_tensor_shape[3];
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
int32_t total_block_num_local = cache_key_tensor_shape[0];
|
||||
int32_t kv_num_head_local = cache_key_tensor_shape[1];
|
||||
int32_t block_size_local = cache_key_tensor_shape[2];
|
||||
int32_t hidden_size_local = cache_key_tensor_shape[3];
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
|
||||
auto local_block_ids_ptr = local_block_ids.data<int32_t>(); // cpu
|
||||
auto remote_block_ids_ptr = remote_block_ids.data<int32_t>(); // cpu
|
||||
auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
|
||||
auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
auto local_block_ids_ptr = local_block_ids.data<int32_t>(); // cpu
|
||||
auto remote_block_ids_ptr = remote_block_ids.data<int32_t>(); // cpu
|
||||
auto remote_key_ptr = remote_key_tensor.data<int64_t>()[0];
|
||||
auto remote_value_ptr = remote_value_tensor.data<int64_t>()[0];
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout<<"#### 1111"
|
||||
<< " remote_key_ptr: "<<remote_key_ptr
|
||||
<< " remote_value_ptr: "<<remote_value_ptr<<std::endl;
|
||||
#endif
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
int64_t block_idx_stride = kv_num_head_local*block_size_local*hidden_size_local;
|
||||
auto local_key_tensor_ptr = local_key_tensor.data();
|
||||
auto local_value_tensor_ptr = local_value_tensor.data();
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout<<"#### 2222"<<std::endl;
|
||||
#endif
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout << "#### 1111"
|
||||
<< " remote_key_ptr: " << remote_key_ptr
|
||||
<< " remote_value_ptr: " << remote_value_ptr << std::endl;
|
||||
#endif
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
int64_t block_idx_stride =
|
||||
kv_num_head_local * block_size_local * hidden_size_local;
|
||||
if (is_scale == true) {
|
||||
block_idx_stride = kv_num_head_local * block_size_local;
|
||||
}
|
||||
auto local_key_tensor_ptr = local_key_tensor.data();
|
||||
auto local_value_tensor_ptr = local_value_tensor.data();
|
||||
getNvidiaGPUMemoryUsage(__LINE__);
|
||||
#ifdef DEBUG_IPC_SENT
|
||||
std::cout << "#### 2222" << std::endl;
|
||||
#endif
|
||||
|
||||
switch (local_key_tensor.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using dataT=__nv_bfloat16;
|
||||
// std::cout<<"#### cache type __nv_bfloat16" << std::endl;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 2,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream
|
||||
);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using dataT=half;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 2,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream
|
||||
);
|
||||
}
|
||||
case paddle::DataType::INT8: {
|
||||
using dataT=int8_t;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 1,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream
|
||||
);
|
||||
}
|
||||
case paddle::DataType::UINT8: {
|
||||
using dataT=uint8_t;
|
||||
// std::cout<<"#### cache type uint8" << std::endl;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 1,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream
|
||||
);
|
||||
}
|
||||
switch (local_key_tensor.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using dataT = __nv_bfloat16;
|
||||
// std::cout<<"#### cache type __nv_bfloat16" << std::endl;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 2,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
}
|
||||
// using dataT=std::remove_pointer<decltype(local_block_ids_ptr)>;
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using dataT = half;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 2,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
}
|
||||
case paddle::DataType::INT8: {
|
||||
using dataT = int8_t;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 1,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
}
|
||||
case paddle::DataType::UINT8: {
|
||||
using dataT = uint8_t;
|
||||
// std::cout<<"#### cache type uint8" << std::endl;
|
||||
return sent_key_value_by_remote_ptr<dataT>(
|
||||
reinterpret_cast<const dataT*>(local_key_tensor_ptr),
|
||||
reinterpret_cast<const dataT*>(local_value_tensor_ptr),
|
||||
local_block_ids_ptr,
|
||||
remote_block_ids_ptr,
|
||||
block_num,
|
||||
block_idx_stride,
|
||||
block_idx_stride * 1,
|
||||
local_device_id,
|
||||
remote_device_id,
|
||||
reinterpret_cast<dataT*>((void*)remote_key_ptr),
|
||||
reinterpret_cast<dataT*>((void*)remote_value_ptr),
|
||||
cuda_stream);
|
||||
}
|
||||
}
|
||||
// using dataT=std::remove_pointer<decltype(local_block_ids_ptr)>;
|
||||
}
|
||||
|
||||
void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
|
||||
const paddle::Tensor& local_value_tensor,
|
||||
const int64_t& cuda_stream_raw) {
|
||||
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
|
||||
cudaStreamSynchronize(cuda_stream);
|
||||
}
|
||||
const paddle::Tensor& local_value_tensor,
|
||||
const int64_t& cuda_stream_raw) {
|
||||
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
|
||||
cudaStreamSynchronize(cuda_stream);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
|
||||
.Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
|
||||
.Attrs({ "block_num: int",
|
||||
"local_device_id: int",
|
||||
"remote_device_id: int",
|
||||
"cuda_stream_raw: int64_t"})
|
||||
.Inputs({"local_key_tensor",
|
||||
"local_value_tensor",
|
||||
"local_block_ids",
|
||||
"remote_block_ids",
|
||||
"remote_key_tensor",
|
||||
"remote_value_tensor"})
|
||||
.Attrs({"block_num: int",
|
||||
"local_device_id: int",
|
||||
"remote_device_id: int",
|
||||
"cuda_stream_raw: int64_t",
|
||||
"is_scale: bool"})
|
||||
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
|
||||
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
|
||||
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},
|
||||
{"local_value_tensor", "local_value_tensor_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtr));
|
||||
|
||||
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
|
||||
.Inputs({"local_key_tensor", "local_value_tensor"})
|
||||
.Attrs({"cuda_stream_raw: int64_t"})
|
||||
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
|
||||
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
|
||||
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},
|
||||
{"local_value_tensor", "local_value_tensor_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));
|
||||
|
||||
@@ -81,9 +81,16 @@ def parse_args():
|
||||
"--cache_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["uint8", "bfloat16"],
|
||||
choices=["uint8", "bfloat16", "block_wise_fp8"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["float16", "bfloat16", "uint8", "int8"],
|
||||
help="paddle default dtype, cache manager only support float16、bfloat16、int8 and uint8 now",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative_config",
|
||||
type=json.loads,
|
||||
@@ -127,6 +134,7 @@ class CacheMessager:
|
||||
num_layers,
|
||||
gpu_id=0,
|
||||
rdma_port=None,
|
||||
cache_dtype="bfloat16",
|
||||
):
|
||||
"""
|
||||
Initialize the CacheMessager object.
|
||||
@@ -170,6 +178,8 @@ class CacheMessager:
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
k_scale_ptr_list, v_scale_ptr_list = [], []
|
||||
k_cache_scale, v_cache_scale = [], []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
# value cache
|
||||
@@ -181,6 +191,11 @@ class CacheMessager:
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
if cache_dtype == "block_wise_fp8":
|
||||
val_scale_key = f"value_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
val_cache_scale = self.gpu_cache_kvs[val_scale_key]
|
||||
v_cache_scale.append(val_cache_scale)
|
||||
v_scale_ptr_list.append(val_cache_scale.data_ptr())
|
||||
# key cache
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
@@ -188,6 +203,11 @@ class CacheMessager:
|
||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||
else:
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
if cache_dtype == "block_wise_fp8":
|
||||
key_scale_key = f"key_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
key_cache_scale = self.gpu_cache_kvs[key_scale_key]
|
||||
k_cache_scale.append(key_cache_scale)
|
||||
k_scale_ptr_list.append(key_cache_scale.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
@@ -211,6 +231,9 @@ class CacheMessager:
|
||||
gpu_id,
|
||||
cache_k,
|
||||
cache_v,
|
||||
k_cache_scale,
|
||||
v_cache_scale,
|
||||
cache_dtype,
|
||||
)
|
||||
local_device_id = int(str(cache_k[0].place)[-2])
|
||||
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
|
||||
@@ -437,6 +460,7 @@ class CacheMessagerV1:
|
||||
gpu_id=0,
|
||||
block_size=64,
|
||||
rdma_port=None,
|
||||
cache_dtype="bfloat16",
|
||||
):
|
||||
"""
|
||||
Initialize the CacheMessager object.
|
||||
@@ -473,7 +497,9 @@ class CacheMessagerV1:
|
||||
self.block_size = block_size
|
||||
transfer_protocol = transfer_protocol.split(",")
|
||||
|
||||
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
|
||||
logger.info(
|
||||
f"splitwise role: {splitwise_role}, transfer_protocol: {transfer_protocol}, rank: {rank}, cache_dtype: {cache_dtype}"
|
||||
)
|
||||
|
||||
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
|
||||
self.num_layers = num_layers
|
||||
@@ -481,6 +507,8 @@ class CacheMessagerV1:
|
||||
cache_v_ptr_list = []
|
||||
cache_k = []
|
||||
cache_v = []
|
||||
k_scale_ptr_list, v_scale_ptr_list = [], []
|
||||
k_cache_scale, v_cache_scale = [], []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
# value cache
|
||||
@@ -492,6 +520,11 @@ class CacheMessagerV1:
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
if cache_dtype == "block_wise_fp8":
|
||||
val_scale_key = f"value_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
val_cache_scale = self.gpu_cache_kvs[val_scale_key]
|
||||
v_cache_scale.append(val_cache_scale)
|
||||
v_scale_ptr_list.append(val_cache_scale.data_ptr())
|
||||
# key cache
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
@@ -499,6 +532,11 @@ class CacheMessagerV1:
|
||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||
else:
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
if cache_dtype == "block_wise_fp8":
|
||||
key_scale_key = f"key_cache_scales_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
key_cache_scale = self.gpu_cache_kvs[key_scale_key]
|
||||
k_cache_scale.append(key_cache_scale)
|
||||
k_scale_ptr_list.append(key_cache_scale.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
@@ -522,6 +560,9 @@ class CacheMessagerV1:
|
||||
gpu_id,
|
||||
cache_k,
|
||||
cache_v,
|
||||
k_cache_scale,
|
||||
v_cache_scale,
|
||||
cache_dtype,
|
||||
)
|
||||
local_device_id = int(str(cache_k[0].place)[-2])
|
||||
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
|
||||
@@ -850,7 +891,13 @@ def main():
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
set_device(device)
|
||||
cache_type = args.cache_dtype
|
||||
|
||||
paddle.set_default_dtype(args.default_dtype)
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
cache_type = "uint8"
|
||||
else:
|
||||
cache_type = args.cache_dtype
|
||||
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
num_extra_layers = speculative_config.num_extra_cache_layer
|
||||
key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")]
|
||||
@@ -894,6 +941,16 @@ def main():
|
||||
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
|
||||
f"key_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[num_gpu_blocks, key_cache_shape[1], key_cache_shape[2]],
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
|
||||
f"key_cache_scales_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
if value_cache_shape_list:
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=value_cache_shape,
|
||||
@@ -906,6 +963,16 @@ def main():
|
||||
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
|
||||
f"value_caches_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
|
||||
shape=[num_gpu_blocks, value_cache_shape[1], value_cache_shape[2]],
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
|
||||
f"value_cache_scales_{i}_rank{rank}.device{device}",
|
||||
)
|
||||
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
|
||||
logger.info(f"device :{device}")
|
||||
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
|
||||
@@ -924,6 +991,7 @@ def main():
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
cache_dtype=args.cache_dtype,
|
||||
)
|
||||
else:
|
||||
cache_messager = CacheMessager(
|
||||
@@ -938,6 +1006,7 @@ def main():
|
||||
num_layers=args.num_layers + num_extra_layers,
|
||||
gpu_id=device,
|
||||
rdma_port=args.rdma_port,
|
||||
cache_dtype=args.cache_dtype,
|
||||
)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
|
||||
@@ -380,6 +380,7 @@ class PrefixCacheManager:
|
||||
+ f" --key_cache_shape {key_cache_shape}"
|
||||
+ val_cache_arg_str
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --default_dtype '{self.config.model_config.dtype}'"
|
||||
+ f" --cache_queue_port {cache_config.local_cache_queue_port}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --protocol {cache_config.cache_transfer_protocol}"
|
||||
|
||||
@@ -31,7 +31,7 @@ class IPCConnector:
|
||||
IPC communication class.
|
||||
"""
|
||||
|
||||
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_):
|
||||
def __init__(self, rank_id_, remote_gpu_id_, layer_num, local_gpu_id_, cache_dtype):
|
||||
"""
|
||||
Args:
|
||||
rank_id_: rank id
|
||||
@@ -40,16 +40,26 @@ class IPCConnector:
|
||||
"""
|
||||
self.remote_key_tensor_ptr_list = []
|
||||
self.remote_value_tensor_ptr_list = []
|
||||
self.remote_key_scale_tensor_ptr_list = []
|
||||
self.remote_value_scale_tensor_ptr_list = []
|
||||
self.remote_gpu_id = int(remote_gpu_id_)
|
||||
self.rank_id = rank_id_
|
||||
self.local_gpu_id = int(local_gpu_id_)
|
||||
self.cache_dtype = cache_dtype
|
||||
tmp = paddle.ones([1, 1])
|
||||
logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
|
||||
logger.info(
|
||||
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}, cache dtype {self.cache_dtype}"
|
||||
)
|
||||
for layer_id in range(layer_num):
|
||||
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
self.remote_key_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_unique_name))
|
||||
self.remote_value_tensor_ptr_list.append(get_data_ptr_ipc(tmp, value_unique_name))
|
||||
if self.cache_dtype == "block_wise_fp8":
|
||||
key_scale_name = f"key_cache_scales_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
val_scale_name = f"value_cache_scales_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
|
||||
self.remote_key_scale_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_scale_name))
|
||||
self.remote_value_scale_tensor_ptr_list.append(get_data_ptr_ipc(tmp, val_scale_name))
|
||||
self.write_stream = paddle.device.Stream(f"gpu:{self.local_gpu_id}")
|
||||
|
||||
|
||||
@@ -64,13 +74,19 @@ class IPCCommManager:
|
||||
gpu_idx_,
|
||||
local_key_cache_tensor_list, # tensor list
|
||||
local_value_cache_tensor_list, # tensor
|
||||
local_key_cache_scale_list,
|
||||
local_value_cache_scale_list,
|
||||
cache_dtype,
|
||||
):
|
||||
self.rank_id = rank_id_
|
||||
self.gpu_idx = gpu_idx_
|
||||
self.cache_dtype = cache_dtype
|
||||
# local cache to tensor
|
||||
self.local_key_cache_tensor_list = local_key_cache_tensor_list
|
||||
self.local_value_cache_tensor_list = local_value_cache_tensor_list
|
||||
self.layer_num = len(self.local_key_cache_tensor_list)
|
||||
self.local_key_cache_scale_list = local_key_cache_scale_list
|
||||
self.local_value_cache_scale_list = local_value_cache_scale_list
|
||||
# record connected ipc info
|
||||
self.comm_map = {}
|
||||
|
||||
@@ -82,7 +98,9 @@ class IPCCommManager:
|
||||
if self.is_connected(remote_gpu_id_):
|
||||
return True
|
||||
else:
|
||||
self.comm_map[remote_gpu_id_] = IPCConnector(self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
|
||||
self.comm_map[remote_gpu_id_] = IPCConnector(
|
||||
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx, self.cache_dtype
|
||||
)
|
||||
return True
|
||||
|
||||
def is_connected(self, remote_gpu_id_=0):
|
||||
@@ -114,7 +132,23 @@ class IPCCommManager:
|
||||
self.gpu_idx,
|
||||
comm.remote_gpu_id,
|
||||
comm.write_stream.stream_base.cuda_stream,
|
||||
False,
|
||||
)
|
||||
if self.cache_dtype == "block_wise_fp8":
|
||||
logger.info(f"IPC write cache scales for layer: {layer_idx}")
|
||||
ipc_sent_key_value_cache_by_remote_ptr(
|
||||
self.local_key_cache_scale_list[layer_idx],
|
||||
self.local_value_cache_scale_list[layer_idx],
|
||||
local_block_ids,
|
||||
remote_block_ids,
|
||||
comm.remote_key_scale_tensor_ptr_list[layer_idx],
|
||||
comm.remote_value_scale_tensor_ptr_list[layer_idx],
|
||||
block_num,
|
||||
self.gpu_idx,
|
||||
comm.remote_gpu_id,
|
||||
comm.write_stream.stream_base.cuda_stream,
|
||||
True,
|
||||
)
|
||||
return 0
|
||||
|
||||
def write_block_by_sync(self, remote_gpu_id):
|
||||
|
||||
Reference in New Issue
Block a user