mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] Fix token_penalty kernel (#6069)
* fix token_penalty kernel * try to fix xpu * fix xpu * fix unit test
This commit is contained in:
+222
-212
@@ -27,20 +27,20 @@ __global__ inline void min_length_logits_process(
|
||||
const int64_t length,
|
||||
const int64_t end_length,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; i++) {
|
||||
logits[token_idx * length + eos_token_id[i]] = -1e10;
|
||||
}
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; i++) {
|
||||
logits[token_idx * length + eos_token_id[i]] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -56,20 +56,20 @@ __global__ inline void min_length_logits_process<half>(
|
||||
const int64_t length,
|
||||
const int64_t end_length,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; i++) {
|
||||
logits[token_idx * length + eos_token_id[i]] = -1e4;
|
||||
}
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
if (cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]) {
|
||||
for (int i = 0; i < end_length; i++) {
|
||||
logits[token_idx * length + eos_token_id[i]] = -1e4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void update_repeat_times(const int64_t *pre_ids,
|
||||
@@ -81,21 +81,21 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
int tid = threadIdx.x;
|
||||
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
||||
int *repeat_times_now = repeat_times + token_idx * length;
|
||||
for (int i = tid; i < length_id; i += blockDim.x) {
|
||||
int64_t id = pre_ids_now[i];
|
||||
if (id < 0) break;
|
||||
atomicAdd(&repeat_times_now[id], 1);
|
||||
}
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
}
|
||||
int tid = threadIdx.x;
|
||||
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
||||
int *repeat_times_now = repeat_times + token_idx * length;
|
||||
for (int i = tid; i < length_id; i += blockDim.x) {
|
||||
int64_t id = pre_ids_now[i];
|
||||
if (id < 0) break;
|
||||
atomicAdd(&repeat_times_now[id], 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -110,47 +110,52 @@ __global__ void update_value_by_repeat_times(const int *repeat_times,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
const int *repeat_times_now = repeat_times + token_idx * length;
|
||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||
float beta = static_cast<float>(frequency_score[bi]);
|
||||
float gamma = static_cast<float>(presence_score[bi]);
|
||||
for (int i = tid; i < length; i += blockDim.x) {
|
||||
int times = repeat_times_now[i];
|
||||
float logit_now = static_cast<float>(logits_now[i]);
|
||||
if (times != 0) {
|
||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logit_now = logit_now - times * beta - gamma;
|
||||
}
|
||||
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
const int *repeat_times_now = repeat_times + token_idx * length;
|
||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||
float beta = static_cast<float>(frequency_score[bi]);
|
||||
float gamma = static_cast<float>(presence_score[bi]);
|
||||
for (int i = tid; i < length; i += blockDim.x) {
|
||||
int times = repeat_times_now[i];
|
||||
float logit_now = static_cast<float>(logits_now[i]);
|
||||
if (times != 0) {
|
||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||
logit_now = logit_now - times * beta - gamma;
|
||||
}
|
||||
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ban_bad_words(T *logits,
|
||||
const int64_t *bad_words_list,
|
||||
const int64_t *bad_tokens,
|
||||
const int64_t *bad_tokens_len,
|
||||
const int *output_padding_offset,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t bad_words_length,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
for (int i = tid; i < bad_words_length; i += blockDim.x) {
|
||||
const int64_t bad_words_token_id = bad_words_list[i];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length;
|
||||
const int32_t bad_token_len =
|
||||
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
|
||||
for (int i = tid; i < bad_token_len; i += blockDim.x) {
|
||||
const int64_t bad_words_token_id = bad_tokens_now[i];
|
||||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
|
||||
logits_now[bad_words_token_id] = -1e10;
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
@@ -162,6 +167,7 @@ void token_penalty_multi_scores_kernel(
|
||||
const paddle::Tensor &presence_score,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &bad_tokens_len,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
@@ -169,156 +175,159 @@ void token_penalty_multi_scores_kernel(
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const int max_seq_len) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto cu_stream = logits.stream();
|
||||
std::vector<int64_t> shape = logits.shape();
|
||||
auto repeat_times =
|
||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||
int64_t bs = seq_lens_this_time.shape()[0];
|
||||
int64_t token_num = shape[0];
|
||||
int64_t length = shape[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t length_bad_words = bad_tokens.shape()[1];
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto cu_stream = logits.stream();
|
||||
std::vector<int64_t> shape = logits.shape();
|
||||
auto repeat_times =
|
||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||
int64_t bs = seq_lens_this_time.shape()[0];
|
||||
int64_t token_num = shape[0];
|
||||
int64_t length = shape[1];
|
||||
int64_t length_id = pre_ids.shape()[1];
|
||||
int64_t length_bad_words = bad_tokens.shape()[1];
|
||||
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
int64_t end_length = eos_token_id.shape()[0];
|
||||
int block_size = (token_num + 32 - 1) / 32 * 32;
|
||||
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
end_length,
|
||||
max_seq_len);
|
||||
|
||||
int block_size = (token_num + 32 - 1) / 32 * 32;
|
||||
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
end_length,
|
||||
max_seq_len);
|
||||
block_size = (length_id + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
update_repeat_times<<<token_num, block_size, 0, cu_stream>>>(
|
||||
pre_ids.data<int64_t>(),
|
||||
cur_len.data<int64_t>(),
|
||||
repeat_times.data<int>(),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
max_seq_len);
|
||||
|
||||
block_size = (length_id + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
update_repeat_times<<<token_num, block_size, 0, cu_stream>>>(
|
||||
pre_ids.data<int64_t>(),
|
||||
cur_len.data<int64_t>(),
|
||||
repeat_times.data<int>(),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
max_seq_len);
|
||||
|
||||
block_size = (length + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
update_value_by_repeat_times<DataType_>
|
||||
<<<token_num, block_size, 0, cu_stream>>>(
|
||||
repeat_times.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(penalty_scores.data<data_t>())),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(frequency_score.data<data_t>())),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(presence_score.data<data_t>())),
|
||||
temperatures.data<float>(),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
max_seq_len);
|
||||
|
||||
block_size = (length_bad_words + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
bad_tokens.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
length_bad_words,
|
||||
max_seq_len);
|
||||
block_size = (length + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
update_value_by_repeat_times<DataType_>
|
||||
<<<token_num, block_size, 0, cu_stream>>>(
|
||||
repeat_times.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(penalty_scores.data<data_t>())),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(frequency_score.data<data_t>())),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(presence_score.data<data_t>())),
|
||||
temperatures.data<float>(),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
max_seq_len);
|
||||
block_size = (length_bad_words + 32 - 1) / 32 * 32;
|
||||
block_size = min(block_size, 512);
|
||||
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bad_tokens_len.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
length_bad_words,
|
||||
max_seq_len);
|
||||
}
|
||||
|
||||
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const int max_seq_len) {
|
||||
switch (logits.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return token_penalty_multi_scores_kernel<
|
||||
paddle::DataType::BFLOAT16>(pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16, bfloat16 and float32 are supported. ");
|
||||
break;
|
||||
}
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &bad_tokens_len,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const int max_seq_len) {
|
||||
switch (logits.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
bad_tokens_len,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
bad_tokens_len,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
||||
pre_ids,
|
||||
logits,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
bad_tokens,
|
||||
bad_tokens_len,
|
||||
cur_len,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_seq_len);
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16, bfloat16 and float32 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
|
||||
@@ -329,6 +338,7 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
|
||||
"presence_scores",
|
||||
"temperatures",
|
||||
"bad_tokens",
|
||||
"bad_tokens_len",
|
||||
"cur_len",
|
||||
"min_len",
|
||||
"eos_token_id",
|
||||
|
||||
Reference in New Issue
Block a user