[Bugfix] Align thinking_budget behavior with ERNIE reasoning flow (#6934)

* [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow

* [Docs] Fix thinking_budget markdown formatting

* [Test] Align ernie thinking budget test with process_request_dict
jackyYang6
2026-03-23 14:15:55 +08:00
committed by GitHub
parent 7a78001be2
commit 634d23a38a
10 changed files with 663 additions and 285 deletions
+28 -12
@@ -2,7 +2,9 @@
## Overview

`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the `<think> ... </think>` segment. When the budget is reached, it terminates thinking by forcing `</think>`. If `think_stop_sentence` is configured, it forces the custom sentence first and then `</think>`.
## When to Use
@@ -11,19 +13,22 @@
## How It Works

1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
2. **Per-step update**: during decoding, the processor tracks `last_token_id` and `tokens_after_start`.
3. **Budget enforcement**: once the budget is reached, it forces `</think>` directly. If `think_stop_sentence` is configured, it forces that sentence first and then `</think>` (see the sketch after this list).
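A minimal standalone sketch of step 3 may help. It assumes a NumPy logits row and a hypothetical `State` holder whose fields loosely mirror the per-request state described above; it is not the processor's actual code:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class State:
    tokens_after_start: int = 0  # decode-time tokens emitted since <think>
    stop_pos: int = 0            # progress through the forced stop sentence
    ended: bool = False          # True once </think> has been forced


def enforce_budget(logits_row, state, budget, think_end_id, stop_sentence_ids=None):
    """Mask one request's logits row once its thinking budget is spent."""
    if state.ended or state.tokens_after_start < budget:
        return logits_row  # within budget: sampling proceeds normally
    if stop_sentence_ids and state.stop_pos < len(stop_sentence_ids):
        # Force the configured stop sentence, one token per step.
        forced = stop_sentence_ids[state.stop_pos]
        state.stop_pos += 1
    else:
        # No (remaining) stop sentence: force </think> and finish thinking.
        forced = think_end_id
        state.ended = True
    logits_row[:] = -np.inf
    logits_row[forced] = 0.0
    return logits_row


# Budget already spent and no stop sentence: </think> (id 7 here) is forced.
row = enforce_budget(np.zeros(8), State(tokens_after_start=3), 3, think_end_id=7)
assert row.argmax() == 7
```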
## Requirements

- The model must provide valid token ids for `think_start_id` and `think_end_id` (via `ModelConfig`).
- If either of these ids is invalid, the processor is disabled and `thinking_budget` will not take effect.
## Request Parameters

- `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `<think>` before forced termination.
- `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side and enforced near the budget boundary (see the example after this list).
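For concreteness, a request-side sketch. The two keys under `logits_processors_args` are the parameters documented above; the surrounding payload shape is illustrative only:

```python
request = {
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    "logits_processors_args": {
        "thinking_budget": 256,  # cap on decode-time tokens after <think>
        "think_stop_sentence": "Time to answer.",  # optional closing sentence
    },
}
```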
## Operator-Level vs LogitsProcessor
@@ -41,16 +46,25 @@ FastDeploy has two ways to limit thinking length:
In short:

- If you only need a hard cap on thinking length, prefer `reasoning_max_tokens`.
- If you need custom behavior (for example, inserting a custom sentence before `</think>`), use `ThinkingBudgetLogitsProcessor`.
## Practical guidance

`reasoning_max_tokens` and `thinking_budget` are not mutually exclusive in the current implementation.
If both are configured for the same request, both constraints can take effect, and whichever triggers first will end the thinking phase.

- To use **operator-level-only** behavior: this is request-level config only. Set `enable_thinking=true` and `reasoning_max_tokens` in the request, and do not set `thinking_budget`.
- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires service-level + request-level config. Start the service with `--logits-processors ThinkingBudgetLogitsProcessor`, and set `thinking_budget` (and the optional `think_stop_sentence`) in `logits_processors_args`; leave `reasoning_max_tokens` unset.
- `thinking_budget` itself does not require `enable_thinking=true`.
- If an ERNIE chat template already appends `<think>` in the prompt, `thinking_budget` should still take effect; it does not require the model to emit another `<think>` during decoding.
- Avoid enabling both when a custom sentence must be inserted in full, because operator-level termination may cut the custom sentence path earlier.
## Online Usage
@@ -120,4 +134,6 @@ print(outputs[0].outputs.text)
## Performance Note

This processor runs `update_state` and `apply` on every decode step. If you only need a hard thinking-length cap and care most about throughput, consider the operator-level reasoning-length controls instead of per-step logits processing.
+20 -11
@@ -2,7 +2,9 @@
## Overview

`ThinkingBudgetLogitsProcessor` limits the generated length of the `<think> ... </think>` span. When the budget threshold is reached, it directly forces `</think>` to end the thinking section; if `think_stop_sentence` is configured, the custom sentence is forced first, followed by `</think>`.
## When to Use
@@ -11,19 +13,20 @@
## How It Works

1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are used to determine whether the thinking section has started, whether it has already ended, and how long it already is.
2. **Per-step update**: during decoding, `last_token_id` and `tokens_after_start` are tracked.
3. **Budget enforcement**: once the budget is reached, `</think>` is forced directly by default; if `think_stop_sentence` is configured, that sentence is forced token by token first, followed by `</think>`.
## Requirements

- The model must provide valid `think_start_id` and `think_end_id` (from `ModelConfig`).
- If either id is invalid, the processor is disabled and `thinking_budget` does not take effect.
## Request Parameters

- `thinking_budget` (int, required to enable): maximum number of decode-stage tokens allowed after `<think>`.
- `think_stop_sentence` (string, optional): a custom closing sentence, encoded literally and forced near the budget boundary.
## Operator-Level vs LogitsProcessor
@@ -35,21 +38,27 @@ FastDeploy currently has two ways to control thinking length:
  - Suited to simple cases that only cap the thinking length.
- **`ThinkingBudgetLogitsProcessor`** (`logits_processors_args.thinking_budget`):
  - Implemented by per-step Python-side logits processing.
  - Supports more flexible behavior, such as `think_stop_sentence`.
  - Typically has higher overhead under high concurrency than the operator-level cap.

Choose by these rules:

- If you only need to cap the thinking length: prefer `reasoning_max_tokens`.
- If you need more flexible control (such as inserting a custom sentence before `</think>`): use `ThinkingBudgetLogitsProcessor`.
## Practical guidance

In the current implementation, `reasoning_max_tokens` and `thinking_budget` are not mutually exclusive.
If both are configured for the same request, both constraints can take effect, and whichever triggers first ends the thinking section.

- **Operator-level cap only**: request-level config. Set only `enable_thinking=true` + `reasoning_max_tokens` in the request, and do not pass `thinking_budget`.
- **LogitsProcessor only** (especially when using `think_stop_sentence`): two-level config, service startup plus request parameters. Start the service with `--logits-processors ThinkingBudgetLogitsProcessor`, pass `thinking_budget` (and the optional `think_stop_sentence`) via `logits_processors_args` in the request, and do not set `reasoning_max_tokens`.
- `thinking_budget` itself does not depend on `enable_thinking=true`.
- If the ERNIE chat template already splices `<think>` into the prompt, `thinking_budget` should still take effect; the model is not required to emit `<think>` again during decoding.
- If the business requires the custom sentence to be inserted in full, do not enable both mechanisms together, or the operator-level cap may truncate it early.
## Online Usage
+9
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request["bad_words_token_ids"] = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            request.get("logits_processors_args") or {}, max_model_len
        )
        request["logits_processors_args"] = logits_processors_args
        # processing prompt_token_ids
        if not request.get("prompt_token_ids"):
            if request.get("prompt"):
@@ -143,6 +148,10 @@ class Ernie4_5Processor(BaseDataProcessor):
        # truncate prompts that exceed the length limit
        if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
            request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
        logits_processors_args = self._update_thinking_prompt_state(
            request["prompt_token_ids"], request.get("logits_processors_args") or {}
        )
        request["logits_processors_args"] = logits_processors_args
        max_tokens = max_model_len - len(request["prompt_token_ids"])
        if request.get("max_tokens") is None:
            request["max_tokens"] = max(1, max_tokens)
@@ -210,6 +210,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request["bad_words_token_ids"] = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            request.get("logits_processors_args") or {}, max_model_len
        )
        request["logits_processors_args"] = logits_processors_args
        if request.get("prompt_token_ids"):
            messages = request.get("messages")
            if messages:
@@ -257,6 +262,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
        # truncate prompts that exceed the length limit
        if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
            request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
        logits_processors_args = self._update_thinking_prompt_state(
            request["prompt_token_ids"], request.get("logits_processors_args") or {}
        )
        request["logits_processors_args"] = logits_processors_args
        max_tokens = max_model_len - len(request["prompt_token_ids"])
        if request.get("max_tokens") is None:
+97 -67
@@ -92,6 +92,9 @@ class BaseDataProcessor(ABC):
""" """
Encode text into token ids with a small LRU cache. Encode text into token ids with a small LRU cache.
""" """
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens)) key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key) cached = self._tokenize_cache.get(key)
if cached is not None: if cached is not None:
@@ -107,6 +110,25 @@ class BaseDataProcessor(ABC):
            self._tokenize_cache.popitem(last=False)
        return token_ids

    def _encode_literal_text_with_cache(self, text):
        if not hasattr(self, "_tokenize_cache"):
            self._tokenize_cache = OrderedDict()
            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
        key = ("literal_text", text)
        cached = self._tokenize_cache.get(key)
        if cached is not None:
            self._tokenize_cache.move_to_end(key)
            return cached
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        elif not isinstance(token_ids, list):
            token_ids = list(token_ids)
        self._tokenize_cache[key] = token_ids
        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
            self._tokenize_cache.popitem(last=False)
        return token_ids
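Unlike `encode_with_cache`, this literal path goes through `tokenize` plus `convert_tokens_to_ids`, so no special tokens are inserted and the stop sentence comes back as its surface tokens only; the `("literal_text", ...)` key prefix keeps these entries from colliding with the `(text, add_special_tokens)` keys in the shared cache. A small usage sketch, assuming `proc` is any processor instance with a loaded tokenizer:

```python
ids_first = proc._encode_literal_text_with_cache("Time to answer.")
ids_again = proc._encode_literal_text_with_cache("Time to answer.")
assert ids_first == ids_again  # second call is an LRU cache hit
assert ids_first is ids_again  # the cached list object itself is returned
```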
    def messages2ids(self, messages):
        """
        Convert multi-turn messages into ID sequences.
@@ -119,6 +141,77 @@ class BaseDataProcessor(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def _get_think_token_ids(self):
think_token_ids = getattr(self, "_think_token_ids", None)
if think_token_ids is not None:
return think_token_ids
tokenizer = getattr(self, "tokenizer", None)
vocab = tokenizer.get_vocab() if tokenizer is not None else {}
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
return logits_processors_args
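The rewrite is one-way: the string key is consumed and replaced by its token ids, so downstream code only ever sees `think_stop_sentence_token_ids`. Illustratively (actual ids depend on the tokenizer; `proc` is a hypothetical processor instance):

```python
args = {"thinking_budget": 64, "think_stop_sentence": "Time to answer."}
args = proc._prepare_think_stop_sentence(args)
# -> {"thinking_budget": 64, "think_stop_sentence_token_ids": [...]}
#    ("think_stop_sentence" itself has been popped)
```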
    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
        if not isinstance(logits_processors_args, dict):
            return logits_processors_args
        thinking_budget = logits_processors_args.get("thinking_budget")
        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
            return logits_processors_args
        if logits_processors_args.get("think_prompt_checked"):
            return logits_processors_args
        if prompt_token_ids is None:
            return logits_processors_args
        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
        if token_len == 0:
            return logits_processors_args
        think_start_id, think_end_id = self._get_think_token_ids()
        if think_start_id < 0 or think_end_id < 0:
            return logits_processors_args
        if hasattr(prompt_token_ids, "tolist"):
            token_list = prompt_token_ids.tolist()
        else:
            token_list = list(prompt_token_ids)
        started = False
        ended = False
        tokens_after_start = 0
        last_token_id = None
        in_thinking = False
        for token_id in token_list:
            if token_id == think_start_id:
                started = True
                ended = False
                in_thinking = True
            elif token_id == think_end_id and in_thinking:
                ended = True
                in_thinking = False
        if started and token_list:
            # Align with operator-level reasoning_max_tokens: prompt-side tokens
            # inside <think> do not consume thinking budget.
            last_token_id = int(token_list[-1])
        logits_processors_args["think_prompt_checked"] = True
        logits_processors_args["think_prompt_started"] = started
        logits_processors_args["think_prompt_ended"] = ended
        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
        if last_token_id is not None:
            logits_processors_args["think_prompt_last_token_id"] = last_token_id
        else:
            logits_processors_args.pop("think_prompt_last_token_id", None)
        return logits_processors_args
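A worked sketch of the scan above, assuming a toy vocabulary where `<think>` maps to 100 and `</think>` to 101 (`proc` is any processor over such a tokenizer). Note that `think_prompt_tokens_after_start` stays 0: prompt-side thinking tokens no longer consume the budget:

```python
args = proc._update_thinking_prompt_state(
    [1, 100, 5, 6],  # <think> opened in the prompt and never closed
    {"thinking_budget": 8},
)
# args ==
# {"thinking_budget": 8,
#  "think_prompt_checked": True,
#  "think_prompt_started": True,
#  "think_prompt_ended": False,
#  "think_prompt_tokens_after_start": 0,
#  "think_prompt_last_token_id": 6}
```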
    def ids2tokens(self, token_id, task_id=None):
        """
        token ids to strings
@@ -207,65 +300,6 @@ class DataProcessor(BaseDataProcessor):
            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
        self.tokenizer.pad_token_id = self.pad_token_id

-        self._think_token_ids = None
-
-    def _get_think_token_ids(self):
-        if self._think_token_ids is not None:
-            return self._think_token_ids
-        vocab = self.tokenizer.get_vocab()
-        think_start_id = vocab.get("<think>", -1)
-        think_end_id = vocab.get("</think>", -1)
-        self._think_token_ids = (think_start_id, think_end_id)
-        return self._think_token_ids
-
-    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
-        if not isinstance(logits_processors_args, dict):
-            return logits_processors_args
-        thinking_budget = logits_processors_args.get("thinking_budget")
-        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
-            return logits_processors_args
-        if logits_processors_args.get("think_prompt_checked"):
-            return logits_processors_args
-        if prompt_token_ids is None:
-            return logits_processors_args
-        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
-        if token_len == 0:
-            return logits_processors_args
-        think_start_id, think_end_id = self._get_think_token_ids()
-        if think_start_id < 0 or think_end_id < 0:
-            return logits_processors_args
-        if hasattr(prompt_token_ids, "tolist"):
-            token_list = prompt_token_ids.tolist()
-        else:
-            token_list = list(prompt_token_ids)
-        started = think_start_id in token_list
-        ended = False
-        tokens_after_start = 0
-        last_token_id = None
-        if started:
-            start_pos = token_list.index(think_start_id)
-            tokens_after = token_list[start_pos + 1 :]
-            if think_end_id in tokens_after:
-                end_pos = tokens_after.index(think_end_id)
-                tokens_after_start = end_pos + 1
-                ended = True
-            else:
-                tokens_after_start = len(tokens_after)
-            if token_list:
-                last_token_id = int(token_list[-1])
-        logits_processors_args["think_prompt_checked"] = True
-        logits_processors_args["think_prompt_started"] = started
-        logits_processors_args["think_prompt_ended"] = ended
-        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
-        if last_token_id is not None:
-            logits_processors_args["think_prompt_last_token_id"] = last_token_id
-        else:
-            logits_processors_args.pop("think_prompt_last_token_id", None)
-        return logits_processors_args
    def process_request_dict(self, request, max_model_len=None, **kwargs):
        """
        Preprocess the request
@@ -292,14 +326,10 @@ class DataProcessor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request["bad_words_token_ids"] = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            request.get("logits_processors_args") or {}, max_model_len
        )
        request["logits_processors_args"] = logits_processors_args
        # processing prompt_token_ids
        if not request.get("prompt_token_ids"):
+20 -1
@@ -107,6 +107,11 @@ class Ernie4_5Processor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request["bad_words_token_ids"] = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            request.get("logits_processors_args") or {}, max_model_len
        )
        request["logits_processors_args"] = logits_processors_args
        # processing prompt_token_ids
        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
            if request.prompt is not None:
@@ -123,7 +128,8 @@ class Ernie4_5Processor(BaseDataProcessor):
                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                request.prompt_token_ids = token_ids
                data_processor_logger.debug(
                    f"request_ids: {request.request_id}, prompt: {prompt}, "
                    f"tokens: {tokens}, token_ids: {token_ids}"
                )
            elif request.messages is not None:
                task = request.to_dict()
@@ -145,6 +151,10 @@ class Ernie4_5Processor(BaseDataProcessor):
        # truncate prompts that exceed the length limit
        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
        logits_processors_args = self._update_thinking_prompt_state(
            request.prompt_token_ids, request.get("logits_processors_args") or {}
        )
        request["logits_processors_args"] = logits_processors_args
        max_tokens = max_model_len - len(request.prompt_token_ids)
        if request.get("max_tokens") is None:
            request.set("max_tokens", max(1, max_tokens))
@@ -201,6 +211,11 @@ class Ernie4_5Processor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request.sampling_params.bad_words_token_ids = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
        )
        request.sampling_params.logits_processors_args = logits_processors_args
        # processing prompt_token_ids
        if not request.prompt_token_ids:
            if request.prompt:
@@ -241,6 +256,10 @@ class Ernie4_5Processor(BaseDataProcessor):
        # truncate prompts that exceed the length limit
        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
        logits_processors_args = self._update_thinking_prompt_state(
            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
        )
        request.sampling_params.logits_processors_args = logits_processors_args
        max_tokens = max_model_len - len(request.prompt_token_ids)
        if getattr(request.sampling_params, "max_tokens", None) is None:
            request.sampling_params.max_tokens = max(1, max_tokens)
@@ -216,6 +216,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request.sampling_params.bad_words_token_ids = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
        )
        request.sampling_params.logits_processors_args = logits_processors_args
        if request.prompt_token_ids:
            messages = request.messages
            if messages:
@@ -267,6 +272,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
        # truncate prompts that exceed the length limit
        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
        logits_processors_args = self._update_thinking_prompt_state(
            request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
        )
        request.sampling_params.logits_processors_args = logits_processors_args
        max_tokens = max_model_len - len(request.prompt_token_ids)
        if getattr(request.sampling_params, "max_tokens", None) is None:
+101 -75
@@ -115,6 +115,9 @@ class BaseDataProcessor(ABC):
""" """
Encode text into token ids with a small LRU cache. Encode text into token ids with a small LRU cache.
""" """
if not hasattr(self, "_tokenize_cache"):
self._tokenize_cache = OrderedDict()
self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens)) key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key) cached = self._tokenize_cache.get(key)
if cached is not None: if cached is not None:
@@ -130,6 +133,25 @@ class BaseDataProcessor(ABC):
            self._tokenize_cache.popitem(last=False)
        return token_ids

    def _encode_literal_text_with_cache(self, text):
        if not hasattr(self, "_tokenize_cache"):
            self._tokenize_cache = OrderedDict()
            self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
        key = ("literal_text", text)
        cached = self._tokenize_cache.get(key)
        if cached is not None:
            self._tokenize_cache.move_to_end(key)
            return cached
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        elif not isinstance(token_ids, list):
            token_ids = list(token_ids)
        self._tokenize_cache[key] = token_ids
        if len(self._tokenize_cache) > self._tokenize_cache_capacity:
            self._tokenize_cache.popitem(last=False)
        return token_ids
    def messages2ids(self, messages):
        """
        Convert multi-turn messages into ID sequences.
@@ -142,6 +164,77 @@ class BaseDataProcessor(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def _get_think_token_ids(self):
think_token_ids = getattr(self, "_think_token_ids", None)
if think_token_ids is not None:
return think_token_ids
tokenizer = getattr(self, "tokenizer", None)
vocab = tokenizer.get_vocab() if tokenizer is not None else {}
think_start_id = vocab.get("<think>", -1)
think_end_id = vocab.get("</think>", -1)
self._think_token_ids = (think_start_id, think_end_id)
return self._think_token_ids
def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
if not isinstance(logits_processors_args, dict):
return logits_processors_args
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
if isinstance(think_stop_sentence, str) and think_stop_sentence:
sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
logits_processors_args.pop("think_stop_sentence", None)
return logits_processors_args
    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
        if not isinstance(logits_processors_args, dict):
            return logits_processors_args
        thinking_budget = logits_processors_args.get("thinking_budget")
        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
            return logits_processors_args
        if logits_processors_args.get("think_prompt_checked"):
            return logits_processors_args
        if prompt_token_ids is None:
            return logits_processors_args
        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
        if token_len == 0:
            return logits_processors_args
        think_start_id, think_end_id = self._get_think_token_ids()
        if think_start_id < 0 or think_end_id < 0:
            return logits_processors_args
        if hasattr(prompt_token_ids, "tolist"):
            token_list = prompt_token_ids.tolist()
        else:
            token_list = list(prompt_token_ids)
        started = False
        ended = False
        tokens_after_start = 0
        last_token_id = None
        in_thinking = False
        for token_id in token_list:
            if token_id == think_start_id:
                started = True
                ended = False
                in_thinking = True
            elif token_id == think_end_id and in_thinking:
                ended = True
                in_thinking = False
        if started and token_list:
            # Align with operator-level reasoning_max_tokens: prompt-side tokens
            # inside <think> do not consume thinking budget.
            last_token_id = int(token_list[-1])
        logits_processors_args["think_prompt_checked"] = True
        logits_processors_args["think_prompt_started"] = started
        logits_processors_args["think_prompt_ended"] = ended
        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
        if last_token_id is not None:
            logits_processors_args["think_prompt_last_token_id"] = last_token_id
        else:
            logits_processors_args.pop("think_prompt_last_token_id", None)
        return logits_processors_args
    def ids2tokens(self, token_id, task_id=None):
        """
        token ids to strings
@@ -221,65 +314,6 @@ class DataProcessor(BaseDataProcessor):
            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
        self.tokenizer.pad_token_id = self.pad_token_id

-        self._think_token_ids = None
-
-    def _get_think_token_ids(self):
-        if self._think_token_ids is not None:
-            return self._think_token_ids
-        vocab = self.tokenizer.get_vocab()
-        think_start_id = vocab.get("<think>", -1)
-        think_end_id = vocab.get("</think>", -1)
-        self._think_token_ids = (think_start_id, think_end_id)
-        return self._think_token_ids
-
-    def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
-        if not isinstance(logits_processors_args, dict):
-            return logits_processors_args
-        thinking_budget = logits_processors_args.get("thinking_budget")
-        if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
-            return logits_processors_args
-        if logits_processors_args.get("think_prompt_checked"):
-            return logits_processors_args
-        if prompt_token_ids is None:
-            return logits_processors_args
-        token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
-        if token_len == 0:
-            return logits_processors_args
-        think_start_id, think_end_id = self._get_think_token_ids()
-        if think_start_id < 0 or think_end_id < 0:
-            return logits_processors_args
-        if hasattr(prompt_token_ids, "tolist"):
-            token_list = prompt_token_ids.tolist()
-        else:
-            token_list = list(prompt_token_ids)
-        started = think_start_id in token_list
-        ended = False
-        tokens_after_start = 0
-        last_token_id = None
-        if started:
-            start_pos = token_list.index(think_start_id)
-            tokens_after = token_list[start_pos + 1 :]
-            if think_end_id in tokens_after:
-                end_pos = tokens_after.index(think_end_id)
-                tokens_after_start = end_pos + 1
-                ended = True
-            else:
-                tokens_after_start = len(tokens_after)
-            if token_list:
-                last_token_id = int(token_list[-1])
-        logits_processors_args["think_prompt_checked"] = True
-        logits_processors_args["think_prompt_started"] = started
-        logits_processors_args["think_prompt_ended"] = ended
-        logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
-        if last_token_id is not None:
-            logits_processors_args["think_prompt_last_token_id"] = last_token_id
-        else:
-            logits_processors_args.pop("think_prompt_last_token_id", None)
-        return logits_processors_args
    def process_request(self, request, max_model_len=None, **kwargs):
        """
        Preprocess the request
@@ -306,14 +340,10 @@ class DataProcessor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request["bad_words_token_ids"] = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            request.get("logits_processors_args") or {}, max_model_len
        )
        request["logits_processors_args"] = logits_processors_args
        # processing prompt_token_ids
        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
@@ -411,14 +441,10 @@ class DataProcessor(BaseDataProcessor):
        bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
        request.sampling_params.bad_words_token_ids = bad_words_token_ids
        logits_processors_args = self._prepare_think_stop_sentence(
            getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
        )
        request.sampling_params.logits_processors_args = logits_processors_args
        # processing prompt_token_ids
        if not request.prompt_token_ids:
@@ -42,14 +42,15 @@ class _ThinkingState:
class ThinkingBudgetLogitsProcessor(LogitsProcessor):
    """Limit the number of tokens generated in the thinking phase.

    The processor tracks per-request thinking state and forces the thinking end
    token when the budget is reached. If a stop sentence is configured, the
    processor emits the stop sentence first and then the thinking end token.

    Request-specific configuration is provided via logits_processors_args:
        {"thinking_budget": <int>}

    Requires model_config to provide think_start_id and think_end_id. If any of
    these are missing or invalid (-1), the processor will be disabled.
    """
    def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
        self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
        self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
        self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
        self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
        if not self._enabled:
            logger.warning(
                "ThinkingBudgetLogitsProcessor disabled: missing token ids "
                f"(think_start={think_start_id}, think_end={think_end_id}). "
                "Ensure model vocab contains <think> and </think> tokens."
            )
        self._states: Dict[str, _ThinkingState] = {}
        self._active_req_ids: list[str] = []
        self._active_budgets: list[int] = []
        self._active_slots: list[int] = []
    def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
        started = False
        ended = False
        in_thinking = False
        for token_id in prompt_slice:
            if token_id == self.think_start_token_id:
                started = True
                ended = False
                in_thinking = True
            elif token_id == self.think_end_token_id and in_thinking:
                ended = True
                in_thinking = False
        last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
        return started, ended, 0, last_token_id
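The loop tolerates repeated or reopened thinking segments; only the last segment's open/closed status survives, and the reported token count is always 0. Expected results for a few toy prompts, with 100/101 as stand-ins for the start/end ids:

```python
# prompt_slice             -> (started, ended, tokens_after_start, last_token_id)
# [1, 100, 5, 6]           -> (True,  False, 0, 6)     # still thinking
# [1, 100, 5, 101, 7]      -> (True,  True,  0, 7)     # thinking closed
# [1, 100, 5, 101, 100, 9] -> (True,  False, 0, 9)     # reopened segment
# [1, 2, 3]                -> (False, False, 0, None)   # never started
```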
    def update_state(self, share_inputs: dict) -> None:
        if not self._enabled:
            return
@@ -82,6 +96,7 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
        req_ids = share_inputs["req_ids"]
        logits_processors_args = share_inputs["logits_processors_args"]
        prompt_ids = share_inputs.get("prompt_ids")
        token_ids_all = share_inputs.get("token_ids_all")
        prompt_lens = share_inputs.get("prompt_lens")
        pre_ids = share_inputs.get("pre_ids")
        next_tokens = share_inputs.get("next_tokens")
@@ -153,6 +168,8 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
        for idx, slot_id in enumerate(candidate_slots):
            next_token_by_slot[slot_id] = int(next_sel[idx])

        prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
        for idx, slot_id in enumerate(candidate_slots):
            req_id = candidate_req_ids[idx]
            logit_proc_args = candidate_args[idx]
@@ -190,25 +207,28 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
            state.current_step_idx = current_step_idx

            if not state.started and not state.prompt_checked:
                if prompt_source is not None and prompt_lens is not None:
                    if prompt_lens_np is not None:
                        prompt_len = int(prompt_lens_np[slot_id])
                    else:
                        prompt_len = int(prompt_lens[slot_id])
                    prompt_slice = prompt_source[slot_id, :prompt_len]
                    if hasattr(prompt_slice, "numpy"):
                        prompt_slice = prompt_slice.numpy().tolist()
                    elif hasattr(prompt_slice, "tolist"):
                        prompt_slice = prompt_slice.tolist()
                    else:
                        prompt_slice = list(prompt_slice)
                    if prompt_ids is None:
                        prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
                    prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
                        self._scan_prompt_state(prompt_slice)
                    )
                    if prompt_started:
                        state.started = True
                        state.ended = prompt_ended
                        state.tokens_after_start = prompt_tokens_after_start
                        state.last_token_id = prompt_last_token_id
                if current_step_idx is not None and state.last_step_idx is None:
                    state.last_step_idx = current_step_idx
                state.prompt_checked = True
@@ -304,13 +324,6 @@ class ThinkingBudgetLogitsProcessor(LogitsProcessor):
            if state.tokens_after_start < budget:
                continue

-            if state.last_token_id != self.line_break_token_id:
-                logits[slot_id, :] = -float("inf")
-                logits[slot_id, self.line_break_token_id] = 0.0
-                state.last_token_id = self.line_break_token_id
-                state.last_step_idx = state.current_step_idx
-                continue
            logits[slot_id, :] = -float("inf")
            logits[slot_id, self.think_end_token_id] = 0.0
            state.last_token_id = self.think_end_token_id
+327 -89
@@ -20,7 +20,19 @@ from fastdeploy.engine import common_engine as common_engine_module
from fastdeploy.engine import engine as engine_module
from fastdeploy.engine.args_utils import EngineArgs  # Import EngineArgs
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.input.ernie4_5_processor import (
    Ernie4_5Processor as ErnieTextDataProcessor,
)
from fastdeploy.input.ernie4_5_vl_processor import (
    Ernie4_5_VLProcessor as ErnieVLDataProcessor,
)
from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
from fastdeploy.input.v1.ernie4_5_processor import (
    Ernie4_5Processor as V1ErnieTextDataProcessor,
)
from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
    Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
)
from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
from fastdeploy.scheduler import SchedulerConfig
@@ -52,6 +64,7 @@ class MockModelRunner:
"req_ids": [f"req_{i}" for i in range(max_num_seqs)], "req_ids": [f"req_{i}" for i in range(max_num_seqs)],
"logits_processors_args": [{} for _ in range(max_num_seqs)], "logits_processors_args": [{} for _ in range(max_num_seqs)],
"prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10 "prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10
"token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)), "prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
"pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)), "pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)), "step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
@@ -73,6 +86,12 @@ class MockModelRunner:
self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor( self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64 req.prompt_ids, dtype=paddle.int64
) )
self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
)
self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64) self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
if req.sampling_params.logits_processors_args: if req.sampling_params.logits_processors_args:
self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
@@ -152,7 +171,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        return logits

    def test_thinking_budget_not_reached(self):
        # Scenario: Thinking budget is 5, and prompt-side tokens after <think> do not consume budget.
        req_id = "test_req_1"
        prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
        sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
@@ -167,9 +186,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        processor.update_state(mock_runner.share_inputs)
        self.assertTrue(processor._states[req_id].started)
        self.assertFalse(processor._states[req_id].ended)
        self.assertEqual(processor._states[req_id].tokens_after_start, 0)

        # Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)  # Update state before apply
        processed_logits = processor.apply(logits)
@@ -181,9 +200,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        mock_runner.update_request_state(0, mock_req, pre_id=0)  # Update last token
        processor.update_state(mock_runner.share_inputs)  # Update state after generating token
        self.assertEqual(processor._states[req_id].tokens_after_start, 1)

        # Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)  # Update state before apply
        processed_logits = processor.apply(logits)
@@ -191,14 +210,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        mock_runner.update_request_state(0, mock_req, pre_id=0)  # Update last token
        processor.update_state(mock_runner.share_inputs)  # Update state after generating token
        self.assertEqual(processor._states[req_id].tokens_after_start, 2)
        # LogitsProcessor should still not restrict since the budget is not yet reached

    def test_thinking_budget_reached_forces_think_end(self):
        # Scenario: Budget is 3 and only decode-time tokens count toward the budget.
        req_id = "test_req_2"
        prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
        sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
        mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -206,16 +225,26 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
        think_end_id = processor.think_end_token_id

        # Step 1: Initial state update (prompt-side tokens do not count)
        processor.update_state(mock_runner.share_inputs)
        self.assertEqual(processor._states[req_id].tokens_after_start, 0)
        self.assertFalse(processor._states[req_id].ended)
        self.assertEqual(processor._states[req_id].last_token_id, 3)

        # Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
        processed_logits = processor.apply(logits)
        next_token = mock_runner.generate_next_token(processed_logits)[0]
        self.assertEqual(next_token, 0)  # Normal generation
        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
        processor.update_state(mock_runner.share_inputs)
        self.assertEqual(processor._states[req_id].tokens_after_start, 1)
        self.assertEqual(processor._states[req_id].last_token_id, 0)

        # Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
        processed_logits = processor.apply(logits)
@@ -226,42 +255,17 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        self.assertEqual(processor._states[req_id].tokens_after_start, 2)
        self.assertEqual(processor._states[req_id].last_token_id, 0)

        # Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
        processed_logits = processor.apply(logits)
        next_token = mock_runner.generate_next_token(processed_logits)[0]
        self.assertEqual(next_token, 0)
        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
        processor.update_state(mock_runner.share_inputs)
        self.assertEqual(processor._states[req_id].tokens_after_start, 3)

        # Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
        processed_logits = processor.apply(logits)
@@ -324,9 +328,9 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        self.assertEqual(next_token, stop_sentence_token_ids[0])

    def test_thinking_budget_no_stop_sentence_defaults(self):
-        # Scenario: No stop sentence, budget reached should force newline.
+        # Scenario: No stop sentence, budget reached should force thinking_end directly.
        req_id = "test_req_no_stop_sentence"
-        prompt_ids = [THINKING_START_TOKEN_ID, 42]  # 1 token after start
+        prompt_ids = [THINKING_START_TOKEN_ID, 42]
        sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
        mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -334,35 +338,42 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])

        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
-        line_break_id = processor.line_break_token_id
        processor.update_state(mock_runner.share_inputs)
        logits = self._get_initial_logits(1)
        processed_logits = processor.apply(logits)
        next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, line_break_id)
+        self.assertEqual(next_token, 0)
+        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
+        processor.update_state(mock_runner.share_inputs)
+        logits = self._get_initial_logits(1)
+        processor.update_state(mock_runner.share_inputs)
+        processed_logits = processor.apply(logits)
+        next_token = mock_runner.generate_next_token(processed_logits)[0]
+        self.assertEqual(next_token, THINKING_END_TOKEN_ID)
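The rewritten default-behavior test walks a two-step trace: the prompt token after `<think>` no longer counts against the budget, so the first decode step is free and only the second is forced. A toy restatement of that trace in plain Python:

```python
budget = 1
tokens_after_start = 0  # prompt content after <think> starts the count at 0 now

for step in (1, 2):
    if tokens_after_start >= budget:
        print(f"step {step}: force </think>")
    else:
        print(f"step {step}: normal generation")
        tokens_after_start += 1  # only decode-time tokens consume budget
# step 1: normal generation
# step 2: force </think>
```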
    def test_thinking_budget_uses_config_token_ids(self):
        # Scenario: Processor should use token ids from model config.
        self.fd_config.model_config.think_start_id = 123
        self.fd_config.model_config.think_end_id = 124
-        self.fd_config.model_config.line_break_id = 125
+        self.fd_config.model_config.line_break_id = -1
        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
        self.assertEqual(processor.think_start_token_id, 123)
        self.assertEqual(processor.think_end_token_id, 124)
-        self.assertEqual(processor.line_break_token_id, 125)
+        self.assertEqual(processor.line_break_token_id, -1)
        self.assertTrue(processor._enabled)

    def test_thinking_budget_disabled_when_token_ids_missing(self):
        # Scenario: Processor should be disabled when token ids are not configured.
        self.fd_config.model_config.think_start_id = -1
        self.fd_config.model_config.think_end_id = -1
-        self.fd_config.model_config.line_break_id = -1
+        self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
        self.assertFalse(processor._enabled)
        self.assertEqual(processor.think_start_token_id, -1)
        self.assertEqual(processor.think_end_token_id, -1)
-        self.assertEqual(processor.line_break_token_id, -1)
+        self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)

        # update_state and apply should be no-op when disabled
        mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
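Read together, these two tests pin the new enablement rule: only `think_start_id` and `think_end_id` gate the processor, while `line_break_id` may hold any value. A sketch of that check, assuming a negative id means "not configured":

```python
def is_enabled(think_start_id: int, think_end_id: int, line_break_id: int) -> bool:
    # line_break_id is intentionally ignored: </think> is forced directly now
    return think_start_id >= 0 and think_end_id >= 0

assert is_enabled(123, 124, -1) is True   # valid start/end, no line break id
assert is_enabled(-1, -1, 23) is False    # a line break id alone cannot enable
```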
@@ -493,6 +504,28 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        self.assertEqual(state.tokens_after_start, 2)
        self.assertEqual(state.last_token_id, 99)

+    def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
+        req_id = "req_gpu_fallback"
+        mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
+        mock_runner.share_inputs["req_ids"][0] = req_id
+        mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
+        mock_runner.share_inputs["prompt_ids"] = None
+        mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
+            [1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
+        )
+        mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
+        mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
+
+        processor = ThinkingBudgetLogitsProcessor(self.fd_config)
+        processor.update_state(mock_runner.share_inputs)
+
+        state = processor._states[req_id]
+        self.assertTrue(state.prompt_checked)
+        self.assertTrue(state.started)
+        self.assertFalse(state.ended)
+        self.assertEqual(state.tokens_after_start, 0)
+        self.assertEqual(state.last_token_id, 3)
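The fallback test recovers the prompt from `token_ids_all` when `prompt_ids` is absent, then derives the same prompt state either way. An illustrative version of that scan (toy ids, not the actual implementation):

```python
def scan_prompt(prompt_ids, think_start_id, think_end_id):
    started = think_start_id in prompt_ids
    ended = started and think_end_id in prompt_ids[prompt_ids.index(think_start_id):]
    return {
        "started": started,
        "ended": ended,
        "tokens_after_start": 0,  # prompt content never consumes budget anymore
        "last_token_id": prompt_ids[-1],
    }

print(scan_prompt([1, 5, 2, 3], think_start_id=5, think_end_id=6))
# {'started': True, 'ended': False, 'tokens_after_start': 0, 'last_token_id': 3}
```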
    def test_thinking_budget_not_configured(self):
        # Scenario: Processor is active, but request does not provide thinking_budget
        req_id = "test_req_3"
@@ -532,20 +565,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        self.assertEqual(processor._states[req_id].tokens_after_start, 0)  # No tokens after start yet
        self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)

-        # Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
-        logits = self._get_initial_logits(1)
-        processor.update_state(mock_runner.share_inputs)
-        processed_logits = processor.apply(logits)
-        self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
-        next_token = mock_runner.generate_next_token(processed_logits)[0]
-        self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
-        mock_runner.update_request_state(0, mock_req, pre_id=next_token)
-        processor.update_state(mock_runner.share_inputs)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 1)  # Still increments
-        self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
-        # Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
+        # Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
        processed_logits = processor.apply(logits)
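With a zero budget, the very first `apply` call forces termination, and the forced tail is one token shorter than under the old newline-first flow. A sketch of the new tail composition, with and without a custom stop sentence (toy ids):

```python
def forced_tail(stop_sentence_ids, think_end_id):
    # optional custom sentence first, then </think>; no newline step in between
    return list(stop_sentence_ids) + [think_end_id]

assert forced_tail([], 6) == [6]                     # plain budget hit
assert forced_tail([201, 202], 6) == [201, 202, 6]   # with think_stop_sentence
```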
@@ -572,7 +592,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        processor.update_state(mock_runner.share_inputs)
        self.assertTrue(processor._states[req_id].started)
        self.assertTrue(processor._states[req_id].ended)
-        self.assertEqual(processor._states[req_id].tokens_after_start, 2)  # Tokens 2, THINKING_END after start
+        self.assertEqual(processor._states[req_id].tokens_after_start, 0)

        logits = self._get_initial_logits(1)
        processor.update_state(mock_runner.share_inputs)
@@ -582,7 +602,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
    def test_multiple_requests(self):
        # Scenario: Multiple requests with different thinking states
        req_id_1 = "req_a"
-        prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]  # budget 2, last_token=11, tokens_after_start=2
+        prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
        sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
        mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
@@ -592,7 +612,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)

        req_id_3 = "req_c"
-        prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]  # budget 1, last_token=30, tokens_after_start=1
+        prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
        sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
        mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
@@ -609,14 +629,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        # Verify initial states
        self.assertTrue(processor._states[req_id_1].started)
        self.assertFalse(processor._states[req_id_1].ended)
-        self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
+        self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
        self.assertEqual(processor._states[req_id_1].last_token_id, 11)
        self.assertNotIn(req_id_2, processor._states)  # No budget specified for req_2
        self.assertTrue(processor._states[req_id_3].started)
        self.assertFalse(processor._states[req_id_3].ended)
-        self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
        self.assertEqual(processor._states[req_id_3].last_token_id, 30)
@@ -624,16 +644,14 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        processor.update_state(mock_runner.share_inputs)  # Ensure state is updated before apply
        processed_batch_logits = processor.apply(batch_logits)

-        # Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
-        self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
-        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
+        # Req 1: prompt-side content does not consume budget, so first step is normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
        # Req 2: No thinking budget, normal generation
        self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
-        # Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
-        self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
-        self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
+        # Req 3: prompt-side content does not consume budget, so first step is normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)

        # Simulate generating next tokens and updating state
        next_tokens = mock_runner.generate_next_token(processed_batch_logits)
@@ -643,20 +661,25 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
        processor.update_state(mock_runner.share_inputs)

        # Verify updated states for next step
-        self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
-        self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
+        self.assertEqual(processor._states[req_id_1].last_token_id, 0)
+        self.assertEqual(processor._states[req_id_3].last_token_id, 0)
+        self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
+        self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+        self.assertFalse(processor._states[req_id_1].ended)
+        self.assertFalse(processor._states[req_id_3].ended)

        batch_logits = self._get_initial_logits(3)
        processor.update_state(mock_runner.share_inputs)
        processed_batch_logits = processor.apply(batch_logits)

-        # Req 1: last token was NEW_LINE. Should force THINKING_END
-        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
+        # Req 1: budget 2, tokens_after_start 1. Still normal generation.
+        self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
        # Req 2: Still normal generation
        self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
-        # Req 3: last token was NEW_LINE. Should force THINKING_END
+        # Req 3: budget 1 reached after one generated token, should now force THINKING_END.
+        self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
        self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
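The batch test checks that each row gets an independent decision inside a single `apply` call. A numpy sketch of that row-wise dispatch (state layout and names are illustrative only):

```python
import numpy as np

def apply_batch(logits: np.ndarray, states: list) -> np.ndarray:
    out = logits.copy()
    for row, state in enumerate(states):
        if state is None or state["ended"]:
            continue  # no budget configured, or thinking already closed
        if state["tokens_after_start"] >= state["budget"]:
            out[row, :] = float("-inf")
            out[row, state["think_end_id"]] = 0.0  # force </think> on this row only
    return out

states = [
    {"ended": False, "tokens_after_start": 1, "budget": 2, "think_end_id": 6},
    None,  # request without thinking_budget
    {"ended": False, "tokens_after_start": 1, "budget": 1, "think_end_id": 6},
]
batch = np.zeros((3, 16), dtype=np.float32)
print(np.argmax(apply_batch(batch, states), axis=-1))  # [0 0 6]
```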
@@ -724,7 +747,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertTrue(updated["think_prompt_checked"]) self.assertTrue(updated["think_prompt_checked"])
self.assertTrue(updated["think_prompt_started"]) self.assertTrue(updated["think_prompt_started"])
self.assertTrue(updated["think_prompt_ended"]) self.assertTrue(updated["think_prompt_ended"])
self.assertEqual(updated["think_prompt_tokens_after_start"], 2) self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
self.assertEqual(updated["think_prompt_last_token_id"], 3) self.assertEqual(updated["think_prompt_last_token_id"], 3)
def test_v1_process_request_missing_logits_processors_args(self): def test_v1_process_request_missing_logits_processors_args(self):
@@ -833,6 +856,58 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
self.assertEqual(processor.encode_with_cache("iter"), [41, 42]) self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
self.assertNotIn(("np", False), processor._tokenize_cache) self.assertNotIn(("np", False), processor._tokenize_cache)
def test_text_encode_with_cache_lazy_init(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([51, 52], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
self.assertEqual(call_counter["count"], 1)
def test_v1_encode_with_cache_lazy_init(self):
processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
call_counter = {"count": 0}
def _text2ids(text, max_model_len=None, add_special_tokens=False):
call_counter["count"] += 1
return np.array([61, 62], dtype=np.int64)
processor.text2ids = _text2ids
self.assertFalse(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertTrue(hasattr(processor, "_tokenize_cache"))
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
self.assertEqual(call_counter["count"], 1)
+    def test_ernie_encode_literal_text_with_cache(self):
+        processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+        processor.tokenizer = SimpleNamespace(
+            tokenize=lambda text: ["token_a", "token_b"],
+            convert_tokens_to_ids=lambda tokens: [71, 72],
+        )
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+
+    def test_v1_ernie_encode_literal_text_with_cache(self):
+        processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+        processor.tokenizer = SimpleNamespace(
+            tokenize=lambda text: ["token_c", "token_d"],
+            convert_tokens_to_ids=lambda tokens: [81, 82],
+        )
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
+        self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
    def test_text_update_thinking_prompt_state_branches(self):
        processor = TextDataProcessor.__new__(TextDataProcessor)
        processor._think_token_ids = None
@@ -868,7 +943,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        )
        self.assertTrue(with_start_no_end["think_prompt_started"])
        self.assertFalse(with_start_no_end["think_prompt_ended"])
-        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
        self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)

        # Hit the cached branch of _get_think_token_ids
@@ -891,7 +966,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        )
        self.assertTrue(with_start_no_end["think_prompt_started"])
        self.assertFalse(with_start_no_end["think_prompt_ended"])
-        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+        self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
        self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)

        # Hit the cached branch of _get_think_token_ids
@@ -903,7 +978,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processor.eos_token_ids = [1]
        processor.update_stop_seq = lambda *args, **kwargs: None
        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
+        processor._encode_literal_text_with_cache = lambda text: [201, 202]
        processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
        processor.reasoning_parser = None
@@ -924,7 +999,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processed = processor.process_request_dict(request, max_model_len=16)
        self.assertEqual(
            processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
-            [23, 201, 202],
+            [201, 202],
        )
        self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
@@ -934,7 +1009,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processor.eos_token_ids = [1]
        processor.update_stop_seq = lambda *args, **kwargs: None
        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
+        processor._encode_literal_text_with_cache = lambda text: [301, 302]
        processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
        processor.reasoning_parser = None
@@ -955,7 +1030,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processed = processor.process_request(request, max_model_len=16)
        self.assertEqual(
            processed.logits_processors_args.get("think_stop_sentence_token_ids"),
-            [23, 301, 302],
+            [301, 302],
        )
        self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
@@ -965,7 +1040,7 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processor.eos_token_ids = [1]
        processor.update_stop_seq = lambda *args, **kwargs: None
        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
-        processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
+        processor._encode_literal_text_with_cache = lambda text: [401, 402]
        processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
        processor.reasoning_parser = None
@@ -992,10 +1067,173 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
        processed = processor.process_request_dict(request, max_model_len=16)
        self.assertEqual(
            processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
-            [23, 401, 402],
+            [401, 402],
        )
        self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
+    def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
+        processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [501, 502]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+
+        request = {
+            "request_id": "req_ernie_text",
+            "eos_token_ids": [1],
+            "prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
+            "prompt": None,
+            "messages": None,
+            "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
+            "bad_words": None,
+            "bad_words_token_ids": None,
+            "max_tokens": 1,
+            "temperature": 1.0,
+            "top_p": 0.9,
+            "response_max_tokens": None,
+            "enable_thinking": True,
+        }
+        with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+            processed = processor.process_request_dict(request, max_model_len=16)
+        self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
+        self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
+        self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
+        self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
+
+    def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
+        processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [601, 602]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+
+        request = DummyRequestV1(
+            request_id="req_v1_ernie_text",
+            eos_token_ids=[1],
+            prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
+            prompt=None,
+            messages=None,
+            chat_template_kwargs=None,
+            enable_thinking=True,
+            sampling_params=SimpleNamespace(
+                bad_words=None,
+                bad_words_token_ids=None,
+                max_tokens=1,
+                temperature=1.0,
+                top_p=0.9,
+                repetition_penalty=1.0,
+                frequency_penalty=0.0,
+                presence_penalty=0.0,
+                response_max_tokens=None,
+                n=1,
+                logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
+            ),
+        )
+        with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+            processed = processor.process_request_dict(request, max_model_len=16)
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602])
+        self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
+        self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
+
+    def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
+        processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [701, 702]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+        processor._check_mm_limits = lambda *args, **kwargs: None
+        processor.append_completion_tokens = lambda *args, **kwargs: None
+        processor.pack_outputs = lambda outs: outs
+        processor.ernie4_5_processor = SimpleNamespace(
+            request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
+        )
+
+        request = {
+            "request_id": "req_ernie_vl",
+            "eos_token_ids": [1],
+            "messages": [{"role": "user", "content": "hi"}],
+            "bad_words": None,
+            "bad_words_token_ids": None,
+            "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
+            "max_tokens": 1,
+            "top_p": 0.9,
+            "response_max_tokens": None,
+        }
+        with patch(
+            "fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
+            lambda *args, **kwargs: None,
+        ):
+            processed = processor.process_request_dict(request, max_model_len=16)
+        self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702])
+        self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
+        self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
+        self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
+
+    def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
+        processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
+        processor._apply_default_parameters = lambda request: request
+        processor.eos_token_ids = [1]
+        processor.update_stop_seq = lambda *args, **kwargs: None
+        processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+        processor._encode_literal_text_with_cache = lambda text: [801, 802]
+        processor.tokenizer = DummyTokenizerForTextProcessor()
+        processor.reasoning_parser = None
+        processor._check_mm_limits = lambda *args, **kwargs: None
+        processor.append_completion_tokens = lambda *args, **kwargs: None
+        processor.pack_outputs = lambda outs: outs
+        processor.ernie4_5_processor = SimpleNamespace(
+            request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
+        )
+
+        request = DummyRequestV1(
+            request_id="req_v1_ernie_vl",
+            eos_token_ids=[1],
+            prompt_token_ids=None,
+            prompt=None,
+            messages=[{"role": "user", "content": "hi"}],
+            chat_template_kwargs=None,
+            enable_thinking=True,
+            completion_token_ids=None,
+            multimodal_data=None,
+            sampling_params=SimpleNamespace(
+                bad_words=None,
+                bad_words_token_ids=None,
+                max_tokens=1,
+                temperature=1.0,
+                top_p=0.9,
+                repetition_penalty=1.0,
+                frequency_penalty=0.0,
+                presence_penalty=0.0,
+                response_max_tokens=None,
+                reasoning_max_tokens=None,
+                n=1,
+                logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
+            ),
+        )
+        with patch(
+            "fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
+            lambda *args, **kwargs: None,
+        ):
+            processed = processor.process_request_dict(request, max_model_len=16)
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802])
+        self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
+        self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
+        self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
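Across all four processor flavors (text and VL, legacy and v1), the new tests assert one shared request-side contract. A compact restatement of what `logits_processors_args` should contain after processing, using the mocks' token id values:

```python
expected_args_after_processing = {
    "think_stop_sentence_token_ids": [501, 502],  # literal encoding of "done" (mock ids)
    "think_prompt_started": True,                 # <think> was seen in the prompt
    "think_prompt_ended": False,                  # </think> was not
    "think_prompt_tokens_after_start": 0,         # prompt never consumes budget
}
# The raw "think_stop_sentence" string itself is removed before dispatch.
```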
if __name__ == "__main__":
    unittest.main()