[XPU] refactor: XPU plugin namespace migration (#6799)

* [XPU] refactor: XPU plugin namespace migration

- Migrate wrapper layer namespace from baidu::xpu::api::plugin to fastdeploy::plugin
- Migrate kernel layer namespace from xpu3::plugin to fd_xpu3
- Add api:: prefix to baidu::xpu::api symbols (Context, SUCCESS, XPUIndexType, ctx_guard, VectorParam)
- Remove XPU2 support, keep only XPU3
- Update ops/ directory to use new namespace

Total: 137 files changed

* [XPU] fix: add return value check and correct error messages

- Add PADDLE_ENFORCE_XDNN_SUCCESS check for speculate_get_logits and update_attn_mask_offsets
- Fix empty error message in draft_model_postprocess
- Correct function name in speculate_schedule_cache error message
- Update error messages from 'xpu::plugin::' to 'fastdeploy::plugin::'
This commit is contained in:
mayang002
2026-03-13 10:21:51 +08:00
committed by GitHub
parent d73fd876ba
commit 1f9f889e37
138 changed files with 1086 additions and 1467 deletions
@@ -18,13 +18,17 @@
#pragma once
#include "xpu/xdnn.h"
namespace baidu {
namespace xpu {
namespace api {
namespace fd_xpu3 {
typedef xpu3::int64_t int64_t;
}
namespace fastdeploy {
namespace plugin {
namespace api = baidu::xpu::api;
template <typename T>
DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
DLL_EXPORT int set_stop_value_multi_ends(api::Context* ctx,
bool* stop_flags,
T* topk_ids,
T* next_tokens,
@@ -34,7 +38,7 @@ DLL_EXPORT int set_stop_value_multi_ends(Context* ctx,
const int end_length,
const bool beam_search);
DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
DLL_EXPORT int set_value_by_flags_and_idx(api::Context* ctx,
const bool* stop_flags,
int64_t* pre_ids_all,
const int64_t* input_ids,
@@ -46,7 +50,7 @@ DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx,
int length_input_ids);
template <typename T>
DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
DLL_EXPORT int token_penalty_multi_scores(api::Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
@@ -63,7 +67,7 @@ DLL_EXPORT int token_penalty_multi_scores(Context* ctx,
const int64_t end_length,
const int64_t length_bad_words);
DLL_EXPORT int get_padding_offset(Context* ctx,
DLL_EXPORT int get_padding_offset(api::Context* ctx,
int* padding_offset,
int* cum_offsets_out,
int* cu_seqlens_q,
@@ -75,7 +79,7 @@ DLL_EXPORT int get_padding_offset(Context* ctx,
const int max_seq_len,
const int bs);
DLL_EXPORT int speculate_get_padding_offset(Context* ctx,
DLL_EXPORT int speculate_get_padding_offset(api::Context* ctx,
int* batch_id_per_token,
int* cum_offsets_out,
int* cu_seqlens_q,
@@ -117,7 +121,7 @@ DLL_EXPORT int draft_model_preprocess(api::Context* ctx,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
DLL_EXPORT int update_inputs(Context* ctx,
DLL_EXPORT int update_inputs(api::Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
@@ -130,7 +134,7 @@ DLL_EXPORT int update_inputs(Context* ctx,
const int max_bsz,
const int input_ids_stride);
DLL_EXPORT int free_and_dispatch_block(Context* ctx,
DLL_EXPORT int free_and_dispatch_block(api::Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
@@ -153,7 +157,7 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx,
const int max_decoder_block_num);
DLL_EXPORT int speculate_free_and_dispatch_block(
Context* ctx,
api::Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
@@ -177,7 +181,7 @@ DLL_EXPORT int speculate_free_and_dispatch_block(
const int max_decoder_block_num,
const int max_draft_tokens);
DLL_EXPORT int recover_block(Context* ctx,
DLL_EXPORT int recover_block(api::Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
@@ -200,7 +204,7 @@ DLL_EXPORT int recover_block(Context* ctx,
const int length,
const int pre_id_length);
DLL_EXPORT int speculate_recover_block(Context* ctx,
DLL_EXPORT int speculate_recover_block(api::Context* ctx,
int* recover_block_list, // [bsz]
int* recover_len,
bool* stop_flags,
@@ -224,7 +228,7 @@ DLL_EXPORT int speculate_recover_block(Context* ctx,
const int length,
const int pre_id_length);
DLL_EXPORT int recover_decode_task(Context* ctx,
DLL_EXPORT int recover_decode_task(api::Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
@@ -236,7 +240,7 @@ DLL_EXPORT int recover_decode_task(Context* ctx,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int recover_spec_decode_task(Context* ctx,
DLL_EXPORT int recover_spec_decode_task(api::Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
@@ -253,7 +257,7 @@ DLL_EXPORT int recover_spec_decode_task(Context* ctx,
const int draft_tokens_len,
const int num_extra_tokens);
DLL_EXPORT int update_inputs_v1(Context* ctx,
DLL_EXPORT int update_inputs_v1(api::Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
@@ -274,45 +278,45 @@ DLL_EXPORT int update_inputs_v1(Context* ctx,
template <typename TX, typename TY>
DLL_EXPORT int eb_adjust_batch(
Context* ctx,
api::Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int eb_gather_next_token(
Context* ctx,
api::Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int eb_mtp_gather_next_token(
Context* ctx,
api::Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int eb_recover_batch_sequence(
Context* ctx,
api::Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
api::VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
api::VectorParam<int32_t>& encoder_batch_map, // NOLINT
api::VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TSCALE = float, typename TY = int8_t>
@@ -324,7 +328,7 @@ DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
int64_t m,
int64_t n);
DLL_EXPORT int text_image_index_out(Context* ctx,
DLL_EXPORT int text_image_index_out(api::Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
@@ -372,7 +376,7 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v2(
template <typename T>
DLL_EXPORT int speculate_token_penalty_multi_scores(
Context* ctx,
api::Context* ctx,
const int64_t* pre_ids,
T* logits,
const T* penalty_scores,
@@ -392,7 +396,7 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
const int64_t length_bad_words,
const int64_t token_num,
const int64_t max_seq_len);
DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
DLL_EXPORT int mtp_free_and_dispatch_block(api::Context* ctx,
bool* base_model_stop_flags,
bool* stop_flags,
bool* batch_drop,
@@ -409,7 +413,7 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx,
const int max_draft_tokens);
template <bool ENABLE_TOPP, bool USE_TOPK>
DLL_EXPORT int speculate_verify(Context* ctx,
DLL_EXPORT int speculate_verify(api::Context* ctx,
const int64_t* sampled_token_ids,
int64_t* accept_tokens,
int* accept_num,
@@ -440,19 +444,19 @@ DLL_EXPORT int speculate_verify(Context* ctx,
const bool accept_all_drafts,
const bool use_target_sampling);
DLL_EXPORT int speculate_clear_accept_nums(Context* ctx,
DLL_EXPORT int speculate_clear_accept_nums(api::Context* ctx,
int* accept_num,
const int* seq_lens_decoder,
const int max_bsz);
DLL_EXPORT int speculate_get_seq_lens_output(Context* ctx,
DLL_EXPORT int speculate_get_seq_lens_output(api::Context* ctx,
int* seq_lens_output,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int real_bsz);
DLL_EXPORT int draft_model_update(Context* ctx,
DLL_EXPORT int draft_model_update(api::Context* ctx,
const int64_t* inter_next_tokens,
int64_t* draft_tokens,
int64_t* pre_ids,
@@ -475,7 +479,7 @@ DLL_EXPORT int draft_model_update(Context* ctx,
const int substep,
const bool prefill_one_step_stop);
DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx,
DLL_EXPORT int speculate_set_stop_value_multi_seqs(api::Context* ctx,
bool* stop_flags,
int64_t* accept_tokens,
int* accept_nums,
@@ -504,7 +508,7 @@ DLL_EXPORT int speculate_rebuild_append_padding(api::Context* ctx,
T* out);
template <typename T>
DLL_EXPORT int speculate_remove_padding(Context* ctx,
DLL_EXPORT int speculate_remove_padding(api::Context* ctx,
T* x_remove_padding,
const T* input_ids,
const T* draft_tokens,
@@ -536,7 +540,7 @@ DLL_EXPORT int compute_order(api::Context* ctx,
const int actual_draft_token_num,
const int input_token_num);
DLL_EXPORT int draft_model_postprocess(Context* ctx,
DLL_EXPORT int draft_model_postprocess(api::Context* ctx,
const int64_t* base_model_draft_tokens,
int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
@@ -544,7 +548,7 @@ DLL_EXPORT int draft_model_postprocess(Context* ctx,
int bsz,
int base_model_draft_token_len);
DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx,
DLL_EXPORT int speculate_set_value_by_flag_and_id(api::Context* ctx,
int64_t* pre_ids_all,
const int64_t* accept_tokens,
int* accept_num,
@@ -557,7 +561,7 @@ DLL_EXPORT int speculate_set_value_by_flag_and_id(Context* ctx,
int max_draft_tokens);
DLL_EXPORT int speculate_get_output_padding_offset(
Context* ctx,
api::Context* ctx,
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
@@ -578,7 +582,7 @@ DLL_EXPORT int top_p_candidates(api::Context* ctx,
int max_cadidate_len,
int max_seq_len);
DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
DLL_EXPORT int speculate_free_and_reschedule(api::Context* ctx,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_decoder,
@@ -601,7 +605,7 @@ DLL_EXPORT int speculate_free_and_reschedule(Context* ctx,
const int max_decoder_block_num,
const int max_draft_tokens);
DLL_EXPORT int speculate_schedule_cache(Context* ctx,
DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
const int64_t* draft_tokens,
int* block_tables,
bool* stop_flags,
@@ -625,7 +629,7 @@ DLL_EXPORT int speculate_schedule_cache(Context* ctx,
const int block_num_per_seq,
const bool prefill_one_step_stop);
DLL_EXPORT int speculate_update_v3(Context* ctx,
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* not_need_stop,
@@ -641,7 +645,7 @@ DLL_EXPORT int speculate_update_v3(Context* ctx,
const int max_bsz,
const int max_draft_tokens);
DLL_EXPORT int speculate_update(Context* ctx,
DLL_EXPORT int speculate_update(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* not_need_stop,
@@ -674,7 +678,7 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
int dim_embed,
int elem_cnt);
DLL_EXPORT int speculate_get_logits(Context* ctx,
DLL_EXPORT int speculate_get_logits(api::Context* ctx,
float* draft_logits,
int* next_token_num,
int* batch_token_num,
@@ -687,7 +691,7 @@ DLL_EXPORT int speculate_get_logits(Context* ctx,
const int real_bsz,
const int vocab_size);
DLL_EXPORT int update_attn_mask_offsets(Context* ctx,
DLL_EXPORT int update_attn_mask_offsets(api::Context* ctx,
int* attn_mask_offsets,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
@@ -723,6 +727,4 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
* --------------------------------------------*/
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
} // namespace fastdeploy