mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Optimize attn_mask_offset and fix mtp bug (#7005)
* optimize attn_mask_offset and optimize mtp usage * delete useless branch * fix kernel format * fix kernel runner
This commit is contained in:
@@ -880,7 +880,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& is_paused,
|
||||
const paddle::Tensor& mask_rollback,
|
||||
const paddle::Tensor& token_ids_all,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
@@ -1146,10 +1145,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& attn_mask_offsets_full,
|
||||
const paddle::Tensor& attn_mask_offsets_decoder,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& decode_states,
|
||||
const paddle::Tensor& mask_rollback);
|
||||
const paddle::Tensor& decode_states);
|
||||
|
||||
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
|
||||
const paddle::Tensor& qkv,
|
||||
|
||||
@@ -41,7 +41,6 @@ template <int BLOCK_SIZE>
|
||||
__global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
bool *has_running_seqs,
|
||||
int *mask_rollback,
|
||||
int64_t *step_input_ids,
|
||||
int64_t *step_output_ids,
|
||||
int *step_output_len,
|
||||
@@ -100,16 +99,12 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
cur_seq_len_encoder = 0;
|
||||
} else if (cur_seq_len_decoder > 0) {
|
||||
cur_seq_len_decoder += output_len;
|
||||
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
|
||||
} else {
|
||||
mask_rollback[batch_id] = 0;
|
||||
}
|
||||
|
||||
if (cur_stop_flag) {
|
||||
// It should clear seq_lens_decoder in next step for save_output
|
||||
stop_flag_int = 1;
|
||||
stop_flags[batch_id] = true;
|
||||
mask_rollback[batch_id] = 0;
|
||||
}
|
||||
|
||||
// 4. Update model status
|
||||
@@ -174,7 +169,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &is_paused,
|
||||
const paddle::Tensor &mask_rollback,
|
||||
const paddle::Tensor &token_ids_all,
|
||||
const paddle::Tensor &prompt_lens,
|
||||
const paddle::Tensor &step_idx,
|
||||
@@ -203,7 +197,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
|
||||
const_cast<int *>(mask_rollback.data<int>()),
|
||||
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
|
||||
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||
const_cast<int *>(step_output_len.data<int>()),
|
||||
@@ -237,7 +230,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"is_paused",
|
||||
"mask_rollback",
|
||||
"token_ids_all",
|
||||
"prompt_lens",
|
||||
"step_idx",
|
||||
@@ -251,7 +243,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
"step_output_len_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"mask_rollback_out",
|
||||
"token_ids_all_out",
|
||||
"step_idx_out"})
|
||||
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
@@ -262,7 +253,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||
{"step_output_len", "step_output_len_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"mask_rollback", "mask_rollback_out"},
|
||||
{"token_ids_all", "token_ids_all_out"},
|
||||
{"step_idx", "step_idx_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus));
|
||||
|
||||
@@ -129,52 +129,11 @@ struct VerifyContext {
|
||||
output_len_now++;
|
||||
}
|
||||
|
||||
// TOPP-only: verify_window bulk-accept fallback.
|
||||
//
|
||||
// When draft token is NOT in top-p set but IS the top-2 token,
|
||||
// check verify_window consecutive positions for top-1 match.
|
||||
// If all match, bulk-accept from position i through ii.
|
||||
//
|
||||
// Returns the new loop position (i) after handling.
|
||||
// Sets *rejected=true if fallback was not triggered (caller should break).
|
||||
__device__ __forceinline__ int try_verify_window_fallback(
|
||||
int i,
|
||||
bool *rejected,
|
||||
const int64_t *verify_tokens_now,
|
||||
int seq_len_this_time,
|
||||
int max_candidate_len,
|
||||
int verify_window) {
|
||||
int ii = i;
|
||||
if (max_candidate_len >= 2 &&
|
||||
verify_tokens_now[ii * max_candidate_len + 1] ==
|
||||
step_input_ids_now[ii + 1]) {
|
||||
// top-2 matches — scan verify_window consecutive top-1 matches
|
||||
int j = 0;
|
||||
ii += 1;
|
||||
for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) {
|
||||
if (verify_tokens_now[ii * max_candidate_len] !=
|
||||
step_input_ids_now[ii + 1]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j >= verify_window) {
|
||||
// Bulk accept all tokens from i to ii
|
||||
for (; i < ii; i++) {
|
||||
if (emit_token(i, step_input_ids_now[i + 1])) return i;
|
||||
}
|
||||
return i; // continue outer loop from position ii
|
||||
}
|
||||
}
|
||||
// Fallback not triggered or insufficient window — reject
|
||||
*rejected = true;
|
||||
return i;
|
||||
}
|
||||
// ============================================================
|
||||
// Phase 2 helpers — sample token for rejected/last position
|
||||
// ============================================================
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Phase 2 helpers — sample token for rejected/last position
|
||||
// ============================================================
|
||||
|
||||
__device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
const float *candidate_scores,
|
||||
curandState_t *curand_states,
|
||||
@@ -193,7 +152,6 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
}
|
||||
return candidate_ids[0];
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Main verification kernel
|
||||
// ============================================================
|
||||
@@ -201,8 +159,8 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
// Input parameter groups by strategy:
|
||||
// - target_tokens: GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused
|
||||
// (None)
|
||||
// - candidate_ids/scores: TOPP=full candidate set, GREEDY/TARGET_MATCH=unused
|
||||
// (None)
|
||||
// - candidate_ids/scores: TOPP=full candidate set,
|
||||
// GREEDY/TARGET_MATCH=unused (None)
|
||||
// - candidate_lens: TOPP=actual length per position,
|
||||
// GREEDY/TARGET_MATCH=unused (None)
|
||||
//
|
||||
@@ -250,6 +208,9 @@ __global__ void verify_draft_tokens(
|
||||
// Initialize step_output_len to 0 for ALL slots
|
||||
if (bid < max_bsz) {
|
||||
step_output_len[bid] = 0;
|
||||
for (int i = 0; i < max_step_tokens; i++) {
|
||||
step_output_ids[bid * max_step_tokens + i] = -1;
|
||||
}
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
@@ -307,17 +268,6 @@ __global__ void verify_draft_tokens(
|
||||
accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
|
||||
ctx.step_input_ids_now[i + 1],
|
||||
actual_cand_len);
|
||||
if (!accepted) {
|
||||
bool rejected = false;
|
||||
i = ctx.try_verify_window_fallback(i,
|
||||
&rejected,
|
||||
candidate_ids_now,
|
||||
seq_lens_this_time[bid],
|
||||
max_candidate_len,
|
||||
verify_window);
|
||||
if (ctx.stopped || rejected) goto phase1_done;
|
||||
continue; // bulk accept succeeded, continue from new i
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: // GREEDY
|
||||
@@ -333,7 +283,6 @@ __global__ void verify_draft_tokens(
|
||||
break; // reject
|
||||
}
|
||||
}
|
||||
phase1_done:
|
||||
|
||||
// ======== Phase 2: Output token for rejected/last position ========
|
||||
if (!ctx.stopped) {
|
||||
|
||||
@@ -21,10 +21,8 @@ __global__ void update_attn_mask_offsets_kernel(
|
||||
const int* seq_lens_decoder,
|
||||
const int* cu_seqlens_q,
|
||||
const int* attn_mask_offsets_full,
|
||||
int* attn_mask_offsets_decoder,
|
||||
const bool* is_block_step,
|
||||
int* decode_states,
|
||||
int* mask_rollback,
|
||||
const int real_bsz,
|
||||
const int max_model_len,
|
||||
const int decode_states_len) {
|
||||
@@ -55,15 +53,10 @@ __global__ void update_attn_mask_offsets_kernel(
|
||||
}
|
||||
}
|
||||
} else if (seq_len_decoder > 0) {
|
||||
// Status: decoder -- normal or chunk_prefill
|
||||
// TODO: support speculative decoding.
|
||||
attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
|
||||
mask_rollback[bid] = 0;
|
||||
for (int i = 0; i < seq_len_this_time; i++) {
|
||||
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
|
||||
attn_mask_offsets_decoder[bid] + 1 + i;
|
||||
seq_len_decoder + 1 + i;
|
||||
}
|
||||
attn_mask_offsets_decoder[bid] += seq_len_this_time;
|
||||
|
||||
// Speculative decoding in text_generation
|
||||
for (int i = 0; i < decode_states_len; i++) {
|
||||
@@ -85,10 +78,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& attn_mask_offsets_full,
|
||||
const paddle::Tensor& attn_mask_offsets_decoder,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& decode_states,
|
||||
const paddle::Tensor& mask_rollback) {
|
||||
const paddle::Tensor& decode_states) {
|
||||
int max_model_len = attn_mask_offsets_full.shape()[1];
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int batch_seq_lens = ids_remove_padding.shape()[0];
|
||||
@@ -112,10 +103,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
attn_mask_offsets_full.data<int>(),
|
||||
const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
|
||||
is_block_step.data<bool>(),
|
||||
const_cast<int*>(decode_states.data<int>()),
|
||||
const_cast<int*>(mask_rollback.data<int>()),
|
||||
real_bsz,
|
||||
max_model_len,
|
||||
decode_states_len);
|
||||
@@ -130,11 +119,8 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
|
||||
"seq_lens_decoder",
|
||||
"cu_seqlens_q",
|
||||
"attn_mask_offsets_full",
|
||||
"attn_mask_offsets_decoder",
|
||||
"is_block_step",
|
||||
"decode_states",
|
||||
"mask_rollback"})
|
||||
.Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"})
|
||||
.SetInplaceMap({{"decode_states", "decode_states_out"},
|
||||
{"mask_rollback", "mask_rollback_out"}})
|
||||
"decode_states"})
|
||||
.Outputs({"attn_mask_offsets", "decode_states_out"})
|
||||
.SetInplaceMap({{"decode_states", "decode_states_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
|
||||
|
||||
Reference in New Issue
Block a user