diff --git a/custom_ops/gpu_ops/token_penalty_multi_scores.cu b/custom_ops/gpu_ops/token_penalty_multi_scores.cu
index 7db52f38af..22f529d3b1 100644
--- a/custom_ops/gpu_ops/token_penalty_multi_scores.cu
+++ b/custom_ops/gpu_ops/token_penalty_multi_scores.cu
@@ -22,16 +22,16 @@ __global__ inline void min_length_logits_process(T *logits,
                                                  const int64_t bs,
                                                  const int64_t vocab_size,
                                                  const int64_t eos_len) {
-    int bi = threadIdx.x;
-    if (bi >= bs) return;
-    if (cur_len[bi] < 0) {
-        return;
-    }
-    if (cur_len[bi] < min_len[bi]) {
-        for (int i = 0; i < eos_len; i++) {
-            logits[bi * vocab_size + eos_token_id[i]] = -1e10;
-        }
+  int bi = threadIdx.x;
+  if (bi >= bs) return;
+  if (cur_len[bi] < 0) {
+    return;
+  }
+  if (cur_len[bi] < min_len[bi]) {
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e10;
     }
+  }
 }
 
 template <>
@@ -43,16 +43,16 @@ __global__ inline void min_length_logits_process(
     const int64_t bs,
     const int64_t vocab_size,
     const int64_t eos_len) {
-    int bi = threadIdx.x;
-    if (bi >= bs) return;
-    if (cur_len[bi] < 0) {
-        return;
-    }
-    if (cur_len[bi] < min_len[bi]) {
-        for (int i = 0; i < eos_len; i++) {
-            logits[bi * vocab_size + eos_token_id[i]] = -1e4;
-        }
+  int bi = threadIdx.x;
+  if (bi >= bs) return;
+  if (cur_len[bi] < 0) {
+    return;
+  }
+  if (cur_len[bi] < min_len[bi]) {
+    for (int i = 0; i < eos_len; i++) {
+      logits[bi * vocab_size + eos_token_id[i]] = -1e4;
     }
+  }
 }
 
 __global__ void update_repeat_times(const int64_t *pre_ids,
@@ -61,36 +61,46 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
                                     const int64_t *cur_len,
                                     int *repeat_times,
                                     int *is_repeated,
+                                    const float *penalty_scores,
+                                    const float *frequency_score,
+                                    const float *presence_score,
                                     const int64_t bs,
                                     const int64_t vocab_size,
                                     const int64_t max_dec_len,
                                     const int64_t max_model_len) {
-    int64_t bi = blockIdx.x;
-    if (cur_len[bi] < 0) {
-        return;
+  int64_t bi = blockIdx.x;
+  float alpha = penalty_scores[bi];
+  float beta = frequency_score[bi];
+  float gamma = presence_score[bi];
+  if (alpha == 1.f && beta == 0.f && gamma == 0.f) {
+    return;
+  }
+  if (cur_len[bi] < 0) {
+    return;
+  }
+  const int64_t prompt_len_now = prompt_len[bi];
+  int64_t tid = threadIdx.x;
+  const int64_t *prompt_now = prompt_ids + bi * max_model_len;
+  const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
+  int *repeat_times_now = repeat_times + bi * vocab_size;
+  int *is_repeated_now = is_repeated + bi * vocab_size;
+  const int64_t loop_len =
+      prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
+  for (int64_t i = tid; i < loop_len; i += blockDim.x) {
+    if (i < max_dec_len) {
+      int64_t id = pre_ids_now[i];
+      if (id >= 0) {
+        atomicAdd(&repeat_times_now[id], 1);
+        atomicAdd(&is_repeated_now[id], 1);
+      }
     }
-    const int64_t prompt_len_now = prompt_len[bi];
-    int64_t tid = threadIdx.x;
-    const int64_t *prompt_now = prompt_ids + bi * max_model_len;
-    const int64_t *pre_ids_now = pre_ids + bi * max_dec_len;
-    int *repeat_times_now = repeat_times + bi * vocab_size;
-    int *is_repeated_now = is_repeated + bi * vocab_size;
-    const int64_t loop_len = prompt_len_now > max_dec_len ? prompt_len_now : max_dec_len;
-    for (int64_t i = tid; i < loop_len; i += blockDim.x) {
-        if (i < max_dec_len) {
-            int64_t id = pre_ids_now[i];
-            if (id >= 0) {
-                atomicAdd(&repeat_times_now[id], 1);
-                atomicAdd(&is_repeated_now[id], 1);
-            }
-        }
-        if (i < prompt_len_now) {
-            int64_t id = prompt_now[i];
-            if (id >= 0) {
-                atomicAdd(&is_repeated_now[id], 1);
-            }
-        }
+    if (i < prompt_len_now) {
+      int64_t id = prompt_now[i];
+      if (id >= 0) {
+        atomicAdd(&is_repeated_now[id], 1);
+      }
     }
+  }
 }
 
 template <typename T>
@@ -103,25 +113,29 @@ __global__ void update_value_by_repeat_times(const int *repeat_times,
                                              T *logits,
                                              const int64_t bs,
                                              const int64_t vocab_size) {
-    int bi = blockIdx.x;
-    int tid = threadIdx.x;
-    T *logits_now = logits + bi * vocab_size;
-    const int *repeat_times_now = repeat_times + bi * vocab_size;
-    const int *is_repeated_now = is_repeated + bi * vocab_size;
-    float alpha = static_cast<float>(penalty_scores[bi]);
-    float beta = static_cast<float>(frequency_score[bi]);
-    float gamma = static_cast<float>(presence_score[bi]);
-    for (int i = tid; i < vocab_size; i += blockDim.x) {
-        int times = repeat_times_now[i];
-        float logit_now = static_cast<float>(logits_now[i]);
-        if (is_repeated_now[i] != 0) {
-            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
-        }
-        if (times != 0) {
-            logit_now = logit_now - times * beta - gamma;
-        }
-        logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
+  int bi = blockIdx.x;
+  int tid = threadIdx.x;
+  T *logits_now = logits + bi * vocab_size;
+  const int *repeat_times_now = repeat_times + bi * vocab_size;
+  const int *is_repeated_now = is_repeated + bi * vocab_size;
+  float alpha = static_cast<float>(penalty_scores[bi]);
+  float beta = static_cast<float>(frequency_score[bi]);
+  float gamma = static_cast<float>(presence_score[bi]);
+  float temperature = temperatures[bi];
+  if (alpha == 1.f && beta == 0.f && gamma == 0.f && temperature == 1.f) {
+    return;
+  }
+  for (int i = tid; i < vocab_size; i += blockDim.x) {
+    int times = repeat_times_now[i];
+    float logit_now = static_cast<float>(logits_now[i]);
+    if (is_repeated_now[i] != 0) {
+      logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
     }
+    if (times != 0) {
+      logit_now = logit_now - times * beta - gamma;
+    }
+    logits_now[i] = static_cast<T>(logit_now / temperature);
+  }
 }
 
 template <typename T>
@@ -130,14 +144,14 @@ __global__ void ban_bad_words(T *logits,
                               const int64_t bs,
                               const int64_t vocab_size,
                               const int64_t bad_words_len) {
-    const int bi = blockIdx.x;
-    int tid = threadIdx.x;
-    T *logits_now = logits + bi * vocab_size;
-    for (int i = tid; i < bad_words_len; i += blockDim.x) {
-        const int64_t bad_words_token_id = bad_words_list[i];
-        if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
-        logits_now[bad_words_token_id] = -1e10;
-    }
+  const int bi = blockIdx.x;
+  int tid = threadIdx.x;
+  T *logits_now = logits + bi * vocab_size;
+  for (int i = tid; i < bad_words_len; i += blockDim.x) {
+    const int64_t bad_words_token_id = bad_words_list[i];
+    if (bad_words_token_id >= vocab_size || bad_words_token_id < 0) continue;
+    logits_now[bad_words_token_id] = -1e10;
+  }
 }
 
 template <paddle::DataType D>
@@ -153,91 +167,95 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
                                        const paddle::Tensor &cur_len,
                                        const paddle::Tensor &min_len,
                                        const paddle::Tensor &eos_token_id) {
-    typedef PDTraits<D> traits_;
-    typedef typename traits_::DataType DataType_;
-    typedef typename traits_::data_t data_t;
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto dev_ctx = static_cast<const phi::CustomContext *>(paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
-    auto cu_stream = dev_ctx->stream();
+  auto dev_ctx = static_cast<const phi::CustomContext *>(
+      paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
+  auto cu_stream = dev_ctx->stream();
 #else
-    auto cu_stream = logits.stream();
+  auto cu_stream = logits.stream();
 #endif
-    std::vector<int64_t> shape = logits.shape();
-    auto repeat_times =
-        paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
-    auto is_repeated =
-        paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
-    int64_t bs = shape[0];
+  std::vector<int64_t> shape = logits.shape();
+  auto repeat_times =
+      paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
+  auto is_repeated =
+      paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
+  int64_t bs = shape[0];
 
-    int64_t vocab_size = shape[1];
-    int64_t max_dec_len = pre_ids.shape()[1];
-    int64_t bad_words_len = bad_tokens.shape()[1];
-    int64_t eos_len = eos_token_id.shape()[0];
-    int64_t max_model_len = prompt_ids.shape()[1];
+  int64_t vocab_size = shape[1];
+  int64_t max_dec_len = pre_ids.shape()[1];
+  int64_t bad_words_len = bad_tokens.shape()[1];
+  int64_t eos_len = eos_token_id.shape()[0];
+  int64_t max_model_len = prompt_ids.shape()[1];
 
-    int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
-    min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(logits.data<data_t>())),
-        cur_len.data<int64_t>(),
-        min_len.data<int64_t>(),
-        eos_token_id.data<int64_t>(),
-        bs,
-        vocab_size,
-        eos_len);
+  int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(logits.data<data_t>())),
+      cur_len.data<int64_t>(),
+      min_len.data<int64_t>(),
+      eos_token_id.data<int64_t>(),
+      bs,
+      vocab_size,
+      eos_len);
 
-    block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (max_dec_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
-    block_size = std::min(block_size, 512);
+  block_size = std::min(block_size, 512);
 #else
-    block_size = min(block_size, 512);
+  block_size = min(block_size, 512);
 #endif
-    update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
-        pre_ids.data<int64_t>(),
-        prompt_ids.data<int64_t>(),
-        prompt_len.data<int64_t>(),
-        cur_len.data<int64_t>(),
-        repeat_times.data<int>(),
-        is_repeated.data<int>(),
-        bs,
-        vocab_size,
-        max_dec_len,
-        max_model_len);
+  update_repeat_times<<<bs, block_size, 0, cu_stream>>>(
+      pre_ids.data<int64_t>(),
+      prompt_ids.data<int64_t>(),
+      prompt_len.data<int64_t>(),
+      cur_len.data<int64_t>(),
+      repeat_times.data<int>(),
+      is_repeated.data<int>(),
+      penalty_scores.data<float>(),
+      frequency_score.data<float>(),
+      presence_score.data<float>(),
+      bs,
+      vocab_size,
+      max_dec_len,
+      max_model_len);
 
-    block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (vocab_size + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
-    block_size = std::min(block_size, 512);
+  block_size = std::min(block_size, 512);
 #else
-    block_size = min(block_size, 512);
+  block_size = min(block_size, 512);
 #endif
-    update_value_by_repeat_times<<<bs, block_size, 0, cu_stream>>>(
-        repeat_times.data<int>(),
-        is_repeated.data<int>(),
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(penalty_scores.data<data_t>())),
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(frequency_score.data<data_t>())),
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(presence_score.data<data_t>())),
-        temperatures.data<float>(),
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(logits.data<data_t>())),
-        bs,
-        vocab_size);
+  update_value_by_repeat_times<<<bs, block_size, 0, cu_stream>>>(
+      repeat_times.data<int>(),
+      is_repeated.data<int>(),
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(penalty_scores.data<data_t>())),
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(frequency_score.data<data_t>())),
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(presence_score.data<data_t>())),
+      temperatures.data<float>(),
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(logits.data<data_t>())),
+      bs,
+      vocab_size);
 
-    block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
+  block_size = (bad_words_len + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
 #ifdef PADDLE_WITH_COREX
-    block_size = std::min(block_size, 512);
+  block_size = std::min(block_size, 512);
 #else
-    block_size = min(block_size, 512);
+  block_size = min(block_size, 512);
 #endif
-    ban_bad_words<<<bs, block_size, 0, cu_stream>>>(
-        reinterpret_cast<DataType_ *>(
-            const_cast<data_t *>(logits.data<data_t>())),
-        bad_tokens.data<int64_t>(),
-        bs,
-        vocab_size,
-        bad_words_len);
+  ban_bad_words<<<bs, block_size, 0, cu_stream>>>(
+      reinterpret_cast<DataType_ *>(
+          const_cast<data_t *>(logits.data<data_t>())),
+      bad_tokens.data<int64_t>(),
+      bs,
+      vocab_size,
+      bad_words_len);
 }
 
 void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
@@ -252,59 +270,59 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                              const paddle::Tensor &cur_len,
                              const paddle::Tensor &min_len,
                              const paddle::Tensor &eos_token_id) {
-    switch (logits.type()) {
-        case paddle::DataType::BFLOAT16: {
-            return token_penalty_multi_scores_kernel<
-                paddle::DataType::BFLOAT16>(pre_ids,
-                                            prompt_ids,
-                                            prompt_len,
-                                            logits,
-                                            penalty_scores,
-                                            frequency_scores,
-                                            presence_scores,
-                                            temperatures,
-                                            bad_tokens,
-                                            cur_len,
-                                            min_len,
-                                            eos_token_id);
-        }
-        case paddle::DataType::FLOAT16: {
-            return token_penalty_multi_scores_kernel<
-                paddle::DataType::FLOAT16>(pre_ids,
-                                           prompt_ids,
-                                           prompt_len,
-                                           logits,
-                                           penalty_scores,
-                                           frequency_scores,
-                                           presence_scores,
-                                           temperatures,
-                                           bad_tokens,
-                                           cur_len,
-                                           min_len,
-                                           eos_token_id);
-        }
-        case paddle::DataType::FLOAT32: {
-            return token_penalty_multi_scores_kernel<
-                paddle::DataType::FLOAT32>(pre_ids,
-                                           prompt_ids,
-                                           prompt_len,
-                                           logits,
-                                           penalty_scores,
-                                           frequency_scores,
-                                           presence_scores,
-                                           temperatures,
-                                           bad_tokens,
-                                           cur_len,
-                                           min_len,
-                                           eos_token_id);
-        }
-        default: {
-            PD_THROW(
-                "NOT supported data type. "
-                "Only float16, bfloat16 and float32 are supported. ");
-            break;
-        }
-    }
+  switch (logits.type()) {
+    case paddle::DataType::BFLOAT16: {
+      return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
+          pre_ids,
+          prompt_ids,
+          prompt_len,
+          logits,
+          penalty_scores,
+          frequency_scores,
+          presence_scores,
+          temperatures,
+          bad_tokens,
+          cur_len,
+          min_len,
+          eos_token_id);
+    }
+    case paddle::DataType::FLOAT16: {
+      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT16>(
+          pre_ids,
+          prompt_ids,
+          prompt_len,
+          logits,
+          penalty_scores,
+          frequency_scores,
+          presence_scores,
+          temperatures,
+          bad_tokens,
+          cur_len,
+          min_len,
+          eos_token_id);
+    }
+    case paddle::DataType::FLOAT32: {
+      return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
+          pre_ids,
+          prompt_ids,
+          prompt_len,
+          logits,
+          penalty_scores,
+          frequency_scores,
+          presence_scores,
+          temperatures,
+          bad_tokens,
+          cur_len,
+          min_len,
+          eos_token_id);
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16, bfloat16 and float32 are supported. ");
+      break;
+    }
+  }
 }
 
 PD_BUILD_STATIC_OP(get_token_penalty_multi_scores)