c++ code format (#4527)

This commit is contained in:
zhupengyang
2025-10-22 17:59:50 +08:00
committed by GitHub
parent d7bcedf421
commit 3a6883ac1a
97 changed files with 8760 additions and 7382 deletions
+159 -103
View File
@@ -24,121 +24,176 @@ namespace api {
namespace plugin {
template <typename T>
DLL_EXPORT int set_stop_value_multi_ends(Context *ctx, bool *stop_flags,
T *topk_ids, T *next_tokens,
const T *end_ids, const int *seq_lens,
const int bs, const int end_length,
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
bool* stop_flags,
T* topk_ids,
T* next_tokens,
const T* end_ids,
const int* seq_lens,
const int bs,
const int end_length,
const bool beam_search);
DLL_EXPORT int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int64_t *step_idx, int bs,
int length, int length_input_ids);
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
const bool* stop_flags,
int64_t* pre_ids_all,
const int64_t* input_ids,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int64_t* step_idx,
int bs,
int length,
int length_input_ids);
template <typename T>
DLL_EXPORT int token_penalty_multi_scores(
Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores,
const T *frequency_scores, const T *presence_scores,
const float *temperatures, const int64_t *cur_len, const int64_t *min_len,
const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs,
const int64_t length, const int64_t length_id, const int64_t end_length,
const int64_t length_bad_words);
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
const T* frequency_scores,
const T* presence_scores,
const float* temperatures,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int64_t bs,
const int64_t length,
const int64_t length_id,
const int64_t end_length,
const int64_t length_bad_words);
DLL_EXPORT int get_padding_offset(Context *ctx, int *padding_offset,
int *cum_offsets_out, int *cu_seqlens_q,
int *cu_seqlens_k, int64_t *x_remove_padding,
const int64_t *input_ids,
const int *cum_offsets, const int *seq_lens,
const int max_seq_len, const int bs);
DLL_EXPORT int get_padding_offset(Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
int* cu_seqlens_k,
int64_t* x_remove_padding,
const int64_t* input_ids,
const int* cum_offsets,
const int* seq_lens,
const int max_seq_len,
const int bs);
DLL_EXPORT int update_inputs(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time, int *seq_lens_encoder,
int *seq_lens_decoder, int64_t *input_ids,
const int64_t *stop_nums, const bool *stop_flags,
const bool *is_block_step,
const int64_t *next_tokens, const int bsz,
const int max_bsz, const int input_ids_stride);
DLL_EXPORT int update_inputs(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* input_ids,
const int64_t* stop_nums,
const bool* stop_flags,
const bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride);
DLL_EXPORT int free_and_dispatch_block(
Context *ctx, bool *stop_flags, int *seq_lens_this_time,
int *seq_lens_decoder, int *block_tables, int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len, int *recover_block_list, int *recover_len,
int *need_block_list, int *need_block_len, int *used_list_len,
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
const int block_size, const int block_num_per_seq,
const int max_decoder_block_num);
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
int* block_tables,
int* encoder_block_lens,
bool* is_block_step,
int* step_block_list, // [bsz]
int* step_len,
int* recover_block_list,
int* recover_len,
int* need_block_list,
int* need_block_len,
int* used_list_len,
int* free_list,
int* free_list_len,
int64_t* first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num);
DLL_EXPORT int
recover_block(Context *ctx,
int *recover_block_list, // [bsz]
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
const int *seq_lens_decoder, int *block_tables, int *free_list,
int *free_list_len, int64_t *input_ids, const int64_t *pre_ids,
const int64_t *step_idx, const int *encoder_block_lens,
const int *used_list_len, const int64_t *next_tokens,
const int64_t *first_token_ids, const int bsz,
const int block_num_per_seq, const int length,
const int pre_id_length);
DLL_EXPORT int recover_block(Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
int* seq_lens_this_time,
const int* ori_seq_lens_encoder,
int* seq_lens_encoder,
const int* seq_lens_decoder,
int* block_tables,
int* free_list,
int* free_list_len,
int64_t* input_ids,
const int64_t* pre_ids,
const int64_t* step_idx,
const int* encoder_block_lens,
const int* used_list_len,
const int64_t* next_tokens,
const int64_t* first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length);
DLL_EXPORT int
recover_decode_task(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
DLL_EXPORT int recover_decode_task(Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int* block_tables,
bool* is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int
update_inputs_v1(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int update_inputs_v1(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
template <typename TX, typename TY>
DLL_EXPORT int
eb_adjust_batch(Context *ctx, const TX *x, TY *y,
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
VectorParam<int32_t> &encoder_batch_map, // NOLINT
VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim);
DLL_EXPORT int eb_adjust_batch(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int
eb_gather_next_token(Context *ctx, const TX *x, TY *y,
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
VectorParam<int32_t> &encoder_batch_map, // NOLINT
VectorParam<int32_t> &decoder_batch_map, // NOLINT
int64_t hidden_dim);
DLL_EXPORT int eb_gather_next_token(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
const TSCALE *scale_in, TY *y,
TSCALE *scale_out, int64_t m, int64_t n);
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
const TX* x,
const TSCALE* scale_in,
TY* y,
TSCALE* scale_out,
int64_t m,
int64_t n);
DLL_EXPORT int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
@@ -160,7 +215,8 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
int64_t hidden_size,
bool is_scatter);
/*--------------------------------------- MTP being --------------------------------------------*/
/*--------------------------------------- MTP being
* --------------------------------------------*/
template <typename T>
DLL_EXPORT int speculate_token_penalty_multi_scores(
@@ -200,7 +256,6 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
const int block_num_per_seq,
const int max_draft_tokens);
template <bool ENABLE_TOPP, bool USE_TOPK>
DLL_EXPORT int speculate_verify(Context* ctx,
int64_t* accept_tokens,
@@ -457,9 +512,10 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
T* output,
int dim_embed,
int elem_cnt);
/*--------------------------------------- MTP end --------------------------------------------*/
/*--------------------------------------- MTP end
* --------------------------------------------*/
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu