[BugFix] fix shm opened but not closed in set_data_ipc (#5826)

This commit is contained in:
Yonghua Li
2025-12-29 23:35:07 +08:00
committed by GitHub
parent deb9698ac5
commit a8d3e3ba12
2 changed files with 48 additions and 60 deletions
+47 -60
View File
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "cuda_multiprocess.h"
#include "helper.h"
int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info) {
int sharedMemoryCreate(const char* name, size_t sz, sharedMemoryInfo* info) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
info->size = sz;
info->shmHandle = CreateFileMapping(INVALID_HANDLE_VALUE, NULL,
PAGE_READWRITE, 0, (DWORD)sz, name);
info->shmHandle = CreateFileMapping(
INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, 0, (DWORD)sz, name);
if (info->shmHandle == 0) {
return GetLastError();
}
@@ -42,20 +42,22 @@ int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info) {
status = ftruncate(info->shmFd, sz);
if (status != 0) {
return status;
return errno;
}
info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0);
if (info->addr == NULL) {
if (info->addr == MAP_FAILED) {
return errno;
}
close(info->shmFd);
info->shmFd = -1;
return 0;
#endif
}
template <typename T>
__global__ void set_data(T *input, int n) {
__global__ void set_data(T* input, int n) {
if (threadIdx.x == 0) {
for (int i = 0; i < n; ++i) {
*(input + i) = static_cast<T>(i);
@@ -65,7 +67,7 @@ __global__ void set_data(T *input, int n) {
}
template <typename T>
__global__ void print_data(const T *input, int n) {
__global__ void print_data(const T* input, int n) {
if (threadIdx.x == 0) {
for (int i = 0; i < n; ++i) {
printf("input[%d]: %f\n", i, input[i]);
@@ -81,72 +83,57 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
typedef typename traits_::data_t data_t;
sharedMemoryInfo info;
volatile shmStruct *shm = NULL;
volatile shmStruct* shm = NULL;
if (sharedMemoryCreate(shm_name.c_str(), sizeof(*shm), &info) != 0) {
printf("Failed to create shared memory slab\n");
printf("Func sharedMemoryCreate. Shm_name: %s\n", shm_name.c_str());
exit(EXIT_FAILURE);
printf("Failed to create shared memory slab\n");
printf("Func sharedMemoryCreate. Shm_name: %s\n", shm_name.c_str());
exit(EXIT_FAILURE);
}
shm = (volatile shmStruct *)info.addr;
memset((void *)shm, 0, sizeof(*shm));
shm = (volatile shmStruct*)info.addr;
memset((void*)shm, 0, sizeof(*shm));
void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
void* data_ptr_now =
reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
#ifdef PADDLE_WITH_HIP
checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
checkCudaErrors(
hipIpcGetMemHandle((hipIpcMemHandle_t*)&shm->memHandle, data_ptr_now));
#else
checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
checkCudaErrors(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&shm->memHandle, data_ptr_now));
#endif
}
// Dtype dispatcher for set_data_ipc<T>(): selects the template instantiation
// matching the tensor's runtime dtype, which publishes the tensor's device
// memory as a CUDA/HIP IPC handle inside the shared-memory segment `shm_name`
// (see set_data_ipc above, which calls cudaIpcGetMemHandle/hipIpcGetMemHandle).
//
// Supported dtypes: bfloat16, float16, float32, int8, uint8.
// Any other dtype raises via PD_THROW.
void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name) {
  switch (tmp_input.type()) {
    case paddle::DataType::BFLOAT16: {
      return set_data_ipc<paddle::DataType::BFLOAT16>(tmp_input, shm_name);
    }
    case paddle::DataType::FLOAT16: {
      return set_data_ipc<paddle::DataType::FLOAT16>(tmp_input, shm_name);
    }
    case paddle::DataType::FLOAT32: {
      return set_data_ipc<paddle::DataType::FLOAT32>(tmp_input, shm_name);
    }
    case paddle::DataType::INT8: {
      return set_data_ipc<paddle::DataType::INT8>(tmp_input, shm_name);
    }
    case paddle::DataType::UINT8: {
      return set_data_ipc<paddle::DataType::UINT8>(tmp_input, shm_name);
    }
    default: {
      // Keep this message in sync with the cases handled above.
      PD_THROW(
          "NOT supported data type. "
          "Only bfloat16, float16, float32, int8 and uint8 are supported. ");
      break;
    }
  }
}
// Registers the custom Paddle op `set_data_ipc`:
//   input  : tensor "tmp_input" (the tensor whose device buffer is exported)
//   attr   : "shm_name" (std::string) — name of the POSIX/Win32 shared-memory
//            segment that receives the IPC handle
//   output : "tmp_input_out", aliased in-place to "tmp_input"
// Kernel entry point: SetDataIpc (dtype dispatcher above).
PD_BUILD_STATIC_OP(set_data_ipc)
    .Inputs({"tmp_input"})
    .Attrs({"shm_name: std::string"})
    .Outputs({"tmp_input_out"})
    .SetInplaceMap({{"tmp_input", "tmp_input_out"}})
    .SetKernelFn(PD_KERNEL(SetDataIpc));