Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
+44 -48
View File
@@ -29,39 +29,38 @@ update_split_fuse_inputs_kernel(int* split_fuse_seq_lens,
int64_t* step_idx,
const int split_fuse_size,
const int max_seq_len) {
const int bi = blockIdx.x;
const int tidx = threadIdx.x;
if (split_fuse_seq_lens[bi] <= 0) {
return;
const int bi = blockIdx.x;
const int tidx = threadIdx.x;
if (split_fuse_seq_lens[bi] <= 0) {
return;
}
if (split_fuse_cur_seq_lens[bi] < split_fuse_seq_lens[bi]) {
const int cur_add_tokens = min(
split_fuse_seq_lens[bi] - split_fuse_cur_seq_lens[bi], split_fuse_size);
int64_t* split_fuse_all_input_ids_cur_batch = split_fuse_all_input_ids +
bi * max_seq_len +
split_fuse_cur_seq_lens[bi];
int64_t* input_ids_cur_batch = input_ids + bi * max_seq_len;
for (int i = tidx; i < cur_add_tokens; i += blockDim.x) {
input_ids_cur_batch[i] = split_fuse_all_input_ids_cur_batch[i];
}
if (split_fuse_cur_seq_lens[bi] < split_fuse_seq_lens[bi]) {
const int cur_add_tokens =
min(split_fuse_seq_lens[bi] - split_fuse_cur_seq_lens[bi],
split_fuse_size);
int64_t* split_fuse_all_input_ids_cur_batch =
split_fuse_all_input_ids + bi * max_seq_len +
split_fuse_cur_seq_lens[bi];
int64_t* input_ids_cur_batch = input_ids + bi * max_seq_len;
for (int i = tidx; i < cur_add_tokens; i += blockDim.x) {
input_ids_cur_batch[i] = split_fuse_all_input_ids_cur_batch[i];
}
if (threadIdx.x == 0) {
seq_lens_this_time[bi] = cur_add_tokens;
seq_lens_encoder[bi] = cur_add_tokens;
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
step_idx[bi] = 0;
split_fuse_cur_seq_lens[bi] += cur_add_tokens;
}
} else if (split_fuse_cur_seq_lens[bi] >= split_fuse_seq_lens[bi]) {
if (threadIdx.x == 0) {
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
seq_lens_this_time[bi] = 1;
step_idx[bi] = 1;
seq_lens_encoder[bi] = 0;
split_fuse_cur_seq_lens[bi] = 0;
split_fuse_seq_lens[bi] = 0;
}
if (threadIdx.x == 0) {
seq_lens_this_time[bi] = cur_add_tokens;
seq_lens_encoder[bi] = cur_add_tokens;
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
step_idx[bi] = 0;
split_fuse_cur_seq_lens[bi] += cur_add_tokens;
}
} else if (split_fuse_cur_seq_lens[bi] >= split_fuse_seq_lens[bi]) {
if (threadIdx.x == 0) {
seq_lens_decoder[bi] = split_fuse_cur_seq_lens[bi];
seq_lens_this_time[bi] = 1;
step_idx[bi] = 1;
seq_lens_encoder[bi] = 0;
split_fuse_cur_seq_lens[bi] = 0;
split_fuse_seq_lens[bi] = 0;
}
}
}
void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
@@ -75,23 +74,20 @@ void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
const int max_seq_len,
const int max_batch_size,
const int split_fuse_size) {
dim3 grids;
grids.x = max_batch_size;
const int block_size = 128;
update_split_fuse_inputs_kernel<<<grids,
block_size,
0,
input_ids.stream()>>>(
const_cast<int*>(split_fuse_seq_lens.data<int>()),
const_cast<int*>(split_fuse_cur_seq_lens.data<int>()),
const_cast<int64_t*>(split_fuse_all_input_ids.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
split_fuse_size,
max_seq_len);
dim3 grids;
grids.x = max_batch_size;
const int block_size = 128;
update_split_fuse_inputs_kernel<<<grids, block_size, 0, input_ids.stream()>>>(
const_cast<int*>(split_fuse_seq_lens.data<int>()),
const_cast<int*>(split_fuse_cur_seq_lens.data<int>()),
const_cast<int64_t*>(split_fuse_all_input_ids.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
split_fuse_size,
max_seq_len);
}
PD_BUILD_STATIC_OP(update_split_fuse_inputs)