mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[XPU] refactor: XPU plugin namespace migration (#6799)
* [XPU] refactor: XPU plugin namespace migration
  - Migrate wrapper layer namespace from baidu::xpu::api::plugin to fastdeploy::plugin
  - Migrate kernel layer namespace from xpu3::plugin to fd_xpu3
  - Add api:: prefix for types (Context, SUCCESS, XPUIndexType, ctx_guard)
  - Remove XPU2 support, keep only XPU3
  - Update ops/ directory to use new namespace
  Total: 137 files changed
* [XPU] fix: add return value check and correct error messages
  - Add PADDLE_ENFORCE_XDNN_SUCCESS check for speculate_get_logits and update_attn_mask_offsets
  - Fix empty error message in draft_model_postprocess
  - Correct function name in speculate_schedule_cache error message
  - Update error messages from 'xpu::plugin::' to 'fastdeploy::plugin::'
This commit is contained in:
@@ -18,13 +18,17 @@
|
||||
#pragma once
|
||||
#include "xpu/xdnn.h"
|
||||
|
||||
namespace baidu {
|
||||
namespace xpu {
|
||||
namespace api {
|
||||
namespace fd_xpu3 {
|
||||
typedef xpu3::int64_t int64_t;
|
||||
}
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace plugin {
|
||||
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
|
||||
DLL_EXPORT int set_stop_value_multi_ends(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
T* topk_ids,
|
||||
T* next_tokens,
|
||||
@@ -34,7 +38,7 @@ DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
|
||||
const int end_length,
|
||||
const bool beam_search);
|
||||
|
||||
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
|
||||
DLL_EXPORT int set_value_by_flags_and_idx(api::Context* ctx,
|
||||
const bool* stop_flags,
|
||||
int64_t* pre_ids_all,
|
||||
const int64_t* input_ids,
|
||||
@@ -46,7 +50,7 @@ DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
|
||||
int length_input_ids);
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
|
||||
DLL_EXPORT int token_penalty_multi_scores(api::Context* ctx,
|
||||
const int64_t* pre_ids,
|
||||
T* logits,
|
||||
const T* penalty_scores,
|
||||
@@ -63,7 +67,7 @@ DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
|
||||
const int64_t end_length,
|
||||
const int64_t length_bad_words);
|
||||
|
||||
DLL_EXPORT int get_padding_offset(Context* ctx,
|
||||
DLL_EXPORT int get_padding_offset(api::Context* ctx,
|
||||
int* padding_offset,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
@@ -75,7 +79,7 @@ DLL_EXPORT int get_padding_offset(Context* ctx,
|
||||
const int max_seq_len,
|
||||
const int bs);
|
||||
|
||||
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
|
||||
DLL_EXPORT int speculate_get_padding_offset(api::Context* ctx,
|
||||
int* batch_id_per_token,
|
||||
int* cum_offsets_out,
|
||||
int* cu_seqlens_q,
|
||||
@@ -117,7 +121,7 @@ DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
DLL_EXPORT int update_inputs(Context* ctx,
|
||||
DLL_EXPORT int update_inputs(api::Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
@@ -130,7 +134,7 @@ DLL_EXPORT int update_inputs(Context* ctx,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride);
|
||||
|
||||
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
|
||||
DLL_EXPORT int free_and_dispatch_block(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_decoder,
|
||||
@@ -153,7 +157,7 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx,
|
||||
const int max_decoder_block_num);
|
||||
|
||||
DLL_EXPORT int speculate_free_and_dispatch_block(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_decoder,
|
||||
@@ -177,7 +181,7 @@ DLL_EXPORT int speculate_free_and_dispatch_block(
|
||||
const int max_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int recover_block(Context* ctx,
|
||||
DLL_EXPORT int recover_block(api::Context* ctx,
|
||||
int* recover_block_list, // [bsz]
|
||||
int* recover_len,
|
||||
bool* stop_flags,
|
||||
@@ -200,7 +204,7 @@ DLL_EXPORT int recover_block(Context* ctx,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
DLL_EXPORT int speculate_recover_block(Context* ctx,
|
||||
DLL_EXPORT int speculate_recover_block(api::Context* ctx,
|
||||
int* recover_block_list, // [bsz]
|
||||
int* recover_len,
|
||||
bool* stop_flags,
|
||||
@@ -224,7 +228,7 @@ DLL_EXPORT int speculate_recover_block(Context* ctx,
|
||||
const int length,
|
||||
const int pre_id_length);
|
||||
|
||||
DLL_EXPORT int recover_decode_task(Context* ctx,
|
||||
DLL_EXPORT int recover_decode_task(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
@@ -236,7 +240,7 @@ DLL_EXPORT int recover_decode_task(Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
|
||||
DLL_EXPORT int recover_spec_decode_task(Context* ctx,
|
||||
DLL_EXPORT int recover_spec_decode_task(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
@@ -253,7 +257,7 @@ DLL_EXPORT int recover_spec_decode_task(Context* ctx,
|
||||
const int draft_tokens_len,
|
||||
const int num_extra_tokens);
|
||||
|
||||
DLL_EXPORT int update_inputs_v1(Context* ctx,
|
||||
DLL_EXPORT int update_inputs_v1(api::Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
@@ -274,45 +278,45 @@ DLL_EXPORT int update_inputs_v1(Context* ctx,
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_adjust_batch(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_gather_next_token(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_mtp_gather_next_token(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_recover_batch_sequence(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TSCALE = float, typename TY = int8_t>
|
||||
@@ -324,7 +328,7 @@ DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
|
||||
int64_t m,
|
||||
int64_t n);
|
||||
|
||||
DLL_EXPORT int text_image_index_out(Context* ctx,
|
||||
DLL_EXPORT int text_image_index_out(api::Context* ctx,
|
||||
const int* token_type_ids, // x
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
@@ -372,7 +376,7 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v2(
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int speculate_token_penalty_multi_scores(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
const int64_t* pre_ids,
|
||||
T* logits,
|
||||
const T* penalty_scores,
|
||||
@@ -392,7 +396,7 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
|
||||
const int64_t length_bad_words,
|
||||
const int64_t token_num,
|
||||
const int64_t max_seq_len);
|
||||
DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
|
||||
DLL_EXPORT int mtp_free_and_dispatch_block(api::Context* ctx,
|
||||
bool* base_model_stop_flags,
|
||||
bool* stop_flags,
|
||||
bool* batch_drop,
|
||||
@@ -409,7 +413,7 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
|
||||
const int max_draft_tokens);
|
||||
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
DLL_EXPORT int speculate_verify(Context* ctx,
|
||||
DLL_EXPORT int speculate_verify(api::Context* ctx,
|
||||
const int64_t* sampled_token_ids,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
@@ -440,19 +444,19 @@ DLL_EXPORT int speculate_verify(Context* ctx,
|
||||
const bool accept_all_drafts,
|
||||
const bool use_target_sampling);
|
||||
|
||||
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
|
||||
DLL_EXPORT int speculate_clear_accept_nums(api::Context* ctx,
|
||||
int* accept_num,
|
||||
const int* seq_lens_decoder,
|
||||
const int max_bsz);
|
||||
|
||||
DLL_EXPORT int speculate_get_seq_lens_output(Context* ctx,
|
||||
DLL_EXPORT int speculate_get_seq_lens_output(api::Context* ctx,
|
||||
int* seq_lens_output,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
const int* seq_lens_decoder,
|
||||
const int real_bsz);
|
||||
|
||||
DLL_EXPORT int draft_model_update(Context* ctx,
|
||||
DLL_EXPORT int draft_model_update(api::Context* ctx,
|
||||
const int64_t* inter_next_tokens,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* pre_ids,
|
||||
@@ -475,7 +479,7 @@ DLL_EXPORT int draft_model_update(Context* ctx,
|
||||
const int substep,
|
||||
const bool prefill_one_step_stop);
|
||||
|
||||
DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
|
||||
DLL_EXPORT int speculate_set_stop_value_multi_seqs(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int64_t* accept_tokens,
|
||||
int* accept_nums,
|
||||
@@ -504,7 +508,7 @@ DLL_EXPORT int speculate_rebuild_append_padding(api::Context* ctx,
|
||||
T* out);
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int speculate_remove_padding(Context* ctx,
|
||||
DLL_EXPORT int speculate_remove_padding(api::Context* ctx,
|
||||
T* x_remove_padding,
|
||||
const T* input_ids,
|
||||
const T* draft_tokens,
|
||||
@@ -536,7 +540,7 @@ DLL_EXPORT int compute_order(api::Context* ctx,
|
||||
const int actual_draft_token_num,
|
||||
const int input_token_num);
|
||||
|
||||
DLL_EXPORT int draft_model_postprocess(Context* ctx,
|
||||
DLL_EXPORT int draft_model_postprocess(api::Context* ctx,
|
||||
const int64_t* base_model_draft_tokens,
|
||||
int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
@@ -544,7 +548,7 @@ DLL_EXPORT int draft_model_postprocess(Context* ctx,
|
||||
int bsz,
|
||||
int base_model_draft_token_len);
|
||||
|
||||
DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx,
|
||||
DLL_EXPORT int speculate_set_value_by_flag_and_id(api::Context* ctx,
|
||||
int64_t* pre_ids_all,
|
||||
const int64_t* accept_tokens,
|
||||
int* accept_num,
|
||||
@@ -557,7 +561,7 @@ DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx,
|
||||
int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int speculate_get_output_padding_offset(
|
||||
Context* ctx,
|
||||
api::Context* ctx,
|
||||
int* output_padding_offset,
|
||||
int* output_cum_offsets,
|
||||
const int* output_cum_offsets_tmp,
|
||||
@@ -578,7 +582,7 @@ DLL_EXPORT int top_p_candidates(api::Context* ctx,
|
||||
int max_cadidate_len,
|
||||
int max_seq_len);
|
||||
|
||||
DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
|
||||
DLL_EXPORT int speculate_free_and_reschedule(api::Context* ctx,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_decoder,
|
||||
@@ -601,7 +605,7 @@ DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
|
||||
const int max_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int speculate_schedule_cache(Context* ctx,
|
||||
DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
|
||||
const int64_t* draft_tokens,
|
||||
int* block_tables,
|
||||
bool* stop_flags,
|
||||
@@ -625,7 +629,7 @@ DLL_EXPORT int speculate_schedule_cache(Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop);
|
||||
|
||||
DLL_EXPORT int speculate_update_v3(Context* ctx,
|
||||
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
bool* not_need_stop,
|
||||
@@ -641,7 +645,7 @@ DLL_EXPORT int speculate_update_v3(Context* ctx,
|
||||
const int max_bsz,
|
||||
const int max_draft_tokens);
|
||||
|
||||
DLL_EXPORT int speculate_update(Context* ctx,
|
||||
DLL_EXPORT int speculate_update(api::Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
bool* not_need_stop,
|
||||
@@ -674,7 +678,7 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
|
||||
int dim_embed,
|
||||
int elem_cnt);
|
||||
|
||||
DLL_EXPORT int speculate_get_logits(Context* ctx,
|
||||
DLL_EXPORT int speculate_get_logits(api::Context* ctx,
|
||||
float* draft_logits,
|
||||
int* next_token_num,
|
||||
int* batch_token_num,
|
||||
@@ -687,7 +691,7 @@ DLL_EXPORT int speculate_get_logits(Context* ctx,
|
||||
const int real_bsz,
|
||||
const int vocab_size);
|
||||
|
||||
DLL_EXPORT int update_attn_mask_offsets(Context* ctx,
|
||||
DLL_EXPORT int update_attn_mask_offsets(api::Context* ctx,
|
||||
int* attn_mask_offsets,
|
||||
const int* seq_lens_this_time,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -723,6 +727,4 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
|
||||
* --------------------------------------------*/
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace api
|
||||
} // namespace xpu
|
||||
} // namespace baidu
|
||||
} // namespace fastdeploy
|
||||
|
||||
Reference in New Issue
Block a user