[XPU] Add speculate_limit_thinking_content_length Op. (#6627)

* [XPU] Add speculate_limit_thinking_content_length op for XPU.

* Add unit test.

* Format code.

* Format code.

* Format code.

* Fix unused kernel launch return value.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-03-11 17:30:17 +08:00
committed by GitHub
parent 9f0778f991
commit 88c4fbf8e1
5 changed files with 1038 additions and 11 deletions
@@ -645,17 +645,17 @@ DLL_EXPORT int speculate_update_v3(Context* ctx,
const int max_draft_tokens);
// Exported declaration of speculate_update. The bracketed trailing comments
// give the extent of each buffer as annotated by the original author.
// NOTE(review): this span appears to show both sides of a diff hunk -- the
// pointer parameters are listed twice, first with shape annotations and then
// without. Confirm which form is current before relying on either.
// NOTE(review): the int return value is presumably a status code -- TODO confirm.
DLL_EXPORT int speculate_update(Context* ctx,
int* seq_lens_encoder, // input [B_max, ]
int* seq_lens_decoder, // output [B_max, ]
bool* not_need_stop, // [1,]
int64_t* draft_tokens, // [B_max, T_max]
int* actual_draft_token_nums, // [B_max, ]
const int64_t* accept_tokens, // [B_max, T_max]
const int* accept_num, // [B_max, ]
const bool* stop_flags, // [B_max, ]
const int* seq_lens_this_time, // [B_real,]
const bool* is_block_step, // [B_max, ]
int* mask_rollback, // [1,]
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* not_need_stop,
int64_t* draft_tokens,
int* actual_draft_token_nums,
const int64_t* accept_tokens,
const int* accept_num,
const bool* stop_flags,
const int* seq_lens_this_time,
const bool* is_block_step,
int* mask_rollback,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens);
@@ -704,6 +704,24 @@ DLL_EXPORT int update_attn_mask_offsets(Context* ctx,
int real_bsz,
int max_model_len,
int decode_states_len);
// Exported declaration of speculate_limit_thinking_content_length_kernel.
//
// Pointer-to-const parameters (max_think_lens, eos_token_ids, stop_flags,
// inject_token_ids) are read-only for the callee; the non-const pointers
// (next_tokens, max_reply_lens, step_idx, limit_status, accept_num) may be
// written through.
//
// NOTE(review): buffer extents are not annotated on this declaration;
// presumably the per-request arrays hold `bs` entries and next_tokens holds
// `bs * tokens_per_step` entries -- TODO confirm against the implementation.
// NOTE(review): the int return value is presumably a status code -- TODO confirm.
DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
api::Context* ctx,
int64_t* next_tokens,
const int* max_think_lens,
int* max_reply_lens,
int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_status,
int* accept_num,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t* inject_token_ids,
const int tokens_per_step,
const int bs,
const int eos_token_id_len,  // presumably the element count of eos_token_ids -- confirm
const int inject_len,  // presumably the element count of inject_token_ids -- confirm
const bool splitwise_role_is_decode);
/*--------------------------------------- MTP end
* --------------------------------------------*/