mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
c++ code format (#4527)
This commit is contained in:
@@ -24,121 +24,176 @@ namespace api {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int set_stop_value_multi_ends(Context *ctx, bool *stop_flags,
|
||||
T *topk_ids, T *next_tokens,
|
||||
const T *end_ids, const int *seq_lens,
|
||||
const int bs, const int end_length,
|
||||
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
|
||||
bool* stop_flags,
|
||||
T* topk_ids,
|
||||
T* next_tokens,
|
||||
const T* end_ids,
|
||||
const int* seq_lens,
|
||||
const int bs,
|
||||
const int end_length,
|
||||
const bool beam_search);
|
||||
|
||||
DLL_EXPORT int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int64_t *step_idx, int bs,
|
||||
int length, int length_input_ids);
|
||||
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
|
||||
const bool* stop_flags,
|
||||
int64_t* pre_ids_all,
|
||||
const int64_t* input_ids,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int64_t* step_idx,
|
||||
int bs,
|
||||
int length,
|
||||
int length_input_ids);
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int token_penalty_multi_scores(
|
||||
Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores,
|
||||
const T *frequency_scores, const T *presence_scores,
|
||||
const float *temperatures, const int64_t *cur_len, const int64_t *min_len,
|
||||
const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs,
|
||||
const int64_t length, const int64_t length_id, const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
|
||||
const int64_t* pre_ids,
|
||||
T* logits,
|
||||
const T* penalty_scores,
|
||||
const T* frequency_scores,
|
||||
const T* presence_scores,
|
||||
const float* temperatures,
|
||||
const int64_t* cur_len,
|
||||
const int64_t* min_len,
|
||||
const int64_t* eos_token_id,
|
||||
const int64_t* bad_words,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
|
||||
DLL_EXPORT int get_padding_offset(Context *ctx, int *padding_offset,
|
||||
int *cum_offsets_out, int *cu_seqlens_q,
|
||||
int *cu_seqlens_k, int64_t *x_remove_padding,
|
||||
const int64_t *input_ids,
|
||||
const int *cum_offsets, const int *seq_lens,
|
||||
const int max_seq_len, const int bs);
|
||||
DLL_EXPORT int get_padding_offset(Context* ctx,
|
||||
int* padding_offset,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int64_t* x_remove_padding,
|
||||
const int64_t* input_ids,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
const int bs);
|
||||
|
||||
DLL_EXPORT int update_inputs(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time, int *seq_lens_encoder,
|
||||
int *seq_lens_decoder, int64_t *input_ids,
|
||||
const int64_t *stop_nums, const bool *stop_flags,
|
||||
const bool *is_block_step,
|
||||
const int64_t *next_tokens, const int bsz,
|
||||
const int max_bsz, const int input_ids_stride);
|
||||
DLL_EXPORT int update_inputs(Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* input_ids,
|
||||
const int64_t* stop_nums,
|
||||
const bool* stop_flags,
|
||||
const bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride);
|
||||
|
||||
DLL_EXPORT int free_and_dispatch_block(
|
||||
Context *ctx, bool *stop_flags, int *seq_lens_this_time,
|
||||
int *seq_lens_decoder, int *block_tables, int *encoder_block_lens,
|
||||
bool *is_block_step,
|
||||
int *step_block_list, // [bsz]
|
||||
int *step_len, int *recover_block_list, int *recover_len,
|
||||
int *need_block_list, int *need_block_len, int *used_list_len,
|
||||
int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz,
|
||||
const int block_size, const int block_num_per_seq,
|
||||
const int max_decoder_block_num);
|
||||
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_decoder,
|
||||
int* block_tables,
|
||||
int* encoder_block_lens,
|
||||
bool* is_block_step,
|
||||
int* step_block_list, // [bsz]
|
||||
int* step_len,
|
||||
int* recover_block_list,
|
||||
int* recover_len,
|
||||
int* need_block_list,
|
||||
int* need_block_len,
|
||||
int* used_list_len,
|
||||
int* free_list,
|
||||
int* free_list_len,
|
||||
int64_t* first_token_ids,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num);
|
||||
|
||||
DLL_EXPORT int
|
||||
recover_block(Context *ctx,
|
||||
int *recover_block_list, // [bsz]
|
||||
int *recover_len, bool *stop_flags, int *seq_lens_this_time,
|
||||
const int *ori_seq_lens_encoder, int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder, int *block_tables, int *free_list,
|
||||
int *free_list_len, int64_t *input_ids, const int64_t *pre_ids,
|
||||
const int64_t *step_idx, const int *encoder_block_lens,
|
||||
const int *used_list_len, const int64_t *next_tokens,
|
||||
const int64_t *first_token_ids, const int bsz,
|
||||
const int block_num_per_seq, const int length,
|
||||
const int pre_id_length);
|
||||
DLL_EXPORT int recover_block(Context* ctx,
|
||||
int* recover_block_list, // [bsz]
|
||||
int* recover_len,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
const int* ori_seq_lens_encoder,
|
||||
int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
int* block_tables,
|
||||
int* free_list,
|
||||
int* free_list_len,
|
||||
int64_t* input_ids,
|
||||
const int64_t* pre_ids,
|
||||
const int64_t* step_idx,
|
||||
const int* encoder_block_lens,
|
||||
const int* used_list_len,
|
||||
const int64_t* next_tokens,
|
||||
const int64_t* first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
|
||||
DLL_EXPORT int
|
||||
recover_decode_task(Context *ctx, bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
DLL_EXPORT int recover_decode_task(Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int* block_tables,
|
||||
bool* is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
|
||||
DLL_EXPORT int
|
||||
update_inputs_v1(Context *ctx, bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
DLL_EXPORT int update_inputs_v1(Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int
|
||||
eb_adjust_batch(Context *ctx, const TX *x, TY *y,
|
||||
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
DLL_EXPORT int eb_adjust_batch(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int
|
||||
eb_gather_next_token(Context *ctx, const TX *x, TY *y,
|
||||
VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t> &encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t> &decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
DLL_EXPORT int eb_gather_next_token(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TSCALE = float, typename TY = int8_t>
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
|
||||
const TSCALE *scale_in, TY *y,
|
||||
TSCALE *scale_out, int64_t m, int64_t n);
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
|
||||
const TX* x,
|
||||
const TSCALE* scale_in,
|
||||
TY* y,
|
||||
TSCALE* scale_out,
|
||||
int64_t m,
|
||||
int64_t n);
|
||||
|
||||
DLL_EXPORT int text_image_index_out(Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
@@ -160,7 +215,8 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter);
|
||||
|
||||
/*--------------------------------------- MTP begin --------------------------------------------*/
|
||||
/*--------------------------------------- MTP begin
|
||||
* --------------------------------------------*/
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int speculate_token_penalty_multi_scores(
|
||||
@@ -200,7 +256,6 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
DLL_EXPORT int speculate_verify(Context* ctx,
|
||||
int64_t* accept_tokens,
|
||||
@@ -457,9 +512,10 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
|
||||
T* output,
|
||||
int dim_embed,
|
||||
int elem_cnt);
|
||||
/*--------------------------------------- MTP end --------------------------------------------*/
|
||||
/*--------------------------------------- MTP end
|
||||
* --------------------------------------------*/
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
|
||||
Reference in New Issue
Block a user