mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Fix speculate schedule (#7049)
* [BugFix] xpu fix speculate schedule cache kernel * fix code style
This commit is contained in:
@@ -93,6 +93,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
if (stop_flags.is_cpu()) {
|
||||
delete ctx;
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
|
||||
+13
-11
@@ -61,8 +61,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop) {
|
||||
const int cid = core_id();
|
||||
const int tid = core_id() * cluster_num() + cluster_id();
|
||||
const int nthreads = core_num() * cluster_num();
|
||||
const int nthreads = core_num();
|
||||
|
||||
__shared__ int stop_flag_now_int_sm[64];
|
||||
int value_zero = 0;
|
||||
@@ -70,13 +69,15 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
bool value_false = false;
|
||||
|
||||
int bid_start_core, bid_end_core;
|
||||
partition(tid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
|
||||
partition(cid, nthreads, max_bsz, 1, &bid_start_core, &bid_end_core);
|
||||
|
||||
int64_t draft_tokens_lm[draft_tokens_len];
|
||||
int64_t step_draft_tokens_lm[draft_tokens_len];
|
||||
int block_table_lm[block_num_per_seq];
|
||||
int64_t accept_tokens_lm[accept_tokens_len];
|
||||
const int max_draft_tokens = 6;
|
||||
|
||||
int64_t draft_tokens_lm[max_draft_tokens];
|
||||
int64_t step_draft_tokens_lm[max_draft_tokens];
|
||||
int64_t accept_tokens_lm[max_draft_tokens];
|
||||
|
||||
int block_table_lm;
|
||||
int *seq_lens_encoder_lm;
|
||||
int seq_lens_decoder_lm;
|
||||
int64_t prompt_len_lm;
|
||||
@@ -92,9 +93,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
GM2LM_ASYNC(step_draft_tokens + bid * draft_tokens_len,
|
||||
step_draft_tokens_lm,
|
||||
draft_tokens_len * sizeof(int64_t));
|
||||
GM2LM_ASYNC(block_tables + bid * block_num_per_seq,
|
||||
block_table_lm,
|
||||
block_num_per_seq * sizeof(int));
|
||||
|
||||
GM2LM_ASYNC(accept_tokens + bid * accept_tokens_len,
|
||||
accept_tokens_lm,
|
||||
accept_tokens_len * sizeof(int64_t));
|
||||
@@ -108,6 +107,9 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
if (seq_lens_decoder_lm >= prompt_len_lm) {
|
||||
const int max_possible_block_idx =
|
||||
(seq_lens_decoder_lm + max_next_step_tokens) / block_size;
|
||||
GM2LM(block_tables + bid * block_num_per_seq + max_possible_block_idx,
|
||||
&block_table_lm,
|
||||
sizeof(int));
|
||||
if (prefill_one_step_stop) {
|
||||
LM2GM_ASYNC(&value_true, stop_flags + bid, sizeof(bool));
|
||||
LM2GM_ASYNC(&value_zero, seq_lens_this_time + bid, sizeof(int));
|
||||
@@ -117,7 +119,7 @@ __global__ void speculate_schedule_cache(const int64_t *draft_tokens,
|
||||
mfence();
|
||||
stop_flag_now_int = 1;
|
||||
} else if (max_possible_block_idx < block_num_per_seq &&
|
||||
block_table_lm[max_possible_block_idx] == -1) {
|
||||
block_table_lm == -1) {
|
||||
LM2GM_ASYNC(&value_true, is_block_step + bid, sizeof(bool));
|
||||
LM2GM_ASYNC(&seq_lens_this_time_lm,
|
||||
step_seq_lens_this_time + bid,
|
||||
|
||||
Reference in New Issue
Block a user