mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix][Speculative Decoding] Correct index calculation in speculate decoding operators (#7121)
- Fix accept_idx calculation in spec_set_value_by_stop_seqs
- Fix condition check from < to <= for token matching
- Fix accept_tokens indexing logic
- Remove unnecessary -1 in current_step comparison for max_think_len

Co-authored-by: guanshihui <guanshihui@baidu.com>
This commit is contained in:
@@ -98,10 +98,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
if (max_think_len > 0) {
|
||||
// A) 超长触发:到达 max_think_len 时开始注入(从本 token 起输出
|
||||
// inject_token_ids[0])
|
||||
if (status == 0 &&
|
||||
(current_step - 1) ==
|
||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
||||
// step_idx + 1 了
|
||||
if (status == 0 && current_step == max_think_len) {
|
||||
status = (inject_len > 0) ? 1 : done_status;
|
||||
}
|
||||
} else if (max_think_len == 0) {
|
||||
|
||||
@@ -58,7 +58,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
bool is_end = false;
|
||||
// 遍历起始位置
|
||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("num %d < stop_seq_len %d\n",
|
||||
step_idx_now - accept_num + accept_idx + 1,
|
||||
@@ -71,7 +71,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
int64_t cur_token_idx = -1;
|
||||
|
||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
||||
if (stop_seq_len - 1 - i < accept_idx) {
|
||||
if (stop_seq_len - 1 - i <= accept_idx) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||
@@ -83,7 +83,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
||||
#endif
|
||||
cur_token_idx =
|
||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)];
|
||||
} else {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
@@ -98,7 +98,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
(stop_seq_len - 1 - i));
|
||||
#endif
|
||||
int pre_ids_idx =
|
||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
||||
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
|
||||
// EC3
|
||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||
// 导致异常结束
|
||||
|
||||
Reference in New Issue
Block a user