[XPU] Fix speculate schedule (#7049)

* [BugFix] XPU: fix the speculate schedule cache kernel

* Fix code style
This commit is contained in:
cmcamdy
2026-03-27 18:28:17 +08:00
committed by GitHub
parent 11ad95ba91
commit bf8e9bf81d
2 changed files with 16 additions and 11 deletions
@@ -93,6 +93,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
if (stop_flags.is_cpu()) {
delete ctx;
}
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
@@ -61,8 +61,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
const int block_num_per_seq,
const bool prefill_one_step_stop) {
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
const int nthreads = core_num();
__shared__ int stop_flag_now_int_sm[64];
int value_zero = 0;
@@ -70,13 +69,15 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
bool value_false = false;
int bid_start_core, bid_end_core;
partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
partition(cid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
int64_t draft_tokens_lm[draft_tokens_len];
int64_t step_draft_tokens_lm[draft_tokens_len];
int block_table_lm[block_num_per_seq];
int64_t accept_tokens_lm[accept_tokens_len];
const int max_draft_tokens = 6;
int64_t draft_tokens_lm[max_draft_tokens];
int64_t step_draft_tokens_lm[max_draft_tokens];
int64_t accept_tokens_lm[max_draft_tokens];
int block_table_lm;
int *seq_lens_encoder_lm;
int seq_lens_decoder_lm;
int64_t prompt_len_lm;
@@ -92,9 +93,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
GM2LM_ASYNC(step_draft_tokens + bid * draft_tokens_len,
step_draft_tokens_lm,
draft_tokens_len * sizeof(int64_t));
GM2LM_ASYNC(block_tables + bid * block_num_per_seq,
block_table_lm,
block_num_per_seq * sizeof(int));
GM2LM_ASYNC(accept_tokens + bid * accept_tokens_len,
accept_tokens_lm,
accept_tokens_len * sizeof(int64_t));
@@ -108,6 +107,9 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
if (seq_lens_decoder_lm >= prompt_len_lm) {
const int max_possible_block_idx =
(seq_lens_decoder_lm + max_next_step_tokens) / block_size;
GM2LM(block_tables + bid * block_num_per_seq + max_possible_block_idx,
&block_table_lm,
sizeof(int));
if (prefill_one_step_stop) {
LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
@@ -117,7 +119,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
mfence();
stop_flag_now_int = 1;
} else if (max_possible_block_idx < block_num_per_seq &&
block_table_lm[max_possible_block_idx] == -1) {
block_table_lm == -1) {
LM2GM_ASYNC(&value_true, is_block_step + bid, sizeof(bool));
LM2GM_ASYNC(&seq_lens_this_time_lm,
step_seq_lens_this_time + bid,