mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)
* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow * [Docs] Fix thinking_budget markdown formatting * [Test] Align ernie thinking budget test with process_request_dict
This commit is contained in:
@@ -20,7 +20,19 @@ from fastdeploy.engine import common_engine as common_engine_module
|
||||
from fastdeploy.engine import engine as engine_module
|
||||
from fastdeploy.engine.args_utils import EngineArgs # Import EngineArgs
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.input.ernie4_5_processor import (
|
||||
Ernie4_5Processor as ErnieTextDataProcessor,
|
||||
)
|
||||
from fastdeploy.input.ernie4_5_vl_processor import (
|
||||
Ernie4_5_VLProcessor as ErnieVLDataProcessor,
|
||||
)
|
||||
from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
|
||||
from fastdeploy.input.v1.ernie4_5_processor import (
|
||||
Ernie4_5Processor as V1ErnieTextDataProcessor,
|
||||
)
|
||||
from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
|
||||
Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
|
||||
)
|
||||
from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
|
||||
from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
|
||||
from fastdeploy.scheduler import SchedulerConfig
|
||||
@@ -52,6 +64,7 @@ class MockModelRunner:
|
||||
"req_ids": [f"req_{i}" for i in range(max_num_seqs)],
|
||||
"logits_processors_args": [{} for _ in range(max_num_seqs)],
|
||||
"prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10
|
||||
"token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
|
||||
"prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
|
||||
"pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
|
||||
"step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
|
||||
@@ -73,6 +86,12 @@ class MockModelRunner:
|
||||
self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
|
||||
req.prompt_ids, dtype=paddle.int64
|
||||
)
|
||||
self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
|
||||
np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
|
||||
)
|
||||
self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
|
||||
req.prompt_ids, dtype=paddle.int64
|
||||
)
|
||||
self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
|
||||
if req.sampling_params.logits_processors_args:
|
||||
self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
|
||||
@@ -152,7 +171,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
return logits
|
||||
|
||||
def test_thinking_budget_not_reached(self):
|
||||
# Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase
|
||||
# Scenario: Thinking budget is 5, and prompt-side tokens after <think> do not consume budget.
|
||||
req_id = "test_req_1"
|
||||
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
|
||||
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
|
||||
@@ -167,9 +186,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertTrue(processor._states[req_id].started)
|
||||
self.assertFalse(processor._states[req_id].ended)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # (3, 4, 5) after THINKING_START
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
|
||||
|
||||
# Step 2: Simulate one generation step (budget 5, generated 3 -> 4)
|
||||
# Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs) # Update state before apply
|
||||
processed_logits = processor.apply(logits)
|
||||
@@ -181,9 +200,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
|
||||
|
||||
processor.update_state(mock_runner.share_inputs) # Update state after generating token
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # (3, 4, 5, 0)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
|
||||
|
||||
# Step 3: Simulate another generation step (budget 5, generated 4 -> 5)
|
||||
# Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs) # Update state before apply
|
||||
processed_logits = processor.apply(logits)
|
||||
@@ -191,14 +210,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
|
||||
|
||||
processor.update_state(mock_runner.share_inputs) # Update state after generating token
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 5) # (3, 4, 5, 0, 0)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
|
||||
|
||||
# LogitsProcessor should still not restrict as NEW_LINE_TOKEN is not yet last token
|
||||
|
||||
def test_thinking_budget_reached_forces_newline(self):
|
||||
# Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end
|
||||
def test_thinking_budget_reached_forces_think_end(self):
|
||||
# Scenario: Budget is 3 and only decode-time tokens count toward the budget.
|
||||
req_id = "test_req_2"
|
||||
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] # Initial 1 token after start
|
||||
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
|
||||
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
|
||||
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
|
||||
|
||||
@@ -206,16 +225,26 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
|
||||
|
||||
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
|
||||
line_break_id = processor.line_break_token_id
|
||||
think_end_id = processor.think_end_token_id
|
||||
|
||||
# Step 1: Initial state update (1 token after start)
|
||||
# Step 1: Initial state update (prompt-side tokens do not count)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
|
||||
self.assertFalse(processor._states[req_id].ended)
|
||||
self.assertEqual(processor._states[req_id].last_token_id, 3)
|
||||
|
||||
# Step 2: Generate 2nd token (budget 3, generated 1 -> 2)
|
||||
# Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, 0) # Normal generation
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
|
||||
self.assertEqual(processor._states[req_id].last_token_id, 0)
|
||||
|
||||
# Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
@@ -226,42 +255,17 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
|
||||
self.assertEqual(processor._states[req_id].last_token_id, 0)
|
||||
|
||||
# Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE.
|
||||
# Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, 0) # Normal generation
|
||||
self.assertEqual(next_token, 0)
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # Budget is now met
|
||||
self.assertEqual(processor._states[req_id].last_token_id, 0)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 3)
|
||||
|
||||
# Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID.
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
|
||||
# Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0
|
||||
# Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues
|
||||
other_logits = paddle.concat(
|
||||
[
|
||||
processed_logits[0, :line_break_id],
|
||||
processed_logits[0, line_break_id + 1 : VOCAB_SIZE],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item())
|
||||
self.assertEqual(processed_logits[0, line_break_id].item(), 0.0)
|
||||
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, line_break_id) # Forces NEW_LINE
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # Still increment
|
||||
self.assertEqual(processor._states[req_id].last_token_id, line_break_id)
|
||||
|
||||
# Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
|
||||
# Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
@@ -324,9 +328,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self.assertEqual(next_token, stop_sentence_token_ids[0])
|
||||
|
||||
def test_thinking_budget_no_stop_sentence_defaults(self):
|
||||
# Scenario: No stop sentence, budget reached should force newline.
|
||||
# Scenario: No stop sentence, budget reached should force thinking_end directly.
|
||||
req_id = "test_req_no_stop_sentence"
|
||||
prompt_ids = [THINKING_START_TOKEN_ID, 42] # 1 token after start
|
||||
prompt_ids = [THINKING_START_TOKEN_ID, 42]
|
||||
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
|
||||
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
|
||||
|
||||
@@ -334,35 +338,42 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
|
||||
|
||||
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
|
||||
line_break_id = processor.line_break_token_id
|
||||
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
logits = self._get_initial_logits(1)
|
||||
processed_logits = processor.apply(logits)
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, line_break_id)
|
||||
self.assertEqual(next_token, 0)
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, THINKING_END_TOKEN_ID)
|
||||
|
||||
def test_thinking_budget_uses_config_token_ids(self):
|
||||
# Scenario: Processor should use token ids from model config.
|
||||
self.fd_config.model_config.think_start_id = 123
|
||||
self.fd_config.model_config.think_end_id = 124
|
||||
self.fd_config.model_config.line_break_id = 125
|
||||
self.fd_config.model_config.line_break_id = -1
|
||||
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
|
||||
self.assertEqual(processor.think_start_token_id, 123)
|
||||
self.assertEqual(processor.think_end_token_id, 124)
|
||||
self.assertEqual(processor.line_break_token_id, 125)
|
||||
self.assertEqual(processor.line_break_token_id, -1)
|
||||
self.assertTrue(processor._enabled)
|
||||
|
||||
def test_thinking_budget_disabled_when_token_ids_missing(self):
|
||||
# Scenario: Processor should be disabled when token ids are not configured.
|
||||
self.fd_config.model_config.think_start_id = -1
|
||||
self.fd_config.model_config.think_end_id = -1
|
||||
self.fd_config.model_config.line_break_id = -1
|
||||
self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
|
||||
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
|
||||
self.assertFalse(processor._enabled)
|
||||
self.assertEqual(processor.think_start_token_id, -1)
|
||||
self.assertEqual(processor.think_end_token_id, -1)
|
||||
self.assertEqual(processor.line_break_token_id, -1)
|
||||
self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)
|
||||
|
||||
# update_state and apply should be no-op when disabled
|
||||
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
|
||||
@@ -493,6 +504,28 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self.assertEqual(state.tokens_after_start, 2)
|
||||
self.assertEqual(state.last_token_id, 99)
|
||||
|
||||
def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
|
||||
req_id = "req_gpu_fallback"
|
||||
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
|
||||
mock_runner.share_inputs["req_ids"][0] = req_id
|
||||
mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
|
||||
mock_runner.share_inputs["prompt_ids"] = None
|
||||
mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
|
||||
[1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
|
||||
)
|
||||
mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
|
||||
mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
|
||||
|
||||
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
|
||||
state = processor._states[req_id]
|
||||
self.assertTrue(state.prompt_checked)
|
||||
self.assertTrue(state.started)
|
||||
self.assertFalse(state.ended)
|
||||
self.assertEqual(state.tokens_after_start, 0)
|
||||
self.assertEqual(state.last_token_id, 3)
|
||||
|
||||
def test_thinking_budget_not_configured(self):
|
||||
# Scenario: Processor is active, but request does not provide thinking_budget
|
||||
req_id = "test_req_3"
|
||||
@@ -532,20 +565,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 0) # No tokens after start yet
|
||||
self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)
|
||||
|
||||
# Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
|
||||
self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
|
||||
next_token = mock_runner.generate_next_token(processed_logits)[0]
|
||||
self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
|
||||
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 1) # Still increments
|
||||
self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
|
||||
|
||||
# Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
|
||||
# Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_logits = processor.apply(logits)
|
||||
@@ -572,7 +592,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
self.assertTrue(processor._states[req_id].started)
|
||||
self.assertTrue(processor._states[req_id].ended)
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 2) # Tokens 2, THINKING_END after start
|
||||
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
|
||||
|
||||
logits = self._get_initial_logits(1)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
@@ -582,7 +602,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
def test_multiple_requests(self):
|
||||
# Scenario: Multiple requests with different thinking states
|
||||
req_id_1 = "req_a"
|
||||
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] # budget 2, last_token=11, tokens_after_start=2
|
||||
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
|
||||
sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
|
||||
mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
|
||||
|
||||
@@ -592,7 +612,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)
|
||||
|
||||
req_id_3 = "req_c"
|
||||
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] # budget 1, last_token=30, tokens_after_start=1
|
||||
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
|
||||
sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
|
||||
mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
|
||||
|
||||
@@ -609,14 +629,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
# Verify initial states
|
||||
self.assertTrue(processor._states[req_id_1].started)
|
||||
self.assertFalse(processor._states[req_id_1].ended)
|
||||
self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
|
||||
self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
|
||||
self.assertEqual(processor._states[req_id_1].last_token_id, 11)
|
||||
|
||||
self.assertNotIn(req_id_2, processor._states) # No budget specified for req_2
|
||||
|
||||
self.assertTrue(processor._states[req_id_3].started)
|
||||
self.assertFalse(processor._states[req_id_3].ended)
|
||||
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
|
||||
self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
|
||||
self.assertEqual(processor._states[req_id_3].last_token_id, 30)
|
||||
|
||||
# Simulate logits for the batch
|
||||
@@ -624,16 +644,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
processor.update_state(mock_runner.share_inputs) # Ensure state is updated before apply
|
||||
processed_batch_logits = processor.apply(batch_logits)
|
||||
|
||||
# Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
|
||||
self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
|
||||
# Req 1: prompt-side content does not consume budget, so first step is normal generation.
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
|
||||
|
||||
# Req 2: No thinking budget, normal generation
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
|
||||
|
||||
# Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
|
||||
self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
|
||||
# Req 3: prompt-side content does not consume budget, so first step is normal generation.
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)
|
||||
|
||||
# Simulate generating next tokens and updating state
|
||||
next_tokens = mock_runner.generate_next_token(processed_batch_logits)
|
||||
@@ -643,20 +661,25 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
|
||||
# Verify updated states for next step
|
||||
self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
|
||||
self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
|
||||
self.assertEqual(processor._states[req_id_1].last_token_id, 0)
|
||||
self.assertEqual(processor._states[req_id_3].last_token_id, 0)
|
||||
self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
|
||||
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
|
||||
self.assertFalse(processor._states[req_id_1].ended)
|
||||
self.assertFalse(processor._states[req_id_3].ended)
|
||||
|
||||
batch_logits = self._get_initial_logits(3)
|
||||
processor.update_state(mock_runner.share_inputs)
|
||||
processed_batch_logits = processor.apply(batch_logits)
|
||||
|
||||
# Req 1: last token was NEW_LINE. Should force THINKING_END
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
|
||||
# Req 1: budget 2, tokens_after_start 1. Still normal generation.
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
|
||||
|
||||
# Req 2: Still normal generation
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
|
||||
|
||||
# Req 3: last token was NEW_LINE. Should force THINKING_END
|
||||
# Req 3: budget 1 reached after one generated token, should now force THINKING_END.
|
||||
self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
|
||||
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
|
||||
|
||||
|
||||
@@ -724,7 +747,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
self.assertTrue(updated["think_prompt_checked"])
|
||||
self.assertTrue(updated["think_prompt_started"])
|
||||
self.assertTrue(updated["think_prompt_ended"])
|
||||
self.assertEqual(updated["think_prompt_tokens_after_start"], 2)
|
||||
self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
|
||||
self.assertEqual(updated["think_prompt_last_token_id"], 3)
|
||||
|
||||
def test_v1_process_request_missing_logits_processors_args(self):
|
||||
@@ -833,6 +856,58 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
|
||||
self.assertNotIn(("np", False), processor._tokenize_cache)
|
||||
|
||||
def test_text_encode_with_cache_lazy_init(self):
|
||||
processor = TextDataProcessor.__new__(TextDataProcessor)
|
||||
call_counter = {"count": 0}
|
||||
|
||||
def _text2ids(text, max_model_len=None, add_special_tokens=False):
|
||||
call_counter["count"] += 1
|
||||
return np.array([51, 52], dtype=np.int64)
|
||||
|
||||
processor.text2ids = _text2ids
|
||||
|
||||
self.assertFalse(hasattr(processor, "_tokenize_cache"))
|
||||
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
|
||||
self.assertTrue(hasattr(processor, "_tokenize_cache"))
|
||||
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
|
||||
self.assertEqual(call_counter["count"], 1)
|
||||
|
||||
def test_v1_encode_with_cache_lazy_init(self):
|
||||
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
|
||||
call_counter = {"count": 0}
|
||||
|
||||
def _text2ids(text, max_model_len=None, add_special_tokens=False):
|
||||
call_counter["count"] += 1
|
||||
return np.array([61, 62], dtype=np.int64)
|
||||
|
||||
processor.text2ids = _text2ids
|
||||
|
||||
self.assertFalse(hasattr(processor, "_tokenize_cache"))
|
||||
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
|
||||
self.assertTrue(hasattr(processor, "_tokenize_cache"))
|
||||
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
|
||||
self.assertEqual(call_counter["count"], 1)
|
||||
|
||||
def test_ernie_encode_literal_text_with_cache(self):
|
||||
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
|
||||
processor.tokenizer = SimpleNamespace(
|
||||
tokenize=lambda text: ["token_a", "token_b"],
|
||||
convert_tokens_to_ids=lambda tokens: [71, 72],
|
||||
)
|
||||
|
||||
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
|
||||
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
|
||||
|
||||
def test_v1_ernie_encode_literal_text_with_cache(self):
|
||||
processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
|
||||
processor.tokenizer = SimpleNamespace(
|
||||
tokenize=lambda text: ["token_c", "token_d"],
|
||||
convert_tokens_to_ids=lambda tokens: [81, 82],
|
||||
)
|
||||
|
||||
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
|
||||
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
|
||||
|
||||
def test_text_update_thinking_prompt_state_branches(self):
|
||||
processor = TextDataProcessor.__new__(TextDataProcessor)
|
||||
processor._think_token_ids = None
|
||||
@@ -868,7 +943,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(with_start_no_end["think_prompt_started"])
|
||||
self.assertFalse(with_start_no_end["think_prompt_ended"])
|
||||
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
|
||||
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
|
||||
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
|
||||
|
||||
# 命中 _get_think_token_ids 的缓存分支
|
||||
@@ -891,7 +966,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(with_start_no_end["think_prompt_started"])
|
||||
self.assertFalse(with_start_no_end["think_prompt_ended"])
|
||||
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
|
||||
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
|
||||
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
|
||||
|
||||
# 命中 _get_think_token_ids 的缓存分支
|
||||
@@ -903,7 +978,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processor.eos_token_ids = [1]
|
||||
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
|
||||
processor._encode_literal_text_with_cache = lambda text: [201, 202]
|
||||
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
|
||||
processor.reasoning_parser = None
|
||||
|
||||
@@ -924,7 +999,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processed = processor.process_request_dict(request, max_model_len=16)
|
||||
self.assertEqual(
|
||||
processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
|
||||
[23, 201, 202],
|
||||
[201, 202],
|
||||
)
|
||||
self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
|
||||
|
||||
@@ -934,7 +1009,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processor.eos_token_ids = [1]
|
||||
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
|
||||
processor._encode_literal_text_with_cache = lambda text: [301, 302]
|
||||
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
|
||||
processor.reasoning_parser = None
|
||||
|
||||
@@ -955,7 +1030,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processed = processor.process_request(request, max_model_len=16)
|
||||
self.assertEqual(
|
||||
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
|
||||
[23, 301, 302],
|
||||
[301, 302],
|
||||
)
|
||||
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
|
||||
|
||||
@@ -965,7 +1040,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processor.eos_token_ids = [1]
|
||||
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
|
||||
processor._encode_literal_text_with_cache = lambda text: [401, 402]
|
||||
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
|
||||
processor.reasoning_parser = None
|
||||
|
||||
@@ -992,10 +1067,173 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
||||
processed = processor.process_request_dict(request, max_model_len=16)
|
||||
self.assertEqual(
|
||||
processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
|
||||
[23, 401, 402],
|
||||
[401, 402],
|
||||
)
|
||||
self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
|
||||
|
||||
def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
|
||||
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
|
||||
processor._apply_default_parameters = lambda request: request
|
||||
processor.eos_token_ids = [1]
|
||||
processor.update_stop_seq = lambda *args, **kwargs: None
|
||||
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
||||
processor._encode_literal_text_with_cache = lambda text: [501, 502]
|
||||
processor.tokenizer = DummyTokenizerForTextProcessor()
|
||||
processor.reasoning_parser = None
|
||||
|
||||
request = {
|
||||
"request_id": "req_ernie_text",
|
||||
"eos_token_ids": [1],
|
||||
"prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
|
||||
"prompt": None,
|
||||
"messages": None,
|
||||
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
|
||||
"bad_words": None,
|
||||
"bad_words_token_ids": None,
|
||||
"max_tokens": 1,
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
"response_max_tokens": None,
|
||||
"enable_thinking": True,
|
||||
}
|
||||
with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
|
||||
processed = processor.process_request_dict(request, max_model_len=16)
|
||||
|
||||
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
|
||||
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
|
||||
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
|
||||
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
|
||||
|
||||
def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
    """V1 ERNIE text requests (object-style) get thinking-budget args on sampling_params."""
    noop = lambda *args, **kwargs: None

    # Stub every collaborator so only process_request_dict itself is exercised.
    processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
    processor.tokenizer = DummyTokenizerForTextProcessor()
    processor.reasoning_parser = None
    processor.eos_token_ids = [1]
    processor._apply_default_parameters = lambda request: request
    processor.update_stop_seq = noop
    processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
    # Deterministic encoding of the think-stop sentence.
    processor._encode_literal_text_with_cache = lambda text: [601, 602]

    sampling = SimpleNamespace(
        bad_words=None,
        bad_words_token_ids=None,
        max_tokens=1,
        temperature=1.0,
        top_p=0.9,
        repetition_penalty=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        response_max_tokens=None,
        n=1,
        logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
    )
    # Prompt already includes the thinking-start token.
    request = DummyRequestV1(
        request_id="req_v1_ernie_text",
        eos_token_ids=[1],
        prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
        prompt=None,
        messages=None,
        chat_template_kwargs=None,
        enable_thinking=True,
        sampling_params=sampling,
    )

    with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", noop):
        result = processor.process_request_dict(request, max_model_len=16)

    lp_args = result.sampling_params.logits_processors_args
    self.assertEqual(lp_args["think_stop_sentence_token_ids"], [601, 602])
    self.assertTrue(lp_args["think_prompt_started"])
    self.assertFalse(lp_args["think_prompt_ended"])
    self.assertEqual(lp_args["think_prompt_tokens_after_start"], 0)
||||
def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
    """Dict-style ERNIE VL requests get thinking-budget args after multimodal tokenization."""
    noop = lambda *args, **kwargs: None

    # Stub out the VL-specific pipeline (mm limits, completion tokens, packing)
    # alongside the usual text-processor collaborators.
    processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
    processor.tokenizer = DummyTokenizerForTextProcessor()
    processor.reasoning_parser = None
    processor.eos_token_ids = [1]
    processor._apply_default_parameters = lambda request: request
    processor.update_stop_seq = noop
    processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
    processor._encode_literal_text_with_cache = lambda text: [701, 702]
    processor._check_mm_limits = noop
    processor.append_completion_tokens = noop
    processor.pack_outputs = lambda outs: outs
    # request2ids yields a prompt that already contains the thinking-start token.
    processor.ernie4_5_processor = SimpleNamespace(
        request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
    )

    request = dict(
        request_id="req_ernie_vl",
        eos_token_ids=[1],
        messages=[{"role": "user", "content": "hi"}],
        bad_words=None,
        bad_words_token_ids=None,
        logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
        max_tokens=1,
        top_p=0.9,
        response_max_tokens=None,
    )

    with patch(
        "fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
        noop,
    ):
        result = processor.process_request_dict(request, max_model_len=16)

    lp_args = result["logits_processors_args"]
    self.assertEqual(lp_args["think_stop_sentence_token_ids"], [701, 702])
    self.assertTrue(lp_args["think_prompt_started"])
    self.assertFalse(lp_args["think_prompt_ended"])
    self.assertEqual(lp_args["think_prompt_tokens_after_start"], 0)
||||
def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
    """V1 ERNIE VL requests (object-style) get thinking-budget args on sampling_params."""
    noop = lambda *args, **kwargs: None

    # Stub the VL pipeline and text collaborators so only the request
    # processing under test runs.
    processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
    processor.tokenizer = DummyTokenizerForTextProcessor()
    processor.reasoning_parser = None
    processor.eos_token_ids = [1]
    processor._apply_default_parameters = lambda request: request
    processor.update_stop_seq = noop
    processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
    processor._encode_literal_text_with_cache = lambda text: [801, 802]
    processor._check_mm_limits = noop
    processor.append_completion_tokens = noop
    processor.pack_outputs = lambda outs: outs
    # Tokenization output already contains the thinking-start token.
    processor.ernie4_5_processor = SimpleNamespace(
        request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
    )

    sampling = SimpleNamespace(
        bad_words=None,
        bad_words_token_ids=None,
        max_tokens=1,
        temperature=1.0,
        top_p=0.9,
        repetition_penalty=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        response_max_tokens=None,
        reasoning_max_tokens=None,
        n=1,
        logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
    )
    request = DummyRequestV1(
        request_id="req_v1_ernie_vl",
        eos_token_ids=[1],
        prompt_token_ids=None,
        prompt=None,
        messages=[{"role": "user", "content": "hi"}],
        chat_template_kwargs=None,
        enable_thinking=True,
        completion_token_ids=None,
        multimodal_data=None,
        sampling_params=sampling,
    )

    with patch(
        "fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
        noop,
    ):
        result = processor.process_request_dict(request, max_model_len=16)

    lp_args = result.sampling_params.logits_processors_args
    self.assertEqual(lp_args["think_stop_sentence_token_ids"], [801, 802])
    self.assertTrue(lp_args["think_prompt_started"])
    self.assertFalse(lp_args["think_prompt_ended"])
    self.assertEqual(lp_args["think_prompt_tokens_after_start"], 0)
||||
|
||||
# Allow running this test module directly (python this_file.py) in addition
# to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
||||
Reference in New Issue
Block a user