mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix][Speculative Decoding] Correct index calculation in speculate decoding operators (#7121)
- Fix accept_idx calculation in spec_set_value_by_stop_seqs
- Fix condition check from < to <= for token matching
- Fix accept_tokens indexing logic
- Remove unnecessary -1 in current_step comparison for max_think_len

Co-authored-by: guanshihui <guanshihui@baidu.com>
This commit is contained in:
@@ -98,10 +98,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
|||||||
if (max_think_len > 0) {
|
if (max_think_len > 0) {
|
||||||
// A) 超长触发:到达 max_think_len 时开始注入(从本 token 起输出
|
// A) 超长触发:到达 max_think_len 时开始注入(从本 token 起输出
|
||||||
// inject_token_ids[0])
|
// inject_token_ids[0])
|
||||||
if (status == 0 &&
|
if (status == 0 && current_step == max_think_len) {
|
||||||
(current_step - 1) ==
|
|
||||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
|
||||||
// step_idx + 1 了
|
|
||||||
status = (inject_len > 0) ? 1 : done_status;
|
status = (inject_len > 0) ? 1 : done_status;
|
||||||
}
|
}
|
||||||
} else if (max_think_len == 0) {
|
} else if (max_think_len == 0) {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
bool is_end = false;
|
bool is_end = false;
|
||||||
// 遍历起始位置
|
// 遍历起始位置
|
||||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
if (step_idx_now - accept_num + accept_idx + 1 < stop_seq_len) {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf("num %d < stop_seq_len %d\n",
|
printf("num %d < stop_seq_len %d\n",
|
||||||
step_idx_now - accept_num + accept_idx + 1,
|
step_idx_now - accept_num + accept_idx + 1,
|
||||||
@@ -71,7 +71,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
int64_t cur_token_idx = -1;
|
int64_t cur_token_idx = -1;
|
||||||
|
|
||||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
||||||
if (stop_seq_len - 1 - i < accept_idx) {
|
if (stop_seq_len - 1 - i <= accept_idx) {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf(
|
printf(
|
||||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||||
@@ -83,7 +83,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
accept_idx - (stop_seq_len - 1 - i) - 1);
|
||||||
#endif
|
#endif
|
||||||
cur_token_idx =
|
cur_token_idx =
|
||||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)];
|
||||||
} else {
|
} else {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf(
|
printf(
|
||||||
@@ -98,7 +98,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
(stop_seq_len - 1 - i));
|
(stop_seq_len - 1 - i));
|
||||||
#endif
|
#endif
|
||||||
int pre_ids_idx =
|
int pre_ids_idx =
|
||||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
step_idx_now - accept_num + accept_idx - (stop_seq_len - 1 - i);
|
||||||
// EC3
|
// EC3
|
||||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||||
// 导致异常结束
|
// 导致异常结束
|
||||||
|
|||||||
@@ -175,17 +175,19 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
|||||||
accept_idx = 0
|
accept_idx = 0
|
||||||
is_end = False
|
is_end = False
|
||||||
while accept_idx <= an - 1 and not is_end:
|
while accept_idx <= an - 1 and not is_end:
|
||||||
if step_idx_now + accept_idx + 1 < stop_seq_len:
|
if step_idx_now - an + accept_idx + 1 < stop_seq_len:
|
||||||
accept_idx += 1
|
accept_idx += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check one stop_seq match
|
# Check one stop_seq match
|
||||||
for i in range(stop_seq_len - 1, -1, -1):
|
for i in range(stop_seq_len - 1, -1, -1):
|
||||||
cur_token_idx = -1
|
cur_token_idx = -1
|
||||||
if stop_seq_len - 1 - i < accept_idx:
|
# 注意:新版本kernel改成了 <=,并且去掉了 -1
|
||||||
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
|
if stop_seq_len - 1 - i <= accept_idx:
|
||||||
|
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i)]
|
||||||
else:
|
else:
|
||||||
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
|
# 新版本:step_idx已经包含accept_num,所以要减去
|
||||||
|
pre_ids_idx = step_idx_now - an + accept_idx - (stop_seq_len - 1 - i)
|
||||||
if pre_ids_idx <= 0:
|
if pre_ids_idx <= 0:
|
||||||
break
|
break
|
||||||
cur_token_idx = pre_ids_now[pre_ids_idx]
|
cur_token_idx = pre_ids_now[pre_ids_idx]
|
||||||
@@ -290,12 +292,12 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 6
|
inputs["step_idx"][:] = 6
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
|
# stop_seq spans pre_ids and accept_tokens
|
||||||
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
|
# For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 6-3+1+1 = 5 >= stop_seq_len=3, so we check
|
||||||
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
|
# i=2: stop_seq_len-1-i=0 <= accept_idx(1) -> accept_tokens[1-0] = accept_tokens[1] = 22
|
||||||
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
|
# i=1: stop_seq_len-1-i=1 <= accept_idx(1) -> accept_tokens[1-1] = accept_tokens[0] = 11
|
||||||
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
|
# i=0: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 6-3+1-(3-1-0) = 4-2 = 2 -> pre_ids[2] = 99
|
||||||
inputs["token_ids_all"][0, 6] = 99
|
inputs["token_ids_all"][0, 2] = 99
|
||||||
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["stop_seqs_len"][0, 0] = 3
|
||||||
@@ -303,9 +305,9 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
# Match at accept_idx=2, loop increments to 3
|
# Match at accept_idx=1, loop increments to 2
|
||||||
self.assertEqual(outputs["accept_num"][0], 3)
|
self.assertEqual(outputs["accept_num"][0], 2)
|
||||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
|
self.assertEqual(outputs["accept_tokens"][0, 1], inputs["end_ids"][0])
|
||||||
|
|
||||||
def test_match_in_pre_ids_only(self):
|
def test_match_in_pre_ids_only(self):
|
||||||
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
|
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
|
||||||
@@ -314,29 +316,29 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
accept_tokens_len=5,
|
accept_tokens_len=5,
|
||||||
max_model_len=32,
|
max_model_len=32,
|
||||||
stop_seqs_bs=1,
|
stop_seqs_bs=1,
|
||||||
stop_seqs_max_len=3,
|
stop_seqs_max_len=4, # 需要4个元素
|
||||||
seed=30,
|
seed=30,
|
||||||
)
|
)
|
||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 8
|
inputs["step_idx"][:] = 8
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
|
# stop_seq partially in pre_ids, partially in accept_tokens
|
||||||
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
|
# For accept_idx=1: step_idx_now - accept_num + 1 + 1 = 8-3+1+1 = 7 >= stop_seq_len=4, so we check
|
||||||
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
|
# i=3: stop_seq_len-1-i=0 <= accept_idx(1) -> accept_tokens[1-0] = accept_tokens[1] = 22
|
||||||
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
|
# i=2: stop_seq_len-1-i=1 <= accept_idx(1) -> accept_tokens[1-1] = accept_tokens[0] = 11
|
||||||
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
|
# i=1: stop_seq_len-1-i=2 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-1) = 6-2 = 4 -> pre_ids[4] = 60
|
||||||
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
|
# i=0: stop_seq_len-1-i=3 > accept_idx(1) -> pre_ids_idx = 8-3+1-(4-1-0) = 6-3 = 3 -> pre_ids[3] = 50
|
||||||
inputs["token_ids_all"][0, 6] = 50
|
inputs["token_ids_all"][0, 3] = 50
|
||||||
inputs["token_ids_all"][0, 7] = 60
|
inputs["token_ids_all"][0, 4] = 60
|
||||||
inputs["token_ids_all"][0, 8] = 70
|
inputs["accept_tokens"][0, :3] = [11, 22, 3]
|
||||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
inputs["stop_seqs"][0, 0, :4] = [50, 60, 11, 22]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
inputs["stop_seqs_len"][0, 0] = 4
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
|
||||||
inputs["stop_flags"][:] = False
|
inputs["stop_flags"][:] = False
|
||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
self.assertEqual(outputs["accept_num"][0], 1)
|
# Match at accept_idx=1, loop increments to 2
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 2)
|
||||||
|
|
||||||
def test_already_stopped(self):
|
def test_already_stopped(self):
|
||||||
"""Kernel skips sequences with stop_flags=True."""
|
"""Kernel skips sequences with stop_flags=True."""
|
||||||
|
|||||||
Reference in New Issue
Block a user