mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[XPU] support kernel for mtp(base) (#4748)
* [XPU] support kernel for mtp(base) * [XPU] support kernel for mtp(base) * format * format * format * fix gather next token * fix step && add test * fix * mv pre/post process * add adjust batch / gather next token for mtp * fix code style * fix mtp kenrel name * fix mtp kernel test * mv xpu pre/post process * mv xpu pre/post process
This commit is contained in:
@@ -75,6 +75,48 @@ DLL_EXPORT int get_padding_offset(Context* ctx,
|
||||
const int max_seq_len,
|
||||
const int bs);
|
||||
|
||||
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
int bsz);
|
||||
|
||||
DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
DLL_EXPORT int update_inputs(Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
@@ -111,6 +153,31 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num);
|
||||
|
||||
DLL_EXPORT int speculate_free_and_dispatch_block(
|
||||
Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_decoder,
|
||||
int* block_tables,
|
||||
int* encoder_block_lens,
|
||||
bool* is_block_step,
|
||||
int* step_block_list, // [bsz]
|
||||
int* step_len,
|
||||
int* recover_block_list,
|
||||
int* recover_len,
|
||||
int* need_block_list,
|
||||
int* need_block_len,
|
||||
int* used_list_len,
|
||||
int* free_list,
|
||||
int* free_list_len,
|
||||
int64_t* first_token_ids,
|
||||
int* accept_num,
|
||||
const int bsz,
|
||||
const int block_size,
|
||||
const int block_num_per_seq,
|
||||
const int max_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int recover_block(Context* ctx,
|
||||
int* recover_block_list, // [bsz]
|
||||
int* recover_len,
|
||||
@@ -134,6 +201,29 @@ DLL_EXPORT int recover_block(Context* ctx,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
DLL_EXPORT int speculate_recover_block(Context* ctx,
|
||||
int* recover_block_list, // [bsz]
|
||||
int* recover_len,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
const int* ori_seq_lens_encoder,
|
||||
int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
int* block_tables,
|
||||
int* free_list,
|
||||
int* free_list_len,
|
||||
int64_t* input_ids,
|
||||
const int64_t* pre_ids,
|
||||
const int64_t* step_idx,
|
||||
const int* encoder_block_lens,
|
||||
const int* used_list_len,
|
||||
const int64_t* next_tokens,
|
||||
const int64_t* first_token_ids,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
DLL_EXPORT int recover_decode_task(Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
@@ -172,6 +262,7 @@ DLL_EXPORT int eb_adjust_batch(
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
@@ -186,6 +277,17 @@ DLL_EXPORT int eb_gather_next_token(
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_mtp_gather_next_token(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TSCALE = float, typename TY = int8_t>
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
|
||||
const TX* x,
|
||||
@@ -305,7 +407,8 @@ DLL_EXPORT int speculate_verify(Context* ctx,
|
||||
const int max_seq_len,
|
||||
const int max_candidate_len,
|
||||
const int verify_window,
|
||||
const bool prefill_one_step_stop);
|
||||
const bool prefill_one_step_stop,
|
||||
const bool benchmark_mode);
|
||||
|
||||
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
|
||||
int* accept_num,
|
||||
@@ -342,35 +445,6 @@ DLL_EXPORT int draft_model_update(Context* ctx,
|
||||
const int substep,
|
||||
const bool prefill_one_step_stop);
|
||||
|
||||
DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
int* seq_lens_encoder_record,
|
||||
int* seq_lens_decoder_record,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
int real_bsz,
|
||||
int max_draft_token,
|
||||
int accept_tokens_len,
|
||||
int draft_tokens_len,
|
||||
int input_ids_len,
|
||||
int base_model_draft_tokens_len,
|
||||
bool truncate_first_token,
|
||||
bool splitwise_prefill);
|
||||
|
||||
DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
|
||||
bool* stop_flags,
|
||||
int64_t* accept_tokens,
|
||||
@@ -411,16 +485,6 @@ DLL_EXPORT int speculate_remove_padding(Context* ctx,
|
||||
int bsz,
|
||||
int token_num_data);
|
||||
|
||||
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
|
||||
int* padding_offset,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
const int* cum_offsets,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
int bsz);
|
||||
|
||||
DLL_EXPORT int compute_self_order(api::Context* ctx,
|
||||
const int* last_seq_lens_this_time,
|
||||
const int* seq_lens_this_time,
|
||||
|
||||
Reference in New Issue
Block a user