mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -19,13 +19,13 @@
|
||||
#endif
|
||||
|
||||
__device__ bool is_in_list(const int64_t id, const int64_t *ids, int bs_id) {
|
||||
bool is_in_list = false;
|
||||
for (int i = 0; i < bs_id; i++) {
|
||||
if (ids[i] == id) {
|
||||
return true;
|
||||
}
|
||||
bool is_in_list = false;
|
||||
for (int i = 0; i < bs_id; i++) {
|
||||
if (ids[i] == id) {
|
||||
return true;
|
||||
}
|
||||
return is_in_list;
|
||||
}
|
||||
return is_in_list;
|
||||
}
|
||||
|
||||
__global__ void set_value_by_id(const bool *stop_flags,
|
||||
@@ -33,46 +33,45 @@ __global__ void set_value_by_id(const bool *stop_flags,
|
||||
bool *stop_flags_out,
|
||||
int bs,
|
||||
int bs_id) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid < bs && !is_in_list(tid, ids, bs_id)) {
|
||||
stop_flags_out[tid] = true;
|
||||
}
|
||||
int tid = threadIdx.x;
|
||||
if (tid < bs && !is_in_list(tid, ids, bs_id)) {
|
||||
stop_flags_out[tid] = true;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> SetFlags(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &gather_id) {
|
||||
PD_CHECK(gather_id.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
auto cu_stream = stop_flags.stream();
|
||||
std::vector<int64_t> flag_shape = stop_flags.shape();
|
||||
std::vector<int64_t> id_shape = gather_id.shape();
|
||||
auto stop_flags_out =
|
||||
stop_flags.copy_to(stop_flags.place(), false); // gpu -> gpu
|
||||
if (flag_shape[0] == id_shape[0]) {
|
||||
return {stop_flags_out};
|
||||
}
|
||||
int flag_bs = flag_shape[0];
|
||||
int id_bs = id_shape[0];
|
||||
int block_size = (flag_bs + 32 - 1) / 32 * 32;
|
||||
set_value_by_id<<<1, block_size, 0, cu_stream>>>(
|
||||
stop_flags.data<bool>(),
|
||||
gather_id.data<int64_t>(),
|
||||
stop_flags_out.data<bool>(),
|
||||
flag_bs,
|
||||
id_bs);
|
||||
PD_CHECK(gather_id.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
auto cu_stream = stop_flags.stream();
|
||||
std::vector<int64_t> flag_shape = stop_flags.shape();
|
||||
std::vector<int64_t> id_shape = gather_id.shape();
|
||||
auto stop_flags_out =
|
||||
stop_flags.copy_to(stop_flags.place(), false); // gpu -> gpu
|
||||
if (flag_shape[0] == id_shape[0]) {
|
||||
return {stop_flags_out};
|
||||
}
|
||||
int flag_bs = flag_shape[0];
|
||||
int id_bs = id_shape[0];
|
||||
int block_size = (flag_bs + 32 - 1) / 32 * 32;
|
||||
set_value_by_id<<<1, block_size, 0, cu_stream>>>(stop_flags.data<bool>(),
|
||||
gather_id.data<int64_t>(),
|
||||
stop_flags_out.data<bool>(),
|
||||
flag_bs,
|
||||
id_bs);
|
||||
return {stop_flags_out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> SetFlagsInferShape(
|
||||
const std::vector<int64_t> &stop_flags_shape,
|
||||
const std::vector<int64_t> &gather_id_shape) {
|
||||
return {stop_flags_shape};
|
||||
return {stop_flags_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> SetFlagsInferDtype(
|
||||
const paddle::DataType &stop_flags_dtype,
|
||||
const paddle::DataType &gather_id_dtype) {
|
||||
return {stop_flags_dtype};
|
||||
return {stop_flags_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(set_flags)
|
||||
|
||||
Reference in New Issue
Block a user