Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
+44 -46
View File
@@ -27,73 +27,71 @@ __global__ void System2GroupKernel(int* group_ids,
const int* system_ids,
const int bsz,
const int max_bsz) {
const int ti = threadIdx.x;
if (ti < bsz) {
if (seq_lens_this_time[ti] <=
0) { // 终止位置不参与分组,encoder需要是一个特定的system
// id,在seqs2seqs里处理
return;
}
int group_id = system_ids[ti];
int group_len_now = atomicAdd(&group_lens[group_id], 1);
if (seq_lens_encoder[ti] <= 0) { // is decoder
atomicAdd(dec_group_num, 1);
atomicAdd(&group_lens_without_encoder[group_id], 1);
}
group_ids[group_id * max_bsz + group_len_now] = ti;
const int ti = threadIdx.x;
if (ti < bsz) {
if (seq_lens_this_time[ti] <=
0) { // 终止位置不参与分组,encoder需要是一个特定的system
// id,在seqs2seqs里处理
return;
}
int group_id = system_ids[ti];
int group_len_now = atomicAdd(&group_lens[group_id], 1);
if (seq_lens_encoder[ti] <= 0) { // is decoder
atomicAdd(dec_group_num, 1);
atomicAdd(&group_lens_without_encoder[group_id], 1);
}
group_ids[group_id * max_bsz + group_len_now] = ti;
}
}
std::vector<paddle::Tensor> System2Group(
const paddle::Tensor& system_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int max_bsz = seq_lens_encoder.shape()[0];
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
const int max_bsz = seq_lens_encoder.shape()[0];
auto group_ids = paddle::full({bsz, max_bsz},
-1,
paddle::DataType::INT32,
seq_lens_this_time.place());
auto group_lens = paddle::full(
{bsz, 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
auto group_lens_without_encoder = paddle::full(
{bsz, 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
auto dec_group_num = paddle::full(
{1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
auto group_ids = paddle::full(
{bsz, max_bsz}, -1, paddle::DataType::INT32, seq_lens_this_time.place());
auto group_lens = paddle::full(
{bsz, 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
auto group_lens_without_encoder = paddle::full(
{bsz, 1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
auto dec_group_num =
paddle::full({1}, 0, paddle::DataType::INT32, seq_lens_this_time.place());
const int blockSize = (bsz + 32 - 1) / 32 * 32;
System2GroupKernel<<<1, blockSize, 0, cu_stream>>>(
group_ids.data<int>(),
group_lens.data<int>(),
group_lens_without_encoder.data<int>(),
dec_group_num.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
system_ids.data<int>(),
bsz,
max_bsz);
return {group_ids, group_lens, group_lens_without_encoder, dec_group_num};
const int blockSize = (bsz + 32 - 1) / 32 * 32;
System2GroupKernel<<<1, blockSize, 0, cu_stream>>>(
group_ids.data<int>(),
group_lens.data<int>(),
group_lens_without_encoder.data<int>(),
dec_group_num.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
system_ids.data<int>(),
bsz,
max_bsz);
return {group_ids, group_lens, group_lens_without_encoder, dec_group_num};
}
std::vector<std::vector<int64_t>> System2GroupInferShape(
const std::vector<int64_t>& system_ids_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_lens_this_time_shape[0];
int64_t max_bsz = seq_lens_encoder_shape[0];
return {{bsz, max_bsz}, {bsz, 1}, {bsz, 1}, {1}};
int64_t bsz = seq_lens_this_time_shape[0];
int64_t max_bsz = seq_lens_encoder_shape[0];
return {{bsz, max_bsz}, {bsz, 1}, {bsz, 1}, {1}};
}
std::vector<paddle::DataType> System2GroupInferDtype(
const paddle::DataType& system_ids_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {seq_lens_this_time_dtype,
seq_lens_this_time_dtype,
seq_lens_this_time_dtype,
seq_lens_this_time_dtype};
return {seq_lens_this_time_dtype,
seq_lens_this_time_dtype,
seq_lens_this_time_dtype,
seq_lens_this_time_dtype};
}
PD_BUILD_STATIC_OP(system2group)