mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
@@ -29,39 +29,38 @@ update_split_fuse_inputs_kernel(int* split_fuse_seq_lens,
|
||||
int64_t* step_idx,
|
||||
const int split_fuse_size,
|
||||
const int max_seq_len) {
|
||||
const int bi = blockIdx.x;
|
||||
const int tidx = threadIdx.x;
|
||||
if (split_fuse_seq_lens[bi] <= 0) {
|
||||
return;
|
||||
const int bi = blockIdx.x;
|
||||
const int tidx = threadIdx.x;
|
||||
if (split_fuse_seq_lens[bi] <= 0) {
|
||||
return;
|
||||
}
|
||||
if (split_fuse_cur_seq_lens[bi] < split_fuse_seq_lens[bi]) {
|
||||
const int cur_add_tokens = min(
|
||||
split_fuse_seq_lens[bi] - split_fuse_cur_seq_lens[bi], split_fuse_size);
|
||||
int64_t* split_fuse_all_input_ids_cur_batch = split_fuse_all_input_ids +
|
||||
bi * max_seq_len +
|
||||
split_fuse_cur_seq_lens[bi];
|
||||
int64_t* input_ids_cur_batch = input_ids + bi * max_seq_len;
|
||||
for (int i = tidx; i < cur_add_tokens; i += blockDim.x) {
|
||||
input_ids_cur_batch[i] = split_fuse_all_input_ids_cur_batch[i];
|
||||
}
|
||||
if (split_fuse_cur_seq_lens[bi] < split_fuse_seq_lens[bi]) {
|
||||
const int cur_add_tokens =
|
||||
min(split_fuse_seq_lens[bi] - split_fuse_cur_seq_lens[bi],
|
||||
split_fuse_size);
|
||||
int64_t* split_fuse_all_input_ids_cur_batch =
|
||||
split_fuse_all_input_ids + bi * max_seq_len +
|
||||
split_fuse_cur_seq_lens[bi];
|
||||
int64_t* input_ids_cur_batch = input_ids + bi * max_seq_len;
|
||||
for (int i = tidx; i < cur_add_tokens; i += blockDim.x) {
|
||||
input_ids_cur_batch[i] = split_fuse_all_input_ids_cur_batch[i];
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
seq_lens_this_time[bi] = cur_add_tokens;
|
||||
seq_lens_encoder[bi] = cur_add_tokens;
|
||||
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
|
||||
step_idx[bi] = 0;
|
||||
split_fuse_cur_seq_lens[bi] += cur_add_tokens;
|
||||
}
|
||||
} else if (split_fuse_cur_seq_lens[bi] >= split_fuse_seq_lens[bi]) {
|
||||
if (threadIdx.x == 0) {
|
||||
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
|
||||
seq_lens_this_time[bi] = 1;
|
||||
step_idx[bi] = 1;
|
||||
seq_lens_encoder[bi] = 0;
|
||||
split_fuse_cur_seq_lens[bi] = 0;
|
||||
split_fuse_seq_lens[bi] = 0;
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
seq_lens_this_time[bi] = cur_add_tokens;
|
||||
seq_lens_encoder[bi] = cur_add_tokens;
|
||||
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
|
||||
step_idx[bi] = 0;
|
||||
split_fuse_cur_seq_lens[bi] += cur_add_tokens;
|
||||
}
|
||||
} else if (split_fuse_cur_seq_lens[bi] >= split_fuse_seq_lens[bi]) {
|
||||
if (threadIdx.x == 0) {
|
||||
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
|
||||
seq_lens_this_time[bi] = 1;
|
||||
step_idx[bi] = 1;
|
||||
seq_lens_encoder[bi] = 0;
|
||||
split_fuse_cur_seq_lens[bi] = 0;
|
||||
split_fuse_seq_lens[bi] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
|
||||
@@ -75,23 +74,20 @@ void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
|
||||
const int max_seq_len,
|
||||
const int max_batch_size,
|
||||
const int split_fuse_size) {
|
||||
dim3 grids;
|
||||
grids.x = max_batch_size;
|
||||
const int block_size = 128;
|
||||
update_split_fuse_inputs_kernel<<<grids,
|
||||
block_size,
|
||||
0,
|
||||
input_ids.stream()>>>(
|
||||
const_cast<int*>(split_fuse_seq_lens.data<int>()),
|
||||
const_cast<int*>(split_fuse_cur_seq_lens.data<int>()),
|
||||
const_cast<int64_t*>(split_fuse_all_input_ids.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
split_fuse_size,
|
||||
max_seq_len);
|
||||
dim3 grids;
|
||||
grids.x = max_batch_size;
|
||||
const int block_size = 128;
|
||||
update_split_fuse_inputs_kernel<<<grids, block_size, 0, input_ids.stream()>>>(
|
||||
const_cast<int*>(split_fuse_seq_lens.data<int>()),
|
||||
const_cast<int*>(split_fuse_cur_seq_lens.data<int>()),
|
||||
const_cast<int64_t*>(split_fuse_all_input_ids.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
split_fuse_size,
|
||||
max_seq_len);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(update_split_fuse_inputs)
|
||||
|
||||
Reference in New Issue
Block a user