mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
c++ code format (#4527)
This commit is contained in:
@@ -23,13 +23,13 @@
|
||||
#endif
|
||||
|
||||
bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
|
||||
bool flag = false;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
bool flag = false;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
|
||||
void set_value_by_flags(bool *stop_flags,
|
||||
@@ -40,23 +40,23 @@ void set_value_by_flags(bool *stop_flags,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
bool beam_search) {
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (stop_flags[bi]) {
|
||||
if ((seq_lens[bi] == 0)) {
|
||||
topk_ids[bi] = -1;
|
||||
} else {
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[bi] = topk_ids[bi];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
||||
stop_flags[bi] = true;
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
for (int bi = 0; bi < bs; bi++) {
|
||||
if (stop_flags[bi]) {
|
||||
if ((seq_lens[bi] == 0)) {
|
||||
topk_ids[bi] = -1;
|
||||
} else {
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
} else {
|
||||
next_tokens[bi] = topk_ids[bi];
|
||||
}
|
||||
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
||||
stop_flags[bi] = true;
|
||||
topk_ids[bi] = end_ids[0];
|
||||
next_tokens[bi] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
@@ -65,17 +65,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const bool beam_search) {
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(),
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
false);
|
||||
std::vector<int64_t> shape = topk_ids.shape();
|
||||
int64_t bs_now = shape[0];
|
||||
int64_t end_length = end_ids.shape()[0];
|
||||
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||
end_ids.data<int64_t>(),
|
||||
seq_lens.data<int>(),
|
||||
bs_now,
|
||||
end_length,
|
||||
false);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu)
|
||||
|
||||
Reference in New Issue
Block a user