[Optimization] Avoid unnecessary penalty computation (#6078)

This commit is contained in:
sunxin
2026-01-19 15:24:12 +08:00
committed by GitHub
parent 05fbd89a8e
commit a4144e0b8e
+206 -188
View File
// NOTE(review): this file is a flattened commit-diff rendering; inside each hunk
// the pre-change lines appear first, immediately followed by the post-change
// lines, so the body below occurs twice. Only comments are added here.
// min_length_logits_process<T>: while a sequence is still shorter than its
// min_len, force every EOS token's logit to -1e10 so EOS cannot be selected yet.
// Launch shape (from the caller further down): <<<1, block_size>>>, one thread
// per batch row.
@@ -22,16 +22,16 @@ __global__ inline void min_length_logits_process(T *logits,
const int64_t bs,
const int64_t vocab_size,
const int64_t eos_len) {
// ---- first (pre-change) copy of the body; trailing braces are shared context ----
int bi = threadIdx.x;
// guard: blockDim.x is rounded up to a WARP_SIZE multiple, may exceed bs
if (bi >= bs) return;
// cur_len < 0 marks an inactive/finished batch slot — nothing to do
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e10;
}
// ---- second (post-change) copy of the body — same logic, reformatted ----
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < eos_len; i++) {
// suppress EOS: -1e10 is effectively -inf after softmax
logits[bi * vocab_size + eos_token_id[i]] = -1e10;
}
}
}
// NOTE(review): diff hunk — pre-change body followed by post-change body (duplicated).
// Specialization of min_length_logits_process for half-precision logits.
// Identical logic to the generic template, but the mask value is -1e4 instead
// of -1e10 (a magnitude representable in half — presumably chosen to avoid
// overflowing half's range; confirm against the half max of ~65504).
template <>
@@ -43,16 +43,16 @@ __global__ inline void min_length_logits_process<half>(
const int64_t bs,
const int64_t vocab_size,
const int64_t eos_len) {
// ---- first (pre-change) copy ----
int bi = threadIdx.x;
if (bi >= bs) return;
// cur_len < 0 marks an inactive/finished batch slot
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < eos_len; i++) {
logits[bi * vocab_size + eos_token_id[i]] = -1e4;
}
// ---- second (post-change) copy ----
int bi = threadIdx.x;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
}
if (cur_len[bi] < min_len[bi]) {
for (int i = 0; i < eos_len; i++) {
// half-safe "minus infinity" stand-in
logits[bi * vocab_size + eos_token_id[i]] = -1e4;
}
}
}
// NOTE(review): diff hunk with pre- and post-change fragments interleaved; the
// post-change version adds penalty_scores/frequency_score/presence_score
// parameters and an early exit when all penalties are neutral (the commit's
// "avoid unnecessary penalty computation" optimization).
// update_repeat_times: one block per batch row (bi = blockIdx.x); threads
// stride over max(prompt_len, max_dec_len) positions. Counts how often each
// vocab id appears among already-decoded tokens (pre_ids -> repeat_times) and
// flags ids seen in either the decode history or the prompt (-> is_repeated),
// using atomicAdd because multiple positions may hold the same id.
__global__ void update_repeat_times(const int64_t *pre_ids,
@@ -61,36 +61,46 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
const int64_t *cur_len,
int *repeat_times,
int *is_repeated,
const float *penalty_scores,
const float *frequency_score,
const float *presence_score,
const int64_t bs,
const int64_t vocab_size,
const int64_t max_dec_len,
const int64_t max_model_len) {
// pre-change prologue: only the inactive-slot check
int64_t bi = blockIdx.x;
if (cur_len[bi] < 0) {
return;
// post-change prologue: skip the whole count pass when the penalties are
// no-ops (alpha==1 repetition, beta==0 frequency, gamma==0 presence)
int64_t bi = blockIdx.x;
float alpha = penalty_scores[bi];
float beta = frequency_score[bi];
float gamma = presence_score[bi];
if (alpha == 1.f && beta == 0.f && gamma == 0.f) {
return;
}
// cur_len < 0 marks an inactive/finished batch slot
if (cur_len[bi] < 0) {
return;
}
const int64_t prompt_len_now = prompt_len[bi];
int64_t tid = threadIdx.x;
// per-row base pointers into the flattened [bs, *] buffers
const int64_t *prompt_now = prompt_ids + bi * max_model_len;
const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
int *repeat_times_now = repeat_times + bi * vocab_size;
int *is_repeated_now = is_repeated + bi * vocab_size;
// one loop covers both histories: iterate to the longer of the two lengths
const int64_t loop_len =
prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
for (int64_t i = tid; i < loop_len; i += blockDim.x) {
if (i < max_dec_len) {
int64_t id = pre_ids_now[i];
// id < 0 is a padding/sentinel slot — skip
if (id >= 0) {
atomicAdd(&repeat_times_now[id], 1);
atomicAdd(&is_repeated_now[id], 1);
}
}
// ---- duplicate copy of the body from the other side of the diff ----
const int64_t prompt_len_now = prompt_len[bi];
int64_t tid = threadIdx.x;
const int64_t *prompt_now = prompt_ids + bi * max_model_len;
const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
int *repeat_times_now = repeat_times + bi * vocab_size;
int *is_repeated_now = is_repeated + bi * vocab_size;
const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
for (int64_t i = tid; i < loop_len; i += blockDim.x) {
if (i < max_dec_len) {
int64_t id = pre_ids_now[i];
if (id >= 0) {
atomicAdd(&repeat_times_now[id], 1);
atomicAdd(&is_repeated_now[id], 1);
}
}
// prompt tokens only mark is_repeated (repetition penalty applies to them,
// frequency/presence penalties count decoded tokens only)
if (i < prompt_len_now) {
int64_t id = prompt_now[i];
if (id >= 0) {
atomicAdd(&is_repeated_now[id], 1);
}
}
if (i < prompt_len_now) {
int64_t id = prompt_now[i];
if (id >= 0) {
atomicAdd(&is_repeated_now[id], 1);
}
}
}
}
// NOTE(review): diff hunk — pre-change body (first copy) followed by the
// post-change body, which caches temperatures[bi] in a local and adds an early
// exit when every parameter is neutral (alpha==1, beta==0, gamma==0, temp==1).
// update_value_by_repeat_times: one block per batch row; threads stride over
// the vocab. For each logit:
//   - repetition penalty alpha when the id was seen anywhere (is_repeated):
//     negative logits are multiplied by alpha, positive logits divided by it;
//   - frequency/presence penalties when the id was decoded (times != 0):
//     subtract times*beta + gamma;
//   - finally divide by the per-row temperature.
template <typename T>
@@ -103,25 +113,29 @@ __global__ void update_value_by_repeat_times(const int *repeat_times,
T *logits,
const int64_t bs,
const int64_t vocab_size) {
// ---- first (pre-change) copy ----
int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
const int *repeat_times_now = repeat_times + bi * vocab_size;
const int *is_repeated_now = is_repeated + bi * vocab_size;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
for (int i = tid; i < vocab_size; i += blockDim.x) {
int times = repeat_times_now[i];
float logit_now = static_cast<float>(logits_now[i]);
if (is_repeated_now[i] != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
}
if (times != 0) {
logit_now = logit_now - times * beta - gamma;
}
logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
// ---- second (post-change) copy ----
int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
const int *repeat_times_now = repeat_times + bi * vocab_size;
const int *is_repeated_now = is_repeated + bi * vocab_size;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
// hoist the per-row temperature out of the vocab loop
float temperature = temperatures[bi];
// all four parameters neutral => the whole pass is an identity; skip it
// (the early return also skips the /temperature, which is safe only
// because temperature == 1.f is part of the condition)
if (alpha == 1.f && beta == 0.f && gamma == 0.f && temperature == 1.f) {
return;
}
for (int i = tid; i < vocab_size; i += blockDim.x) {
int times = repeat_times_now[i];
// compute in float regardless of T for precision
float logit_now = static_cast<float>(logits_now[i]);
if (is_repeated_now[i] != 0) {
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
}
if (times != 0) {
logit_now = logit_now - times * beta - gamma;
}
logits_now[i] = static_cast<T>(logit_now / temperature);
}
}
// NOTE(review): diff hunk — the two copies below are byte-identical aside from
// the diff's re-indentation; no logic change in this function.
// ban_bad_words: one block per batch row; threads stride over the bad-words
// list and force each listed token's logit to -1e10. Out-of-range ids
// (>= vocab_size or < 0, e.g. padding entries) are skipped.
template <typename T>
@@ -130,14 +144,14 @@ __global__ void ban_bad_words(T *logits,
const int64_t bs,
const int64_t vocab_size,
const int64_t bad_words_len) {
// ---- first (pre-change) copy ----
const int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
for (int i = tid; i < bad_words_len; i += blockDim.x) {
const int64_t bad_words_token_id = bad_words_list[i];
if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
// ---- second (post-change) copy ----
const int bi = blockIdx.x;
int tid = threadIdx.x;
T *logits_now = logits + bi * vocab_size;
for (int i = tid; i < bad_words_len; i += blockDim.x) {
const int64_t bad_words_token_id = bad_words_list[i];
// skip padding / invalid ids rather than writing out of bounds
if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
logits_now[bad_words_token_id] = -1e10;
}
}
// NOTE(review): diff hunk with pre- and post-change statements interleaved in
// small runs; the functional change is that the post-change launch of
// update_repeat_times additionally passes penalty_scores / frequency_score /
// presence_score (as float*) so that kernel can early-exit on neutral penalties.
// token_penalty_multi_scores_kernel<D>: host-side launcher. Resolves the
// stream, zero-fills two INT32 [bs, vocab_size] work tensors (repeat_times,
// is_repeated), then launches four kernels in order on the same stream:
// min_length_logits_process -> update_repeat_times ->
// update_value_by_repeat_times -> ban_bad_words.
// Each block_size is rounded up to a multiple of WARP_SIZE and capped at 512.
template <paddle::DataType D>
@@ -153,91 +167,95 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
// map the compile-time Paddle dtype D to device/host C++ types
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = logits.stream();
auto cu_stream = logits.stream();
#endif
// per-call scratch tensors, zero-initialized before the counting kernel
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
auto is_repeated =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
std::vector<int64_t> shape = logits.shape();
auto repeat_times =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
auto is_repeated =
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
int64_t bs = shape[0];
int64_t vocab_size = shape[1];
int64_t max_dec_len = pre_ids.shape()[1];
int64_t bad_words_len = bad_tokens.shape()[1];
int64_t eos_len = eos_token_id.shape()[0];
int64_t max_model_len = prompt_ids.shape()[1];
int64_t vocab_size = shape[1];
int64_t max_dec_len = pre_ids.shape()[1];
int64_t bad_words_len = bad_tokens.shape()[1];
int64_t eos_len = eos_token_id.shape()[0];
int64_t max_model_len = prompt_ids.shape()[1];
// kernel 1: EOS masking — single block, one thread per batch row
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
vocab_size,
eos_len);
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
bs,
vocab_size,
eos_len);
// kernel 2: occurrence counting — one block per row, threads over positions
block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
block_size = min(block_size, 512);
#endif
// pre-change launch (no penalty tensors)
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
prompt_ids.data<int64_t>(),
prompt_len.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
is_repeated.data<int>(),
bs,
vocab_size,
max_dec_len,
max_model_len);
// post-change launch: penalty tensors forwarded so the kernel can skip
// rows whose penalties are neutral.
// NOTE(review): they are read here as float (.data<float>()) but
// reinterpreted as DataType_ in update_value_by_repeat_times below —
// confirm the penalty tensors' dtype supports both reads.
update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
pre_ids.data<int64_t>(),
prompt_ids.data<int64_t>(),
prompt_len.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
is_repeated.data<int>(),
penalty_scores.data<float>(),
frequency_score.data<float>(),
presence_score.data<float>(),
bs,
vocab_size,
max_dec_len,
max_model_len);
// kernel 3: apply penalties + temperature — threads over the vocab
block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
block_size = min(block_size, 512);
#endif
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
is_repeated.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(presence_score.data<data_t>())),
temperatures.data<float>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bs,
vocab_size);
update_value_by_repeat_times<DataType_><<<bs, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
is_repeated.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(presence_score.data<data_t>())),
temperatures.data<float>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bs,
vocab_size);
// kernel 4: ban listed tokens — threads over the bad-words list
block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
#ifdef PADDLE_WITH_COREX
block_size = std::min(block_size, 512);
block_size = std::min(block_size, 512);
#else
block_size = min(block_size, 512);
block_size = min(block_size, 512);
#endif
ban_bad_words<DataType_><<<bs, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bs,
vocab_size,
bad_words_len);
ban_bad_words<DataType_><<<bs, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bs,
vocab_size,
bad_words_len);
}
// NOTE(review): diff hunk — the entire pre-change switch appears first, then
// the post-change switch (reformatted only; no logic change).
// TokenPenaltyMultiScores: public entry point. Dispatches on the logits dtype
// to the templated launcher; bfloat16 / float16 / float32 are supported, any
// other dtype raises via PD_THROW.
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
@@ -252,59 +270,59 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id) {
// ---- first (pre-change) copy of the dispatch switch ----
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<
paddle::DataType::BFLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT16>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<
paddle::DataType::FLOAT32>(pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 and float32 are supported. ");
break;
}
// ---- second (post-change) copy — identical dispatch, new formatting ----
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
case paddle::DataType::FLOAT32: {
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
pre_ids,
prompt_ids,
prompt_len,
logits,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
bad_tokens,
cur_len,
min_len,
eos_token_id);
}
default: {
// unreachable after the return cases; break kept for switch hygiene
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 and float32 are supported. ");
break;
}
}
}
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)