mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[stop sequence] support stop sequence (#3025)
* stop seqs in multi-ends * unittest for gpu stop op * kernel tid==0
This commit is contained in:
@@ -30,30 +30,62 @@ __global__ void set_value_by_flags(bool *stop_flags,
|
||||
const int *seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const int64_t *pre_ids,
|
||||
const int pre_ids_len,
|
||||
const int64_t *step_idx,
|
||||
const int64_t *stop_seqs,
|
||||
const int *stop_seqs_len,
|
||||
const int stop_seqs_bs,
|
||||
const int stop_seqs_max_len,
|
||||
bool beam_search,
|
||||
bool prefill_one_step_stop) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid < bs) {
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags[tid] = true;
|
||||
if (seq_lens[tid] == 0) {
|
||||
topk_ids[tid] = -1;
|
||||
}
|
||||
next_tokens[tid] = topk_ids[tid];
|
||||
} else {
|
||||
if (stop_flags[tid]) {
|
||||
if (seq_lens[tid] == 0) {
|
||||
topk_ids[tid] = -1;
|
||||
} else {
|
||||
topk_ids[tid] = end_ids[0];
|
||||
next_tokens[tid] = end_ids[0];
|
||||
int bid = blockIdx.x;
|
||||
if (tid >= stop_seqs_bs) return;
|
||||
if (bid < bs) {
|
||||
if(tid == 0){
|
||||
if (prefill_one_step_stop) {
|
||||
stop_flags[bid] = true;
|
||||
if (seq_lens[bid] == 0) {
|
||||
topk_ids[bid] = -1;
|
||||
}
|
||||
next_tokens[bid] = topk_ids[bid];
|
||||
} else {
|
||||
next_tokens[tid] = topk_ids[tid];
|
||||
if (stop_flags[bid]) {
|
||||
if (seq_lens[bid] == 0) {
|
||||
topk_ids[bid] = -1;
|
||||
} else {
|
||||
topk_ids[bid] = end_ids[0];
|
||||
next_tokens[bid] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[bid] = topk_ids[bid];
|
||||
}
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
|
||||
stop_flags[bid] = true;
|
||||
}
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[tid], end_ids, end_length)) {
|
||||
stop_flags[tid] = true;
|
||||
// dealing stop_seqs
|
||||
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
|
||||
if (stop_seq_len <= 0) return;
|
||||
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
|
||||
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
|
||||
const int64_t step_idx_now = step_idx[bid];
|
||||
|
||||
bool is_end = true;
|
||||
int count = 1;
|
||||
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
||||
if ((step_idx_now - count) < 0 ||
|
||||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
|
||||
is_end = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_end) {
|
||||
next_tokens[bid] = end_ids[0];
|
||||
stop_flags[bid] = true;
|
||||
topk_ids[bid] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,6 +95,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &seq_lens,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len,
|
||||
const bool beam_search) {
|
||||
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
@@ -83,8 +119,10 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||
set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
|
||||
int stop_seqs_bs = stop_seqs.shape()[1];
|
||||
int stop_seqs_max_len = stop_seqs.shape()[2];
|
||||
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
|
||||
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
@@ -92,12 +130,19 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
pre_ids.data<int64_t>(),
|
||||
pre_ids.shape()[1],
|
||||
step_idx.data<int64_t>(),
|
||||
stop_seqs.data<int64_t>(),
|
||||
stop_seqs_len.data<int>(),
|
||||
stop_seqs_bs,
|
||||
stop_seqs_max_len,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
|
||||
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"})
|
||||
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
|
||||
.Attrs({"beam_search: bool"})
|
||||
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
|
||||
.SetInplaceMap({{"topk_ids", "topk_ids_out"},
|
||||
|
||||
Reference in New Issue
Block a user