Files
FastDeploy/custom_ops/gpu_ops/share_external_data.cu
T
gongweibao a6351dea0b [BugFix][Optimization] Replace silent failures with catchable exceptions and informative error messages (#6533)
* init

* init

* fix format

* add

* add files

* add ut

* fix some

* add ut

* add more

* add

* fix pre-commit

* fix pre-commit

* fix cover

* skip long seq

* add

* add

* fix

* remove not need

* fix set attr

* fix comments

* fix comments

* fix failed tests

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
2026-03-16 21:32:43 +08:00

61 lines
2.2 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <stdio.h>
#include "cuda_multiprocess.h"
#include "helper.h"
#include "paddle/phi/core/tensor_meta.h"
// Imports a device buffer exported by another process and wraps it as a
// tensor without copying.
//
// Opens the shared-memory segment `shm_name` (sized as `shmStruct`), reads
// the IPC memory handle stored in its `memHandle` field, and opens that
// handle with lazy peer access enabled. Presumably the exporting process
// wrote the handle there via cudaIpcGetMemHandle — confirm against the
// producer side.
//
// Args:
//   input:    only its dtype is used; the returned tensor gets input.type().
//   shm_name: name of the shared-memory segment holding the IPC handle.
//   shape:    shape of the returned tensor.
//
// Returns:
//   A one-element vector holding a tensor created with paddle::from_blob over
//   the imported device pointer (aliases the memory, no copy).
//
// Throws:
//   std::runtime_error if the shared-memory segment cannot be opened.
//   Errors from the IPC open are reported through checkCudaErrors (defined
//   elsewhere; whether it throws or terminates depends on that macro).
//
// NOTE(review): from_blob is called without a deleter, so nothing here ever
// calls cudaIpcCloseMemHandle on the mapping — presumably it is intended to
// live for the rest of the process; confirm.
std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor &input,
                                              const std::string shm_name,
                                              const std::vector<int> &shape) {
  sharedMemoryInfo info;
  if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
    // Snapshot errno before building the message: the operands of the
    // overloaded `+` chain below are evaluated in unspecified order, and the
    // string allocations could run before errno is read.
    const int open_errno = errno;
    throw std::runtime_error(
        "Failed to open shared memory slab in ShareExternalData, shm_name: " +
        shm_name + ", errno: " + std::string(strerror(open_errno)));
  }

  // Copy the IPC handle out of the mapping and release the mapping right
  // away, so that a failure in the IPC open or in from_blob below cannot
  // leak the shared-memory mapping.
  volatile shmStruct *shm = (volatile shmStruct *)info.addr;
#ifdef PADDLE_WITH_HIP
  hipIpcMemHandle_t mem_handle =
      *(hipIpcMemHandle_t *)&shm->memHandle;  // NOLINT
#else
  cudaIpcMemHandle_t mem_handle =
      *(cudaIpcMemHandle_t *)&shm->memHandle;  // NOLINT
#endif
  sharedMemoryClose(&info);

  void *ptr = nullptr;
#ifdef PADDLE_WITH_HIP
  checkCudaErrors(
      hipIpcOpenMemHandle(&ptr, mem_handle, hipIpcMemLazyEnablePeerAccess));
#else
  checkCudaErrors(
      cudaIpcOpenMemHandle(&ptr, mem_handle, cudaIpcMemLazyEnablePeerAccess));
#endif
  paddle::Tensor tmp_tensor = paddle::from_blob(ptr, shape, input.type());
  return {tmp_tensor};
}
// Registers the `share_external_data` custom op via PD_BUILD_STATIC_OP.
// One input ("input", used by the kernel only for its dtype) and one output
// ("output"). The attributes `shm_name` and `shape` are passed to the
// trailing parameters of ShareExternalData, the kernel bound below.
PD_BUILD_STATIC_OP(share_external_data)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"shm_name: std::string", "shape: std::vector<int>"})
.SetKernelFn(PD_KERNEL(ShareExternalData));