mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Fix speculate stop_seqs and fix accept_num in eos branch (#6825)
This commit is contained in:
@@ -58,7 +58,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
bool is_end = false;
|
||||
// 遍历起始位置
|
||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
|
||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("num %d < stop_seq_len %d\n",
|
||||
step_idx_now - accept_num + accept_idx + 1,
|
||||
@@ -98,7 +98,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
(stop_seq_len - 1 - i));
|
||||
#endif
|
||||
int pre_ids_idx =
|
||||
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
|
||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
||||
// EC3
|
||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||
// 导致异常结束
|
||||
@@ -131,7 +131,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
|
||||
accept_nums[bid] = accept_idx;
|
||||
accept_tokens_now[accept_idx - 1] = end_ids[0];
|
||||
stop_flags[bid] = true;
|
||||
// stop_flags[bid] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,8 +116,8 @@ struct VerifyContext {
|
||||
}
|
||||
|
||||
// Emit the final token at position `pos` in Phase 2.
|
||||
// Same EOS/max_dec_len logic, but does NOT increment output_len_now
|
||||
// (Phase 2's token is already counted in the initial output_len_now=1).
|
||||
// Same EOS/max_dec_len logic. Increments output_len_now since
|
||||
// Phase 2 produces one additional token.
|
||||
__device__ __forceinline__ void emit_final_token(int pos, int64_t token) {
|
||||
cur_step_idx++;
|
||||
bool is_eos = is_in_end(token, end_tokens, end_length);
|
||||
@@ -126,6 +126,7 @@ struct VerifyContext {
|
||||
token = end_tokens[0];
|
||||
}
|
||||
step_output_ids[bid * max_step_tokens + pos] = token;
|
||||
output_len_now++;
|
||||
}
|
||||
|
||||
// TOPP-only: verify_window bulk-accept fallback.
|
||||
@@ -278,7 +279,7 @@ __global__ void verify_draft_tokens(
|
||||
ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens;
|
||||
ctx.step_output_ids = step_output_ids;
|
||||
ctx.cur_step_idx = step_idx[bid];
|
||||
ctx.output_len_now = 1;
|
||||
ctx.output_len_now = 0;
|
||||
ctx.stopped = false;
|
||||
|
||||
// ======== Phase 1: Verify draft tokens ========
|
||||
|
||||
Reference in New Issue
Block a user