mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)
* support new mtp * refactor(speculate_decoding and mtp): optimize mtp sturcture logic. Update spec-branch status-process * fix cuda-graph for spec-decoding * fix xpu mtp and fix some note * fix unittest and optmize note * fix model status update in eos-branch
This commit is contained in:
@@ -875,7 +875,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& has_running_seqs,
|
||||
const paddle::Tensor& step_input_ids,
|
||||
const paddle::Tensor& adaptive_step_input_len,
|
||||
const paddle::Tensor& step_output_ids,
|
||||
const paddle::Tensor& step_output_len,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -886,9 +885,13 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& end_tokens,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const bool is_naive_mode,
|
||||
const bool prefill_one_step_stop);
|
||||
const paddle::Tensor& max_dec_len);
|
||||
|
||||
void NaiveUpdateModelStatus(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& cu_seqlens_q_output);
|
||||
|
||||
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
@@ -971,24 +974,17 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& mask_rollback,
|
||||
const paddle::Tensor& recompute_token_num,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& base_model_is_block_step,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
const paddle::Tensor& target_model_seq_lens_encoder,
|
||||
const paddle::Tensor& target_model_seq_lens_decoder,
|
||||
const paddle::Tensor& target_model_step_idx,
|
||||
const paddle::Tensor& target_model_stop_flags,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& target_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool is_splitwise_prefill);
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
@@ -1785,6 +1781,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&UnifiedUpdateModelStatus,
|
||||
"unified_update_model_status function");
|
||||
|
||||
m.def("naive_update_model_status",
|
||||
&NaiveUpdateModelStatus,
|
||||
"naive_update_model_status function");
|
||||
|
||||
m.def("speculate_set_value_by_flags_and_idx",
|
||||
&SpeculateSetValueByFlagsAndIdx,
|
||||
"speculate_set_value_by_flags_and_idx function");
|
||||
|
||||
@@ -15,119 +15,16 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
|
||||
do { \
|
||||
constexpr int BlockSize = BLOCK_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_TRUNCATE_FIRST_TOKEN( \
|
||||
truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
|
||||
do { \
|
||||
if (truncate_first_token) { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_KVCACHE_SCHEDULER( \
|
||||
kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
|
||||
do { \
|
||||
if (kvcache_scheduler_v1) { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
|
||||
do { \
|
||||
if (splitwise_prefill) { \
|
||||
constexpr bool SPLITWISE_PREFILL = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool SPLITWISE_PREFILL = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <int THREADBLOCK_SIZE,
|
||||
bool TRUNCATE_FIRST_TOKEN,
|
||||
bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void process_splitwise_prefill(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len) {
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
int64_t not_stop_flag = 0;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
if (tid < bsz) {
|
||||
int base_model_step_idx_now = base_model_step_idx[tid];
|
||||
auto* input_ids_now = input_ids + tid * input_ids_len;
|
||||
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
not_stop_flag = 1;
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
int position = seq_len_encoder;
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else {
|
||||
stop_flags[tid] = true;
|
||||
seq_lens_this_time[tid] = 0;
|
||||
seq_lens_decoder[tid] = 0;
|
||||
seq_lens_encoder[tid] = 0;
|
||||
not_stop_flag = 0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
|
||||
if (tid == 0) {
|
||||
not_need_stop[0] = not_stop_flag_sum > 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <int THREADBLOCK_SIZE,
|
||||
bool TRUNCATE_FIRST_TOKEN,
|
||||
bool KVCACHE_SCHEDULER_V1>
|
||||
// Main draft preprocess kernel.
|
||||
// MTP state (seq_lens_decoder, step_idx) is "shadow state":
|
||||
// - Initialized from target model state each round
|
||||
// - Used for MTP forward, but not committed until verify
|
||||
// - No rollback needed since it's always re-initialized
|
||||
//
|
||||
// is_splitwise_prefill: set true on P-D disaggregated prefill node.
|
||||
// In that mode, only prefill requests (seq_lens_encoder > 0) run MTP;
|
||||
// decode requests are marked stopped and skipped.
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void draft_model_preprocess_kernel(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -137,27 +34,23 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
int* mask_rollback,
|
||||
int* recompute_token_num,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int* target_model_seq_lens_encoder,
|
||||
const int* target_model_seq_lens_decoder,
|
||||
const int64_t* target_model_step_idx,
|
||||
const bool* target_model_stop_flags,
|
||||
const int64_t* max_dec_len,
|
||||
int64_t* target_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len) {
|
||||
const int target_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool is_splitwise_prefill) {
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
int64_t not_stop_flag = 0;
|
||||
@@ -165,96 +58,83 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int tid = threadIdx.x;
|
||||
|
||||
if (tid < bsz) {
|
||||
const int32_t base_model_step_idx_now = base_model_step_idx[tid];
|
||||
auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len;
|
||||
auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len;
|
||||
const int32_t accept_num_now = accept_num[tid];
|
||||
auto* input_ids_now = input_ids + tid * input_ids_len;
|
||||
auto* base_model_draft_tokens_now =
|
||||
base_model_draft_tokens + tid * base_model_draft_tokens_len;
|
||||
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
|
||||
const int32_t base_model_seq_len_this_time =
|
||||
base_model_seq_lens_this_time[tid];
|
||||
auto* target_model_draft_tokens_now =
|
||||
target_model_draft_tokens + tid * target_model_draft_tokens_len;
|
||||
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
|
||||
const auto target_step = target_model_step_idx[tid];
|
||||
auto seq_len_encoder = seq_lens_encoder[tid];
|
||||
|
||||
// Clear target_model_draft_tokens (keep first token)
|
||||
#pragma unroll
|
||||
for (int i = 1; i < base_model_draft_tokens_len; i++) {
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
for (int i = 1; i < target_model_draft_tokens_len; i++) {
|
||||
target_model_draft_tokens_now[i] = -1;
|
||||
}
|
||||
|
||||
// 1. process block_step situation
|
||||
// -- In v0 mode, block_step will drop mtp query.
|
||||
// -- In v1 mode, block_step will continue to infer.
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
is_block_step[tid] = true;
|
||||
// Need to continue infer
|
||||
}
|
||||
} else {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
// ============================================================
|
||||
// Decision: Should MTP/Draft model run?
|
||||
// ============================================================
|
||||
bool should_skip = false;
|
||||
|
||||
// Target model stopped
|
||||
if (target_model_stop_flags[tid]) {
|
||||
should_skip = true;
|
||||
}
|
||||
|
||||
// 2. process normal query, not in any special case.
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// prefill generation
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else { // decode generation
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
// 3. try to recover mtp infer in V1 mode
|
||||
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
|
||||
is_block_step[tid] = false;
|
||||
}
|
||||
}
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
// TODO: check
|
||||
seq_lens_decoder[tid] =
|
||||
base_model_seq_len_decoder - base_model_seq_len_this_time;
|
||||
step_idx[tid] =
|
||||
base_model_step_idx[tid] - base_model_seq_len_this_time;
|
||||
} else {
|
||||
// 2: Last base model generated token and first MTP token
|
||||
const int recompute_token_num_now = recompute_token_num[tid];
|
||||
seq_lens_decoder[tid] -= recompute_token_num_now;
|
||||
step_idx[tid] -= recompute_token_num_now;
|
||||
mask_rollback[tid] += recompute_token_num_now;
|
||||
// NOTE(liuzichang): Used for PD-split mode and future dynamic
|
||||
// strategies.
|
||||
recompute_token_num[tid] = num_model_step - 1;
|
||||
}
|
||||
for (int i = 0; i < accept_num_now; i++) {
|
||||
draft_tokens_now[i] = accept_tokens_now[i];
|
||||
const int pre_id_pos =
|
||||
base_model_step_idx[tid] - (accept_num_now - i);
|
||||
const int64_t accept_token = accept_tokens_now[i];
|
||||
pre_ids_now[pre_id_pos] = accept_token;
|
||||
}
|
||||
seq_lens_this_time[tid] = accept_num_now;
|
||||
}
|
||||
} else {
|
||||
// Near end of max_dec_len in no prefill node
|
||||
if ((not is_splitwise_prefill &&
|
||||
target_step + num_model_step >= max_dec_len[tid])) {
|
||||
should_skip = true;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Execute based on decision
|
||||
// ============================================================
|
||||
if (should_skip) {
|
||||
stop_flags[tid] = true;
|
||||
seq_lens_this_time[tid] = 0;
|
||||
seq_lens_decoder[tid] = 0;
|
||||
seq_lens_encoder[tid] = 0;
|
||||
step_idx[tid] = 0;
|
||||
not_stop_flag = 0;
|
||||
} else {
|
||||
not_stop_flag = 1;
|
||||
stop_flags[tid] = false;
|
||||
|
||||
if (seq_len_encoder > 0) {
|
||||
// prefill | chunk_prefill | prompt_cache | recover after preempted
|
||||
int64_t target_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = target_model_first_token;
|
||||
|
||||
input_ids_now[seq_len_encoder - 1] = target_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
|
||||
// Shadow state: initialize from target model (prefill just finished)
|
||||
step_idx[tid] = target_step - 1;
|
||||
} else {
|
||||
// Decode: MTP shadow state from target model
|
||||
// Shadow state is initialized from target model state each round
|
||||
// This is the key simplification: no rollback needed
|
||||
int32_t need_compute_token = accept_num_now;
|
||||
seq_lens_decoder[tid] =
|
||||
target_model_seq_lens_decoder[tid] - need_compute_token;
|
||||
step_idx[tid] = target_model_step_idx[tid] - need_compute_token;
|
||||
|
||||
// Prepare draft input tokens from accepted tokens
|
||||
for (int i = 0; i < accept_num_now; i++) {
|
||||
draft_tokens_now[i] = accept_tokens_now[i];
|
||||
const int pre_id_pos =
|
||||
target_model_step_idx[tid] - (accept_num_now - i);
|
||||
pre_ids_now[pre_id_pos] = accept_tokens_now[i];
|
||||
}
|
||||
seq_lens_this_time[tid] = accept_num_now;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
int64_t not_stop_flag_sum = BlockReduce(temp_storage).Sum(not_stop_flag);
|
||||
if (tid == 0) {
|
||||
@@ -262,114 +142,6 @@ __global__ void draft_model_preprocess_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchRunner(const cudaStream_t& stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
int* mask_rollback,
|
||||
int* recompute_token_num,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
DISPATCH_BLOCKSIZE(512, {
|
||||
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
|
||||
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
|
||||
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
|
||||
if constexpr (SPLITWISE_PREFILL) {
|
||||
process_splitwise_prefill<BlockSize,
|
||||
TRUNCATE_FIRST_TOKEN,
|
||||
KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize,
|
||||
TRUNCATE_FIRST_TOKEN,
|
||||
KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -378,69 +150,61 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& mask_rollback,
|
||||
const paddle::Tensor& recompute_token_num,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& base_model_is_block_step,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& target_model_seq_lens_encoder,
|
||||
const paddle::Tensor& target_model_seq_lens_decoder,
|
||||
const paddle::Tensor& target_model_step_idx,
|
||||
const paddle::Tensor& target_model_stop_flags,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& target_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
const bool is_splitwise_prefill) {
|
||||
constexpr int kBlockSize = 1024;
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
real_bsz,
|
||||
kBlockSize,
|
||||
phi::errors::InvalidArgument(
|
||||
"draft_model_preprocess: real_bsz (%d) exceeds kBlockSize (%d).",
|
||||
real_bsz,
|
||||
kBlockSize));
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
int draft_tokens_len = draft_tokens.shape()[1];
|
||||
int pre_ids_len = pre_ids.shape()[1];
|
||||
auto cu_stream = seq_lens_this_time.stream();
|
||||
constexpr int BlockSize = 512;
|
||||
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
|
||||
int target_model_draft_tokens_len = target_model_draft_tokens.shape()[1];
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
DispatchRunner(cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
const_cast<int*>(mask_rollback.data<int>()),
|
||||
const_cast<int*>(recompute_token_num.data<int>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
|
||||
draft_model_preprocess_kernel<kBlockSize><<<1, kBlockSize, 0, cu_stream>>>(
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
target_model_seq_lens_encoder.data<int>(),
|
||||
target_model_seq_lens_decoder.data<int>(),
|
||||
target_model_step_idx.data<int64_t>(),
|
||||
target_model_stop_flags.data<bool>(),
|
||||
max_dec_len.data<int64_t>(),
|
||||
const_cast<int64_t*>(target_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
target_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
is_splitwise_prefill);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
@@ -456,20 +220,15 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"is_block_step",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"mask_rollback",
|
||||
"recompute_token_num",
|
||||
"accept_tokens",
|
||||
"accept_num",
|
||||
"base_model_seq_lens_this_time",
|
||||
"base_model_seq_lens_encoder",
|
||||
"base_model_seq_lens_decoder",
|
||||
"base_model_step_idx",
|
||||
"base_model_stop_flags",
|
||||
"base_model_is_block_step",
|
||||
"base_model_draft_tokens"})
|
||||
"target_model_seq_lens_encoder",
|
||||
"target_model_seq_lens_decoder",
|
||||
"target_model_step_idx",
|
||||
"target_model_stop_flags",
|
||||
"max_dec_len",
|
||||
"target_model_draft_tokens"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"input_ids_out",
|
||||
"stop_flags_out",
|
||||
@@ -478,12 +237,8 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder_out",
|
||||
"step_idx_out",
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int",
|
||||
"truncate_first_token: bool",
|
||||
"splitwise_prefill: bool",
|
||||
"kvcache_scheduler_v1: bool"})
|
||||
.Attrs({"num_model_step: int", "is_splitwise_prefill: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
@@ -492,6 +247,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_idx", "step_idx_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},
|
||||
{"batch_drop", "batch_drop_out"},
|
||||
{"pre_ids", "pre_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));
|
||||
|
||||
@@ -54,43 +54,31 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
auto seq_len_decoder = seq_lens_decoder[tid];
|
||||
|
||||
// 1. update step_idx && seq_lens_dec
|
||||
if (!stop_flags[tid] /* seq_lens_decoder > 0 or seq_lens_encoder > 0 */) {
|
||||
if (!stop_flags[tid]) {
|
||||
int64_t token_this_time = -1;
|
||||
// decoder step
|
||||
if (seq_len_decoder > 0 && seq_len_encoder <= 0) {
|
||||
seq_lens_decoder[tid] += seq_len_this_time;
|
||||
token_this_time = next_tokens_start[seq_len_this_time - 1];
|
||||
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
step_idx[tid] += seq_len_this_time;
|
||||
pre_ids_now[step_idx[tid]] = token_this_time;
|
||||
|
||||
} else {
|
||||
if (seq_len_encoder > 0) {
|
||||
token_this_time = next_tokens_start[0];
|
||||
|
||||
// seq_lens_decoder[tid] = seq_lens_encoder[tid];
|
||||
seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder;
|
||||
seq_lens_encoder[tid] = 0;
|
||||
pre_ids_now[1] = token_this_time;
|
||||
step_idx[tid] += 1;
|
||||
draft_token_now[0] = token_this_time;
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
if (step_idx[tid] < max_dec_len[tid]) {
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
}
|
||||
} else if (seq_len_decoder > 0) {
|
||||
if (step_idx[tid] >= max_dec_len[tid] - 1) {
|
||||
// If up to max_dec_len -1. Recompute but not update.
|
||||
base_model_draft_tokens_now[substep + 1] = -1;
|
||||
} else {
|
||||
seq_lens_decoder[tid] += seq_len_this_time;
|
||||
token_this_time = next_tokens_start[seq_len_this_time - 1];
|
||||
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
|
||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||
step_idx[tid] += seq_len_this_time;
|
||||
pre_ids_now[step_idx[tid]] = token_this_time;
|
||||
}
|
||||
}
|
||||
|
||||
// multi_end
|
||||
// TODO(liuzichang): Don't check eos in future
|
||||
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
|
||||
prefill_one_step_stop) {
|
||||
stop_flags[tid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
// max_dec_len
|
||||
} else if (step_idx[tid] >= max_dec_len[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
draft_token_now[seq_len_this_time - 1] = end_ids[0];
|
||||
base_model_draft_tokens_now[substep + 1] = end_ids[0];
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
|
||||
} else {
|
||||
draft_token_now[0] = -1;
|
||||
base_model_draft_tokens_now[substep + 1] = -1;
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
/**
|
||||
* @file naive_update_model_status.cu
|
||||
* @brief Post-sampling update for NAIVE speculative decoding mode.
|
||||
*
|
||||
* Responsibilities (one thread per batch slot, <<<1, 1024>>>):
|
||||
* 1. Scatter sampled token into accept_tokens[i, 0] using cu_seqlens_q_output
|
||||
* to map from the packed next_tokens array to the per-batch output.
|
||||
* 2. Set accept_num[i] = 1 for running slots (seq_lens_this_time > 0), else
|
||||
* 0.
|
||||
* 3. Set seq_lens_this_time[i] = 1 for running, 0 for stopped/paused.
|
||||
*
|
||||
* Running slots are identified by seq_lens_this_time[i] > 0, which is already
|
||||
* zeroed for stopped/paused slots by pre_process before this kernel runs.
|
||||
*
|
||||
* The packed next_tokens layout mirrors cu_seqlens_q_output:
|
||||
* next_tokens[cu_seqlens_q_output[i] .. cu_seqlens_q_output[i+1]-1]
|
||||
* are the output tokens for request i (exactly 1 per running slot in naive
|
||||
* decode; 0 for stopped/encoder-only slots).
|
||||
*/
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void naive_update_model_status_kernel(int64_t *accept_tokens,
|
||||
int *accept_num,
|
||||
int *seq_lens_this_time,
|
||||
const int64_t *next_tokens,
|
||||
const int *cu_seqlens_q_output,
|
||||
int real_bsz,
|
||||
int max_step_tokens) {
|
||||
int bid = threadIdx.x;
|
||||
|
||||
if (bid >= real_bsz) return;
|
||||
if (seq_lens_this_time[bid] > 0) {
|
||||
// Write the last (and only) sampled token to accept_tokens[bid, 0]
|
||||
accept_tokens[bid * max_step_tokens] =
|
||||
next_tokens[cu_seqlens_q_output[bid + 1] - 1];
|
||||
accept_num[bid] = 1;
|
||||
seq_lens_this_time[bid] = 1;
|
||||
} else {
|
||||
accept_num[bid] = 0;
|
||||
seq_lens_this_time[bid] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Host launcher for naive_update_model_status_kernel.
 *
 * In naive (non-speculative) decoding each running request produces exactly
 * one token per step; this op writes that token into the per-slot
 * accept_tokens buffer and normalizes accept_num / seq_lens_this_time.
 * The three mutable tensors are updated in place (see the matching
 * PD_BUILD_STATIC_OP inplace map).
 *
 * All batch slots are handled by a single 1024-thread block, so real_bsz
 * (derived from seq_lens_this_time.shape()[0]) must not exceed 1024.
 */
void NaiveUpdateModelStatus(const paddle::Tensor &accept_tokens,
                            const paddle::Tensor &accept_num,
                            const paddle::Tensor &seq_lens_this_time,
                            const paddle::Tensor &next_tokens,
                            const paddle::Tensor &cu_seqlens_q_output) {
  constexpr int kBlockSize = 1024;

  const int real_bsz = seq_lens_this_time.shape()[0];
  const int max_step_tokens = accept_tokens.shape()[1];

  PADDLE_ENFORCE_LE(
      real_bsz,
      kBlockSize,
      phi::errors::InvalidArgument(
          "naive_update_model_status: real_bsz (%d) must be <= %d.",
          real_bsz,
          kBlockSize));

  auto cu_stream = seq_lens_this_time.stream();

  // In-place tensors: cast away constness for the device-side writes.
  int64_t *accept_tokens_ptr =
      const_cast<int64_t *>(accept_tokens.data<int64_t>());
  int *accept_num_ptr = const_cast<int *>(accept_num.data<int>());
  int *seq_lens_ptr = const_cast<int *>(seq_lens_this_time.data<int>());

  naive_update_model_status_kernel<kBlockSize>
      <<<1, kBlockSize, 0, cu_stream>>>(accept_tokens_ptr,
                                        accept_num_ptr,
                                        seq_lens_ptr,
                                        next_tokens.data<int64_t>(),
                                        cu_seqlens_q_output.data<int>(),
                                        real_bsz,
                                        max_step_tokens);
}
|
||||
|
||||
// Static-op registration for naive_update_model_status.
// accept_tokens / accept_num / seq_lens_this_time are declared in-place
// (each input aliases its *_out output); next_tokens and
// cu_seqlens_q_output are read-only inputs. The op takes no attributes.
PD_BUILD_STATIC_OP(naive_update_model_status)
    .Inputs({"accept_tokens",
             "accept_num",
             "seq_lens_this_time",
             "next_tokens",
             "cu_seqlens_q_output"})
    .Attrs({})
    .Outputs({"accept_tokens_out", "accept_num_out", "seq_lens_this_time_out"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"}})
    .SetKernelFn(PD_KERNEL(NaiveUpdateModelStatus));
|
||||
@@ -43,7 +43,6 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
bool *has_running_seqs,
|
||||
int *mask_rollback,
|
||||
int64_t *step_input_ids,
|
||||
int *adaptive_step_input_len,
|
||||
int64_t *step_output_ids,
|
||||
int *step_output_len,
|
||||
bool *stop_flags,
|
||||
@@ -58,36 +57,25 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int max_bsz,
|
||||
int max_step_tokens,
|
||||
int max_model_len,
|
||||
int num_end_tokens,
|
||||
bool is_naive_mode,
|
||||
bool prefill_one_step_stop) {
|
||||
int num_end_tokens) {
|
||||
const int batch_id = blockIdx.x * BLOCK_SIZE + threadIdx.x;
|
||||
const bool is_valid_slot = batch_id < max_bsz;
|
||||
int stop_flag_int = 0;
|
||||
|
||||
if (is_valid_slot) {
|
||||
// Read state
|
||||
// 1. Read state
|
||||
int cur_seq_len_encoder = seq_lens_encoder[batch_id];
|
||||
int cur_seq_len_decoder = seq_lens_decoder[batch_id];
|
||||
bool cur_stop_flag = stop_flags[batch_id];
|
||||
int output_len = 0;
|
||||
int output_len = step_output_len[batch_id];
|
||||
int64_t cur_step_idx = step_idx[batch_id];
|
||||
bool cur_is_paused = is_paused[batch_id];
|
||||
int64_t prompt_len = prompt_lens[batch_id];
|
||||
|
||||
bool is_running = !cur_stop_flag && !cur_is_paused;
|
||||
|
||||
// Compute output length
|
||||
if (is_running) {
|
||||
if (is_naive_mode) {
|
||||
output_len = 1;
|
||||
} else {
|
||||
output_len = step_output_len[batch_id];
|
||||
}
|
||||
}
|
||||
|
||||
// EOS detection
|
||||
// 2. EOS detection
|
||||
if (is_running && output_len > 0) {
|
||||
bool hit_stop = false;
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
@@ -100,63 +88,56 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
if (!is_eos) output_ids[i] = end_tokens[0];
|
||||
output_len = i + 1;
|
||||
cur_stop_flag = true;
|
||||
hit_stop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
|
||||
cur_stop_flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Update state and write back
|
||||
if (is_running) {
|
||||
if (cur_stop_flag) {
|
||||
stop_flag_int = 1;
|
||||
if (output_len == 0) cur_seq_len_decoder = 0;
|
||||
stop_flags[batch_id] = true;
|
||||
mask_rollback[batch_id] = 0;
|
||||
} else if (cur_seq_len_encoder == 0) {
|
||||
// 3. Update state and write back
|
||||
if (cur_seq_len_encoder > 0) {
|
||||
cur_seq_len_decoder += cur_seq_len_encoder;
|
||||
cur_seq_len_encoder = 0;
|
||||
} else if (cur_seq_len_decoder > 0) {
|
||||
cur_seq_len_decoder += output_len;
|
||||
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
|
||||
} else {
|
||||
mask_rollback[batch_id] = 0;
|
||||
}
|
||||
|
||||
if (cur_seq_len_encoder > 0) {
|
||||
cur_seq_len_decoder += cur_seq_len_encoder;
|
||||
cur_seq_len_encoder = 0;
|
||||
if (cur_stop_flag) {
|
||||
// It should clear seq_lens_decoder in next step for save_output
|
||||
stop_flag_int = 1;
|
||||
stop_flags[batch_id] = true;
|
||||
mask_rollback[batch_id] = 0;
|
||||
}
|
||||
|
||||
// 4. Update model status
|
||||
seq_lens_encoder[batch_id] = cur_seq_len_encoder;
|
||||
seq_lens_decoder[batch_id] = cur_seq_len_decoder;
|
||||
step_output_len[batch_id] = output_len;
|
||||
step_idx[batch_id] = cur_step_idx;
|
||||
|
||||
// Write history to token_ids_all
|
||||
if (cur_step_idx > 0 && output_len > 0) {
|
||||
// Bounds check: highest write index is prompt_lens + cur_step_idx
|
||||
if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
|
||||
// 5. Write history to token_ids_all (forward loop: position base+k =
|
||||
// output_ids[k])
|
||||
if (output_len > 0) {
|
||||
// Bounds check: highest write index is prompt_len + cur_step_idx
|
||||
if (prompt_len + cur_step_idx < max_model_len) {
|
||||
int64_t *token_ids_all_now =
|
||||
&token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
|
||||
&token_ids_all[batch_id * max_model_len + prompt_len];
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
int64_t base = cur_step_idx - output_len + 1;
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
token_ids_all_now[cur_step_idx - i] =
|
||||
output_ids[output_len - 1 - i];
|
||||
token_ids_all_now[base + i] = output_ids[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup next input
|
||||
// 6. Prepare next step input[0]
|
||||
if (output_len > 0) {
|
||||
step_input_ids[batch_id * max_step_tokens] =
|
||||
step_output_ids[batch_id * max_step_tokens + output_len - 1];
|
||||
}
|
||||
|
||||
if (is_naive_mode) {
|
||||
seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
|
||||
}
|
||||
} else if (batch_id >= real_bsz) {
|
||||
// Padding slot: just count as stopped, don't modify state
|
||||
stop_flag_int = 1;
|
||||
@@ -164,6 +145,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
// Stopped or paused slot (batch_id < real_bsz)
|
||||
stop_flag_int = 1;
|
||||
stop_flags[batch_id] = true;
|
||||
seq_lens_encoder[batch_id] = 0;
|
||||
seq_lens_decoder[batch_id] = 0;
|
||||
seq_lens_this_time[batch_id] = 0;
|
||||
step_output_len[batch_id] = 0;
|
||||
@@ -175,11 +157,9 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
typedef cub::BlockReduce<int64_t, BLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
has_running_seqs[0] = stop_sum < max_bsz;
|
||||
}
|
||||
}
|
||||
@@ -189,7 +169,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &has_running_seqs,
|
||||
const paddle::Tensor &step_input_ids,
|
||||
const paddle::Tensor &adaptive_step_input_len,
|
||||
const paddle::Tensor &step_output_ids,
|
||||
const paddle::Tensor &step_output_len,
|
||||
const paddle::Tensor &stop_flags,
|
||||
@@ -200,9 +179,7 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const bool is_naive_mode,
|
||||
const bool prefill_one_step_stop) {
|
||||
const paddle::Tensor &max_dec_len) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
PADDLE_ENFORCE_LE(
|
||||
@@ -228,7 +205,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
|
||||
const_cast<int *>(mask_rollback.data<int>()),
|
||||
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
|
||||
const_cast<int *>(adaptive_step_input_len.data<int>()),
|
||||
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||
const_cast<int *>(step_output_len.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
@@ -243,9 +219,7 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
max_bsz,
|
||||
max_step_tokens,
|
||||
max_model_len,
|
||||
num_end_tokens,
|
||||
is_naive_mode,
|
||||
prefill_one_step_stop);
|
||||
num_end_tokens);
|
||||
// Copy result back to CPU
|
||||
auto has_running_seqs_cpu =
|
||||
has_running_seqs_gpu.copy_to(has_running_seqs.place(), false);
|
||||
@@ -258,7 +232,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
"seq_lens_decoder",
|
||||
"has_running_seqs",
|
||||
"step_input_ids",
|
||||
"adaptive_step_input_len",
|
||||
"step_output_ids",
|
||||
"step_output_len",
|
||||
"stop_flags",
|
||||
@@ -270,12 +243,10 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
"step_idx",
|
||||
"end_tokens",
|
||||
"max_dec_len"})
|
||||
.Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"})
|
||||
.Outputs({"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
"has_running_seqs_out",
|
||||
"step_input_ids_out",
|
||||
"adaptive_step_input_len_out",
|
||||
"step_output_ids_out",
|
||||
"step_output_len_out",
|
||||
"stop_flags_out",
|
||||
@@ -287,7 +258,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"has_running_seqs", "has_running_seqs_out"},
|
||||
{"step_input_ids", "step_input_ids_out"},
|
||||
{"adaptive_step_input_len", "adaptive_step_input_len_out"},
|
||||
{"step_output_ids", "step_output_ids_out"},
|
||||
{"step_output_len", "step_output_len_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
|
||||
Reference in New Issue
Block a user