[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)

* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow

* [Docs] Fix thinking_budget markdown formatting

* [Test] Align ernie thinking budget test with process_request_dict
Author: jackyYang6
Date: 2026-03-23 14:15:55 +08:00 (committed by GitHub)
Parent: 7a78001be2
Commit: 634d23a38a
10 changed files with 663 additions and 285 deletions
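The tests below pin down the two behavioral changes in this commit. First, tokens that sit after <think> inside the prompt no longer consume the thinking budget; only decode-time tokens count (think_prompt_tokens_after_start is now reported as 0). Second, once the budget is met the processor forces the think-end token directly, rather than forcing a newline first and think-end afterwards. A minimal sketch of the new forcing step, with illustrative names only (not the actual ThinkingBudgetLogitsProcessor implementation):

    import paddle

    def force_think_end(logits, state, budget, think_end_id):
        # Sketch: once the decode-time budget is spent, mask everything
        # except </think> so the sampler must emit it on this step.
        if state.tokens_after_start >= budget and not state.ended:
            masked = paddle.full_like(logits, -float("inf"))
            masked[:, think_end_id] = 0.0  # logits assumed [batch, vocab]
            return masked
        return logits  # under budget: pass logits through unchanged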
@@ -20,7 +20,19 @@ from fastdeploy.engine import common_engine as common_engine_module
from fastdeploy.engine import engine as engine_module
from fastdeploy.engine.args_utils import EngineArgs # Import EngineArgs
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.input.ernie4_5_processor import (
Ernie4_5Processor as ErnieTextDataProcessor,
)
from fastdeploy.input.ernie4_5_vl_processor import (
Ernie4_5_VLProcessor as ErnieVLDataProcessor,
)
from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
from fastdeploy.input.v1.ernie4_5_processor import (
Ernie4_5Processor as V1ErnieTextDataProcessor,
)
from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
)
from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
from fastdeploy.scheduler import SchedulerConfig
@@ -52,6 +64,7 @@ class MockModelRunner:
"req_ids": [f"req_{i}" for i in range(max_num_seqs)],
"logits_processors_args": [{} for _ in range(max_num_seqs)],
"prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10
"token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
"pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
@@ -73,6 +86,12 @@ class MockModelRunner:
self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
)
self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
if req.sampling_params.logits_processors_args:
self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
@@ -152,7 +171,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
return logits
def test_thinking_budget_not_reached(self):
# Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase
# Scenario: Thinking budget is 5, and prompt-side tokens after <think> do not consume budget.
req_id = "test_req_1"
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
@@ -167,9 +186,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertFalse(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # (3, 4, 5) after THINKING_START
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
# Step 2: Simulate one generation step (budget 5, generated 3 -> 4)
# Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -181,9 +200,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # (3, 4, 5, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
# Step 3: Simulate another generation step (budget 5, generated 4 -> 5)
# Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -191,14 +210,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
self.assertEqual(processor._states[req_id].tokens_after_start, 5) # (3, 4, 5, 0, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
# LogitsProcessor should still not restrict: budget (5) is not yet reached
def test_thinking_budget_reached_forces_newline(self):
# Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end
def test_thinking_budget_reached_forces_think_end(self):
# Scenario: Budget is 3 and only decode-time tokens count toward the budget.
req_id = "test_req_2"
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] # Initial 1 token after start
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -206,16 +225,26 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
line_break_id = processor.line_break_token_id
think_end_id = processor.think_end_token_id
# Step 1: Initial state update (1 token after start)
# Step 1: Initial state update (prompt-side tokens do not count)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
self.assertFalse(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].last_token_id, 3)
# Step 2: Generate 2nd token (budget 3, generated 1 -> 2)
# Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, 0) # Normal generation
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1)
self.assertEqual(processor._states[req_id].last_token_id, 0)
# Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -226,42 +255,17 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(processor._states[req_id].tokens_after_start, 2)
self.assertEqual(processor._states[req_id].last_token_id, 0)
# Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE.
# Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, 0) # Normal generation
self.assertEqual(next_token, 0)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 3) # Budget is now met
self.assertEqual(processor._states[req_id].last_token_id, 0)
self.assertEqual(processor._states[req_id].tokens_after_start, 3)
# Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
# Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0
# Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues
other_logits = paddle.concat(
[
processed_logits[0, :line_break_id],
processed_logits[0, line_break_id + 1 : VOCAB_SIZE],
],
axis=0,
)
self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item())
self.assertEqual(processed_logits[0, line_break_id].item(), 0.0)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, line_break_id) # Forces NEW_LINE
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 4) # Still increment
self.assertEqual(processor._states[req_id].last_token_id, line_break_id)
# Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
# Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -324,9 +328,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(next_token, stop_sentence_token_ids[0])
def test_thinking_budget_no_stop_sentence_defaults(self):
# Scenario: No stop sentence, budget reached should force newline.
# Scenario: No stop sentence, budget reached should force thinking_end directly.
req_id = "test_req_no_stop_sentence"
prompt_ids = [THINKING_START_TOKEN_ID, 42] # 1 token after start
prompt_ids = [THINKING_START_TOKEN_ID, 42]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -334,35 +338,42 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
line_break_id = processor.line_break_token_id
processor.update_state(mock_runner.share_inputs)
logits = self._get_initial_logits(1)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, line_break_id)
self.assertEqual(next_token, 0)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, THINKING_END_TOKEN_ID)
def test_thinking_budget_uses_config_token_ids(self):
# Scenario: Processor should use token ids from model config.
self.fd_config.model_config.think_start_id = 123
self.fd_config.model_config.think_end_id = 124
self.fd_config.model_config.line_break_id = 125
self.fd_config.model_config.line_break_id = -1
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertEqual(processor.think_start_token_id, 123)
self.assertEqual(processor.think_end_token_id, 124)
self.assertEqual(processor.line_break_token_id, 125)
self.assertEqual(processor.line_break_token_id, -1)
self.assertTrue(processor._enabled)
def test_thinking_budget_disabled_when_token_ids_missing(self):
# Scenario: Processor should be disabled when token ids are not configured.
self.fd_config.model_config.think_start_id = -1
self.fd_config.model_config.think_end_id = -1
self.fd_config.model_config.line_break_id = -1
self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertFalse(processor._enabled)
self.assertEqual(processor.think_start_token_id, -1)
self.assertEqual(processor.think_end_token_id, -1)
self.assertEqual(processor.line_break_token_id, -1)
self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)
# update_state and apply should be no-op when disabled
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
@@ -493,6 +504,28 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(state.tokens_after_start, 2)
self.assertEqual(state.last_token_id, 99)
def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
req_id = "req_gpu_fallback"
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
mock_runner.share_inputs["req_ids"][0] = req_id
mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
mock_runner.share_inputs["prompt_ids"] = None
mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
[1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
)
mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
processor.update_state(mock_runner.share_inputs)
state = processor._states[req_id]
self.assertTrue(state.prompt_checked)
self.assertTrue(state.started)
self.assertFalse(state.ended)
self.assertEqual(state.tokens_after_start, 0)
self.assertEqual(state.last_token_id, 3)
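# A rough sketch of the fallback this test exercises (assumed shape, not the
# exact implementation): with share_inputs["prompt_ids"] unavailable, prompt
# state is rebuilt from token_ids_all, using prompt_lens to bound the slice:
#
#   prompt_len = int(share_inputs["prompt_lens"][slot, 0])
#   prompt = share_inputs["token_ids_all"][slot, :prompt_len].tolist()
#   started = THINKING_START_TOKEN_ID in prompt       # -> state.started
#   last_token_id = prompt[-1] if prompt_len else -1  # -> 3 in this test
#   tokens_after_start = 0   # prompt-side tokens never consume the budget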
def test_thinking_budget_not_configured(self):
# Scenario: Processor is active, but request does not provide thinking_budget
req_id = "test_req_3"
@@ -532,20 +565,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self.assertEqual(processor._states[req_id].tokens_after_start, 0) # No tokens after start yet
self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)
# Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
next_token = mock_runner.generate_next_token(processed_logits)[0]
self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
self.assertEqual(processor._states[req_id].tokens_after_start, 1) # Still increments
self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
# Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
# Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -572,7 +592,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertTrue(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].tokens_after_start, 2) # Tokens 2, THINKING_END after start
self.assertEqual(processor._states[req_id].tokens_after_start, 0)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
@@ -582,7 +602,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
def test_multiple_requests(self):
# Scenario: Multiple requests with different thinking states
req_id_1 = "req_a"
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] # budget 2, last_token=11, tokens_after_start=2
prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
@@ -592,7 +612,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)
req_id_3 = "req_c"
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] # budget 1, last_token=30, tokens_after_start=1
prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
@@ -609,14 +629,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
# Verify initial states
self.assertTrue(processor._states[req_id_1].started)
self.assertFalse(processor._states[req_id_1].ended)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_1].last_token_id, 11)
self.assertNotIn(req_id_2, processor._states) # No budget specified for req_2
self.assertTrue(processor._states[req_id_3].started)
self.assertFalse(processor._states[req_id_3].ended)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_3].last_token_id, 30)
# Simulate logits for the batch
@@ -624,16 +644,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs) # Ensure state is updated before apply
processed_batch_logits = processor.apply(batch_logits)
# Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
# Req 1: prompt-side content does not consume budget, so first step is normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: No thinking budget, normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
# Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
# Req 3: prompt-side content does not consume budget, so first step is normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)
# Simulate generating next tokens and updating state
next_tokens = mock_runner.generate_next_token(processed_batch_logits)
@@ -643,20 +661,25 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
processor.update_state(mock_runner.share_inputs)
# Verify updated states for next step
self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
self.assertEqual(processor._states[req_id_1].last_token_id, 0)
self.assertEqual(processor._states[req_id_3].last_token_id, 0)
self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
self.assertFalse(processor._states[req_id_1].ended)
self.assertFalse(processor._states[req_id_3].ended)
batch_logits = self._get_initial_logits(3)
processor.update_state(mock_runner.share_inputs)
processed_batch_logits = processor.apply(batch_logits)
# Req 1: last token was NEW_LINE. Should force THINKING_END
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
# Req 1: budget 2, tokens_after_start 1. Still normal generation.
self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: Still normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
# Req 3: last token was NEW_LINE. Should force THINKING_END
# Req 3: budget 1 reached after one generated token, should now force THINKING_END.
self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
@@ -724,7 +747,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertTrue(updated["think_prompt_checked"])
self.assertTrue(updated["think_prompt_started"])
self.assertTrue(updated["think_prompt_ended"])
self.assertEqual(updated["think_prompt_tokens_after_start"], 2)
self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
self.assertEqual(updated["think_prompt_last_token_id"], 3)
def test_v1_process_request_missing_logits_processors_args(self):
@@ -833,6 +856,58 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
self.assertNotIn(("np", False), processor._tokenize_cache)
def test_text_encode_with_cache_lazy_init(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([51, 52], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertEqual(call_counter["count"], 1)
def test_v1_encode_with_cache_lazy_init(self):
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([61, 62], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertEqual(call_counter["count"], 1)
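# The two lazy-init tests above pin down a memoization contract rather than
# a specific implementation. A plausible sketch (method body assumed):
#
#   def encode_with_cache(self, text):
#       if not hasattr(self, "_tokenize_cache"):
#           self._tokenize_cache = {}  # created lazily on first call
#       if text not in self._tokenize_cache:
#           self._tokenize_cache[text] = self.text2ids(text).tolist()
#       return self._tokenize_cache[text]
#
# which is why call_counter stays at 1 across the repeated call.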
def test_ernie_encode_literal_text_with_cache(self):
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
processor.tokenizer = SimpleNamespace(
tokenize=lambda text: ["token_a", "token_b"],
convert_tokens_to_ids=lambda tokens: [71, 72],
)
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
def test_v1_ernie_encode_literal_text_with_cache(self):
processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
processor.tokenizer = SimpleNamespace(
tokenize=lambda text: ["token_c", "token_d"],
convert_tokens_to_ids=lambda tokens: [81, 82],
)
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
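# By contrast, _encode_literal_text_with_cache is pinned down only through
# the tokenizer interface. A minimal sketch of the cached path (cache name
# hypothetical):
#
#   def _encode_literal_text_with_cache(self, text):
#       if text not in self._literal_cache:
#           tokens = self.tokenizer.tokenize(text)
#           self._literal_cache[text] = self.tokenizer.convert_tokens_to_ids(tokens)
#       return self._literal_cache[text]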
def test_text_update_thinking_prompt_state_branches(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
processor._think_token_ids = None
@@ -868,7 +943,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# Hit the cache branch of _get_think_token_ids
@@ -891,7 +966,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# Hit the cache branch of _get_think_token_ids
@@ -903,7 +978,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
processor._encode_literal_text_with_cache = lambda text: [201, 202]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -924,7 +999,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
[23, 201, 202],
[201, 202],
)
self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
@@ -934,7 +1009,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
processor._encode_literal_text_with_cache = lambda text: [301, 302]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -955,7 +1030,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request(request, max_model_len=16)
self.assertEqual(
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
[23, 301, 302],
[301, 302],
)
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
@@ -965,7 +1040,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
processor._encode_literal_text_with_cache = lambda text: [401, 402]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -992,10 +1067,173 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
[23, 401, 402],
[401, 402],
)
self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
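# Note the changed expectations in the three tests above: the literal stop
# sentence is now tokenized on its own via _encode_literal_text_with_cache,
# so the newline token id (23) that the old encode_with_cache("\n") path
# prepended no longer appears in think_stop_sentence_token_ids.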
def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [501, 502]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
request = {
"request_id": "req_ernie_text",
"eos_token_ids": [1],
"prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
"prompt": None,
"messages": None,
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
"bad_words": None,
"bad_words_token_ids": None,
"max_tokens": 1,
"temperature": 1.0,
"top_p": 0.9,
"response_max_tokens": None,
"enable_thinking": True,
}
with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
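# The assertions above (repeated in the three processor variants that
# follow) spell out the process_request_dict contract for thinking-budget
# requests. Sketched with illustrative names, not the real implementation:
#
#   args = request["logits_processors_args"]
#   sentence = args.pop("think_stop_sentence")  # "done" in this test
#   args["think_stop_sentence_token_ids"] = self._encode_literal_text_with_cache(sentence)
#   args["think_prompt_started"] = True   # <think> found in prompt_token_ids
#   args["think_prompt_ended"] = False    # no </think> in the prompt
#   args["think_prompt_tokens_after_start"] = 0  # prompt tokens don't count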
def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [601, 602]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
request = DummyRequestV1(
request_id="req_v1_ernie_text",
eos_token_ids=[1],
prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
prompt=None,
messages=None,
chat_template_kwargs=None,
enable_thinking=True,
sampling_params=SimpleNamespace(
bad_words=None,
bad_words_token_ids=None,
max_tokens=1,
temperature=1.0,
top_p=0.9,
repetition_penalty=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
response_max_tokens=None,
n=1,
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
),
)
with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602])
self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [701, 702]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
processor._check_mm_limits = lambda *args, **kwargs: None
processor.append_completion_tokens = lambda *args, **kwargs: None
processor.pack_outputs = lambda outs: outs
processor.ernie4_5_processor = SimpleNamespace(
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
)
request = {
"request_id": "req_ernie_vl",
"eos_token_ids": [1],
"messages": [{"role": "user", "content": "hi"}],
"bad_words": None,
"bad_words_token_ids": None,
"logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
"max_tokens": 1,
"top_p": 0.9,
"response_max_tokens": None,
}
with patch(
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
lambda *args, **kwargs: None,
):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702])
self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
processor._apply_default_parameters = lambda request: request
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
processor._encode_literal_text_with_cache = lambda text: [801, 802]
processor.tokenizer = DummyTokenizerForTextProcessor()
processor.reasoning_parser = None
processor._check_mm_limits = lambda *args, **kwargs: None
processor.append_completion_tokens = lambda *args, **kwargs: None
processor.pack_outputs = lambda outs: outs
processor.ernie4_5_processor = SimpleNamespace(
request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
)
request = DummyRequestV1(
request_id="req_v1_ernie_vl",
eos_token_ids=[1],
prompt_token_ids=None,
prompt=None,
messages=[{"role": "user", "content": "hi"}],
chat_template_kwargs=None,
enable_thinking=True,
completion_token_ids=None,
multimodal_data=None,
sampling_params=SimpleNamespace(
bad_words=None,
bad_words_token_ids=None,
max_tokens=1,
temperature=1.0,
top_p=0.9,
repetition_penalty=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
response_max_tokens=None,
reasoning_max_tokens=None,
n=1,
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
),
)
with patch(
"fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
lambda *args, **kwargs: None,
):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802])
self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
if __name__ == "__main__":
unittest.main()