// Mirror of https://github.com/PaddlePaddle/FastDeploy.git
// synced 2026-04-23 00:17:25 +08:00
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Feeds the next split-fuse chunk of each request's prompt into `input_ids`
// and updates the per-request scheduling arrays.
//
// Launch layout: one block per batch slot (blockIdx.x == batch index);
// blockDim.x threads cooperate on the token copy. No shared memory.
//
// Per batch slot `bi`:
//   * Prefill: while un-consumed prompt tokens remain
//     (cur_seq_lens < seq_lens), copy the next chunk of at most
//     `split_fuse_size` tokens from `split_fuse_all_input_ids` into
//     `input_ids` and record an encoder step of that many tokens.
//   * Decode: once the prompt is fully consumed, record a one-token decoder
//     step and reset the split-fuse counters for this slot.
void __global__
update_split_fuse_inputs_kernel(int* split_fuse_seq_lens,      // [batch] total prompt length per slot
                                int* split_fuse_cur_seq_lens,  // [batch] tokens already consumed per slot
                                int64_t* split_fuse_all_input_ids,  // [batch, max_seq_len] chunked prompt ids
                                int64_t* input_ids,                 // [batch, max_seq_len] model input (written)
                                int* seq_lens_this_time,
                                int* seq_lens_encoder,
                                int* seq_lens_decoder,
                                int64_t* step_idx,
                                const int split_fuse_size,
                                const int max_seq_len) {
  const int bi = blockIdx.x;
  const int tidx = threadIdx.x;

  // Snapshot the per-slot lengths BEFORE any thread mutates them. Previously
  // every thread re-read these globals while thread 0 was updating them, so a
  // lagging warp could observe the new values in the branch condition or the
  // pointer arithmetic and copy the wrong chunk (intra-block data race).
  const int total_len = split_fuse_cur_seq_lens ? split_fuse_seq_lens[bi] : 0;
  const int cur_len = split_fuse_cur_seq_lens[bi];

  if (total_len <= 0) {
    // Uniform per block (all threads read the same slot), so the whole
    // block exits together.
    return;
  }

  // 64-bit row offset: bi * max_seq_len is computed in 32-bit `int` otherwise
  // and can overflow for large batch/sequence products.
  const int64_t row_offset = static_cast<int64_t>(bi) * max_seq_len;

  if (cur_len < total_len) {
    // Prefill (encoder) step: copy the next chunk of prompt tokens.
    const int cur_add_tokens = min(total_len - cur_len, split_fuse_size);
    const int64_t* src = split_fuse_all_input_ids + row_offset + cur_len;
    int64_t* dst = input_ids + row_offset;
    // Block-stride copy; coalesced across threads of the block.
    for (int i = tidx; i < cur_add_tokens; i += blockDim.x) {
      dst[i] = src[i];
    }
    if (tidx == 0) {
      seq_lens_this_time[bi] = cur_add_tokens;
      seq_lens_encoder[bi] = cur_add_tokens;
      seq_lens_decoder[bi] = cur_len;
      step_idx[bi] = 0;
      // Safe: every other thread works from its register snapshot.
      split_fuse_cur_seq_lens[bi] = cur_len + cur_add_tokens;
    }
  } else {
    // Decode step: prompt fully consumed (cur_len >= total_len). Record a
    // one-token decoder step and clear the split-fuse state for this slot.
    if (tidx == 0) {
      seq_lens_decoder[bi] = cur_len;
      seq_lens_this_time[bi] = 1;
      step_idx[bi] = 1;
      seq_lens_encoder[bi] = 0;
      split_fuse_cur_seq_lens[bi] = 0;
      split_fuse_seq_lens[bi] = 0;
    }
  }
}
// Host launcher for update_split_fuse_inputs_kernel.
//
// Launches one block per batch slot (grid.x = max_batch_size) with 128
// threads on `input_ids`'s stream; no synchronization here, callers rely on
// stream ordering. All tensors are updated in place (hence the const_casts
// to satisfy the paddle::Tensor const accessor API).
//
// NOTE: the function name keeps the historical spelling ("Inputes") because
// it is the registered kernel symbol — do not rename.
void UpdateSplitFuseInputes(const paddle::Tensor& split_fuse_seq_lens,
                            const paddle::Tensor& split_fuse_cur_seq_lens,
                            const paddle::Tensor& split_fuse_all_input_ids,
                            const paddle::Tensor& input_ids,
                            const paddle::Tensor& seq_lens_this_time,
                            const paddle::Tensor& seq_lens_encoder,
                            const paddle::Tensor& seq_lens_decoder,
                            const paddle::Tensor& step_idx,
                            const int max_seq_len,
                            const int max_batch_size,
                            const int split_fuse_size) {
  // A zero-sized grid is an invalid CUDA launch configuration; nothing to do
  // for an empty batch anyway.
  if (max_batch_size <= 0) {
    return;
  }
  dim3 grid;  // fixed typo: was `girds`
  grid.x = max_batch_size;
  constexpr int kBlockSize = 128;
  update_split_fuse_inputs_kernel<<<grid, kBlockSize, 0, input_ids.stream()>>>(
      const_cast<int*>(split_fuse_seq_lens.data<int>()),
      const_cast<int*>(split_fuse_cur_seq_lens.data<int>()),
      const_cast<int64_t*>(split_fuse_all_input_ids.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      split_fuse_size,
      max_seq_len);
}
PD_BUILD_STATIC_OP(update_split_fuse_inputs)
    .Inputs(
        {"split_fuse_seq_lens",  // total length of the current query
         "split_fuse_cur_seq_lens",  // portion of the current query already
                                     // computed; a multiple of split size
         "split_fuse_all_input_ids",  // the query's input ids after
                                      // splitting; length is a multiple of
                                      // split size
         "input_ids",  // all input ids of the current query
         "seq_lens_this_time",  // length to compute for the current step
         "seq_lens_encoder",  // encoder length to compute; 0 on decoder steps
         "seq_lens_decoder",  // decoder length to compute; 0 on encoder steps
         "step_idx"})  // token index of the current query: first token is 0,
                       // second token is 1
    .Outputs({"input_ids_out"})
    .Attrs({"max_seq_len: int",  // maximum sequence length
            "max_batch_size: int",  // maximum batch size
            "split_fuse_size: int"})  // split chunk length
    .SetInplaceMap({{"input_ids", "input_ids_out"}})
    .SetKernelFn(PD_KERNEL(UpdateSplitFuseInputes));