[BugFix] fix thinking bug (#4710)

* fix thinking bug

* fix ut

* update

* fix
Author: Yuanle Liu
Committed by: GitHub
Date: 2025-10-31 22:00:31 +08:00
Parent: 10358bf1a0
Commit: b301bd6c31
8 changed files with 458 additions and 290 deletions
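The substance of the change: limit_thinking_content_length_v1 now takes stop_flags and eos_token_ids, and limit_thinking_content_length_v2 now takes stop_flags, inserted before think_end_id; the test diffs below thread the new tensors through every call site. A minimal sketch of the updated calls as the fixtures exercise them (the import path and the shape comments are inferences from the test code, not from operator documentation):

    import paddle
    # Assumed import path for the custom ops; adjust to wherever they are registered:
    # from fastdeploy.model_executor.ops.gpu import (
    #     limit_thinking_content_length_v1,
    #     limit_thinking_content_length_v2,
    # )

    next_tokens = paddle.to_tensor([[100], [200]], dtype="int64")    # [bsz, 1] sampled tokens
    max_think_lens = paddle.to_tensor([10, 15], dtype="int32")       # [bsz]; -1 disables the limit
    step_idx = paddle.to_tensor([[5], [8]], dtype="int64")           # [bsz, 1] decode step
    limit_think_status = paddle.to_tensor([0, 0], dtype="int32")     # [bsz] per-sequence phase
    stop_flags = paddle.to_tensor([False, False], dtype="bool")      # [bsz]; new input in this fix
    eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")      # new input in this fix (v1 only)
    think_end_id, line_break_id = 999, 888

    limit_thinking_content_length_v1(
        next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
    )
    limit_thinking_content_length_v2(
        next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
    )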
@@ -33,10 +33,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
         step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
         limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
+        stop_flags = paddle.to_tensor([False, False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: tokens unchanged, status unchanged
         assert next_tokens.numpy()[0, 0] == 100
@@ -50,10 +54,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5, 8], dtype="int32")
         step_idx = paddle.to_tensor([[5], [10]], dtype="int64")  # Both exceed or equal limit
         limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
+        stop_flags = paddle.to_tensor([False, False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: tokens replaced with think_end_id, status changed to 2
         assert next_tokens.numpy()[0, 0] == 999  # Replaced
@@ -67,10 +75,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10], dtype="int32")
         step_idx = paddle.to_tensor([[3]], dtype="int64")  # Still within limit
         limit_think_status = paddle.to_tensor([0], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: token unchanged (already think_end_id), status changed to 2
         assert next_tokens.numpy()[0, 0] == 999
@@ -82,10 +94,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5], dtype="int32")
         step_idx = paddle.to_tensor([[6]], dtype="int64")
         limit_think_status = paddle.to_tensor([1], dtype="int32")  # Status is 1
+        stop_flags = paddle.to_tensor([False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: status changed to 2
         assert limit_think_status.numpy()[0] == 2
@@ -96,10 +112,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([-1], dtype="int32")  # Disabled
         step_idx = paddle.to_tensor([[100]], dtype="int64")  # Would exceed limit if enabled
         limit_think_status = paddle.to_tensor([0], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: nothing changed
         assert next_tokens.numpy()[0, 0] == 100
@@ -111,10 +131,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5], dtype="int32")
         step_idx = paddle.to_tensor([[10]], dtype="int64")
         limit_think_status = paddle.to_tensor([2], dtype="int32")  # Already in response phase
+        stop_flags = paddle.to_tensor([False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify: nothing changed
         assert next_tokens.numpy()[0, 0] == 100
@@ -126,10 +150,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10, 5, 8, -1], dtype="int32")
         step_idx = paddle.to_tensor([[3], [5], [4], [100]], dtype="int64")
         limit_think_status = paddle.to_tensor([0, 0, 0, 0], dtype="int32")
+        stop_flags = paddle.to_tensor([False, False, False, False], dtype="bool")
+        eos_token_ids = paddle.to_tensor([[2], [2], [2], [2]], dtype="int64")
         think_end_id = 999
         # Run operator
-        limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
+        limit_thinking_content_length_v1(
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
+        )
         # Verify each sequence
         # Seq 0: step 3 < max 10, status 0, token unchanged
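Read together, the v1 hunks pin down a small per-sequence state machine: status 0 means still thinking, status 2 means response phase, and max_think_len == -1 disables the limit. A plain-Python sketch of the behavior the assertions imply (a reading of the tests, not the CUDA kernel; since every fixture passes stop_flags=False, the new stop/EOS handling is not pinned down here):

    def limit_thinking_v1_sketch(next_token, max_think_len, step, status, think_end_id):
        # Per-sequence v1 semantics inferred from the assertions above (sketch).
        if max_think_len < 0 or status == 2:
            return next_token, status      # limit disabled, or already responding
        if step >= max_think_len:
            return think_end_id, 2         # budget exhausted: force the end-of-think token
        if next_token == think_end_id:
            return next_token, 2           # model ended thinking on its own
        return next_token, status          # still thinking, within budget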
@@ -158,12 +186,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
         step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
         limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
+        stop_flags = paddle.to_tensor([False, False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Verify: tokens unchanged, status unchanged
@@ -179,11 +208,12 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5], dtype="int32")
         step_idx = paddle.to_tensor([[5]], dtype="int64")
         limit_think_status = paddle.to_tensor([0], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         assert next_tokens.numpy()[0, 0] == 888  # line_break_id
         assert limit_think_status.numpy()[0] == 1
@@ -194,7 +224,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         limit_think_status = paddle.to_tensor([1], dtype="int32")
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         assert next_tokens.numpy()[0, 0] == 999  # think_end_id
         assert limit_think_status.numpy()[0] == 1
@@ -205,7 +235,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         limit_think_status = paddle.to_tensor([1], dtype="int32")
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         assert next_tokens.numpy()[0, 0] == 888  # line_break_id
         assert limit_think_status.numpy()[0] == 1
@@ -216,7 +246,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         limit_think_status = paddle.to_tensor([1], dtype="int32")
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         assert next_tokens.numpy()[0, 0] == 888  # line_break_id
         assert limit_think_status.numpy()[0] == 3  # Move to status 3
@@ -227,12 +257,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10], dtype="int32")
         step_idx = paddle.to_tensor([[3]], dtype="int64")
         limit_think_status = paddle.to_tensor([0], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Verify: status changed to 3 (response phase)
@@ -245,12 +276,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5], dtype="int32")
         step_idx = paddle.to_tensor([[9]], dtype="int64")
         limit_think_status = paddle.to_tensor([2], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Verify: status changed to 3
@@ -262,12 +294,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([-1], dtype="int32")
         step_idx = paddle.to_tensor([[100]], dtype="int64")
         limit_think_status = paddle.to_tensor([0], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Verify: nothing changed
@@ -280,12 +313,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([5], dtype="int32")
         step_idx = paddle.to_tensor([[10]], dtype="int64")
         limit_think_status = paddle.to_tensor([3], dtype="int32")
+        stop_flags = paddle.to_tensor([False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Verify: nothing changed
@@ -298,12 +332,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
         max_think_lens = paddle.to_tensor([10, 5, 8, -1, 6], dtype="int32")
         step_idx = paddle.to_tensor([[3], [5], [4], [100], [9]], dtype="int64")
         limit_think_status = paddle.to_tensor([0, 0, 0, 0, 2], dtype="int32")
+        stop_flags = paddle.to_tensor([False, False, False, False, False], dtype="bool")
         think_end_id = 999
         line_break_id = 888
         # Run operator
         limit_thinking_content_length_v2(
-            next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
+            next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
         )
         # Seq 0: step 3 < max 10, status 0, unchanged
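The v2 hunks trace a richer machine: instead of forcing a single token, v2 injects a line-break/think-end termination sequence (status 1) before settling into the response phase (status 3). A plain-Python sketch of what the assertions imply (again a reading of the tests, not the kernel: the exact step offsets at which status 1 emits think_end_id versus line_break_id sit in context lines this diff elides, so the offset arithmetic below is an assumption, and stop_flags is unexercised since every fixture passes False):

    def limit_thinking_v2_sketch(next_token, max_think_len, step, status,
                                 think_end_id, line_break_id):
        # Per-sequence v2 semantics inferred from the assertions above (sketch).
        if max_think_len < 0 or status == 3:
            return next_token, status          # limit disabled, or already responding
        if status == 0:
            if next_token == think_end_id:
                return next_token, 3           # model closed its thinking on its own
            if step >= max_think_len:
                return line_break_id, 1        # budget hit: begin forced termination sequence
            return next_token, status          # still thinking, within budget
        if status == 1:
            offset = step - max_think_len      # assumed position within the forced sequence
            if offset == 1:
                return think_end_id, 1         # emit the think-end token
            if offset == 2:
                return line_break_id, 1        # emit a line break
            return line_break_id, 3            # final line break, hand over to response phase
        # status == 2: the tests only show it falling through to the response phase
        return next_token, 3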