[Speculative Decoding] Optimize attn_mask_offset and fix mtp bug (#7005)

* optimize attn_mask_offset and optimize mtp usage

* delete useless branch

* fix kernel format

* fix kernel runner
This commit is contained in:
freeliuzc
2026-03-25 16:52:06 +08:00
committed by GitHub
parent aee293be0f
commit 7a6c28781b
10 changed files with 31 additions and 234 deletions
+1 -4
View File
@@ -880,7 +880,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& is_paused,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
@@ -1146,10 +1145,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);
const paddle::Tensor& decode_states);
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor& qkv,
@@ -41,7 +41,6 @@ template <int BLOCK_SIZE>
__global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
int *seq_lens_decoder,
bool *has_running_seqs,
int *mask_rollback,
int64_t *step_input_ids,
int64_t *step_output_ids,
int *step_output_len,
@@ -100,16 +99,12 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
cur_seq_len_encoder = 0;
} else if (cur_seq_len_decoder > 0) {
cur_seq_len_decoder += output_len;
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
} else {
mask_rollback[batch_id] = 0;
}
if (cur_stop_flag) {
// It should clear seq_lens_decoder in next step for save_output
stop_flag_int = 1;
stop_flags[batch_id] = true;
mask_rollback[batch_id] = 0;
}
// 4. Update model status
@@ -174,7 +169,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_paused,
const paddle::Tensor &mask_rollback,
const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &step_idx,
@@ -203,7 +197,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
const_cast<int *>(mask_rollback.data<int>()),
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
const_cast<int *>(step_output_len.data<int>()),
@@ -237,7 +230,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
"stop_flags",
"seq_lens_this_time",
"is_paused",
"mask_rollback",
"token_ids_all",
"prompt_lens",
"step_idx",
@@ -251,7 +243,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
"step_output_len_out",
"stop_flags_out",
"seq_lens_this_time_out",
"mask_rollback_out",
"token_ids_all_out",
"step_idx_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
@@ -262,7 +253,6 @@ PD_BUILD_STATIC_OP(unified_update_model_status)
{"step_output_len", "step_output_len_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"mask_rollback", "mask_rollback_out"},
{"token_ids_all", "token_ids_all_out"},
{"step_idx", "step_idx_out"}})
.SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus));
@@ -129,52 +129,11 @@ struct VerifyContext {
output_len_now++;
}
// TOPP-only: verify_window bulk-accept fallback.
//
// Invoked when the draft token at position i was NOT accepted by the
// top-p check. If that draft token equals the target model's top-2
// candidate (verify_tokens_now[i * max_candidate_len + 1]), scan up to
// `verify_window` consecutive following positions; if every scanned
// position matches the target's top-1 candidate
// (verify_tokens_now[pos * max_candidate_len]), bulk-accept all tokens
// from position i through ii via emit_token.
//
// Returns the loop position the caller should resume the outer verify
// loop from. Sets *rejected = true when the fallback does not trigger
// (top-2 mismatch, or fewer than verify_window consecutive top-1
// matches) — the caller should then break out of verification.
//
// NOTE(review): member of struct VerifyContext — reads member state
// step_input_ids_now and calls member emit_token; emit_token's nonzero
// return presumably signals that generation stopped mid-accept (caller
// also checks ctx.stopped) — confirm against VerifyContext::emit_token.
__device__ __forceinline__ int try_verify_window_fallback(
int i,                          // current (rejected) verify position
bool *rejected,                 // out: true if fallback not triggered
const int64_t *verify_tokens_now,  // [seq, max_candidate_len] target candidates
int seq_len_this_time,          // tokens to verify this step
int max_candidate_len,          // stride / candidate count per position
int verify_window) {            // required consecutive top-1 matches
int ii = i;
// Fallback precondition: a top-2 candidate exists and the draft token
// (step_input_ids_now[ii + 1]) equals it.
if (max_candidate_len >= 2 &&
verify_tokens_now[ii * max_candidate_len + 1] ==
step_input_ids_now[ii + 1]) {
// top-2 matches — scan verify_window consecutive top-1 matches
int j = 0;
ii += 1;
// Stop early at a top-1 mismatch or at the end of the drafted span
// (seq_len_this_time - 1: the last position has no next draft token).
for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) {
if (verify_tokens_now[ii * max_candidate_len] !=
step_input_ids_now[ii + 1]) {
break;
}
}
// Only bulk-accept when the full window matched (loop not cut short
// by the sequence end).
if (j >= verify_window) {
// Bulk accept all tokens from i to ii
for (; i < ii; i++) {
// Early return if emit_token reports a stop condition — TODO
// confirm emit_token's return semantics in VerifyContext.
if (emit_token(i, step_input_ids_now[i + 1])) return i;
}
return i; // continue outer loop from position ii
}
}
// Fallback not triggered or insufficient window — reject
*rejected = true;
return i;
}
// ============================================================
// Phase 2 helpers — sample token for rejected/last position
// ============================================================
};
// ============================================================
// Phase 2 helpers — sample token for rejected/last position
// ============================================================
__device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float *candidate_scores,
curandState_t *curand_states,
@@ -193,7 +152,6 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
}
return candidate_ids[0];
}
// ============================================================
// Main verification kernel
// ============================================================
@@ -201,8 +159,8 @@ __device__ inline int64_t topp_sampling_kernel(const int64_t *candidate_ids,
// Input parameter groups by strategy:
// - target_tokens: GREEDY=argmax, TARGET_MATCH=sampled, TOPP=unused
// (None)
// - candidate_ids/scores: TOPP=full candidate set, GREEDY/TARGET_MATCH=unused
// (None)
// - candidate_ids/scores: TOPP=full candidate set,
// GREEDY/TARGET_MATCH=unused (None)
// - candidate_lens: TOPP=actual length per position,
// GREEDY/TARGET_MATCH=unused (None)
//
@@ -250,6 +208,9 @@ __global__ void verify_draft_tokens(
// Initialize step_output_len to 0 for ALL slots
if (bid < max_bsz) {
step_output_len[bid] = 0;
for (int i = 0; i < max_step_tokens; i++) {
step_output_ids[bid * max_step_tokens + i] = -1;
}
} else {
return;
}
@@ -307,17 +268,6 @@ __global__ void verify_draft_tokens(
accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len,
ctx.step_input_ids_now[i + 1],
actual_cand_len);
if (!accepted) {
bool rejected = false;
i = ctx.try_verify_window_fallback(i,
&rejected,
candidate_ids_now,
seq_lens_this_time[bid],
max_candidate_len,
verify_window);
if (ctx.stopped || rejected) goto phase1_done;
continue; // bulk accept succeeded, continue from new i
}
break;
}
case 1: // GREEDY
@@ -333,7 +283,6 @@ __global__ void verify_draft_tokens(
break; // reject
}
}
phase1_done:
// ======== Phase 2: Output token for rejected/last position ========
if (!ctx.stopped) {
+5 -19
View File
@@ -21,10 +21,8 @@ __global__ void update_attn_mask_offsets_kernel(
const int* seq_lens_decoder,
const int* cu_seqlens_q,
const int* attn_mask_offsets_full,
int* attn_mask_offsets_decoder,
const bool* is_block_step,
int* decode_states,
int* mask_rollback,
const int real_bsz,
const int max_model_len,
const int decode_states_len) {
@@ -55,15 +53,10 @@ __global__ void update_attn_mask_offsets_kernel(
}
}
} else if (seq_len_decoder > 0) {
// Status: decoder -- normal or chunk_prefill
// TODO: support speculative decoding.
attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
mask_rollback[bid] = 0;
for (int i = 0; i < seq_len_this_time; i++) {
attn_mask_offsets[(query_start_id + i) * 2 + 1] =
attn_mask_offsets_decoder[bid] + 1 + i;
seq_len_decoder + 1 + i;
}
attn_mask_offsets_decoder[bid] += seq_len_this_time;
// Speculative decoding in text_generation
for (int i = 0; i < decode_states_len; i++) {
@@ -85,10 +78,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback) {
const paddle::Tensor& decode_states) {
int max_model_len = attn_mask_offsets_full.shape()[1];
int real_bsz = seq_lens_this_time.shape()[0];
int batch_seq_lens = ids_remove_padding.shape()[0];
@@ -112,10 +103,8 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
attn_mask_offsets_full.data<int>(),
const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
is_block_step.data<bool>(),
const_cast<int*>(decode_states.data<int>()),
const_cast<int*>(mask_rollback.data<int>()),
real_bsz,
max_model_len,
decode_states_len);
@@ -130,11 +119,8 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
"seq_lens_decoder",
"cu_seqlens_q",
"attn_mask_offsets_full",
"attn_mask_offsets_decoder",
"is_block_step",
"decode_states",
"mask_rollback"})
.Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"},
{"mask_rollback", "mask_rollback_out"}})
"decode_states"})
.Outputs({"attn_mask_offsets", "decode_states_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"}})
.SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));