[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)

* support new mtp

* refactor(speculate_decoding and mtp): optimize mtp structure logic. Update spec-branch status-process

* fix cuda-graph for spec-decoding

* fix xpu mtp and fix some note

* fix unittest and optimize note

* fix model status update in eos-branch
This commit is contained in:
freeliuzc
2026-03-24 10:19:01 +08:00
committed by GitHub
parent c62f6b4ea5
commit e87ce4b8cd
13 changed files with 1401 additions and 1150 deletions
+19 -19
View File
@@ -875,7 +875,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& has_running_seqs,
const paddle::Tensor& step_input_ids,
const paddle::Tensor& adaptive_step_input_len,
const paddle::Tensor& step_output_ids,
const paddle::Tensor& step_output_len,
const paddle::Tensor& stop_flags,
@@ -886,9 +885,13 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& end_tokens,
const paddle::Tensor& max_dec_len,
const bool is_naive_mode,
const bool prefill_one_step_stop);
const paddle::Tensor& max_dec_len);
void NaiveUpdateModelStatus(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& next_tokens,
const paddle::Tensor& cu_seqlens_q_output);
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
@@ -971,24 +974,17 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& recompute_token_num,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
const paddle::Tensor& target_model_seq_lens_encoder,
const paddle::Tensor& target_model_seq_lens_decoder,
const paddle::Tensor& target_model_step_idx,
const paddle::Tensor& target_model_stop_flags,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& target_model_draft_tokens,
const int num_model_step,
const bool is_splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
@@ -1785,6 +1781,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&UnifiedUpdateModelStatus,
"unified_update_model_status function");
m.def("naive_update_model_status",
&NaiveUpdateModelStatus,
"naive_update_model_status function");
m.def("speculate_set_value_by_flags_and_idx",
&SpeculateSetValueByFlagsAndIdx,
"speculate_set_value_by_flags_and_idx function");
@@ -15,119 +15,16 @@
#include "helper.h"
#include "paddle/extension.h"
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
do { \
constexpr int BlockSize = BLOCK_SIZE; \
__VA_ARGS__; \
} while (0)
#define DISPATCH_TRUNCATE_FIRST_TOKEN( \
truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
do { \
if (truncate_first_token) { \
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
__VA_ARGS__; \
} else { \
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_KVCACHE_SCHEDULER( \
kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
do { \
if (kvcache_scheduler_v1) { \
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
__VA_ARGS__; \
} else { \
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
do { \
if (splitwise_prefill) { \
constexpr bool SPLITWISE_PREFILL = true; \
__VA_ARGS__; \
} else { \
constexpr bool SPLITWISE_PREFILL = false; \
__VA_ARGS__; \
} \
} while (0)
template <int THREADBLOCK_SIZE,
bool TRUNCATE_FIRST_TOKEN,
bool KVCACHE_SCHEDULER_V1>
__global__ void process_splitwise_prefill(
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
int tid = threadIdx.x;
if (tid < bsz) {
int base_model_step_idx_now = base_model_step_idx[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
if (seq_lens_encoder[tid] > 0) {
not_stop_flag = 1;
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
not_stop_flag = 0;
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
not_need_stop[0] = not_stop_flag_sum > 0;
}
}
template <int THREADBLOCK_SIZE,
bool TRUNCATE_FIRST_TOKEN,
bool KVCACHE_SCHEDULER_V1>
// Main draft preprocess kernel.
// MTP state (seq_lens_decoder, step_idx) is "shadow state":
// - Initialized from target model state each round
// - Used for MTP forward, but not committed until verify
// - No rollback needed since it's always re-initialized
//
// is_splitwise_prefill: set true on P-D disaggregated prefill node.
// In that mode, only prefill requests (seq_lens_encoder > 0) run MTP;
// decode requests are marked stopped and skipped.
template <int THREADBLOCK_SIZE>
__global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens,
int64_t* input_ids,
@@ -137,27 +34,23 @@ __global__ void draft_model_preprocess_kernel(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
int* mask_rollback,
int* recompute_token_num,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int* target_model_seq_lens_encoder,
const int* target_model_seq_lens_decoder,
const int64_t* target_model_step_idx,
const bool* target_model_stop_flags,
const int64_t* max_dec_len,
int64_t* target_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len) {
const int target_model_draft_tokens_len,
const int pre_ids_len,
const bool is_splitwise_prefill) {
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int64_t not_stop_flag = 0;
@@ -165,96 +58,83 @@ __global__ void draft_model_preprocess_kernel(
int tid = threadIdx.x;
if (tid < bsz) {
const int32_t base_model_step_idx_now = base_model_step_idx[tid];
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
const int32_t accept_num_now = accept_num[tid];
auto* input_ids_now = input_ids + tid * input_ids_len;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * base_model_draft_tokens_len;
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
const int32_t base_model_seq_len_this_time =
base_model_seq_lens_this_time[tid];
auto* target_model_draft_tokens_now =
target_model_draft_tokens + tid * target_model_draft_tokens_len;
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
const auto target_step = target_model_step_idx[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
// Clear target_model_draft_tokens (keep first token)
#pragma unroll
for (int i = 1; i < base_model_draft_tokens_len; i++) {
base_model_draft_tokens_now[i] = -1;
for (int i = 1; i < target_model_draft_tokens_len; i++) {
target_model_draft_tokens_now[i] = -1;
}
// 1. process block_step situation
// -- In v0 mode, block_step will drop mtp query.
// -- In v1 mode, block_step will continue to infer.
if constexpr (KVCACHE_SCHEDULER_V1) {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
stop_flags[tid] = true;
is_block_step[tid] = true;
// Need to continue infer
}
} else {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
// ============================================================
// Decision: Should MTP/Draft model run?
// ============================================================
bool should_skip = false;
// Target model stopped
if (target_model_stop_flags[tid]) {
should_skip = true;
}
// 2. process normal query, not in any special case.
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// prefill generation
if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder = seq_lens_encoder[tid];
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else { // decode generation
if constexpr (KVCACHE_SCHEDULER_V1) {
// 3. try to recover mtp infer in V1 mode
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
is_block_step[tid] = false;
}
}
if (stop_flags[tid]) {
stop_flags[tid] = false;
// TODO: check
seq_lens_decoder[tid] =
base_model_seq_len_decoder - base_model_seq_len_this_time;
step_idx[tid] =
base_model_step_idx[tid] - base_model_seq_len_this_time;
} else {
// 2: Last base model generated token and first MTP token
const int recompute_token_num_now = recompute_token_num[tid];
seq_lens_decoder[tid] -= recompute_token_num_now;
step_idx[tid] -= recompute_token_num_now;
mask_rollback[tid] += recompute_token_num_now;
// NOTE(liuzichang): Used for PD-split mode and future dynamic
// strategies.
recompute_token_num[tid] = num_model_step - 1;
}
for (int i = 0; i < accept_num_now; i++) {
draft_tokens_now[i] = accept_tokens_now[i];
const int pre_id_pos =
base_model_step_idx[tid] - (accept_num_now - i);
const int64_t accept_token = accept_tokens_now[i];
pre_ids_now[pre_id_pos] = accept_token;
}
seq_lens_this_time[tid] = accept_num_now;
}
} else {
// Near end of max_dec_len in no prefill node
if ((not is_splitwise_prefill &&
target_step + num_model_step >= max_dec_len[tid])) {
should_skip = true;
}
// ============================================================
// Execute based on decision
// ============================================================
if (should_skip) {
stop_flags[tid] = true;
seq_lens_this_time[tid] = 0;
seq_lens_decoder[tid] = 0;
seq_lens_encoder[tid] = 0;
step_idx[tid] = 0;
not_stop_flag = 0;
} else {
not_stop_flag = 1;
stop_flags[tid] = false;
if (seq_len_encoder > 0) {
// prefill | chunk_prefill | prompt_cache | recover after preempted
int64_t target_model_first_token = accept_tokens_now[0];
pre_ids_now[0] = target_model_first_token;
input_ids_now[seq_len_encoder - 1] = target_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
// Shadow state: initialize from target model (prefill just finished)
step_idx[tid] = target_step - 1;
} else {
// Decode: MTP shadow state from target model
// Shadow state is initialized from target model state each round
// This is the key simplification: no rollback needed
int32_t need_compute_token = accept_num_now;
seq_lens_decoder[tid] =
target_model_seq_lens_decoder[tid] - need_compute_token;
step_idx[tid] = target_model_step_idx[tid] - need_compute_token;
// Prepare draft input tokens from accepted tokens
for (int i = 0; i < accept_num_now; i++) {
draft_tokens_now[i] = accept_tokens_now[i];
const int pre_id_pos =
target_model_step_idx[tid] - (accept_num_now - i);
pre_ids_now[pre_id_pos] = accept_tokens_now[i];
}
seq_lens_this_time[tid] = accept_num_now;
}
}
}
__syncthreads();
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
if (tid == 0) {
@@ -262,114 +142,6 @@ __global__ void draft_model_preprocess_kernel(
}
}
void DispatchRunner(const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
int* mask_rollback,
int* recompute_token_num,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
DISPATCH_BLOCKSIZE(512, {
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
if constexpr (SPLITWISE_PREFILL) {
process_splitwise_prefill<BlockSize,
TRUNCATE_FIRST_TOKEN,
KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize,
TRUNCATE_FIRST_TOKEN,
KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
mask_rollback,
recompute_token_num,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
});
});
});
});
}
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
@@ -378,69 +150,61 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& recompute_token_num,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& target_model_seq_lens_encoder,
const paddle::Tensor& target_model_seq_lens_decoder,
const paddle::Tensor& target_model_step_idx,
const paddle::Tensor& target_model_stop_flags,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& target_model_draft_tokens,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
const bool is_splitwise_prefill) {
constexpr int kBlockSize = 1024;
int real_bsz = seq_lens_this_time.shape()[0];
PADDLE_ENFORCE_LE(
real_bsz,
kBlockSize,
phi::errors::InvalidArgument(
"draft_model_preprocess: real_bsz (%d) exceeds kBlockSize (%d).",
real_bsz,
kBlockSize));
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
int draft_tokens_len = draft_tokens.shape()[1];
int pre_ids_len = pre_ids.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
constexpr int BlockSize = 512;
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
int target_model_draft_tokens_len = target_model_draft_tokens.shape()[1];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchRunner(cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
const_cast<int*>(mask_rollback.data<int>()),
const_cast<int*>(recompute_token_num.data<int>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1);
draft_model_preprocess_kernel<kBlockSize><<<1, kBlockSize, 0, cu_stream>>>(
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
target_model_seq_lens_encoder.data<int>(),
target_model_seq_lens_decoder.data<int>(),
target_model_step_idx.data<int64_t>(),
target_model_stop_flags.data<bool>(),
max_dec_len.data<int64_t>(),
const_cast<int64_t*>(target_model_draft_tokens.data<int64_t>()),
real_bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
target_model_draft_tokens_len,
pre_ids_len,
is_splitwise_prefill);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
@@ -456,20 +220,15 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder",
"step_idx",
"not_need_stop",
"is_block_step",
"batch_drop",
"pre_ids",
"mask_rollback",
"recompute_token_num",
"accept_tokens",
"accept_num",
"base_model_seq_lens_this_time",
"base_model_seq_lens_encoder",
"base_model_seq_lens_decoder",
"base_model_step_idx",
"base_model_stop_flags",
"base_model_is_block_step",
"base_model_draft_tokens"})
"target_model_seq_lens_encoder",
"target_model_seq_lens_decoder",
"target_model_step_idx",
"target_model_stop_flags",
"max_dec_len",
"target_model_draft_tokens"})
.Outputs({"draft_tokens_out",
"input_ids_out",
"stop_flags_out",
@@ -478,12 +237,8 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder_out",
"step_idx_out",
"not_need_stop_out",
"batch_drop_out",
"pre_ids_out"})
.Attrs({"num_model_step: int",
"truncate_first_token: bool",
"splitwise_prefill: bool",
"kvcache_scheduler_v1: bool"})
.Attrs({"num_model_step: int", "is_splitwise_prefill: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
@@ -492,6 +247,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_idx", "step_idx_out"},
{"not_need_stop", "not_need_stop_out"},
{"batch_drop", "batch_drop_out"},
{"pre_ids", "pre_ids_out"}})
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
@@ -54,43 +54,31 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
auto seq_len_decoder = seq_lens_decoder[tid];
// 1. update step_idx && seq_lens_dec
if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
if (!stop_flags[tid]) {
int64_t token_this_time = -1;
// decoder step
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
} else {
if (seq_len_encoder > 0) {
token_this_time = next_tokens_start[0];
// seq_lens_decoder[tid] = seq_lens_encoder[tid];
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
seq_lens_encoder[tid] = 0;
pre_ids_now[1] = token_this_time;
step_idx[tid] += 1;
draft_token_now[0] = token_this_time;
base_model_draft_tokens_now[substep + 1] = token_this_time;
if (step_idx[tid] < max_dec_len[tid]) {
base_model_draft_tokens_now[substep + 1] = token_this_time;
}
} else if (seq_len_decoder > 0) {
if (step_idx[tid] >= max_dec_len[tid] - 1) {
// If up to max_dec_len -1. Recompute but not update.
base_model_draft_tokens_now[substep + 1] = -1;
} else {
seq_lens_decoder[tid] += seq_len_this_time;
token_this_time = next_tokens_start[seq_len_this_time - 1];
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
base_model_draft_tokens_now[substep + 1] = token_this_time;
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
}
}
// multi_end
// TODO(liuzichang): Don't check eos in future
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int = 1;
// max_dec_len
} else if (step_idx[tid] >= max_dec_len[tid]) {
stop_flags[tid] = true;
draft_token_now[seq_len_this_time - 1] = end_ids[0];
base_model_draft_tokens_now[substep + 1] = end_ids[0];
stop_flag_now_int = 1;
}
} else {
draft_token_now[0] = -1;
base_model_draft_tokens_now[substep + 1] = -1;
@@ -0,0 +1,97 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
/**
* @file naive_update_model_status.cu
* @brief Post-sampling update for NAIVE speculative decoding mode.
*
* Responsibilities (one thread per batch slot, <<<1, 1024>>>):
* 1. Scatter sampled token into accept_tokens[i, 0] using cu_seqlens_q_output
* to map from the packed next_tokens array to the per-batch output.
* 2. Set accept_num[i] = 1 for running slots (seq_lens_this_time > 0), else
* 0.
* 3. Set seq_lens_this_time[i] = 1 for running, 0 for stopped/paused.
*
* Running slots are identified by seq_lens_this_time[i] > 0, which is already
* zeroed for stopped/paused slots by pre_process before this kernel runs.
*
* The packed next_tokens layout mirrors cu_seqlens_q_output:
* next_tokens[cu_seqlens_q_output[i] .. cu_seqlens_q_output[i+1]-1]
* are the output tokens for request i (exactly 1 per running slot in naive
* decode; 0 for stopped/encoder-only slots).
*/
template <int THREADBLOCK_SIZE>
__global__ void naive_update_model_status_kernel(int64_t *accept_tokens,
                                                 int *accept_num,
                                                 int *seq_lens_this_time,
                                                 const int64_t *next_tokens,
                                                 const int *cu_seqlens_q_output,
                                                 int real_bsz,
                                                 int max_step_tokens) {
  // One thread handles one batch slot; surplus threads in the block exit.
  const int slot = threadIdx.x;
  if (slot >= real_bsz) {
    return;
  }
  const bool is_running = seq_lens_this_time[slot] > 0;
  if (!is_running) {
    // Stopped / paused slot: nothing was sampled this step.
    accept_num[slot] = 0;
    seq_lens_this_time[slot] = 0;
    return;
  }
  // Running slot: exactly one token per step in naive mode. The packed
  // next_tokens array is addressed via cu_seqlens_q_output, and the last
  // entry of this request's segment is its sampled token.
  const int64_t sampled_token = next_tokens[cu_seqlens_q_output[slot + 1] - 1];
  accept_tokens[slot * max_step_tokens] = sampled_token;  // row slot, col 0
  accept_num[slot] = 1;
  seq_lens_this_time[slot] = 1;
}
// Host-side launcher for the naive-mode post-sampling update.
// Launches a single block with one thread per batch slot; the per-slot
// update logic lives in naive_update_model_status_kernel.
void NaiveUpdateModelStatus(const paddle::Tensor &accept_tokens,
                            const paddle::Tensor &accept_num,
                            const paddle::Tensor &seq_lens_this_time,
                            const paddle::Tensor &next_tokens,
                            const paddle::Tensor &cu_seqlens_q_output) {
  constexpr int kBlockSize = 1024;
  const int real_bsz = seq_lens_this_time.shape()[0];
  // The kernel covers the batch with a single thread block, so the batch
  // size must not exceed the block size.
  PADDLE_ENFORCE_LE(
      real_bsz,
      kBlockSize,
      phi::errors::InvalidArgument(
          "naive_update_model_status: real_bsz (%d) must be <= %d.",
          real_bsz,
          kBlockSize));
  // accept_tokens is [max_bsz, max_step_tokens]; this is the row stride.
  const int max_step_tokens = accept_tokens.shape()[1];
  auto stream = seq_lens_this_time.stream();
  naive_update_model_status_kernel<kBlockSize>
      <<<1, kBlockSize, 0, stream>>>(
          const_cast<int64_t *>(accept_tokens.data<int64_t>()),
          const_cast<int *>(accept_num.data<int>()),
          const_cast<int *>(seq_lens_this_time.data<int>()),
          next_tokens.data<int64_t>(),
          cu_seqlens_q_output.data<int>(),
          real_bsz,
          max_step_tokens);
}
// Operator registration for the naive speculative-decoding status update.
// accept_tokens, accept_num and seq_lens_this_time are updated in place:
// each *_out output aliases its input via SetInplaceMap (zero-copy).
PD_BUILD_STATIC_OP(naive_update_model_status)
    .Inputs({"accept_tokens",
             "accept_num",
             "seq_lens_this_time",
             "next_tokens",
             "cu_seqlens_q_output"})
    .Attrs({})
    .Outputs({"accept_tokens_out", "accept_num_out", "seq_lens_this_time_out"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(NaiveUpdateModelStatus));
@@ -43,7 +43,6 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
bool *has_running_seqs,
int *mask_rollback,
int64_t *step_input_ids,
int *adaptive_step_input_len,
int64_t *step_output_ids,
int *step_output_len,
bool *stop_flags,
@@ -58,36 +57,25 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
int max_bsz,
int max_step_tokens,
int max_model_len,
int num_end_tokens,
bool is_naive_mode,
bool prefill_one_step_stop) {
int num_end_tokens) {
const int batch_id = blockIdx.x * BLOCK_SIZE + threadIdx.x;
const bool is_valid_slot = batch_id < max_bsz;
int stop_flag_int = 0;
if (is_valid_slot) {
// Read state
// 1. Read state
int cur_seq_len_encoder = seq_lens_encoder[batch_id];
int cur_seq_len_decoder = seq_lens_decoder[batch_id];
bool cur_stop_flag = stop_flags[batch_id];
int output_len = 0;
int output_len = step_output_len[batch_id];
int64_t cur_step_idx = step_idx[batch_id];
bool cur_is_paused = is_paused[batch_id];
int64_t prompt_len = prompt_lens[batch_id];
bool is_running = !cur_stop_flag && !cur_is_paused;
// Compute output length
if (is_running) {
if (is_naive_mode) {
output_len = 1;
} else {
output_len = step_output_len[batch_id];
}
}
// EOS detection
// 2. EOS detection
if (is_running && output_len > 0) {
bool hit_stop = false;
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
for (int i = 0; i < output_len; i++) {
@@ -100,63 +88,56 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
if (!is_eos) output_ids[i] = end_tokens[0];
output_len = i + 1;
cur_stop_flag = true;
hit_stop = true;
break;
}
}
if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
cur_stop_flag = true;
}
}
// Update state and write back
if (is_running) {
if (cur_stop_flag) {
stop_flag_int = 1;
if (output_len == 0) cur_seq_len_decoder = 0;
stop_flags[batch_id] = true;
mask_rollback[batch_id] = 0;
} else if (cur_seq_len_encoder == 0) {
// 3. Update state and write back
if (cur_seq_len_encoder > 0) {
cur_seq_len_decoder += cur_seq_len_encoder;
cur_seq_len_encoder = 0;
} else if (cur_seq_len_decoder > 0) {
cur_seq_len_decoder += output_len;
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
} else {
mask_rollback[batch_id] = 0;
}
if (cur_seq_len_encoder > 0) {
cur_seq_len_decoder += cur_seq_len_encoder;
cur_seq_len_encoder = 0;
if (cur_stop_flag) {
// It should clear seq_lens_decoder in next step for save_output
stop_flag_int = 1;
stop_flags[batch_id] = true;
mask_rollback[batch_id] = 0;
}
// 4. Update model status
seq_lens_encoder[batch_id] = cur_seq_len_encoder;
seq_lens_decoder[batch_id] = cur_seq_len_decoder;
step_output_len[batch_id] = output_len;
step_idx[batch_id] = cur_step_idx;
// Write history to token_ids_all
if (cur_step_idx > 0 && output_len > 0) {
// Bounds check: highest write index is prompt_lens + cur_step_idx
if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
// 5. Write history to token_ids_all (forward loop: position base+k =
// output_ids[k])
if (output_len > 0) {
// Bounds check: highest write index is prompt_len + cur_step_idx
if (prompt_len + cur_step_idx < max_model_len) {
int64_t *token_ids_all_now =
&token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
&token_ids_all[batch_id * max_model_len + prompt_len];
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
int64_t base = cur_step_idx - output_len + 1;
for (int i = 0; i < output_len; i++) {
token_ids_all_now[cur_step_idx - i] =
output_ids[output_len - 1 - i];
token_ids_all_now[base + i] = output_ids[i];
}
}
}
// Setup next input
// 6. Prepare next step input[0]
if (output_len > 0) {
step_input_ids[batch_id * max_step_tokens] =
step_output_ids[batch_id * max_step_tokens + output_len - 1];
}
if (is_naive_mode) {
seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
}
} else if (batch_id >= real_bsz) {
// Padding slot: just count as stopped, don't modify state
stop_flag_int = 1;
@@ -164,6 +145,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
// Stopped or paused slot (batch_id < real_bsz)
stop_flag_int = 1;
stop_flags[batch_id] = true;
seq_lens_encoder[batch_id] = 0;
seq_lens_decoder[batch_id] = 0;
seq_lens_this_time[batch_id] = 0;
step_output_len[batch_id] = 0;
@@ -175,11 +157,9 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
typedef cub::BlockReduce<int64_t, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
has_running_seqs[0] = stop_sum < max_bsz;
}
}
@@ -189,7 +169,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &has_running_seqs,
const paddle::Tensor &step_input_ids,
const paddle::Tensor &adaptive_step_input_len,
const paddle::Tensor &step_output_ids,
const paddle::Tensor &step_output_len,
const paddle::Tensor &stop_flags,
@@ -200,9 +179,7 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &end_tokens,
const paddle::Tensor &max_dec_len,
const bool is_naive_mode,
const bool prefill_one_step_stop) {
const paddle::Tensor &max_dec_len) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
PADDLE_ENFORCE_LE(
@@ -228,7 +205,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
const_cast<int *>(mask_rollback.data<int>()),
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
const_cast<int *>(adaptive_step_input_len.data<int>()),
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
const_cast<int *>(step_output_len.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
@@ -243,9 +219,7 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
max_bsz,
max_step_tokens,
max_model_len,
num_end_tokens,
is_naive_mode,
prefill_one_step_stop);
num_end_tokens);
// Copy result back to CPU
auto has_running_seqs_cpu =
has_running_seqs_gpu.copy_to(has_running_seqs.place(), false);
@@ -258,7 +232,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
"seq_lens_decoder",
"has_running_seqs",
"step_input_ids",
"adaptive_step_input_len",
"step_output_ids",
"step_output_len",
"stop_flags",
@@ -270,12 +243,10 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
"step_idx",
"end_tokens",
"max_dec_len"})
.Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"has_running_seqs_out",
"step_input_ids_out",
"adaptive_step_input_len_out",
"step_output_ids_out",
"step_output_len_out",
"stop_flags_out",
@@ -287,7 +258,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"has_running_seqs", "has_running_seqs_out"},
{"step_input_ids", "step_input_ids_out"},
{"adaptive_step_input_len", "adaptive_step_input_len_out"},
{"step_output_ids", "step_output_ids_out"},
{"step_output_len", "step_output_len_out"},
{"stop_flags", "stop_flags_out"},